feat: initial MCP Browser Sidecar implementation
Playwright browser automation service for LLM-driven UI interaction.
Features:
- Session-based browser management with domain allowlisting
- HTTP API endpoints for browser actions (navigate, click, type, wait, screenshot, snapshot)
- Session lifecycle management (create, close, status)
- Automatic session cleanup (idle timeout, max lifetime)
- Resource limits (max sessions, max actions per session)
- Domain filtering via route interception
API Surface:
- POST /sessions - Create session with allowed_domains
- DELETE /sessions/{id} - Close session
- GET /sessions/{id}/status - Get session info
- POST /sessions/{id}/navigate - Navigate to URL
- POST /sessions/{id}/click - Click element
- POST /sessions/{id}/type - Type into element
- POST /sessions/{id}/wait - Wait for condition
- POST /sessions/{id}/screenshot - Capture screenshot
- POST /sessions/{id}/snapshot - Get accessibility tree
Security:
- Mandatory domain allowlist per session
- Network request filtering
- Session isolation via browser contexts
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
7
app/__init__.py
Normal file
7
app/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""LetsBe MCP Browser Sidecar.
|
||||
|
||||
A Playwright browser automation service that provides HTTP API for
|
||||
LLM-driven exploratory browser control with domain allowlisting.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
35
app/config.py
Normal file
35
app/config.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Configuration settings for MCP Browser Sidecar."""
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Runtime configuration for the sidecar, read from environment variables.

    The values below are the defaults used when the corresponding
    environment variable is not set.
    """

    class Config:
        # No prefix: a field maps directly to the variable of the same name.
        env_prefix = ""
        case_sensitive = False

    # --- Session limits ---------------------------------------------------
    max_sessions: int = 3
    idle_timeout_seconds: int = 300  # 5 minutes
    max_session_lifetime_seconds: int = 1800  # 30 minutes
    max_actions_per_session: int = 50

    # --- Playwright -------------------------------------------------------
    browser_headless: bool = True
    default_timeout_ms: int = 30000
    navigation_timeout_ms: int = 60000

    # --- Background cleanup -----------------------------------------------
    cleanup_interval_seconds: int = 60

    # --- Logging ----------------------------------------------------------
    log_level: str = "INFO"
    log_json: bool = True

    # --- Screenshot output directory --------------------------------------
    screenshots_dir: str = "/screenshots"


