262 lines
7.8 KiB
Python
262 lines
7.8 KiB
Python
|
|
"""Task polling and execution management."""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import random
|
||
|
|
import time
|
||
|
|
import traceback
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from app.clients.orchestrator_client import (
|
||
|
|
CircuitBreakerOpen,
|
||
|
|
EventLevel,
|
||
|
|
OrchestratorClient,
|
||
|
|
Task,
|
||
|
|
TaskStatus,
|
||
|
|
)
|
||
|
|
from app.config import Settings, get_settings
|
||
|
|
from app.executors import ExecutionResult, get_executor
|
||
|
|
from app.utils.logger import get_logger
|
||
|
|
|
||
|
|
# Module-level logger shared by TaskManager; calls in this file pass an event
# name followed by keyword fields (e.g. task_id=..., error=...).
logger = get_logger("task_manager")
|
||
|
|
|
||
|
|
|
||
|
|
class TaskManager:
    """Manage task polling, execution, and result submission.

    Features:
    - Concurrent task execution with semaphore
    - Circuit breaker integration
    - Event logging for each task
    - Error handling and result persistence
    """

    # Upper bound on the multiplier applied to ``poll_interval`` while backing off.
    MAX_BACKOFF_MULTIPLIER = 8.0
    # Seconds ``poll_loop`` waits for dispatched tasks once shutdown is requested.
    SHUTDOWN_GRACE_SECONDS = 5.0

    def __init__(
        self,
        client: OrchestratorClient,
        settings: Optional[Settings] = None,
    ):
        """Create a manager bound to an orchestrator client.

        Args:
            client: Client used to fetch tasks and report status/events.
            settings: Runtime settings; falls back to ``get_settings()``.
        """
        self.client = client
        self.settings = settings or get_settings()
        self._shutdown_event = asyncio.Event()
        self._semaphore = asyncio.Semaphore(self.settings.max_concurrent_tasks)
        # IDs of tasks currently holding a semaphore slot.
        self._active_tasks: set[str] = set()
        # Strong references to every dispatched asyncio task. The event loop
        # only keeps weak references, so without this set a running task could
        # be garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    async def poll_loop(self) -> None:
        """Run the task polling loop until shutdown.

        Continuously polls for new tasks and dispatches them for execution.
        Returns immediately if the client has no ``agent_id``. After shutdown
        is requested, waits up to ``SHUTDOWN_GRACE_SECONDS`` for dispatched
        tasks to finish before returning.
        """
        if not self.client.agent_id:
            logger.warning("poll_loop_not_registered")
            return

        logger.info(
            "poll_loop_started",
            interval=self.settings.poll_interval,
            max_concurrent=self.settings.max_concurrent_tasks,
        )

        consecutive_failures = 0
        backoff_multiplier = 1.0

        while not self._shutdown_event.is_set():
            try:
                # fetch_next_task raises CircuitBreakerOpen when the breaker is open.
                task = await self.client.fetch_next_task()

                if task:
                    # Reset backoff once a task has actually been received.
                    # NOTE(review): an empty (but successful) poll does not reset
                    # the backoff — confirm whether that is intended.
                    consecutive_failures = 0
                    backoff_multiplier = 1.0

                    # Dispatch task (non-blocking)
                    self._dispatch(task)
                else:
                    logger.debug("no_tasks_available")

            except CircuitBreakerOpen:
                logger.warning("poll_circuit_breaker_open")
                backoff_multiplier = min(
                    backoff_multiplier * 2, self.MAX_BACKOFF_MULTIPLIER
                )

            except Exception as e:
                consecutive_failures += 1
                backoff_multiplier = min(
                    backoff_multiplier * 1.5, self.MAX_BACKOFF_MULTIPLIER
                )
                logger.error(
                    "poll_error",
                    error=str(e),
                    consecutive_failures=consecutive_failures,
                )

            # Calculate next poll interval
            interval = self.settings.poll_interval * backoff_multiplier
            # Add jitter (0-25% of interval)
            interval += random.uniform(0, interval * 0.25)

            # Wait for next poll or shutdown
            try:
                await asyncio.wait_for(
                    self._shutdown_event.wait(),
                    timeout=interval,
                )
                break  # Shutdown requested
            except asyncio.TimeoutError:
                pass  # Normal timeout, continue polling

        await self._drain_background_tasks()

        logger.info("poll_loop_stopped")

    def _dispatch(self, task: Task) -> None:
        """Start ``_execute_task`` in the background and keep a reference to it."""
        runner = asyncio.create_task(
            self._execute_task(task), name=f"task-{task.id}"
        )
        self._background_tasks.add(runner)
        runner.add_done_callback(self._on_task_done)

    def _on_task_done(self, runner: asyncio.Task) -> None:
        """Forget a finished background task and log any exception it raised."""
        self._background_tasks.discard(runner)
        if runner.cancelled():
            return
        # Retrieving the exception marks it as observed, so failures are
        # logged here instead of surfacing as "exception was never retrieved".
        exc = runner.exception()
        if exc is not None:
            logger.error(
                "task_runner_crashed",
                runner=runner.get_name(),
                error=str(exc),
                traceback="".join(
                    traceback.format_exception(type(exc), exc, exc.__traceback__)
                ),
            )

    async def _drain_background_tasks(self) -> None:
        """Wait (bounded by the grace period) for dispatched tasks to finish."""
        pending = {t for t in self._background_tasks if not t.done()}
        if not pending:
            return

        logger.info("waiting_for_active_tasks", count=len(pending))
        # Returns as soon as everything finishes; tasks are not cancelled on timeout.
        _, still_running = await asyncio.wait(
            pending, timeout=self.SHUTDOWN_GRACE_SECONDS
        )
        if still_running:
            logger.warning("active_tasks_still_running", count=len(still_running))

    async def _execute_task(self, task: Task) -> None:
        """Execute a single task with concurrency control.

        Args:
            task: Task to execute
        """
        # Acquire semaphore for concurrency control
        async with self._semaphore:
            self._active_tasks.add(task.id)

            try:
                await self._run_task(task)
            finally:
                self._active_tasks.discard(task.id)

    async def _run_task(self, task: Task) -> None:
        """Run task execution and handle results.

        Sends a start event, marks the task RUNNING, runs the executor for
        ``task.type`` and reports the outcome. Only executor lookup and
        execution are guarded: an error raised while *reporting* a result
        propagates to the caller instead of being re-reported as a task
        failure.

        Args:
            task: Task to execute
        """
        # Monotonic clock: durations must not be affected by wall-clock changes.
        started = time.monotonic()

        logger.info(
            "task_started",
            task_id=task.id,
            task_type=task.type,
            tenant_id=task.tenant_id,
        )

        # Send start event
        await self.client.send_event(
            EventLevel.INFO,
            f"Task started: {task.type}",
            task_id=task.id,
            metadata={"payload_keys": list(task.payload.keys())},
        )

        # Mark task as in progress
        await self.client.update_task(task.id, TaskStatus.RUNNING)

        try:
            # Get executor for task type
            executor = get_executor(task.type)

            # Execute task
            result = await executor.execute(task.payload)

        except ValueError as e:
            # Unknown task type or validation error
            error_msg = str(e)

            logger.error(
                "task_validation_error",
                task_id=task.id,
                task_type=task.type,
                error=error_msg,
            )

            await self.client.update_task(
                task.id,
                TaskStatus.FAILED,
                error=error_msg,
            )

            await self.client.send_event(
                EventLevel.ERROR,
                f"Task validation failed: {task.type}",
                task_id=task.id,
                metadata={"error": error_msg},
            )

        except Exception as e:
            # Unexpected error
            error_msg = str(e)
            tb = traceback.format_exc()

            logger.error(
                "task_exception",
                task_id=task.id,
                task_type=task.type,
                error=error_msg,
                traceback=tb,
            )

            await self.client.update_task(
                task.id,
                TaskStatus.FAILED,
                error=error_msg,
            )

            await self.client.send_event(
                EventLevel.ERROR,
                f"Task exception: {task.type}",
                task_id=task.id,
                # Only the first 500 characters of the traceback are sent.
                metadata={"error": error_msg, "traceback": tb[:500]},
            )

        else:
            duration_ms = (time.monotonic() - started) * 1000
            await self._report_result(task, result, duration_ms)

    async def _report_result(
        self, task: Task, result: ExecutionResult, duration_ms: float
    ) -> None:
        """Log an executor result and submit it as COMPLETED or FAILED."""
        if result.success:
            logger.info(
                "task_completed",
                task_id=task.id,
                task_type=task.type,
                duration_ms=duration_ms,
            )

            await self.client.update_task(
                task.id,
                TaskStatus.COMPLETED,
                result=result.data,
            )

            await self.client.send_event(
                EventLevel.INFO,
                f"Task completed: {task.type}",
                task_id=task.id,
                metadata={"duration_ms": duration_ms},
            )
            return

        logger.warning(
            "task_failed",
            task_id=task.id,
            task_type=task.type,
            error=result.error,
            duration_ms=duration_ms,
        )

        await self.client.update_task(
            task.id,
            TaskStatus.FAILED,
            result=result.data,
            error=result.error,
        )

        await self.client.send_event(
            EventLevel.ERROR,
            f"Task failed: {task.type}",
            task_id=task.id,
            metadata={"error": result.error, "duration_ms": duration_ms},
        )

    async def shutdown(self) -> None:
        """Initiate graceful shutdown."""
        logger.info("task_manager_shutdown_initiated")
        self._shutdown_event.set()
|