letsbe-mcp-browser/app/server.py

442 lines
13 KiB
Python
Raw Normal View History

"""FastAPI HTTP server for MCP Browser Sidecar."""
import os
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, Optional
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, Field
from app.config import settings
from app.playwright_client import get_playwright_client
from app.session_manager import BrowserSession, SessionManager
# Global session manager: None until `lifespan` creates it at startup;
# read through get_session_manager().
_session_manager: Optional[SessionManager] = None
def get_session_manager() -> SessionManager:
    """Return the process-wide SessionManager.

    Raises:
        RuntimeError: if the module-level manager has not been set yet
            (i.e. before the application lifespan startup has run).
    """
    manager = _session_manager
    if manager is not None:
        return manager
    raise RuntimeError("SessionManager not initialized")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the shared browser for the lifetime of the application.

    Startup launches the Playwright browser, wraps it in the global
    SessionManager and starts the manager's cleanup task. Shutdown stops the
    cleanup task, closes every session and stops the Playwright client.

    Teardown is done in nested ``finally`` blocks. If the rest of startup
    fails, or a single shutdown step raises, the remaining cleanup still
    runs and the browser process is not leaked. The original exception
    still propagates.
    """
    global _session_manager

    # Startup
    client = get_playwright_client()
    browser = await client.start()
    manager: Optional[SessionManager] = None
    try:
        manager = SessionManager(browser)
        _session_manager = manager
        await manager.start_cleanup_task()
        yield
    finally:
        # Shutdown
        try:
            if manager is not None:
                try:
                    await manager.stop_cleanup_task()
                finally:
                    await manager.close_all_sessions()
        finally:
            # Clear the global so late callers get the explicit
            # "not initialized" error rather than a manager whose browser
            # has been stopped.
            _session_manager = None
            await client.stop()
# ASGI application; browser startup/shutdown is handled by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="MCP Browser Sidecar",
    version="0.1.0",
    description="Playwright browser automation service for LLM-driven UI interaction",
)
# =============================================================================
# Request/Response Models
# =============================================================================
class CreateSessionRequest(BaseModel):
    """Request to create a new browser session."""

    # Required; an empty list is rejected by validation (min_length=1).
    allowed_domains: list[str] = Field(
        default=...,
        description="List of allowed domain patterns (required)",
        min_length=1,
    )
class SessionResponse(BaseModel):
    """Response containing session information."""
    # Built from a BrowserSession by session_to_response(). The declaration
    # order below is also the key order of the serialized JSON.
    session_id: str
    created_at: datetime
    last_activity: datetime
    # Action budget counters, copied from the session.
    actions_used: int
    actions_remaining: int
    expires_at: datetime
    # Presumably the patterns supplied when the session was created; confirm
    # in SessionManager.create_session.
    allowed_domains: list[str]
class NavigateRequest(BaseModel):
    """Request to navigate to a URL."""

    # Must be non-empty; the navigate endpoint checks it against the
    # session's domain filter before loading anything.
    url: str = Field(default=..., description="URL to navigate to", min_length=1)
class NavigateResponse(BaseModel):
    """Response from navigation."""
    # The URL the client asked for.
    url: str
    status: str  # "success" or "blocked"
    # Page title and page URL after navigation (may differ from `url`, e.g.
    # after redirects); only populated on success.
    title: Optional[str] = None
    current_url: Optional[str] = None
    # Explanation from the domain filter; only populated when blocked.
    blocked_reason: Optional[str] = None
class ClickRequest(BaseModel):
    """Request to click an element."""

    # Handed to Playwright's page.click() as-is by the click endpoint.
    selector: str = Field(default=..., description="CSS selector or element ref")
class TypeRequest(BaseModel):
    """Request to type into an element."""

    # Target element; used for both page.fill() and the optional Enter press.
    selector: str = Field(default=..., description="CSS selector or element ref")
    # Value written with page.fill().
    text: str = Field(default=..., description="Text to type")
    press_enter: bool = Field(default=False, description="Press Enter after typing")
class WaitRequest(BaseModel):
    """Request to wait for a condition."""

    # The JSON key is "for", which is a Python keyword, hence the alias.
    wait_for: str = Field(
        default=...,
        alias="for",
        description="'selector', 'text', or 'timeout'",
    )
    # Meaning depends on `wait_for`: a selector, a text, or a millisecond count.
    value: str = Field(default=..., description="Selector, text, or milliseconds")
class ScreenshotRequest(BaseModel):
    """Request to take a screenshot."""

    # Forwarded to Playwright's page.screenshot(full_page=...).
    full_page: bool = Field(default=False, description="Capture full scrollable page")
class ScreenshotResponse(BaseModel):
    """Response from screenshot."""
    # Path of the saved PNG inside settings.screenshots_dir, as seen by this
    # service; the image bytes themselves are not returned.
    path: str
class SnapshotNode(BaseModel):
    """A node in the accessibility snapshot."""
    # Sequential id ("n1", "n2", ...) in traversal order; numbering restarts
    # for every snapshot request, so refs are not stable across snapshots.
    ref: str
    # Accessibility role from Playwright ("unknown" if the node has none).
    role: str
    # Accessible name, if any.
    name: Optional[str] = None
    # The node's value, falling back to its description.
    text: Optional[str] = None
class SnapshotResponse(BaseModel):
    """Response from accessibility snapshot."""
    # The accessibility tree flattened into a list in depth-first (parent
    # before children) order; the nesting itself is not preserved.
    nodes: list[SnapshotNode]
class ActionResponse(BaseModel):
    """Generic response for actions."""
    # The click/type/wait handlers only ever return "success"; failures are
    # raised as HTTPException instead of being encoded here.
    status: str
    # Optional detail text; not set by any handler in this file's visible code.
    message: Optional[str] = None
class ErrorResponse(BaseModel):
    """Error response."""
    # Same shape as the {"detail": ...} body FastAPI returns for an
    # HTTPException. NOTE(review): not referenced by any route in the visible
    # part of this file; confirm whether it is used elsewhere.
    detail: str
# =============================================================================
# Helper Functions
# =============================================================================
async def get_session_or_404(session_id: str) -> BrowserSession:
    """Fetch a session by id, mapping "no such session" to an HTTP 404.

    Raises:
        HTTPException: 404 when the session manager returns None for
            ``session_id``.
    """
    found = await get_session_manager().get_session(session_id)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Session {session_id} not found or expired",
    )
def session_to_response(session: BrowserSession) -> SessionResponse:
    """Copy the public attributes of a BrowserSession into a SessionResponse."""
    # Attribute names on BrowserSession match SessionResponse field names 1:1.
    exported = (
        "session_id",
        "created_at",
        "last_activity",
        "actions_used",
        "actions_remaining",
        "expires_at",
        "allowed_domains",
    )
    return SessionResponse(**{attr: getattr(session, attr) for attr in exported})
# =============================================================================
# Health Check
# =============================================================================
@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Report liveness together with the current and maximum session counts."""
    active = get_session_manager().active_session_count
    return {
        "status": "healthy",
        "active_sessions": active,
        "max_sessions": settings.max_sessions,
    }