# Module-level instance imported by the rest of the application.
settings = Settings()
|
||||
82
app/domain_filter.py
Normal file
82
app/domain_filter.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Domain filtering and allowlist validation."""
|
||||
|
||||
import fnmatch
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class DomainFilter:
    """
    Validates URLs against a domain allowlist.

    Supports:
    - Exact domain matching: "example.com"
    - Wildcard subdomains: "*.example.com" (also matches "example.com" itself)
    - Domains with ports: "example.com:8443"

    Matching is case-insensitive and is performed against the URL's network
    location exactly as written, so an explicit port or userinfo in the URL
    must also appear in the pattern. Only web schemes (see ALLOWED_SCHEMES)
    are ever allowed, regardless of host.
    """

    # Schemes permitted through the filter. Anything else (file, ftp, ...)
    # is rejected even when its host part is on the allowlist.
    ALLOWED_SCHEMES = frozenset({"http", "https", "ws", "wss"})

    def __init__(self, allowed_domains: list[str]):
        """
        Initialize the domain filter.

        Args:
            allowed_domains: List of allowed domain patterns

        Raises:
            ValueError: If allowed_domains is empty
        """
        if not allowed_domains:
            raise ValueError("allowed_domains cannot be empty")

        self.allowed_domains = allowed_domains
        self._patterns = self._compile_patterns(allowed_domains)

    def _compile_patterns(self, domains: list[str]) -> list[re.Pattern]:
        """Compile domain patterns into regexes meant for ``fullmatch``."""
        patterns = []
        for domain in domains:
            if domain.startswith("*."):
                # "*.example.com" -> the base domain itself or any chain of
                # subdomain labels in front of it.
                base = re.escape(domain[2:])
                pattern = rf"(?:[a-zA-Z0-9-]+\.)*{base}"
            else:
                # Exact match
                pattern = re.escape(domain)
            patterns.append(re.compile(pattern, re.IGNORECASE))
        return patterns

    def is_allowed(self, url: str) -> bool:
        """
        Check if a URL's scheme and domain are permitted.

        Args:
            url: The URL to check

        Returns:
            True if the scheme is a web scheme and the network location
            matches an allowlist pattern, False otherwise (including for
            unparseable input).
        """
        try:
            parsed = urlparse(url)
            scheme = parsed.scheme
            host = parsed.netloc

            if not host or scheme not in self.ALLOWED_SCHEMES:
                return False

            # fullmatch (rather than match + "$") so that nothing may trail
            # the host, not even a newline.
            return any(pattern.fullmatch(host) for pattern in self._patterns)
        except (ValueError, TypeError, AttributeError):
            # Malformed or non-string input: fail closed.
            return False

    def get_blocked_reason(self, url: str) -> str:
        """Get a human-readable reason for why a URL was blocked."""
        try:
            parsed = urlparse(url)
            scheme = parsed.scheme
            host = parsed.netloc
        except (ValueError, TypeError, AttributeError):
            return f"Invalid URL: {url}"

        if host and scheme not in self.ALLOWED_SCHEMES:
            return (
                f"URL scheme '{scheme}' is not allowed "
                f"(permitted: {sorted(self.ALLOWED_SCHEMES)})"
            )
        return f"Domain '{host}' not in allowlist: {self.allowed_domains}"
|
||||
87
app/playwright_client.py
Normal file
87
app/playwright_client.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Playwright browser management."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from playwright.async_api import Browser, Playwright, async_playwright
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class PlaywrightClient:
    """
    Manages the Playwright browser instance.

    Provides lifecycle management for a single Chromium browser
    that serves all sessions.
    """

    def __init__(self):
        self._playwright: Optional[Playwright] = None
        self._browser: Optional[Browser] = None

    @property
    def browser(self) -> Browser:
        """Get the browser instance.

        Raises:
            RuntimeError: If start() has not been called yet.
        """
        if self._browser is None:
            raise RuntimeError("Browser not started. Call start() first.")
        return self._browser

    @property
    def is_running(self) -> bool:
        """Check if a browser instance has been started and not yet stopped."""
        return self._browser is not None

    async def start(self) -> Browser:
        """
        Start the Playwright browser.

        Calling this again while a browser exists returns the existing one.
        If launching Chromium fails, the Playwright driver started for it is
        shut down again before the error propagates, so a later retry starts
        from a clean state.

        Returns:
            The Browser instance
        """
        if self._browser is not None:
            return self._browser

        self._playwright = await async_playwright().start()

        try:
            # Launch Chromium with Docker-compatible settings
            # NOTE(review): "--single-process" has been associated with
            # Chromium instability under automation - confirm it is needed.
            self._browser = await self._playwright.chromium.launch(
                headless=settings.browser_headless,
                args=[
                    "--no-sandbox",
                    "--disable-setuid-sandbox",
                    "--disable-dev-shm-usage",
                    "--disable-gpu",
                    "--single-process",
                ],
            )
        except Exception:
            # Don't leave an orphaned driver behind when the launch fails.
            await self.stop()
            raise

        return self._browser

    async def stop(self) -> None:
        """Stop the Playwright browser and cleanup.

        Best effort: errors while closing are ignored, and the internal
        references are cleared either way.
        """
        if self._browser:
            try:
                await self._browser.close()
            except Exception:
                pass  # Best effort: the browser may already be gone
            self._browser = None

        if self._playwright:
            try:
                await self._playwright.stop()
            except Exception:
                pass  # Best effort cleanup
            self._playwright = None
|
||||
|
||||
|
||||
# Process-wide client; stays None until the first get_playwright_client() call.
_playwright_client: Optional[PlaywrightClient] = None


def get_playwright_client() -> PlaywrightClient:
    """Return the shared PlaywrightClient, creating it on first use."""
    global _playwright_client
    client = _playwright_client
    if client is None:
        client = _playwright_client = PlaywrightClient()
    return client
|
||||
441
app/server.py
Normal file
441
app/server.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""FastAPI HTTP server for MCP Browser Sidecar."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import settings
|
||||
from app.playwright_client import get_playwright_client
|
||||
from app.session_manager import BrowserSession, SessionManager
|
||||
|
||||
# Shared SessionManager; None until one has been assigned to this global.
_session_manager: Optional[SessionManager] = None


def get_session_manager() -> SessionManager:
    """Return the global SessionManager.

    Raises:
        RuntimeError: If no SessionManager has been initialized yet.
    """
    manager = _session_manager
    if manager is not None:
        return manager
    raise RuntimeError("SessionManager not initialized")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Startup: starts the browser, creates the global SessionManager and
    launches its cleanup task. Shutdown: stops the cleanup task, closes all
    sessions and stops the browser. Shutdown runs in ``finally`` blocks so
    the browser is stopped even if startup fails part-way or an exception is
    thrown into the generator at ``yield``.
    """
    global _session_manager

    client = get_playwright_client()
    try:
        # Startup
        browser = await client.start()
        _session_manager = SessionManager(browser)
        await _session_manager.start_cleanup_task()

        yield
    finally:
        # Shutdown
        try:
            if _session_manager:
                await _session_manager.stop_cleanup_task()
                await _session_manager.close_all_sessions()
        finally:
            # Always stop the browser, even if session teardown raised.
            await client.stop()
