442 lines
13 KiB
Python
442 lines
13 KiB
Python
|
|
"""FastAPI HTTP server for MCP Browser Sidecar."""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import uuid
|
||
|
|
from contextlib import asynccontextmanager
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from typing import Any, Optional
|
||
|
|
|
||
|
|
from fastapi import FastAPI, HTTPException, status
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
|
||
|
|
from app.config import settings
|
||
|
|
from app.playwright_client import get_playwright_client
|
||
|
|
from app.session_manager import BrowserSession, SessionManager
|
||
|
|
|
||
|
|
# Global session manager
|
||
|
|
# Process-wide SessionManager. Starts as None; `lifespan` assigns it at
# startup, and `get_session_manager` raises RuntimeError while it is None.
_session_manager: Optional[SessionManager] = None
|
||
|
|
|
||
|
|
|
||
|
|
def get_session_manager() -> SessionManager:
    """Return the process-wide SessionManager.

    Raises:
        RuntimeError: if the application lifespan has not installed one yet.
    """
    manager = _session_manager
    if manager is not None:
        return manager
    raise RuntimeError("SessionManager not initialized")
|
||
|
|
|
||
|
|
|
||
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the Playwright browser and SessionManager for the app's lifetime.

    Startup launches the browser through the Playwright client, wraps it in
    the global SessionManager and starts the manager's cleanup task.
    Shutdown stops the cleanup task, closes all sessions, clears the global
    and stops the Playwright client.

    Teardown lives in ``finally`` blocks so the browser is released even if
    startup fails part-way (e.g. ``start_cleanup_task`` raises) or an
    exception is thrown into the context at ``yield``.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).
    """
    global _session_manager

    client = get_playwright_client()
    browser = await client.start()
    try:
        # Keep a local reference: the shutdown path must not depend on the
        # truthiness or later state of the module global.
        manager = SessionManager(browser)
        _session_manager = manager
        await manager.start_cleanup_task()
        try:
            yield
        finally:
            await manager.stop_cleanup_task()
            await manager.close_all_sessions()
    finally:
        # Requests arriving after shutdown get a clear "not initialized"
        # error instead of reaching a manager whose browser is gone.
        _session_manager = None
        await client.stop()
|
||
|
|
|
||
|
|
|
||
|
|
# The ASGI application. Browser startup/teardown is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="MCP Browser Sidecar",
    version="0.1.0",
    description="Playwright browser automation service for LLM-driven UI interaction",
)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Request/Response Models
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
class CreateSessionRequest(BaseModel):
    """Body of ``POST /sessions``: the domain allowlist for the new session."""

    # At least one pattern is mandatory; an empty list fails validation.
    allowed_domains: list[str] = Field(
        default=...,
        min_length=1,
        description="List of allowed domain patterns (required)",
    )
|
||
|
|
|
||
|
|
|
||
|
|
class SessionResponse(BaseModel):
    """Public view of a browser session: identity, usage and allowlist."""

    # Identity and activity timestamps.
    session_id: str
    created_at: datetime
    last_activity: datetime
    # Action budget: how many actions were spent and how many are left.
    actions_used: int
    actions_remaining: int
    # Expiry timestamp reported by the session.
    expires_at: datetime
    # Domain patterns this session may navigate to.
    allowed_domains: list[str]
|
||
|
|
|
||
|
|
|
||
|
|
class NavigateRequest(BaseModel):
    """Body of the navigate action: the URL to open (must be non-empty)."""

    url: str = Field(default=..., description="URL to navigate to", min_length=1)
|
||
|
|
|
||
|
|
|
||
|
|
class NavigateResponse(BaseModel):
    """Outcome of a navigate action.

    ``status`` is "success" or "blocked". On success ``title`` and
    ``current_url`` describe the loaded page; when the domain allowlist
    rejects the URL, ``blocked_reason`` explains why.
    """

    url: str
    status: str
    title: Optional[str] = Field(default=None)
    current_url: Optional[str] = Field(default=None)
    blocked_reason: Optional[str] = Field(default=None)
|
||
|
|
|
||
|
|
|
||
|
|
class ClickRequest(BaseModel):
    """Body of the click action: which element to click."""

    selector: str = Field(default=..., description="CSS selector or element ref")
|
||
|
|
|
||
|
|
|
||
|
|
class TypeRequest(BaseModel):
    """Body of the type action: target element, text, and optional Enter press."""

    selector: str = Field(default=..., description="CSS selector or element ref")
    text: str = Field(default=..., description="Text to type")
    press_enter: bool = Field(default=False, description="Press Enter after typing")
|
||
|
|
|
||
|
|
|
||
|
|
class WaitRequest(BaseModel):
    """Body of the wait action: the kind of condition and its argument."""

    # Sent as "for" in JSON; `for` is a Python keyword, hence the alias.
    wait_for: str = Field(
        default=...,
        alias="for",
        description="'selector', 'text', or 'timeout'",
    )
    value: str = Field(default=..., description="Selector, text, or milliseconds")
|
||
|
|
|
||
|
|
|
||
|
|
class ScreenshotRequest(BaseModel):
    """Body of the screenshot action: viewport-only (default) or full page."""

    full_page: bool = Field(default=False, description="Capture full scrollable page")
|
||
|
|
|
||
|
|
|
||
|
|
class ScreenshotResponse(BaseModel):
    """Where the server wrote the captured screenshot."""

    # Filesystem path of the written PNG, on the server's filesystem.
    path: str
|
||
|
|
|
||
|
|
|
||
|
|
class SnapshotNode(BaseModel):
    """One flattened entry from the page's accessibility tree."""

    # Sequential identifier assigned by the snapshot endpoint ("n1", "n2", ...).
    ref: str
    role: str
    name: Optional[str] = Field(default=None)
    text: Optional[str] = Field(default=None)