# =============================================================================
# Session Management
# =============================================================================
@app.post(
    "/sessions",
    response_model=SessionResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_session(request: CreateSessionRequest) -> SessionResponse:
    """
    Create a new browser session.

    The session is created with a domain allowlist that is meant to restrict
    all navigation and network requests to the given domain patterns.
    """
    try:
        return session_to_response(
            await get_session_manager().create_session(request.allowed_domains)
        )
    except ValueError as exc:
        # ValueError is treated as a client error and surfaced as a 400.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        )
@app.get("/sessions/{session_id}/status", response_model=SessionResponse)
async def get_session_status(session_id: str) -> SessionResponse:
    """Return the counters and timestamps of a session (404 if unknown)."""
    return session_to_response(await get_session_or_404(session_id))
@app.delete("/sessions/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
async def close_session(session_id: str) -> None:
    """
    Close a browser session.

    This is idempotent: calling it multiple times is safe.
    """
    # No existence check (no 404) here; idempotency relies on
    # SessionManager.close_session tolerating unknown ids.
    await get_session_manager().close_session(session_id)
# =============================================================================
# Browser Actions
# =============================================================================
@app.post("/sessions/{session_id}/navigate", response_model=NavigateResponse)
async def navigate(session_id: str, request: NavigateRequest) -> NavigateResponse:
    """
    Navigate to a URL.

    The URL must be in the session's allowed domains. A disallowed URL
    returns status "blocked" with a reason and does not consume an action.
    """
    session = await get_session_or_404(session_id)

    # Refuse disallowed URLs before any request is made.
    if not session.domain_filter.is_allowed(request.url):
        return NavigateResponse(
            url=request.url,
            status="blocked",
            blocked_reason=session.domain_filter.get_blocked_reason(request.url),
        )

    try:
        session.increment_actions()
        # The goto() response is deliberately not inspected, so HTTP error
        # pages (4xx/5xx) are still reported as "success".
        await session.page.goto(request.url, wait_until="domcontentloaded")
        return NavigateResponse(
            url=request.url,
            status="success",
            title=await session.page.title(),
            current_url=session.page.url,
        )
    except Exception as e:
        # Every failure (timeouts, DNS errors, ...) is reported as a 504.
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"Navigation failed: {str(e)}",
        )
@app.post("/sessions/{session_id}/click", response_model=ActionResponse)
async def click(session_id: str, request: ClickRequest) -> ActionResponse:
    """Click the element matched by the given selector."""
    session = await get_session_or_404(session_id)
    try:
        session.increment_actions()
        # NOTE(review): the selector is passed to page.click() verbatim; refs
        # produced by the snapshot endpoint ("n1", "n2", ...) are not
        # translated, so Playwright will treat them as ordinary selectors.
        await session.page.click(request.selector)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Click failed: {str(exc)}",
        )
    return ActionResponse(status="success")
