# Source: letsbe-sysadmin/app/task_manager.py (262 lines, 7.8 KiB, Python)

"""Task polling and execution management."""
import asyncio
import random
import time
import traceback
from typing import Optional
from app.clients.orchestrator_client import (
CircuitBreakerOpen,
EventLevel,
OrchestratorClient,
Task,
TaskStatus,
)
from app.config import Settings, get_settings
from app.executors import ExecutionResult, get_executor
from app.utils.logger import get_logger
# Module-level logger for this file. Calls below pass an event name plus
# keyword fields, e.g. logger.info("task_started", task_id=...).
logger = get_logger("task_manager")
class TaskManager:
    """Manage task polling, execution, and result submission.

    Features:
    - Concurrent task execution bounded by a semaphore
    - Circuit breaker integration (polls back off while the breaker is open)
    - Event logging for each task
    - Error handling and result persistence
    - Graceful shutdown that waits, up to a bounded grace period, for
      dispatched tasks to finish
    """

    # Upper bound, in seconds, on how long poll_loop waits for dispatched
    # tasks after shutdown is requested. Override per instance if needed.
    shutdown_grace_period: float = 5.0

    # Ceiling for the multiplier applied to poll_interval while backing off.
    max_backoff_multiplier: float = 8.0

    def __init__(
        self,
        client: OrchestratorClient,
        settings: Optional[Settings] = None,
    ):
        """Create a task manager.

        Args:
            client: Orchestrator client used to fetch tasks and report results.
            settings: Agent settings; defaults to ``get_settings()``.
        """
        self.client = client
        self.settings = settings or get_settings()
        self._shutdown_event = asyncio.Event()
        self._semaphore = asyncio.Semaphore(self.settings.max_concurrent_tasks)
        # IDs of tasks currently holding a semaphore slot (i.e. executing).
        self._active_tasks: set[str] = set()
        # Strong references to dispatched asyncio tasks. The event loop only
        # keeps weak references, so an unreferenced task could be garbage
        # collected mid-execution. Each task removes itself when done.
        self._background_tasks: set[asyncio.Task] = set()

    async def poll_loop(self) -> None:
        """Poll for tasks and dispatch them until shutdown is requested.

        Each cycle fetches at most one task and starts it in the background,
        then sleeps for ``poll_interval`` multiplied by the current backoff
        multiplier, plus 0-25% random jitter. The sleep ends early when
        :meth:`shutdown` is called. Fetching is skipped for a cycle while all
        execution slots are busy.

        After the loop exits, waits up to ``shutdown_grace_period`` seconds
        for dispatched tasks to finish. Returns immediately (with a warning)
        if the client has no ``agent_id``.
        """
        if not self.client.agent_id:
            logger.warning("poll_loop_not_registered")
            return
        logger.info(
            "poll_loop_started",
            interval=self.settings.poll_interval,
            max_concurrent=self.settings.max_concurrent_tasks,
        )
        consecutive_failures = 0
        backoff_multiplier = 1.0
        while not self._shutdown_event.is_set():
            if self._semaphore.locked():
                # Every execution slot is busy. Don't fetch work we cannot
                # start yet: it would only wait locally for a slot before
                # being marked RUNNING.
                logger.debug(
                    "poll_skipped_at_capacity",
                    active=len(self._active_tasks),
                )
            else:
                try:
                    task = await self.client.fetch_next_task()
                except CircuitBreakerOpen:
                    # The client refused to call out because its circuit
                    # breaker is open; back off more aggressively.
                    logger.warning("poll_circuit_breaker_open")
                    backoff_multiplier = min(
                        backoff_multiplier * 2, self.max_backoff_multiplier
                    )
                except Exception as e:
                    consecutive_failures += 1
                    backoff_multiplier = min(
                        backoff_multiplier * 1.5, self.max_backoff_multiplier
                    )
                    logger.error(
                        "poll_error",
                        error=str(e),
                        consecutive_failures=consecutive_failures,
                    )
                else:
                    # Any successful fetch, including one that returns no
                    # task, means the orchestrator is reachable again, so the
                    # backoff is reset here rather than only when work arrives.
                    consecutive_failures = 0
                    backoff_multiplier = 1.0
                    if task:
                        self._dispatch(task)
                    else:
                        logger.debug("no_tasks_available")

            interval = self.settings.poll_interval * backoff_multiplier
            # Jitter (0-25% of the interval) keeps many agents from polling
            # in lockstep.
            interval += random.uniform(0, interval * 0.25)
            if await self._wait_for_shutdown(interval):
                break

        await self._drain_background_tasks()
        logger.info("poll_loop_stopped")

    def _dispatch(self, task: Task) -> None:
        """Start ``task`` in the background and keep a reference to it."""
        background = asyncio.create_task(
            self._execute_task(task), name=f"task-{task.id}"
        )
        self._background_tasks.add(background)
        background.add_done_callback(self._background_tasks.discard)

    async def _wait_for_shutdown(self, timeout: float) -> bool:
        """Sleep up to ``timeout`` seconds; return True if shutdown was requested."""
        try:
            await asyncio.wait_for(self._shutdown_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            return False  # Normal timeout, keep polling.
        return True

    async def _drain_background_tasks(self) -> None:
        """Wait up to ``shutdown_grace_period`` seconds for dispatched tasks.

        Tasks still running after the grace period are left running (not
        cancelled); their count is logged.
        """
        pending = {t for t in self._background_tasks if not t.done()}
        if not pending:
            return
        logger.info("waiting_for_active_tasks", count=len(pending))
        _, still_running = await asyncio.wait(
            pending, timeout=self.shutdown_grace_period
        )
        if still_running:
            logger.warning("active_tasks_still_running", count=len(still_running))

    async def _execute_task(self, task: Task) -> None:
        """Run ``task`` once a concurrency slot is free.

        This is the top-level coroutine of a background asyncio task that
        nothing awaits. Any escaped exception would only surface as an
        asyncio "never retrieved" warning, so it is logged here instead.

        Args:
            task: Task to execute.
        """
        async with self._semaphore:
            self._active_tasks.add(task.id)
            try:
                await self._run_task(task)
            except Exception as e:
                # Executor errors are handled inside _run_task. Anything that
                # reaches here came from the surrounding bookkeeping, e.g. a
                # send_event/update_task call to the orchestrator failing.
                logger.error(
                    "task_unhandled_error",
                    task_id=task.id,
                    task_type=task.type,
                    error=str(e),
                    traceback=traceback.format_exc(),
                )
            finally:
                self._active_tasks.discard(task.id)

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Return milliseconds elapsed since ``start`` (a ``time.monotonic()`` reading)."""
        return (time.monotonic() - start) * 1000

    async def _run_task(self, task: Task) -> None:
        """Execute ``task`` and report its outcome to the orchestrator.

        Sends a start event, marks the task RUNNING, runs the executor
        registered for ``task.type``, then marks the task COMPLETED or FAILED
        and sends a matching event.

        Only the executor lookup and execution are guarded. A failure while
        *reporting* a result propagates to the caller instead of being
        misreported as a task failure.

        Args:
            task: Task to execute.

        Raises:
            Exception: Whatever the orchestrator client raises while sending
                events or updating task status.
        """
        # Monotonic clock: durations must not jump with wall-clock changes.
        start = time.monotonic()
        logger.info(
            "task_started",
            task_id=task.id,
            task_type=task.type,
            tenant_id=task.tenant_id,
        )
        await self.client.send_event(
            EventLevel.INFO,
            f"Task started: {task.type}",
            task_id=task.id,
            metadata={"payload_keys": list(task.payload.keys())},
        )
        await self.client.update_task(task.id, TaskStatus.RUNNING)

        try:
            executor = get_executor(task.type)
            result = await executor.execute(task.payload)
        except ValueError as e:
            # Unknown task type or validation error.
            duration_ms = self._elapsed_ms(start)
            error_msg = str(e)
            logger.error(
                "task_validation_error",
                task_id=task.id,
                task_type=task.type,
                error=error_msg,
                duration_ms=duration_ms,
            )
            await self.client.update_task(
                task.id,
                TaskStatus.FAILED,
                error=error_msg,
            )
            await self.client.send_event(
                EventLevel.ERROR,
                f"Task validation failed: {task.type}",
                task_id=task.id,
                metadata={"error": error_msg, "duration_ms": duration_ms},
            )
        except Exception as e:
            # Unexpected error raised by the executor.
            duration_ms = self._elapsed_ms(start)
            error_msg = str(e)
            tb = traceback.format_exc()
            logger.error(
                "task_exception",
                task_id=task.id,
                task_type=task.type,
                error=error_msg,
                duration_ms=duration_ms,
                traceback=tb,
            )
            await self.client.update_task(
                task.id,
                TaskStatus.FAILED,
                error=error_msg,
            )
            await self.client.send_event(
                EventLevel.ERROR,
                f"Task exception: {task.type}",
                task_id=task.id,
                # Keep the tail: format_exc() ends with the innermost frame
                # and the exception line, which matter most.
                metadata={
                    "error": error_msg,
                    "traceback": tb[-500:],
                    "duration_ms": duration_ms,
                },
            )
        else:
            duration_ms = self._elapsed_ms(start)
            if result.success:
                logger.info(
                    "task_completed",
                    task_id=task.id,
                    task_type=task.type,
                    duration_ms=duration_ms,
                )
                await self.client.update_task(
                    task.id,
                    TaskStatus.COMPLETED,
                    result=result.data,
                )
                await self.client.send_event(
                    EventLevel.INFO,
                    f"Task completed: {task.type}",
                    task_id=task.id,
                    metadata={"duration_ms": duration_ms},
                )
            else:
                logger.warning(
                    "task_failed",
                    task_id=task.id,
                    task_type=task.type,
                    error=result.error,
                    duration_ms=duration_ms,
                )
                await self.client.update_task(
                    task.id,
                    TaskStatus.FAILED,
                    result=result.data,
                    error=result.error,
                )
                await self.client.send_event(
                    EventLevel.ERROR,
                    f"Task failed: {task.type}",
                    task_id=task.id,
                    metadata={"error": result.error, "duration_ms": duration_ms},
                )

    async def shutdown(self) -> None:
        """Request graceful shutdown; ``poll_loop`` stops after its current cycle."""
        logger.info("task_manager_shutdown_initiated")
        self._shutdown_event.set()