|
||
|
|
|
||
|
|
|
||
|
|
class SnapshotResponse(BaseModel):
    """Accessibility snapshot flattened into a list of nodes."""

    # Parents precede their children (depth-first pre-order).
    nodes: list[SnapshotNode]
|
||
|
|
|
||
|
|
|
||
|
|
class ActionResponse(BaseModel):
    """Result of a page action that returns no payload of its own."""

    status: str
    message: Optional[str] = Field(default=None)
|
||
|
|
|
||
|
|
|
||
|
|
class ErrorResponse(BaseModel):
    """Error body: mirrors the ``{"detail": ...}`` shape of FastAPI's HTTPException."""

    detail: str
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Helper Functions
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
async def get_session_or_404(session_id: str) -> BrowserSession:
    """Fetch ``session_id`` from the global SessionManager.

    Raises:
        HTTPException: 404 when the manager does not return a session.
    """
    session = await get_session_manager().get_session(session_id)
    if session is not None:
        return session
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Session {session_id} not found or expired",
    )
|
||
|
|
|
||
|
|
|
||
|
|
def session_to_response(session: BrowserSession) -> SessionResponse:
    """Project a BrowserSession onto the public SessionResponse model."""
    fields = {
        "session_id": session.session_id,
        "created_at": session.created_at,
        "last_activity": session.last_activity,
        "actions_used": session.actions_used,
        "actions_remaining": session.actions_remaining,
        "expires_at": session.expires_at,
        "allowed_domains": session.allowed_domains,
    }
    return SessionResponse(**fields)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Health Check
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Report liveness together with current and maximum session counts."""
    active = get_session_manager().active_session_count
    return {
        "status": "healthy",
        "active_sessions": active,
        "max_sessions": settings.max_sessions,
    }
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Session Management
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
@app.post(
    "/sessions",
    response_model=SessionResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_session(request: CreateSessionRequest) -> SessionResponse:
    """
    Create a new browser session.

    The session will have a domain allowlist that restricts all navigation
    and network requests to the specified domains.

    Raises:
        HTTPException: 400 when the session manager rejects the request by
            raising ``ValueError``.
    """
    manager = get_session_manager()

    try:
        session = await manager.create_session(request.allowed_domains)
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e

    # Built outside the ``try`` on purpose: pydantic's ValidationError is a
    # ValueError subclass, so a failure while building the response must not
    # be misreported to the client as a 400 Bad Request.
    return session_to_response(session)
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/sessions/{session_id}/status", response_model=SessionResponse)
async def get_session_status(session_id: str) -> SessionResponse:
    """Return the current status of a session (404 if unknown or expired)."""
    return session_to_response(await get_session_or_404(session_id))
|
||
|
|
|
||
|
|
|
||
|
|
@app.delete("/sessions/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
async def close_session(session_id: str) -> None:
    """
    Close a browser session.

    Idempotent: repeating the call for the same session id is safe.
    """
    await get_session_manager().close_session(session_id)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Browser Actions
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/sessions/{session_id}/navigate", response_model=NavigateResponse)
async def navigate(session_id: str, request: NavigateRequest) -> NavigateResponse:
    """
    Navigate to a URL.

    The URL must be in the session's allowed domains. A disallowed URL is not
    an HTTP error: the response has ``status="blocked"`` plus a reason, and no
    action is charged. Any failure while navigating is reported as 504.
    """
    session = await get_session_or_404(session_id)
    domain_filter = session.domain_filter

    # Check the allowlist before touching the page or the action budget.
    if not domain_filter.is_allowed(request.url):
        return NavigateResponse(
            url=request.url,
            status="blocked",
            blocked_reason=domain_filter.get_blocked_reason(request.url),
        )

    try:
        session.increment_actions()
        # Only DOMContentLoaded is awaited; goto()'s response object is not
        # used, so it is not kept.
        await session.page.goto(request.url, wait_until="domcontentloaded")

        return NavigateResponse(
            url=request.url,
            status="success",
            title=await session.page.title(),
            current_url=session.page.url,
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"Navigation failed: {str(e)}",
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/sessions/{session_id}/click", response_model=ActionResponse)
async def click(session_id: str, request: ClickRequest) -> ActionResponse:
    """Click the element matched by ``request.selector``.

    One action is charged first; any failure (including one raised while
    charging the action) is reported as 400.
    """
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()
        await session.page.click(request.selector)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Click failed: {str(e)}",
        )
    return ActionResponse(status="success")
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/sessions/{session_id}/type", response_model=ActionResponse)
async def type_text(session_id: str, request: TypeRequest) -> ActionResponse:
    """Fill ``request.selector`` with ``request.text``, optionally pressing Enter.

    One action is charged first; any failure is reported as 400.
    """
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()
        page = session.page
        await page.fill(request.selector, request.text)
        if request.press_enter:
            await page.press(request.selector, "Enter")
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Type failed: {str(e)}",
        )
    return ActionResponse(status="success")
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/sessions/{session_id}/wait", response_model=ActionResponse)
async def wait_for(session_id: str, request: WaitRequest) -> ActionResponse:
    """Wait for a condition.

    ``request.wait_for`` selects the condition:

    * ``"selector"`` - wait until ``request.value`` (a selector) matches.
    * ``"text"`` - wait for a ``text=`` selector built from ``request.value``.
    * ``"timeout"`` - sleep ``request.value`` milliseconds (non-negative int).

    Malformed requests are rejected with 400 *before* an action is charged;
    failures while waiting are reported as 408.
    """
    session = await get_session_or_404(session_id)
    kind = request.wait_for

    # Validate up front: a malformed request must neither consume an action
    # nor be misreported as a timeout by the broad handler below.
    if kind not in ("selector", "text", "timeout"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid wait_for type: {kind}",
        )

    timeout_ms = 0
    if kind == "timeout":
        try:
            timeout_ms = int(request.value)
        except ValueError:
            timeout_ms = -1  # falls through to the rejection below
        if timeout_ms < 0:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid timeout value: {request.value}",
            )

    try:
        session.increment_actions()

        if kind == "selector":
            await session.page.wait_for_selector(request.value)
        elif kind == "text":
            await session.page.wait_for_selector(f"text={request.value}")
        else:
            await session.page.wait_for_timeout(timeout_ms)
    except HTTPException:
        # Let deliberate HTTP errors (e.g. from the action counter) through.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_408_REQUEST_TIMEOUT,
            detail=f"Wait failed: {str(e)}",
        ) from e

    return ActionResponse(status="success")
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/sessions/{session_id}/screenshot", response_model=ScreenshotResponse)
async def screenshot(session_id: str, request: ScreenshotRequest) -> ScreenshotResponse:
    """Capture the session's page as a PNG inside ``settings.screenshots_dir``.

    Returns the path of the written file. One action is charged first; any
    failure is reported as 500.
    """
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()

        # Session-id prefix keeps files traceable; the random suffix makes
        # name collisions between captures unlikely.
        out_dir = settings.screenshots_dir
        target = os.path.join(
            out_dir, f"session-{session_id[:8]}-{uuid.uuid4().hex[:8]}.png"
        )
        os.makedirs(out_dir, exist_ok=True)

        await session.page.screenshot(path=target, full_page=request.full_page)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Screenshot failed: {str(e)}",
        )
    return ScreenshotResponse(path=target)
|
||
|
|
|
||
|
|
|
||
|
|
def _flatten_accessibility_tree(tree: Optional[dict]) -> list[SnapshotNode]:
    """Flatten a Playwright accessibility tree into a list of SnapshotNodes.

    Nodes are emitted in depth-first pre-order (a parent before its children,
    children in the order of their ``children`` list) and receive sequential
    refs ``n1``, ``n2``, ... Empty/None nodes are skipped without consuming a
    ref. An explicit stack is used so deeply nested pages cannot exceed
    Python's recursion limit.
    """
    nodes: list[SnapshotNode] = []
    stack: list[Optional[dict]] = [tree]
    while stack:
        node = stack.pop()
        if not node:
            continue

        # Playwright reports ``value`` as str or number; SnapshotNode.text is a
        # string field, so numbers (sliders, progress bars) are stringified
        # rather than failing validation. Missing/empty values fall back to
        # the node's description.
        value = node.get("value")
        if value is None or value == "":
            text = node.get("description")
        else:
            text = str(value)

        nodes.append(
            SnapshotNode(
                ref=f"n{len(nodes) + 1}",
                role=node.get("role", "unknown"),
                name=node.get("name"),
                text=text,
            )
        )
        # Reverse so the first child is popped (visited) next.
        stack.extend(reversed(node.get("children") or []))
    return nodes


@app.post("/sessions/{session_id}/snapshot", response_model=SnapshotResponse)
async def snapshot(session_id: str) -> SnapshotResponse:
    """
    Get an accessibility snapshot of the page.

    Returns a structured tree of elements suitable for LLM consumption,
    flattened into a list of nodes. One action is charged first; any failure
    is reported as 500.
    """
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()
        tree = await session.page.accessibility.snapshot()
        return SnapshotResponse(nodes=_flatten_accessibility_tree(tree))
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Snapshot failed: {str(e)}",
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Metrics (basic)
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/metrics")
async def metrics() -> dict[str, Any]:
    """Expose basic session gauges as a flat JSON object."""
    manager = get_session_manager()
    payload: dict[str, Any] = {}
    payload["mcp_sessions_active"] = manager.active_session_count
    payload["mcp_max_sessions"] = settings.max_sessions
    return payload
|