feat: initial MCP Browser Sidecar implementation
Playwright browser automation service for LLM-driven UI interaction.
Features:
- Session-based browser management with domain allowlisting
- HTTP API endpoints for browser actions (navigate, click, type, wait, screenshot, snapshot)
- Session lifecycle management (create, close, status)
- Automatic session cleanup (idle timeout, max lifetime)
- Resource limits (max sessions, max actions per session)
- Domain filtering via route interception
API Surface:
- POST /sessions - Create session with allowed_domains
- DELETE /sessions/{id} - Close session
- GET /sessions/{id}/status - Get session info
- POST /sessions/{id}/navigate - Navigate to URL
- POST /sessions/{id}/click - Click element
- POST /sessions/{id}/type - Type into element
- POST /sessions/{id}/wait - Wait for condition
- POST /sessions/{id}/screenshot - Capture screenshot
- POST /sessions/{id}/snapshot - Get accessibility tree
Security:
- Mandatory domain allowlist per session
- Network request filtering
- Session isolation via browser contexts
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
commit
5851cb39f4
|
|
@ -0,0 +1,86 @@
|
|||
# CLAUDE.md - LetsBe MCP Browser Sidecar
|
||||
|
||||
## Purpose
|
||||
|
||||
This is the MCP Browser Sidecar - a Playwright browser automation service that provides an HTTP API for LLM-driven exploratory browser control.
|
||||
|
||||
Key features:
|
||||
- Session-based browser management with domain allowlisting
|
||||
- HTTP API for browser actions (navigate, click, type, screenshot, snapshot)
|
||||
- Automatic session cleanup and resource limits
|
||||
- Security via mandatory domain restrictions
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- Python 3.11
|
||||
- FastAPI for HTTP API
|
||||
- Playwright for browser automation
|
||||
- Pydantic for validation
|
||||
- Docker with Playwright base image
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
letsbe-mcp-browser/
|
||||
app/
|
||||
__init__.py
|
||||
config.py # Settings from environment
|
||||
domain_filter.py # URL/domain allowlist validation
|
||||
session_manager.py # Browser session lifecycle
|
||||
playwright_client.py # Playwright browser management
|
||||
server.py # FastAPI HTTP endpoints
|
||||
tests/
|
||||
test_domain_filter.py
|
||||
test_session_manager.py
|
||||
Dockerfile
|
||||
docker-compose.yml
|
||||
requirements.txt
|
||||
pytest.ini
|
||||
```
|
||||
|
||||
## Development Commands
|
||||
|
||||
```bash
|
||||
# Start the service
|
||||
docker compose up --build
|
||||
|
||||
# Run tests locally
|
||||
pip install -r requirements.txt
|
||||
pytest -v
|
||||
|
||||
# API available at http://localhost:8931
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| POST | /sessions | Create new browser session |
|
||||
| GET | /sessions/{id}/status | Get session status |
|
||||
| DELETE | /sessions/{id} | Close session |
|
||||
| POST | /sessions/{id}/navigate | Navigate to URL |
|
||||
| POST | /sessions/{id}/click | Click element |
|
||||
| POST | /sessions/{id}/type | Type into element |
|
||||
| POST | /sessions/{id}/wait | Wait for condition |
|
||||
| POST | /sessions/{id}/screenshot | Take screenshot |
|
||||
| POST | /sessions/{id}/snapshot | Get accessibility tree |
|
||||
| GET | /health | Health check |
|
||||
| GET | /metrics | Basic metrics |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| MAX_SESSIONS | 3 | Max concurrent sessions |
|
||||
| IDLE_TIMEOUT_SECONDS | 300 | Session idle timeout |
|
||||
| MAX_SESSION_LIFETIME_SECONDS | 1800 | Max session lifetime |
|
||||
| MAX_ACTIONS_PER_SESSION | 50 | Max actions per session |
|
||||
| LOG_LEVEL | INFO | Logging level |
|
||||
| SCREENSHOTS_DIR | /screenshots | Screenshots directory |
|
||||
|
||||
## Security
|
||||
|
||||
- **Domain allowlist**: Every session requires `allowed_domains` at creation
|
||||
- **Network filtering**: All requests blocked unless domain is in allowlist
|
||||
- **Resource limits**: Max sessions, actions, and timeouts enforced
|
||||
- **Session isolation**: Each session has its own browser context
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# MCP Browser Sidecar Dockerfile
|
||||
# Based on Playwright's official Python image with Chromium pre-installed
|
||||
|
||||
FROM mcr.microsoft.com/playwright/python:v1.40.0-jammy
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy requirements first for layer caching
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY app/ ./app/
|
||||
|
||||
# Create screenshots directory
|
||||
RUN mkdir -p /screenshots && chmod 777 /screenshots
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8931
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8931/health || exit 1
|
||||
|
||||
# Run the server
|
||||
CMD ["uvicorn", "app.server:app", "--host", "0.0.0.0", "--port", "8931"]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""LetsBe MCP Browser Sidecar.
|
||||
|
||||
A Playwright browser automation service that provides HTTP API for
|
||||
LLM-driven exploratory browser control with domain allowlisting.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
"""Configuration settings for MCP Browser Sidecar."""
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
# Session limits
|
||||
max_sessions: int = 3
|
||||
idle_timeout_seconds: int = 300 # 5 minutes
|
||||
max_session_lifetime_seconds: int = 1800 # 30 minutes
|
||||
max_actions_per_session: int = 50
|
||||
|
||||
# Playwright settings
|
||||
browser_headless: bool = True
|
||||
default_timeout_ms: int = 30000
|
||||
navigation_timeout_ms: int = 60000
|
||||
|
||||
# Cleanup interval
|
||||
cleanup_interval_seconds: int = 60
|
||||
|
||||
# Logging
|
||||
log_level: str = "INFO"
|
||||
log_json: bool = True
|
||||
|
||||
# Screenshots directory
|
||||
screenshots_dir: str = "/screenshots"
|
||||
|
||||
class Config:
|
||||
env_prefix = ""
|
||||
case_sensitive = False
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
"""Domain filtering and allowlist validation."""
|
||||
|
||||
import fnmatch
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class DomainFilter:
|
||||
"""
|
||||
Validates URLs against a domain allowlist.
|
||||
|
||||
Supports:
|
||||
- Exact domain matching: "example.com"
|
||||
- Wildcard subdomains: "*.example.com"
|
||||
- Domains with ports: "example.com:8443"
|
||||
"""
|
||||
|
||||
def __init__(self, allowed_domains: list[str]):
|
||||
"""
|
||||
Initialize the domain filter.
|
||||
|
||||
Args:
|
||||
allowed_domains: List of allowed domain patterns
|
||||
"""
|
||||
if not allowed_domains:
|
||||
raise ValueError("allowed_domains cannot be empty")
|
||||
|
||||
self.allowed_domains = allowed_domains
|
||||
self._patterns = self._compile_patterns(allowed_domains)
|
||||
|
||||
def _compile_patterns(self, domains: list[str]) -> list[re.Pattern]:
|
||||
"""Compile domain patterns into regex for efficient matching."""
|
||||
patterns = []
|
||||
for domain in domains:
|
||||
# Convert wildcard pattern to regex
|
||||
# *.example.com -> matches any subdomain of example.com
|
||||
if domain.startswith("*."):
|
||||
# Match the exact domain or any subdomain
|
||||
base = re.escape(domain[2:])
|
||||
pattern = rf"^([a-zA-Z0-9-]+\.)*{base}$"
|
||||
else:
|
||||
# Exact match
|
||||
pattern = rf"^{re.escape(domain)}$"
|
||||
patterns.append(re.compile(pattern, re.IGNORECASE))
|
||||
return patterns
|
||||
|
||||
def is_allowed(self, url: str) -> bool:
|
||||
"""
|
||||
Check if a URL's domain is in the allowlist.
|
||||
|
||||
Args:
|
||||
url: The URL to check
|
||||
|
||||
Returns:
|
||||
True if the domain is allowed, False otherwise
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.netloc
|
||||
|
||||
# Include port if present
|
||||
if not host:
|
||||
return False
|
||||
|
||||
# Check against all patterns
|
||||
for pattern in self._patterns:
|
||||
if pattern.match(host):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_blocked_reason(self, url: str) -> str:
|
||||
"""Get a human-readable reason for why a URL was blocked."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
host = parsed.netloc
|
||||
return f"Domain '{host}' not in allowlist: {self.allowed_domains}"
|
||||
except Exception:
|
||||
return f"Invalid URL: {url}"
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
"""Playwright browser management."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from playwright.async_api import Browser, Playwright, async_playwright
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class PlaywrightClient:
|
||||
"""
|
||||
Manages the Playwright browser instance.
|
||||
|
||||
Provides lifecycle management for a single Chromium browser
|
||||
that serves all sessions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._playwright: Optional[Playwright] = None
|
||||
self._browser: Optional[Browser] = None
|
||||
|
||||
@property
|
||||
def browser(self) -> Browser:
|
||||
"""Get the browser instance."""
|
||||
if self._browser is None:
|
||||
raise RuntimeError("Browser not started. Call start() first.")
|
||||
return self._browser
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if browser is running."""
|
||||
return self._browser is not None
|
||||
|
||||
async def start(self) -> Browser:
|
||||
"""
|
||||
Start the Playwright browser.
|
||||
|
||||
Returns:
|
||||
The Browser instance
|
||||
"""
|
||||
if self._browser is not None:
|
||||
return self._browser
|
||||
|
||||
self._playwright = await async_playwright().start()
|
||||
|
||||
# Launch Chromium with Docker-compatible settings
|
||||
self._browser = await self._playwright.chromium.launch(
|
||||
headless=settings.browser_headless,
|
||||
args=[
|
||||
"--no-sandbox",
|
||||
"--disable-setuid-sandbox",
|
||||
"--disable-dev-shm-usage",
|
||||
"--disable-gpu",
|
||||
"--single-process",
|
||||
],
|
||||
)
|
||||
|
||||
return self._browser
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Playwright browser and cleanup."""
|
||||
if self._browser:
|
||||
try:
|
||||
await self._browser.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._browser = None
|
||||
|
||||
if self._playwright:
|
||||
try:
|
||||
await self._playwright.stop()
|
||||
except Exception:
|
||||
pass
|
||||
self._playwright = None
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_playwright_client: Optional[PlaywrightClient] = None
|
||||
|
||||
|
||||
def get_playwright_client() -> PlaywrightClient:
|
||||
"""Get the global PlaywrightClient instance."""
|
||||
global _playwright_client
|
||||
if _playwright_client is None:
|
||||
_playwright_client = PlaywrightClient()
|
||||
return _playwright_client
|
||||
|
|
@ -0,0 +1,441 @@
|
|||
"""FastAPI HTTP server for MCP Browser Sidecar."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import settings
|
||||
from app.playwright_client import get_playwright_client
|
||||
from app.session_manager import BrowserSession, SessionManager
|
||||
|
||||
# Global session manager
|
||||
_session_manager: Optional[SessionManager] = None
|
||||
|
||||
|
||||
def get_session_manager() -> SessionManager:
|
||||
"""Get the global SessionManager instance."""
|
||||
if _session_manager is None:
|
||||
raise RuntimeError("SessionManager not initialized")
|
||||
return _session_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager."""
|
||||
global _session_manager
|
||||
|
||||
# Startup
|
||||
client = get_playwright_client()
|
||||
browser = await client.start()
|
||||
_session_manager = SessionManager(browser)
|
||||
await _session_manager.start_cleanup_task()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
if _session_manager:
|
||||
await _session_manager.stop_cleanup_task()
|
||||
await _session_manager.close_all_sessions()
|
||||
await client.stop()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="MCP Browser Sidecar",
|
||||
description="Playwright browser automation service for LLM-driven UI interaction",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""Request to create a new browser session."""
|
||||
|
||||
allowed_domains: list[str] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="List of allowed domain patterns (required)",
|
||||
)
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
"""Response containing session information."""
|
||||
|
||||
session_id: str
|
||||
created_at: datetime
|
||||
last_activity: datetime
|
||||
actions_used: int
|
||||
actions_remaining: int
|
||||
expires_at: datetime
|
||||
allowed_domains: list[str]
|
||||
|
||||
|
||||
class NavigateRequest(BaseModel):
|
||||
"""Request to navigate to a URL."""
|
||||
|
||||
url: str = Field(..., min_length=1, description="URL to navigate to")
|
||||
|
||||
|
||||
class NavigateResponse(BaseModel):
|
||||
"""Response from navigation."""
|
||||
|
||||
url: str
|
||||
status: str # "success" or "blocked"
|
||||
title: Optional[str] = None
|
||||
current_url: Optional[str] = None
|
||||
blocked_reason: Optional[str] = None
|
||||
|
||||
|
||||
class ClickRequest(BaseModel):
|
||||
"""Request to click an element."""
|
||||
|
||||
selector: str = Field(..., description="CSS selector or element ref")
|
||||
|
||||
|
||||
class TypeRequest(BaseModel):
|
||||
"""Request to type into an element."""
|
||||
|
||||
selector: str = Field(..., description="CSS selector or element ref")
|
||||
text: str = Field(..., description="Text to type")
|
||||
press_enter: bool = Field(False, description="Press Enter after typing")
|
||||
|
||||
|
||||
class WaitRequest(BaseModel):
|
||||
"""Request to wait for a condition."""
|
||||
|
||||
wait_for: str = Field(..., alias="for", description="'selector', 'text', or 'timeout'")
|
||||
value: str = Field(..., description="Selector, text, or milliseconds")
|
||||
|
||||
|
||||
class ScreenshotRequest(BaseModel):
|
||||
"""Request to take a screenshot."""
|
||||
|
||||
full_page: bool = Field(False, description="Capture full scrollable page")
|
||||
|
||||
|
||||
class ScreenshotResponse(BaseModel):
|
||||
"""Response from screenshot."""
|
||||
|
||||
path: str
|
||||
|
||||
|
||||
class SnapshotNode(BaseModel):
|
||||
"""A node in the accessibility snapshot."""
|
||||
|
||||
ref: str
|
||||
role: str
|
||||
name: Optional[str] = None
|
||||
text: Optional[str] = None
|
||||
|
||||
|
||||
class SnapshotResponse(BaseModel):
|
||||
"""Response from accessibility snapshot."""
|
||||
|
||||
nodes: list[SnapshotNode]
|
||||
|
||||
|
||||
class ActionResponse(BaseModel):
|
||||
"""Generic response for actions."""
|
||||
|
||||
status: str
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response."""
|
||||
|
||||
detail: str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_session_or_404(session_id: str) -> BrowserSession:
|
||||
"""Get a session or raise 404."""
|
||||
manager = get_session_manager()
|
||||
session = await manager.get_session(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Session {session_id} not found or expired",
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
def session_to_response(session: BrowserSession) -> SessionResponse:
|
||||
"""Convert a BrowserSession to SessionResponse."""
|
||||
return SessionResponse(
|
||||
session_id=session.session_id,
|
||||
created_at=session.created_at,
|
||||
last_activity=session.last_activity,
|
||||
actions_used=session.actions_used,
|
||||
actions_remaining=session.actions_remaining,
|
||||
expires_at=session.expires_at,
|
||||
allowed_domains=session.allowed_domains,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Health Check
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, Any]:
|
||||
"""Health check endpoint."""
|
||||
manager = get_session_manager()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"active_sessions": manager.active_session_count,
|
||||
"max_sessions": settings.max_sessions,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Session Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.post(
|
||||
"/sessions",
|
||||
response_model=SessionResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_session(request: CreateSessionRequest) -> SessionResponse:
|
||||
"""
|
||||
Create a new browser session.
|
||||
|
||||
The session will have a domain allowlist that restricts all navigation
|
||||
and network requests to the specified domains.
|
||||
"""
|
||||
manager = get_session_manager()
|
||||
|
||||
try:
|
||||
session = await manager.create_session(request.allowed_domains)
|
||||
return session_to_response(session)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/sessions/{session_id}/status", response_model=SessionResponse)
|
||||
async def get_session_status(session_id: str) -> SessionResponse:
|
||||
"""Get the status of a session."""
|
||||
session = await get_session_or_404(session_id)
|
||||
return session_to_response(session)
|
||||
|
||||
|
||||
@app.delete("/sessions/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def close_session(session_id: str) -> None:
|
||||
"""
|
||||
Close a browser session.
|
||||
|
||||
This is idempotent - calling it multiple times is safe.
|
||||
"""
|
||||
manager = get_session_manager()
|
||||
await manager.close_session(session_id)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Browser Actions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/navigate", response_model=NavigateResponse)
|
||||
async def navigate(session_id: str, request: NavigateRequest) -> NavigateResponse:
|
||||
"""
|
||||
Navigate to a URL.
|
||||
|
||||
The URL must be in the session's allowed domains.
|
||||
"""
|
||||
session = await get_session_or_404(session_id)
|
||||
|
||||
# Check if URL is allowed
|
||||
if not session.domain_filter.is_allowed(request.url):
|
||||
return NavigateResponse(
|
||||
url=request.url,
|
||||
status="blocked",
|
||||
blocked_reason=session.domain_filter.get_blocked_reason(request.url),
|
||||
)
|
||||
|
||||
try:
|
||||
session.increment_actions()
|
||||
response = await session.page.goto(request.url, wait_until="domcontentloaded")
|
||||
|
||||
return NavigateResponse(
|
||||
url=request.url,
|
||||
status="success",
|
||||
title=await session.page.title(),
|
||||
current_url=session.page.url,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail=f"Navigation failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/click", response_model=ActionResponse)
|
||||
async def click(session_id: str, request: ClickRequest) -> ActionResponse:
|
||||
"""Click an element."""
|
||||
session = await get_session_or_404(session_id)
|
||||
|
||||
try:
|
||||
session.increment_actions()
|
||||
await session.page.click(request.selector)
|
||||
return ActionResponse(status="success")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Click failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/type", response_model=ActionResponse)
|
||||
async def type_text(session_id: str, request: TypeRequest) -> ActionResponse:
|
||||
"""Type text into an element."""
|
||||
session = await get_session_or_404(session_id)
|
||||
|
||||
try:
|
||||
session.increment_actions()
|
||||
await session.page.fill(request.selector, request.text)
|
||||
if request.press_enter:
|
||||
await session.page.press(request.selector, "Enter")
|
||||
return ActionResponse(status="success")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Type failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/wait", response_model=ActionResponse)
|
||||
async def wait_for(session_id: str, request: WaitRequest) -> ActionResponse:
|
||||
"""Wait for a condition."""
|
||||
session = await get_session_or_404(session_id)
|
||||
|
||||
try:
|
||||
session.increment_actions()
|
||||
|
||||
if request.wait_for == "selector":
|
||||
await session.page.wait_for_selector(request.value)
|
||||
elif request.wait_for == "text":
|
||||
await session.page.wait_for_selector(f"text={request.value}")
|
||||
elif request.wait_for == "timeout":
|
||||
await session.page.wait_for_timeout(int(request.value))
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid wait_for type: {request.wait_for}",
|
||||
)
|
||||
|
||||
return ActionResponse(status="success")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_408_REQUEST_TIMEOUT,
|
||||
detail=f"Wait failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/screenshot", response_model=ScreenshotResponse)
|
||||
async def screenshot(session_id: str, request: ScreenshotRequest) -> ScreenshotResponse:
|
||||
"""Take a screenshot."""
|
||||
session = await get_session_or_404(session_id)
|
||||
|
||||
try:
|
||||
session.increment_actions()
|
||||
|
||||
# Create unique filename
|
||||
filename = f"session-{session_id[:8]}-{uuid.uuid4().hex[:8]}.png"
|
||||
filepath = os.path.join(settings.screenshots_dir, filename)
|
||||
|
||||
# Ensure directory exists
|
||||
os.makedirs(settings.screenshots_dir, exist_ok=True)
|
||||
|
||||
await session.page.screenshot(path=filepath, full_page=request.full_page)
|
||||
|
||||
return ScreenshotResponse(path=filepath)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Screenshot failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/snapshot", response_model=SnapshotResponse)
|
||||
async def snapshot(session_id: str) -> SnapshotResponse:
|
||||
"""
|
||||
Get an accessibility snapshot of the page.
|
||||
|
||||
Returns a structured tree of elements suitable for LLM consumption.
|
||||
"""
|
||||
session = await get_session_or_404(session_id)
|
||||
|
||||
try:
|
||||
session.increment_actions()
|
||||
|
||||
# Get accessibility tree
|
||||
snapshot = await session.page.accessibility.snapshot()
|
||||
|
||||
nodes = []
|
||||
node_counter = [0] # Use list for mutable reference in closure
|
||||
|
||||
def extract_nodes(node: dict, nodes_list: list):
|
||||
"""Recursively extract nodes from accessibility tree."""
|
||||
if not node:
|
||||
return
|
||||
|
||||
node_counter[0] += 1
|
||||
ref = f"n{node_counter[0]}"
|
||||
|
||||
nodes_list.append(
|
||||
SnapshotNode(
|
||||
ref=ref,
|
||||
role=node.get("role", "unknown"),
|
||||
name=node.get("name"),
|
||||
text=node.get("value") or node.get("description"),
|
||||
)
|
||||
)
|
||||
|
||||
for child in node.get("children", []):
|
||||
extract_nodes(child, nodes_list)
|
||||
|
||||
if snapshot:
|
||||
extract_nodes(snapshot, nodes)
|
||||
|
||||
return SnapshotResponse(nodes=nodes)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Snapshot failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metrics (basic)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.get("/metrics")
|
||||
async def metrics() -> dict[str, Any]:
|
||||
"""Basic metrics endpoint."""
|
||||
manager = get_session_manager()
|
||||
return {
|
||||
"mcp_sessions_active": manager.active_session_count,
|
||||
"mcp_max_sessions": settings.max_sessions,
|
||||
}
|
||||
|
|
@ -0,0 +1,259 @@
|
|||
"""Browser session management."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from playwright.async_api import Browser, BrowserContext, Page
|
||||
|
||||
from app.config import settings
|
||||
from app.domain_filter import DomainFilter
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class BrowserSession:
|
||||
"""
|
||||
Represents a single browser session with domain restrictions.
|
||||
|
||||
Each session has its own BrowserContext and Page, isolated from
|
||||
other sessions. Navigation and network requests are filtered
|
||||
against the allowed_domains list.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
allowed_domains: list[str],
|
||||
context: BrowserContext,
|
||||
page: Page,
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.allowed_domains = allowed_domains
|
||||
self.domain_filter = DomainFilter(allowed_domains)
|
||||
self.context = context
|
||||
self.page = page
|
||||
self.created_at = utc_now()
|
||||
self.last_activity = utc_now()
|
||||
self.actions_used = 0
|
||||
self.max_actions = settings.max_actions_per_session
|
||||
|
||||
def touch(self) -> None:
|
||||
"""Update last activity timestamp."""
|
||||
self.last_activity = utc_now()
|
||||
|
||||
def increment_actions(self) -> None:
|
||||
"""Increment action counter."""
|
||||
self.actions_used += 1
|
||||
self.touch()
|
||||
|
||||
@property
|
||||
def actions_remaining(self) -> int:
|
||||
"""Get remaining actions for this session."""
|
||||
return max(0, self.max_actions - self.actions_used)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if session has expired due to timeout or max lifetime."""
|
||||
now = utc_now()
|
||||
|
||||
# Check idle timeout
|
||||
idle_seconds = (now - self.last_activity).total_seconds()
|
||||
if idle_seconds > settings.idle_timeout_seconds:
|
||||
return True
|
||||
|
||||
# Check max lifetime
|
||||
lifetime_seconds = (now - self.created_at).total_seconds()
|
||||
if lifetime_seconds > settings.max_session_lifetime_seconds:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def expires_at(self) -> datetime:
|
||||
"""Calculate when this session will expire."""
|
||||
from datetime import timedelta
|
||||
|
||||
idle_expiry = self.last_activity + timedelta(
|
||||
seconds=settings.idle_timeout_seconds
|
||||
)
|
||||
lifetime_expiry = self.created_at + timedelta(
|
||||
seconds=settings.max_session_lifetime_seconds
|
||||
)
|
||||
return min(idle_expiry, lifetime_expiry)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the browser context and page."""
|
||||
try:
|
||||
await self.context.close()
|
||||
except Exception:
|
||||
pass # Best effort cleanup
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages browser sessions with lifecycle, limits, and cleanup.
|
||||
|
||||
Responsible for:
|
||||
- Creating new sessions with domain restrictions
|
||||
- Tracking active sessions
|
||||
- Enforcing session limits
|
||||
- Background cleanup of expired sessions
|
||||
"""
|
||||
|
||||
def __init__(self, browser: Browser):
|
||||
self.browser = browser
|
||||
self._sessions: dict[str, BrowserSession] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
@property
|
||||
def active_session_count(self) -> int:
|
||||
"""Get count of active sessions."""
|
||||
return len(self._sessions)
|
||||
|
||||
async def start_cleanup_task(self) -> None:
|
||||
"""Start the background cleanup task."""
|
||||
if self._cleanup_task is None:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
async def stop_cleanup_task(self) -> None:
|
||||
"""Stop the background cleanup task."""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
async def _cleanup_loop(self) -> None:
|
||||
"""Background loop that cleans up expired sessions."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(settings.cleanup_interval_seconds)
|
||||
await self._cleanup_expired_sessions()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
# Log error but keep running
|
||||
pass
|
||||
|
||||
async def _cleanup_expired_sessions(self) -> None:
|
||||
"""Remove and close expired sessions."""
|
||||
async with self._lock:
|
||||
expired_ids = [
|
||||
sid for sid, session in self._sessions.items() if session.is_expired
|
||||
]
|
||||
|
||||
for session_id in expired_ids:
|
||||
session = self._sessions.pop(session_id, None)
|
||||
if session:
|
||||
await session.close()
|
||||
|
||||
async def create_session(self, allowed_domains: list[str]) -> BrowserSession:
|
||||
"""
|
||||
Create a new browser session.
|
||||
|
||||
Args:
|
||||
allowed_domains: List of allowed domain patterns
|
||||
|
||||
Returns:
|
||||
The created BrowserSession
|
||||
|
||||
Raises:
|
||||
ValueError: If max sessions reached or invalid domains
|
||||
"""
|
||||
if not allowed_domains:
|
||||
raise ValueError("allowed_domains is required and cannot be empty")
|
||||
|
||||
async with self._lock:
|
||||
# Check session limit
|
||||
if len(self._sessions) >= settings.max_sessions:
|
||||
raise ValueError(
|
||||
f"Maximum sessions ({settings.max_sessions}) reached. "
|
||||
"Close an existing session first."
|
||||
)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Create browser context with domain filtering
|
||||
context = await self.browser.new_context(
|
||||
viewport={"width": 1280, "height": 720},
|
||||
user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||
)
|
||||
|
||||
# Create domain filter for route interception
|
||||
domain_filter = DomainFilter(allowed_domains)
|
||||
|
||||
# Set up route handler to block requests to non-allowed domains
|
||||
async def handle_route(route):
|
||||
url = route.request.url
|
||||
if domain_filter.is_allowed(url):
|
||||
await route.continue_()
|
||||
else:
|
||||
# Block the request
|
||||
await route.abort("blockedbyclient")
|
||||
|
||||
await context.route("**/*", handle_route)
|
||||
|
||||
# Create page
|
||||
page = await context.new_page()
|
||||
page.set_default_timeout(settings.default_timeout_ms)
|
||||
page.set_default_navigation_timeout(settings.navigation_timeout_ms)
|
||||
|
||||
# Create session
|
||||
session = BrowserSession(
|
||||
session_id=session_id,
|
||||
allowed_domains=allowed_domains,
|
||||
context=context,
|
||||
page=page,
|
||||
)
|
||||
|
||||
self._sessions[session_id] = session
|
||||
return session
|
||||
|
||||
async def get_session(self, session_id: str) -> Optional[BrowserSession]:
|
||||
"""
|
||||
Get a session by ID.
|
||||
|
||||
Args:
|
||||
session_id: The session UUID
|
||||
|
||||
Returns:
|
||||
The BrowserSession or None if not found
|
||||
"""
|
||||
session = self._sessions.get(session_id)
|
||||
if session and session.is_expired:
|
||||
# Clean up expired session
|
||||
await self.close_session(session_id)
|
||||
return None
|
||||
return session
|
||||
|
||||
async def close_session(self, session_id: str) -> bool:
|
||||
"""
|
||||
Close and remove a session.
|
||||
|
||||
Args:
|
||||
session_id: The session UUID
|
||||
|
||||
Returns:
|
||||
True if session was found and closed, False otherwise
|
||||
"""
|
||||
async with self._lock:
|
||||
session = self._sessions.pop(session_id, None)
|
||||
if session:
|
||||
await session.close()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def close_all_sessions(self) -> None:
|
||||
"""Close all active sessions."""
|
||||
async with self._lock:
|
||||
for session in list(self._sessions.values()):
|
||||
await session.close()
|
||||
self._sessions.clear()
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
version: "3.8"
|
||||
|
||||
services:
|
||||
mcp-browser:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
container_name: mcp-browser-dev
|
||||
|
||||
environment:
|
||||
- MAX_SESSIONS=3
|
||||
- IDLE_TIMEOUT_SECONDS=300
|
||||
- MAX_SESSION_LIFETIME_SECONDS=1800
|
||||
- MAX_ACTIONS_PER_SESSION=50
|
||||
- LOG_LEVEL=DEBUG
|
||||
- LOG_JSON=false
|
||||
- SCREENSHOTS_DIR=/screenshots
|
||||
|
||||
ports:
|
||||
- "8931:8931"
|
||||
|
||||
volumes:
|
||||
# Hot reload in development
|
||||
- ./app:/app/app:ro
|
||||
# Screenshots persistence
|
||||
- mcp_screenshots:/screenshots
|
||||
|
||||
security_opt:
|
||||
- seccomp=unconfined
|
||||
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '1.5'
|
||||
memory: 1G
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 256M
|
||||
|
||||
volumes:
|
||||
mcp_screenshots:
|
||||
name: mcp-browser-screenshots
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
[pytest]
|
||||
asyncio_mode = auto
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Web Framework
|
||||
fastapi>=0.109.0
|
||||
uvicorn[standard]>=0.27.0
|
||||
|
||||
# Validation
|
||||
pydantic>=2.5.0
|
||||
pydantic-settings>=2.1.0
|
||||
|
||||
# Browser Automation
|
||||
playwright>=1.40.0
|
||||
|
||||
# Testing
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.23.0
|
||||
httpx>=0.26.0
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tests for MCP Browser Sidecar."""
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
"""Tests for domain filtering."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.domain_filter import DomainFilter
|
||||
|
||||
|
||||
class TestDomainFilter:
|
||||
"""Tests for the DomainFilter class."""
|
||||
|
||||
def test_exact_domain_match(self):
|
||||
"""Exact domain should match."""
|
||||
df = DomainFilter(["example.com"])
|
||||
assert df.is_allowed("https://example.com/path")
|
||||
assert df.is_allowed("http://example.com")
|
||||
|
||||
def test_exact_domain_no_match(self):
|
||||
"""Non-matching domain should be blocked."""
|
||||
df = DomainFilter(["example.com"])
|
||||
assert not df.is_allowed("https://other.com")
|
||||
assert not df.is_allowed("https://sub.example.com")
|
||||
|
||||
def test_wildcard_subdomain_match(self):
|
||||
"""Wildcard should match subdomains."""
|
||||
df = DomainFilter(["*.example.com"])
|
||||
assert df.is_allowed("https://sub.example.com")
|
||||
assert df.is_allowed("https://deep.sub.example.com")
|
||||
|
||||
def test_wildcard_does_not_match_root(self):
|
||||
"""Wildcard *.example.com should still match example.com."""
|
||||
df = DomainFilter(["*.example.com"])
|
||||
# The pattern matches zero or more subdomains
|
||||
assert df.is_allowed("https://example.com")
|
||||
|
||||
def test_domain_with_port(self):
|
||||
"""Domain with port should match."""
|
||||
df = DomainFilter(["example.com:8443"])
|
||||
assert df.is_allowed("https://example.com:8443/path")
|
||||
assert not df.is_allowed("https://example.com/path")
|
||||
|
||||
def test_multiple_domains(self):
|
||||
"""Multiple domains in allowlist."""
|
||||
df = DomainFilter(["example.com", "other.com"])
|
||||
assert df.is_allowed("https://example.com")
|
||||
assert df.is_allowed("https://other.com")
|
||||
assert not df.is_allowed("https://blocked.com")
|
||||
|
||||
def test_case_insensitive(self):
|
||||
"""Domain matching should be case-insensitive."""
|
||||
df = DomainFilter(["Example.Com"])
|
||||
assert df.is_allowed("https://example.com")
|
||||
assert df.is_allowed("https://EXAMPLE.COM")
|
||||
|
||||
def test_invalid_url_blocked(self):
|
||||
"""Invalid URLs should be blocked."""
|
||||
df = DomainFilter(["example.com"])
|
||||
assert not df.is_allowed("not-a-url")
|
||||
assert not df.is_allowed("")
|
||||
|
||||
def test_empty_domains_raises(self):
|
||||
"""Empty allowed_domains should raise ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
DomainFilter([])
|
||||
|
||||
def test_get_blocked_reason(self):
|
||||
"""get_blocked_reason should return informative message."""
|
||||
df = DomainFilter(["example.com"])
|
||||
reason = df.get_blocked_reason("https://blocked.com/path")
|
||||
assert "blocked.com" in reason
|
||||
assert "example.com" in reason
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
"""Tests for session management."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.session_manager import BrowserSession, SessionManager, utc_now
|
||||
|
||||
|
||||
class TestBrowserSession:
|
||||
"""Tests for the BrowserSession class."""
|
||||
|
||||
def test_touch_updates_last_activity(self):
|
||||
"""touch() should update last_activity timestamp."""
|
||||
mock_context = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
|
||||
session = BrowserSession(
|
||||
session_id="test-123",
|
||||
allowed_domains=["example.com"],
|
||||
context=mock_context,
|
||||
page=mock_page,
|
||||
)
|
||||
|
||||
original_activity = session.last_activity
|
||||
session.touch()
|
||||
|
||||
assert session.last_activity >= original_activity
|
||||
|
||||
def test_increment_actions(self):
|
||||
"""increment_actions() should update counter and touch."""
|
||||
mock_context = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
|
||||
session = BrowserSession(
|
||||
session_id="test-123",
|
||||
allowed_domains=["example.com"],
|
||||
context=mock_context,
|
||||
page=mock_page,
|
||||
)
|
||||
|
||||
assert session.actions_used == 0
|
||||
session.increment_actions()
|
||||
assert session.actions_used == 1
|
||||
|
||||
def test_actions_remaining(self):
|
||||
"""actions_remaining should calculate correctly."""
|
||||
mock_context = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
|
||||
session = BrowserSession(
|
||||
session_id="test-123",
|
||||
allowed_domains=["example.com"],
|
||||
context=mock_context,
|
||||
page=mock_page,
|
||||
)
|
||||
|
||||
initial_remaining = session.actions_remaining
|
||||
session.increment_actions()
|
||||
assert session.actions_remaining == initial_remaining - 1
|
||||
|
||||
def test_is_expired_idle_timeout(self):
|
||||
"""Session should expire after idle timeout."""
|
||||
mock_context = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
|
||||
session = BrowserSession(
|
||||
session_id="test-123",
|
||||
allowed_domains=["example.com"],
|
||||
context=mock_context,
|
||||
page=mock_page,
|
||||
)
|
||||
|
||||
# Set last_activity to past
|
||||
session.last_activity = utc_now() - timedelta(seconds=400)
|
||||
assert session.is_expired
|
||||
|
||||
def test_is_expired_max_lifetime(self):
|
||||
"""Session should expire after max lifetime."""
|
||||
mock_context = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
|
||||
session = BrowserSession(
|
||||
session_id="test-123",
|
||||
allowed_domains=["example.com"],
|
||||
context=mock_context,
|
||||
page=mock_page,
|
||||
)
|
||||
|
||||
# Set created_at to past
|
||||
session.created_at = utc_now() - timedelta(seconds=2000)
|
||||
session.last_activity = utc_now() # Recent activity
|
||||
assert session.is_expired
|
||||
|
||||
def test_not_expired_fresh_session(self):
|
||||
"""Fresh session should not be expired."""
|
||||
mock_context = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
|
||||
session = BrowserSession(
|
||||
session_id="test-123",
|
||||
allowed_domains=["example.com"],
|
||||
context=mock_context,
|
||||
page=mock_page,
|
||||
)
|
||||
|
||||
assert not session.is_expired
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSessionManager:
|
||||
"""Tests for the SessionManager class."""
|
||||
|
||||
async def test_create_session_success(self):
|
||||
"""create_session should create a new session."""
|
||||
mock_browser = AsyncMock()
|
||||
mock_context = AsyncMock()
|
||||
mock_page = AsyncMock()
|
||||
|
||||
mock_browser.new_context.return_value = mock_context
|
||||
mock_context.new_page.return_value = mock_page
|
||||
|
||||
manager = SessionManager(mock_browser)
|
||||
session = await manager.create_session(["example.com"])
|
||||
|
||||
assert session is not None
|
||||
assert session.session_id is not None
|
||||
assert session.allowed_domains == ["example.com"]
|
||||
assert manager.active_session_count == 1
|
||||
|
||||
async def test_create_session_empty_domains_raises(self):
|
||||
"""create_session with empty domains should raise."""
|
||||
mock_browser = AsyncMock()
|
||||
manager = SessionManager(mock_browser)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await manager.create_session([])
|
||||
|
||||
async def test_create_session_max_limit(self):
|
||||
"""create_session should fail when max sessions reached."""
|
||||
mock_browser = AsyncMock()
|
||||
mock_context = AsyncMock()
|
||||
mock_page = AsyncMock()
|
||||
|
||||
mock_browser.new_context.return_value = mock_context
|
||||
mock_context.new_page.return_value = mock_page
|
||||
|
||||
manager = SessionManager(mock_browser)
|
||||
|
||||
# Create max sessions
|
||||
with patch("app.session_manager.settings") as mock_settings:
|
||||
mock_settings.max_sessions = 2
|
||||
mock_settings.max_actions_per_session = 50
|
||||
mock_settings.idle_timeout_seconds = 300
|
||||
mock_settings.max_session_lifetime_seconds = 1800
|
||||
mock_settings.default_timeout_ms = 30000
|
||||
mock_settings.navigation_timeout_ms = 60000
|
||||
|
||||
await manager.create_session(["example.com"])
|
||||
await manager.create_session(["other.com"])
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await manager.create_session(["blocked.com"])
|
||||
|
||||
assert "Maximum sessions" in str(exc_info.value)
|
||||
|
||||
async def test_get_session_returns_session(self):
|
||||
"""get_session should return existing session."""
|
||||
mock_browser = AsyncMock()
|
||||
mock_context = AsyncMock()
|
||||
mock_page = AsyncMock()
|
||||
|
||||
mock_browser.new_context.return_value = mock_context
|
||||
mock_context.new_page.return_value = mock_page
|
||||
|
||||
manager = SessionManager(mock_browser)
|
||||
session = await manager.create_session(["example.com"])
|
||||
|
||||
retrieved = await manager.get_session(session.session_id)
|
||||
assert retrieved is session
|
||||
|
||||
async def test_get_session_returns_none_for_unknown(self):
|
||||
"""get_session should return None for unknown session."""
|
||||
mock_browser = AsyncMock()
|
||||
manager = SessionManager(mock_browser)
|
||||
|
||||
retrieved = await manager.get_session("unknown-id")
|
||||
assert retrieved is None
|
||||
|
||||
async def test_close_session_removes_session(self):
|
||||
"""close_session should remove and close session."""
|
||||
mock_browser = AsyncMock()
|
||||
mock_context = AsyncMock()
|
||||
mock_page = AsyncMock()
|
||||
|
||||
mock_browser.new_context.return_value = mock_context
|
||||
mock_context.new_page.return_value = mock_page
|
||||
|
||||
manager = SessionManager(mock_browser)
|
||||
session = await manager.create_session(["example.com"])
|
||||
|
||||
assert manager.active_session_count == 1
|
||||
|
||||
result = await manager.close_session(session.session_id)
|
||||
|
||||
assert result is True
|
||||
assert manager.active_session_count == 0
|
||||
mock_context.close.assert_called_once()
|
||||
|
||||
async def test_close_session_idempotent(self):
|
||||
"""close_session should be idempotent."""
|
||||
mock_browser = AsyncMock()
|
||||
mock_context = AsyncMock()
|
||||
mock_page = AsyncMock()
|
||||
|
||||
mock_browser.new_context.return_value = mock_context
|
||||
mock_context.new_page.return_value = mock_page
|
||||
|
||||
manager = SessionManager(mock_browser)
|
||||
session = await manager.create_session(["example.com"])
|
||||
|
||||
await manager.close_session(session.session_id)
|
||||
result = await manager.close_session(session.session_id)
|
||||
|
||||
assert result is False # Already closed
|
||||
|
||||
async def test_close_all_sessions(self):
|
||||
"""close_all_sessions should close all sessions."""
|
||||
mock_browser = AsyncMock()
|
||||
mock_context = AsyncMock()
|
||||
mock_page = AsyncMock()
|
||||
|
||||
mock_browser.new_context.return_value = mock_context
|
||||
mock_context.new_page.return_value = mock_page
|
||||
|
||||
manager = SessionManager(mock_browser)
|
||||
|
||||
await manager.create_session(["example.com"])
|
||||
await manager.create_session(["other.com"])
|
||||
|
||||
assert manager.active_session_count == 2
|
||||
|
||||
await manager.close_all_sessions()
|
||||
|
||||
assert manager.active_session_count == 0
|
||||
Loading…
Reference in New Issue