@app.post("/sessions/{session_id}/type", response_model=ActionResponse)
async def type_text(session_id: str, request: TypeRequest) -> ActionResponse:
    """Fill an element with text, optionally pressing Enter afterwards."""
    session = await get_session_or_404(session_id)
    try:
        session.increment_actions()
        page = session.page
        await page.fill(request.selector, request.text)
        if request.press_enter:
            await page.press(request.selector, "Enter")
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Type failed: {str(exc)}",
        )
    return ActionResponse(status="success")
def _parse_wait_timeout_ms(value: str) -> Optional[int]:
    """Return ``value`` as non-negative integer milliseconds, or None if invalid."""
    try:
        timeout_ms = int(value)
    except ValueError:
        return None
    return timeout_ms if timeout_ms >= 0 else None


@app.post("/sessions/{session_id}/wait", response_model=ActionResponse)
async def wait_for(session_id: str, request: WaitRequest) -> ActionResponse:
    """
    Wait for a condition.

    `for` selects the condition. "selector" waits for `value` to match an
    element. "text" waits for an element matching the text `value`.
    "timeout" pauses for `value` milliseconds, which must be a non-negative
    integer. Malformed requests get a 400 and do not consume an action.
    """
    session = await get_session_or_404(session_id)

    # Validate everything up front. A bad timeout value used to raise
    # ValueError inside the try below and was misreported as a 408.
    if request.wait_for not in ("selector", "text", "timeout"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid wait_for type: {request.wait_for}",
        )
    timeout_ms: Optional[int] = None
    if request.wait_for == "timeout":
        # NOTE(review): no upper bound is enforced on the timeout here.
        timeout_ms = _parse_wait_timeout_ms(request.value)
        if timeout_ms is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid timeout value (expected non-negative milliseconds): {request.value}",
            )

    try:
        session.increment_actions()
        if request.wait_for == "selector":
            await session.page.wait_for_selector(request.value)
        elif request.wait_for == "text":
            await session.page.wait_for_selector(f"text={request.value}")
        else:
            await session.page.wait_for_timeout(timeout_ms)
    except HTTPException:
        # Keep any deliberate HTTP error unchanged rather than turning it into a 408.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_408_REQUEST_TIMEOUT,
            detail=f"Wait failed: {str(e)}",
        )
    return ActionResponse(status="success")