|
||||
|
||||
|
||||
# The ASGI application; startup and shutdown are delegated to lifespan().
app = FastAPI(
    lifespan=lifespan,
    version="0.1.0",
    title="MCP Browser Sidecar",
    description="Playwright browser automation service for LLM-driven UI interaction",
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
    """Request to create a new browser session."""

    # Required (no default) and must contain at least one pattern.
    allowed_domains: list[str] = Field(
        min_length=1,
        description="List of allowed domain patterns (required)",
    )
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
    """Response containing session information."""

    # Identity
    session_id: str
    allowed_domains: list[str]

    # Timing
    created_at: datetime
    last_activity: datetime
    expires_at: datetime

    # Action budget
    actions_used: int
    actions_remaining: int
|
||||
|
||||
|
||||
class NavigateRequest(BaseModel):
    """Request to navigate to a URL."""

    # Required, non-empty.
    url: str = Field(min_length=1, description="URL to navigate to")
|
||||
|
||||
|
||||
class NavigateResponse(BaseModel):
    """Response from navigation."""

    # The URL that was requested.
    url: str
    status: str  # "success" or "blocked"

    # Filled in by the navigate endpoint on success.
    title: Optional[str] = None
    current_url: Optional[str] = None

    # Filled in by the navigate endpoint when status is "blocked".
    blocked_reason: Optional[str] = None
|
||||
|
||||
|
||||
class ClickRequest(BaseModel):
    """Request to click an element."""

    # Required.
    selector: str = Field(description="CSS selector or element ref")
|
||||
|
||||
|
||||
class TypeRequest(BaseModel):
    """Request to type into an element."""

    # Required fields.
    selector: str = Field(description="CSS selector or element ref")
    text: str = Field(description="Text to type")

    # Optional; defaults to not pressing Enter.
    press_enter: bool = Field(default=False, description="Press Enter after typing")
|
||||
|
||||
|
||||
class WaitRequest(BaseModel):
    """Request to wait for a condition."""

    # Sent as "for" in the JSON body; "for" is a Python keyword, hence the alias.
    wait_for: str = Field(alias="for", description="'selector', 'text', or 'timeout'")

    # Always a string; how it is interpreted depends on wait_for.
    value: str = Field(description="Selector, text, or milliseconds")
|
||||
|
||||
|
||||
class ScreenshotRequest(BaseModel):
    """Request to take a screenshot."""

    # Optional; defaults to False.
    full_page: bool = Field(default=False, description="Capture full scrollable page")
|
||||
|
||||
|
||||
class ScreenshotResponse(BaseModel):
    """Response from screenshot."""

    # Filesystem path the screenshot endpoint wrote the image to.
    path: str
|
||||
|
||||
|
||||
class SnapshotNode(BaseModel):
    """A node in the accessibility snapshot."""

    # Identifier assigned by the snapshot endpoint ("n1", "n2", ...).
    ref: str
    role: str

    # Optional descriptive fields; None when not provided.
    name: Optional[str] = None
    text: Optional[str] = None
|
||||
|
||||
|
||||
class SnapshotResponse(BaseModel):
    """Response from accessibility snapshot."""

    # Flattened list of nodes from the accessibility tree.
    nodes: list[SnapshotNode]
|
||||
|
||||
|
||||
class ActionResponse(BaseModel):
    """Generic response for actions."""

    # The action endpoints in this file set this to "success".
    status: str

    # Optional free-form detail.
    message: Optional[str] = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
    """Error response."""

    # Human-readable error description.
    detail: str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_session_or_404(session_id: str) -> BrowserSession:
    """Look up a session by ID, turning a miss into an HTTP 404.

    Args:
        session_id: ID of the session to fetch.

    Returns:
        The matching BrowserSession.

    Raises:
        HTTPException: 404 when the manager returns no session for this ID.
    """
    found = await get_session_manager().get_session(session_id)
    if found is not None:
        return found

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Session {session_id} not found or expired",
    )
