"""Browser session management."""

import asyncio
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

from playwright.async_api import Browser, BrowserContext, Page

from app.config import settings
from app.domain_filter import DomainFilter


def utc_now() -> datetime:
    """Return the current time as a timezone-aware ``datetime`` in UTC.

    Every timestamp in this module goes through this helper, so they are all
    aware datetimes that can be subtracted from one another safely.
    """
    return datetime.now(timezone.utc)


class BrowserSession:
    """
    Represents a single browser session with domain restrictions.

    A session bundles one BrowserContext and one Page together with a
    DomainFilter built from ``allowed_domains``, plus the bookkeeping
    needed to expire it: creation time, last-activity time and an
    action counter capped by ``settings.max_actions_per_session``.

    This class only stores the filter; it does not install any request
    interception on the context itself.
    """

    def __init__(
        self,
        session_id: str,
        allowed_domains: list[str],
        context: BrowserContext,
        page: Page,
    ):
        """
        Initialize the session state.

        Args:
            session_id: Identifier under which the session is tracked.
            allowed_domains: Domain patterns handed to DomainFilter.
            context: The browser context owned by this session.
            page: The page opened inside ``context``.
        """
        self.session_id = session_id
        self.allowed_domains = allowed_domains
        self.domain_filter = DomainFilter(allowed_domains)
        self.context = context
        self.page = page

        # Read the clock once so a brand-new session has exactly zero idle
        # time (created_at == last_activity).
        now = utc_now()
        self.created_at = now
        self.last_activity = now

        self.actions_used = 0
        self.max_actions = settings.max_actions_per_session

    def touch(self) -> None:
        """Update last activity timestamp to the current UTC time."""
        self.last_activity = utc_now()

    def increment_actions(self) -> None:
        """Count one more action and refresh the activity timestamp.

        The counter is not capped here; callers that want to enforce the
        limit should consult ``actions_remaining``.
        """
        self.actions_used += 1
        self.touch()

    @property
    def actions_remaining(self) -> int:
        """Get remaining actions for this session (never negative)."""
        return max(0, self.max_actions - self.actions_used)

    @property
    def is_expired(self) -> bool:
        """Check if session has expired due to idle timeout or max lifetime.

        Both comparisons are strict: a session whose idle time or age is
        exactly equal to the configured limit is still considered live.
        """
        now = utc_now()

        # Idle timeout: time since the last touch().
        idle_seconds = (now - self.last_activity).total_seconds()
        if idle_seconds > settings.idle_timeout_seconds:
            return True

        # Max lifetime: time since construction, regardless of activity.
        lifetime_seconds = (now - self.created_at).total_seconds()
        return lifetime_seconds > settings.max_session_lifetime_seconds

    @property
    def expires_at(self) -> datetime:
        """Calculate when this session will expire.

        Returns:
            The earlier of the idle-timeout deadline (relative to
            ``last_activity``) and the lifetime deadline (relative to
            ``created_at``).
        """
        idle_expiry = self.last_activity + timedelta(
            seconds=settings.idle_timeout_seconds
        )
        lifetime_expiry = self.created_at + timedelta(
            seconds=settings.max_session_lifetime_seconds
        )
        return min(idle_expiry, lifetime_expiry)

    async def close(self) -> None:
        """Close the browser context (best effort, never raises Exception).

        Only the context is closed explicitly; the page is not closed
        separately.
        """
        try:
            await self.context.close()
        except Exception:
            # Best-effort cleanup: a failure to close must not break the
            # caller, but leave a trace for debugging instead of dropping it.
            logging.getLogger(__name__).debug(
                "Failed to close context for session %s",
                self.session_id,
                exc_info=True,
            )


