"""Task polling and execution management.""" import asyncio import random import time import traceback from typing import Optional from app.clients.orchestrator_client import ( CircuitBreakerOpen, EventLevel, OrchestratorClient, Task, TaskStatus, ) from app.config import Settings, get_settings from app.executors import ExecutionResult, get_executor from app.utils.logger import get_logger logger = get_logger("task_manager") class TaskManager: """Manage task polling, execution, and result submission. Features: - Concurrent task execution with semaphore - Circuit breaker integration - Event logging for each task - Error handling and result persistence """ def __init__( self, client: OrchestratorClient, settings: Optional[Settings] = None, ): self.client = client self.settings = settings or get_settings() self._shutdown_event = asyncio.Event() self._semaphore = asyncio.Semaphore(self.settings.max_concurrent_tasks) self._active_tasks: set[str] = set() async def poll_loop(self) -> None: """Run the task polling loop until shutdown. Continuously polls for new tasks and dispatches them for execution. """ if not self.client.agent_id: logger.warning("poll_loop_not_registered") return logger.info( "poll_loop_started", interval=self.settings.poll_interval, max_concurrent=self.settings.max_concurrent_tasks, ) consecutive_failures = 0 backoff_multiplier = 1.0 while not self._shutdown_event.is_set(): try: # Check circuit breaker task = await self.client.fetch_next_task() if task: # Reset backoff on successful fetch consecutive_failures = 0 backoff_multiplier = 1.0 # Dispatch task (non-blocking) asyncio.create_task(self._execute_task(task)) else: logger.debug("no_tasks_available") except CircuitBreakerOpen: logger.warning("poll_circuit_breaker_open") backoff_multiplier = min(backoff_multiplier * 2, 8.0) except Exception as e: consecutive_failures += 1 backoff_multiplier = min(backoff_multiplier * 1.5, 8.0) logger.error( "poll_error", error=str(e), consecutive_failures=consecutive_failures, ) # Calculate next poll interval interval = self.settings.poll_interval * backoff_multiplier # Add jitter (0-25% of interval) interval += random.uniform(0, interval * 0.25) # Wait for next poll or shutdown try: await asyncio.wait_for( self._shutdown_event.wait(), timeout=interval, ) break # Shutdown requested except asyncio.TimeoutError: pass # Normal timeout, continue polling # Wait for active tasks to complete if self._active_tasks: logger.info("waiting_for_active_tasks", count=len(self._active_tasks)) # Give tasks a grace period await asyncio.sleep(5) logger.info("poll_loop_stopped") async def _execute_task(self, task: Task) -> None: """Execute a single task with concurrency control. Args: task: Task to execute """ # Acquire semaphore for concurrency control async with self._semaphore: self._active_tasks.add(task.id) try: await self._run_task(task) finally: self._active_tasks.discard(task.id) async def _run_task(self, task: Task) -> None: """Run task execution and handle results. Args: task: Task to execute """ start_time = time.time() logger.info( "task_started", task_id=task.id, task_type=task.type, tenant_id=task.tenant_id, ) # Send start event await self.client.send_event( EventLevel.INFO, f"Task started: {task.type}", task_id=task.id, metadata={"payload_keys": list(task.payload.keys())}, ) # Mark task as in progress await self.client.update_task(task.id, TaskStatus.RUNNING) try: # Get executor for task type executor = get_executor(task.type) # Execute task result = await executor.execute(task.payload) duration_ms = (time.time() - start_time) * 1000 if result.success: logger.info( "task_completed", task_id=task.id, task_type=task.type, duration_ms=duration_ms, ) await self.client.update_task( task.id, TaskStatus.COMPLETED, result=result.data, ) await self.client.send_event( EventLevel.INFO, f"Task completed: {task.type}", task_id=task.id, metadata={"duration_ms": duration_ms}, ) else: logger.warning( "task_failed", task_id=task.id, task_type=task.type, error=result.error, duration_ms=duration_ms, ) await self.client.update_task( task.id, TaskStatus.FAILED, result=result.data, error=result.error, ) await self.client.send_event( EventLevel.ERROR, f"Task failed: {task.type}", task_id=task.id, metadata={"error": result.error, "duration_ms": duration_ms}, ) except ValueError as e: # Unknown task type or validation error duration_ms = (time.time() - start_time) * 1000 error_msg = str(e) logger.error( "task_validation_error", task_id=task.id, task_type=task.type, error=error_msg, ) await self.client.update_task( task.id, TaskStatus.FAILED, error=error_msg, ) await self.client.send_event( EventLevel.ERROR, f"Task validation failed: {task.type}", task_id=task.id, metadata={"error": error_msg}, ) except Exception as e: # Unexpected error duration_ms = (time.time() - start_time) * 1000 error_msg = str(e) tb = traceback.format_exc() logger.error( "task_exception", task_id=task.id, task_type=task.type, error=error_msg, traceback=tb, ) await self.client.update_task( task.id, TaskStatus.FAILED, error=error_msg, ) await self.client.send_event( EventLevel.ERROR, f"Task exception: {task.type}", task_id=task.id, metadata={"error": error_msg, "traceback": tb[:500]}, ) async def shutdown(self) -> None: """Initiate graceful shutdown.""" logger.info("task_manager_shutdown_initiated") self._shutdown_event.set()