@app.post("/sessions/{session_id}/screenshot", response_model=ScreenshotResponse)
async def screenshot(session_id: str, request: ScreenshotRequest) -> ScreenshotResponse:
    """Save a PNG screenshot of the session's page and return its path."""
    session = await get_session_or_404(session_id)
    try:
        session.increment_actions()
        out_dir = settings.screenshots_dir
        # Short session prefix plus a random suffix gives each capture a
        # unique file name.
        target = os.path.join(
            out_dir, f"session-{session_id[:8]}-{uuid.uuid4().hex[:8]}.png"
        )
        os.makedirs(out_dir, exist_ok=True)
        await session.page.screenshot(path=target, full_page=request.full_page)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Screenshot failed: {str(exc)}",
        )
    return ScreenshotResponse(path=target)
def _flatten_accessibility_tree(tree: Optional[dict]) -> list[SnapshotNode]:
    """Flatten a Playwright accessibility tree into SnapshotNodes.

    Nodes are emitted in depth-first pre-order (parent before children) and
    given sequential refs "n1", "n2", ... in that order. An explicit stack
    replaces recursion so very deep trees cannot hit the recursion limit.
    """
    flat: list[SnapshotNode] = []
    stack = [tree] if tree else []
    while stack:
        node = stack.pop()
        if not node:
            continue
        value = node.get("value")
        # Playwright may report values as numbers (e.g. sliders), but
        # SnapshotNode.text is a str field, so non-empty values are stringified.
        # Empty/missing values fall back to the description.
        if value is None or value == "":
            text = node.get("description")
        else:
            text = str(value)
        flat.append(
            SnapshotNode(
                ref=f"n{len(flat) + 1}",
                role=node.get("role", "unknown"),
                name=node.get("name"),
                text=text,
            )
        )
        # Push children in reverse so the first child is visited next,
        # preserving pre-order. `or []` also tolerates an explicit None.
        stack.extend(reversed(node.get("children") or []))
    return flat


@app.post("/sessions/{session_id}/snapshot", response_model=SnapshotResponse)
async def snapshot(session_id: str) -> SnapshotResponse:
    """
    Get an accessibility snapshot of the page.

    The accessibility tree is returned flattened into a list of nodes in
    depth-first order. Each node has a short ref plus its role, name and text,
    suitable for LLM consumption.
    """
    session = await get_session_or_404(session_id)
    try:
        session.increment_actions()
        # NOTE(review): Playwright's docs mark page.accessibility as
        # deprecated; confirm the pinned Playwright version still provides it.
        tree = await session.page.accessibility.snapshot()
        return SnapshotResponse(nodes=_flatten_accessibility_tree(tree))
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Snapshot failed: {str(e)}",
        )
# =============================================================================
# Metrics (basic)
# =============================================================================
@app.get("/metrics")
async def metrics() -> dict[str, Any]:
    """Return basic session metrics as a flat JSON object."""
    manager = get_session_manager()
    payload: dict[str, Any] = {}
    payload["mcp_sessions_active"] = manager.active_session_count
    payload["mcp_max_sessions"] = settings.max_sessions
    return payload