class SessionManager:
    """
    Manages browser sessions with lifecycle, limits, and cleanup.

    Responsible for:
    - Creating new sessions with domain restrictions
    - Tracking active sessions
    - Enforcing session limits
    - Background cleanup of expired sessions

    Mutations of the session table (create, close, cleanup) are
    serialized with a single ``asyncio.Lock``.
    """

    def __init__(self, browser: Browser):
        """
        Args:
            browser: Browser instance used to create one context per session.
        """
        self.browser = browser
        # session_id -> live session
        self._sessions: dict[str, BrowserSession] = {}
        self._lock = asyncio.Lock()
        # Handle of the background cleanup task; None while not running.
        self._cleanup_task: Optional[asyncio.Task] = None

    @property
    def active_session_count(self) -> int:
        """Get count of tracked sessions.

        Expired sessions that have not been reaped yet are included.
        """
        return len(self._sessions)

    async def start_cleanup_task(self) -> None:
        """Start the background cleanup task (no-op if already started)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop_cleanup_task(self) -> None:
        """Cancel the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self) -> None:
        """Background loop that cleans up expired sessions.

        Sleeps ``settings.cleanup_interval_seconds`` between passes and
        exits when the task is cancelled.
        """
        while True:
            try:
                await asyncio.sleep(settings.cleanup_interval_seconds)
                await self._cleanup_expired_sessions()
            except asyncio.CancelledError:
                break
            except Exception:
                # Log the failure but keep the loop alive so one bad pass
                # does not stop expiry handling for good.
                logging.getLogger(__name__).exception(
                    "Expired-session cleanup pass failed"
                )

    async def _cleanup_expired_sessions(self) -> None:
        """Remove and close expired sessions."""
        async with self._lock:
            # Collect ids first so the dict is not mutated while iterating.
            expired_ids = [
                sid for sid, session in self._sessions.items() if session.is_expired
            ]

            for session_id in expired_ids:
                session = self._sessions.pop(session_id, None)
                if session:
                    await session.close()

    async def _build_session(
        self,
        session_id: str,
        allowed_domains: list[str],
        context: BrowserContext,
    ) -> BrowserSession:
        """Install domain filtering on ``context``, open a page, wrap both."""
        domain_filter = DomainFilter(allowed_domains)

        async def handle_route(route):
            # Let requests to allowed domains through, abort everything else.
            if domain_filter.is_allowed(route.request.url):
                await route.continue_()
            else:
                await route.abort("blockedbyclient")

        # Intercept every request made from this context.
        await context.route("**/*", handle_route)

        page = await context.new_page()
        page.set_default_timeout(settings.default_timeout_ms)
        page.set_default_navigation_timeout(settings.navigation_timeout_ms)

        return BrowserSession(
            session_id=session_id,
            allowed_domains=allowed_domains,
            context=context,
            page=page,
        )

    async def create_session(self, allowed_domains: list[str]) -> BrowserSession:
        """
        Create a new browser session.

        Args:
            allowed_domains: List of allowed domain patterns

        Returns:
            The created BrowserSession

        Raises:
            ValueError: If max sessions reached or allowed_domains is empty
        """
        if not allowed_domains:
            raise ValueError("allowed_domains is required and cannot be empty")

        async with self._lock:
            # Check session limit
            if len(self._sessions) >= settings.max_sessions:
                raise ValueError(
                    f"Maximum sessions ({settings.max_sessions}) reached. "
                    "Close an existing session first."
                )

            session_id = str(uuid.uuid4())

            context = await self.browser.new_context(
                viewport={"width": 1280, "height": 720},
                user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
            )

            try:
                session = await self._build_session(
                    session_id, allowed_domains, context
                )
            except BaseException:
                # Setup failed (or was cancelled) after the context was
                # opened: close it so it is not leaked, then re-raise the
                # original error unchanged.
                try:
                    await context.close()
                except Exception:
                    logging.getLogger(__name__).debug(
                        "Failed to close context after session setup error",
                        exc_info=True,
                    )
                raise

            self._sessions[session_id] = session
            return session

    async def get_session(self, session_id: str) -> Optional[BrowserSession]:
        """
        Get a session by ID.

        An expired session is closed and removed as a side effect, and
        None is returned for it.

        Args:
            session_id: The session UUID

        Returns:
            The BrowserSession or None if not found or expired
        """
        session = self._sessions.get(session_id)
        if session and session.is_expired:
            # Clean up expired session
            await self.close_session(session_id)
            return None
        return session

    async def close_session(self, session_id: str) -> bool:
        """
        Close and remove a session.

        Args:
            session_id: The session UUID

        Returns:
            True if session was found and closed, False otherwise
        """
        async with self._lock:
            session = self._sessions.pop(session_id, None)
            if session:
                await session.close()
                return True
            return False

    async def close_all_sessions(self) -> None:
        """Close all active sessions and empty the session table."""
        async with self._lock:
            # Iterate over a snapshot; the table is cleared afterwards.
            for session in list(self._sessions.values()):
                await session.close()
            self._sessions.clear()