|
||||
|
||||
|
||||
def session_to_response(session: BrowserSession) -> SessionResponse:
    """Build the API representation of a BrowserSession.

    Each response field is read from the session attribute of the same name.
    """
    field_names = (
        "session_id",
        "created_at",
        "last_activity",
        "actions_used",
        "actions_remaining",
        "expires_at",
        "allowed_domains",
    )
    payload = {name: getattr(session, name) for name in field_names}
    return SessionResponse(**payload)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Health Check
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Health check endpoint."""
    # Reports the current session count alongside the configured ceiling.
    active = get_session_manager().active_session_count
    return {
        "status": "healthy",
        "active_sessions": active,
        "max_sessions": settings.max_sessions,
    }
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Session Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.post(
    "/sessions",
    response_model=SessionResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_session(request: CreateSessionRequest) -> SessionResponse:
    """
    Create a new browser session.

    The session will have a domain allowlist that restricts all navigation
    and network requests to the specified domains.
    """
    session_manager = get_session_manager()

    try:
        new_session = await session_manager.create_session(request.allowed_domains)
        return session_to_response(new_session)
    except ValueError as err:
        # Any ValueError from session creation is surfaced as 400 with its message.
        message = str(err)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=message,
        )
|
||||
|
||||
|
||||
@app.get("/sessions/{session_id}/status", response_model=SessionResponse)
async def get_session_status(session_id: str) -> SessionResponse:
    """Get the status of a session."""
    # get_session_or_404 raises 404 when there is no session to report on.
    return session_to_response(await get_session_or_404(session_id))
|
||||
|
||||
|
||||
@app.delete("/sessions/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
async def close_session(session_id: str) -> None:
    """
    Close a browser session.

    This is idempotent - calling it multiple times is safe.
    """
    # The manager's return value is ignored: 204 is sent whatever it returns.
    await get_session_manager().close_session(session_id)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Browser Actions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/navigate", response_model=NavigateResponse)
async def navigate(session_id: str, request: NavigateRequest) -> NavigateResponse:
    """
    Navigate to a URL.

    The URL must be in the session's allowed domains.
    """
    session = await get_session_or_404(session_id)

    # A disallowed URL is reported in the response body (status "blocked"),
    # not as an HTTP error, and does not count as an action.
    if not session.domain_filter.is_allowed(request.url):
        return NavigateResponse(
            url=request.url,
            status="blocked",
            blocked_reason=session.domain_filter.get_blocked_reason(request.url),
        )

    try:
        session.increment_actions()
        # The Response returned by goto() is not needed; the page itself is
        # queried for title and final URL below.
        await session.page.goto(request.url, wait_until="domcontentloaded")

        return NavigateResponse(
            url=request.url,
            status="success",
            title=await session.page.title(),
            current_url=session.page.url,
        )
    except Exception as e:
        # Every failure here is reported to the client as 504; the original
        # exception is chained for server-side debugging.
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"Navigation failed: {str(e)}",
        ) from e
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/click", response_model=ActionResponse)
async def click(session_id: str, request: ClickRequest) -> ActionResponse:
    """Click an element."""
    session = await get_session_or_404(session_id)

    try:
        # The action is counted before it is attempted.
        session.increment_actions()
        await session.page.click(request.selector)
    except Exception as exc:
        # Any failure is reported as 400.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Click failed: {exc!s}",
        )
    else:
        return ActionResponse(status="success")
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/type", response_model=ActionResponse)
async def type_text(session_id: str, request: TypeRequest) -> ActionResponse:
    """Type text into an element."""
    session = await get_session_or_404(session_id)

    try:
        # One action is counted whether or not Enter is pressed afterwards.
        session.increment_actions()
        target = request.selector
        await session.page.fill(target, request.text)
        if request.press_enter:
            await session.page.press(target, "Enter")
    except Exception as exc:
        # Any failure is reported as 400.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Type failed: {exc!s}",
        )
    else:
        return ActionResponse(status="success")
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/wait", response_model=ActionResponse)
async def wait_for(session_id: str, request: WaitRequest) -> ActionResponse:
    """Wait for a condition."""
    session = await get_session_or_404(session_id)

    # Validate the request before touching the session, so a malformed
    # request is answered with 400 and does not consume an action.
    kind = request.wait_for
    if kind not in ("selector", "text", "timeout"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid wait_for type: {request.wait_for}",
        )

    timeout_ms = 0
    if kind == "timeout":
        try:
            timeout_ms = int(request.value)
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid timeout value (expected integer milliseconds): {request.value}",
            ) from None

    try:
        session.increment_actions()

        if kind == "selector":
            await session.page.wait_for_selector(request.value)
        elif kind == "text":
            # Playwright text selector: waits for an element with this text.
            await session.page.wait_for_selector(f"text={request.value}")
        else:
            await session.page.wait_for_timeout(timeout_ms)

        return ActionResponse(status="success")
    except Exception as e:
        # Any failure while waiting is reported as 408.
        raise HTTPException(
            status_code=status.HTTP_408_REQUEST_TIMEOUT,
            detail=f"Wait failed: {str(e)}",
        ) from e
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/screenshot", response_model=ScreenshotResponse)
async def screenshot(session_id: str, request: ScreenshotRequest) -> ScreenshotResponse:
    """Take a screenshot."""
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()

        # File name: first 8 chars of the session ID plus a random 8-hex suffix.
        session_tag = session_id[:8]
        random_tag = uuid.uuid4().hex[:8]
        target_dir = settings.screenshots_dir
        filepath = os.path.join(target_dir, f"session-{session_tag}-{random_tag}.png")

        # Create the output directory on demand.
        os.makedirs(target_dir, exist_ok=True)

        await session.page.screenshot(path=filepath, full_page=request.full_page)
    except Exception as exc:
        # Any failure, including filesystem errors, is reported as 500.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Screenshot failed: {exc!s}",
        )
    else:
        return ScreenshotResponse(path=filepath)
|
||||
|
||||
|
||||
@app.post("/sessions/{session_id}/snapshot", response_model=SnapshotResponse)
async def snapshot(session_id: str) -> SnapshotResponse:
    """
    Get an accessibility snapshot of the page.

    Returns a structured tree of elements suitable for LLM consumption.
    """
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()

        # Get accessibility tree
        # NOTE(review): page.accessibility is deprecated in recent Playwright
        # releases - confirm against the pinned Playwright version.
        tree = await session.page.accessibility.snapshot()

        # Flatten the tree in pre-order (parent before children, children in
        # document order). An explicit stack is used instead of recursion so
        # deeply nested pages cannot hit Python's recursion limit.
        nodes: list[SnapshotNode] = []
        pending = [tree] if tree else []
        while pending:
            node = pending.pop()
            if not node:
                continue

            nodes.append(
                SnapshotNode(
                    # Refs are 1-based and follow output order: n1, n2, ...
                    ref=f"n{len(nodes) + 1}",
                    role=node.get("role", "unknown"),
                    name=node.get("name"),
                    text=node.get("value") or node.get("description"),
                )
            )

            # Push children reversed so the first child is popped first.
            pending.extend(reversed(node.get("children", [])))

        return SnapshotResponse(nodes=nodes)
    except Exception as e:
        # Any failure is reported as 500.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Snapshot failed: {str(e)}",
        ) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metrics (basic)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.get("/metrics")
async def metrics() -> dict[str, Any]:
    """Basic metrics endpoint."""
    # Current session count and the configured session ceiling.
    active = get_session_manager().active_session_count
    limit = settings.max_sessions
    return {
        "mcp_sessions_active": active,
        "mcp_max_sessions": limit,
    }
|
||||
259
app/session_manager.py
Normal file
259
app/session_manager.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""Browser session management."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from playwright.async_api import Browser, BrowserContext, Page
|
||||
|
||||
from app.config import settings
|
||||
from app.domain_filter import DomainFilter
|
||||
|
||||
|
||||
def utc_now() -> datetime:
    """Return the current time as a timezone-aware datetime in UTC."""
    return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
class BrowserSession:
    """
    Represents a single browser session with domain restrictions.

    Each session has its own BrowserContext and Page, isolated from
    other sessions. Navigation and network requests are filtered
    against the allowed_domains list.

    A session also carries an action budget (max_actions, taken from
    settings.max_actions_per_session when the session is created), which
    increment_actions() enforces.
    """

    def __init__(
        self,
        session_id: str,
        allowed_domains: list[str],
        context: BrowserContext,
        page: Page,
    ):
        self.session_id = session_id
        self.allowed_domains = allowed_domains
        self.domain_filter = DomainFilter(allowed_domains)
        self.context = context
        self.page = page
        self.created_at = utc_now()
        self.last_activity = utc_now()
        self.actions_used = 0
        self.max_actions = settings.max_actions_per_session

    def touch(self) -> None:
        """Update last activity timestamp."""
        self.last_activity = utc_now()

    def increment_actions(self) -> None:
        """Consume one action from the session's budget.

        Also refreshes the last-activity timestamp.

        Raises:
            RuntimeError: If the session has already used all of its
                max_actions; the counter is left unchanged in that case.
        """
        if self.actions_used >= self.max_actions:
            raise RuntimeError(
                f"Action limit ({self.max_actions}) reached for session "
                f"{self.session_id}. Create a new session to continue."
            )
        self.actions_used += 1
        self.touch()

    @property
    def actions_remaining(self) -> int:
        """Get remaining actions for this session (never negative)."""
        return max(0, self.max_actions - self.actions_used)

    @property
    def is_expired(self) -> bool:
        """Check if session has expired due to timeout or max lifetime."""
        now = utc_now()

        # Check idle timeout
        idle_seconds = (now - self.last_activity).total_seconds()
        if idle_seconds > settings.idle_timeout_seconds:
            return True

        # Check max lifetime
        lifetime_seconds = (now - self.created_at).total_seconds()
        return lifetime_seconds > settings.max_session_lifetime_seconds

    @property
    def expires_at(self) -> datetime:
        """Calculate when this session will expire.

        This is the earlier of the idle deadline (last activity plus idle
        timeout) and the lifetime deadline (creation plus max lifetime).
        """
        from datetime import timedelta

        idle_expiry = self.last_activity + timedelta(
            seconds=settings.idle_timeout_seconds
        )
        lifetime_expiry = self.created_at + timedelta(
            seconds=settings.max_session_lifetime_seconds
        )
        return min(idle_expiry, lifetime_expiry)

    async def close(self) -> None:
        """Close the browser context and page.

        Errors raised while closing are deliberately ignored.
        """
        try:
            await self.context.close()
        except Exception:
            pass  # Best effort cleanup
|
||||
|
||||
|
||||
class SessionManager:
    """
    Manages browser sessions with lifecycle, limits, and cleanup.

    Responsible for:
    - Creating new sessions with domain restrictions
    - Tracking active sessions
    - Enforcing session limits
    - Background cleanup of expired sessions
    """

    def __init__(self, browser: Browser):
        self.browser = browser
        self._sessions: dict[str, BrowserSession] = {}
        # Held by every method that adds or removes sessions.
        self._lock = asyncio.Lock()
        self._cleanup_task: Optional[asyncio.Task] = None

    @property
    def active_session_count(self) -> int:
        """Get count of active sessions."""
        return len(self._sessions)

    async def start_cleanup_task(self) -> None:
        """Start the background cleanup task (no-op if already started)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop_cleanup_task(self) -> None:
        """Stop the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self) -> None:
        """Background loop that cleans up expired sessions.

        Runs until cancelled. A failure in one pass is logged and the loop
        continues with the next interval.
        """
        # Local import keeps this block self-contained.
        import logging

        logger = logging.getLogger(__name__)

        while True:
            try:
                await asyncio.sleep(settings.cleanup_interval_seconds)
                await self._cleanup_expired_sessions()
            except asyncio.CancelledError:
                break
            except Exception:
                # Log error but keep running
                logger.exception("Expired-session cleanup failed; will retry")

    async def _cleanup_expired_sessions(self) -> None:
        """Remove and close expired sessions."""
        async with self._lock:
            expired_ids = [
                sid for sid, session in self._sessions.items() if session.is_expired
            ]

            for session_id in expired_ids:
                session = self._sessions.pop(session_id, None)
                if session:
                    await session.close()

    async def create_session(self, allowed_domains: list[str]) -> BrowserSession:
        """
        Create a new browser session.

        Args:
            allowed_domains: List of allowed domain patterns

        Returns:
            The created BrowserSession

        Raises:
            ValueError: If max sessions reached or invalid domains
        """
        if not allowed_domains:
            raise ValueError("allowed_domains is required and cannot be empty")

        async with self._lock:
            # Check session limit
            if len(self._sessions) >= settings.max_sessions:
                raise ValueError(
                    f"Maximum sessions ({settings.max_sessions}) reached. "
                    "Close an existing session first."
                )

            session_id = str(uuid.uuid4())

            # Build the filter before allocating browser resources, so an
            # invalid allowlist cannot leave a context behind.
            domain_filter = DomainFilter(allowed_domains)

            context = await self.browser.new_context(
                viewport={"width": 1280, "height": 720},
                user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
            )

            try:
                # Route handler that blocks requests to non-allowed domains
                async def handle_route(route):
                    url = route.request.url
                    if domain_filter.is_allowed(url):
                        await route.continue_()
                    else:
                        # Block the request
                        await route.abort("blockedbyclient")

                await context.route("**/*", handle_route)

                # Create page
                page = await context.new_page()
                page.set_default_timeout(settings.default_timeout_ms)
                page.set_default_navigation_timeout(settings.navigation_timeout_ms)

                # Create session
                session = BrowserSession(
                    session_id=session_id,
                    allowed_domains=allowed_domains,
                    context=context,
                    page=page,
                )
            except BaseException:
                # Setup failed (or was cancelled) after the context was
                # created: close it so it is not leaked, then re-raise.
                try:
                    await context.close()
                except Exception:
                    pass  # Best effort cleanup
                raise

            self._sessions[session_id] = session
            return session

    async def get_session(self, session_id: str) -> Optional[BrowserSession]:
        """
        Get a session by ID.

        An expired session is closed and removed on lookup and reported as
        missing.

        Args:
            session_id: The session UUID

        Returns:
            The BrowserSession or None if not found or expired
        """
        session = self._sessions.get(session_id)
        if session and session.is_expired:
            # Clean up expired session
            await self.close_session(session_id)
            return None
        return session

    async def close_session(self, session_id: str) -> bool:
        """
        Close and remove a session.

        Args:
            session_id: The session UUID

        Returns:
            True if session was found and closed, False otherwise
        """
        async with self._lock:
            session = self._sessions.pop(session_id, None)
            if session:
                await session.close()
                return True
            return False

    async def close_all_sessions(self) -> None:
        """Close all active sessions."""
        async with self._lock:
            for session in list(self._sessions.values()):
                await session.close()
            self._sessions.clear()
|
||||
Reference in New Issue
Block a user