397 lines
11 KiB
Python
397 lines
11 KiB
Python
"""Tests for task endpoints with authentication."""
|
|
|
|
import uuid
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.agent import Agent
|
|
from app.models.task import Task, TaskStatus
|
|
from app.models.tenant import Tenant
|
|
from app.routes.tasks import (
|
|
get_next_pending_task,
|
|
get_next_task_endpoint,
|
|
update_task_endpoint,
|
|
)
|
|
from app.schemas.task import TaskUpdate
|
|
|
|
|
|
@pytest.mark.asyncio
class TestGetNextPendingTask:
    """Exercise the ``get_next_pending_task`` helper directly."""

    async def test_returns_pending_task(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """Of two pending tasks, the one created first comes back."""
        agent, _ = test_agent_with_secret

        # Two pending tasks for ``test_tenant``; the payload's "order" value
        # records creation order so the assertion can tell them apart.
        queued = [
            Task(
                tenant_id=test_tenant.id,
                type="TEST",
                payload={"order": position},
                status=TaskStatus.PENDING.value,
            )
            for position in (1, 2)
        ]
        db.add_all(queued)
        await db.commit()

        picked = await get_next_pending_task(db, agent)

        # The first task (order 1) is expected to be treated as the oldest.
        assert picked is not None
        assert picked.payload["order"] == 1

    async def test_filters_by_tenant(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A pending task owned by some other tenant is not returned."""
        agent, _ = test_agent_with_secret

        # A separate tenant with a freshly generated id.
        foreign_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant")
        db.add(foreign_tenant)

        # The only task this test creates is owned by that separate tenant.
        foreign_task = Task(
            tenant_id=foreign_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.PENDING.value,
        )
        db.add(foreign_task)
        await db.commit()

        picked = await get_next_pending_task(db, agent)

        # The other tenant's task must stay invisible to this agent.
        assert picked is None

    async def test_returns_none_when_empty(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """With no tasks created by the test, the helper yields ``None``."""
        agent, _ = test_agent_with_secret

        picked = await get_next_pending_task(db, agent)

        assert picked is None

    async def test_skips_non_pending_tasks(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A task whose status is RUNNING is not handed out."""
        agent, _ = test_agent_with_secret

        # Same tenant as the fixture, but the task is already RUNNING.
        in_progress = Task(
            tenant_id=test_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.RUNNING.value,
        )
        db.add(in_progress)
        await db.commit()

        picked = await get_next_pending_task(db, agent)

        assert picked is None
|
|
|
|
|
|
@pytest.mark.asyncio
class TestGetNextTaskEndpoint:
    """Call ``get_next_task_endpoint`` (GET /tasks/next) as a plain coroutine."""

    async def test_claims_task(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A pending task comes back RUNNING and bound to the calling agent."""
        agent, _ = test_agent_with_secret

        claimable = Task(
            tenant_id=test_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.PENDING.value,
        )
        db.add(claimable)
        await db.commit()

        claimed = await get_next_task_endpoint(db=db, current_agent=agent)

        assert claimed is not None
        # Claiming is expected to flip the status and record the claimer.
        assert claimed.status == TaskStatus.RUNNING.value
        assert claimed.agent_id == agent.id

    async def test_returns_none_when_empty(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """With no tasks created by the test, the endpoint returns ``None``."""
        agent, _ = test_agent_with_secret

        claimed = await get_next_task_endpoint(db=db, current_agent=agent)

        assert claimed is None

    async def test_tenant_isolation(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A pending task owned by another tenant cannot be claimed."""
        agent, _ = test_agent_with_secret

        # A separate tenant with a freshly generated id, plus its only task.
        foreign_tenant = Tenant(id=uuid.uuid4(), name="Other")
        db.add(foreign_tenant)

        foreign_task = Task(
            tenant_id=foreign_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.PENDING.value,
        )
        db.add(foreign_task)
        await db.commit()

        claimed = await get_next_task_endpoint(db=db, current_agent=agent)

        # The other tenant's task must not be offered to this agent.
        assert claimed is None
|
|
|
|
|
|
@pytest.mark.asyncio
class TestUpdateTaskEndpoint:
    """Call ``update_task_endpoint`` (PATCH /tasks/{task_id}) as a coroutine."""

    async def test_update_task_success(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """The assigned agent can complete its task and attach a result."""
        agent, _ = test_agent_with_secret

        # RUNNING task in the fixture tenant, already assigned to the caller.
        owned_task = Task(
            tenant_id=test_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.RUNNING.value,
            agent_id=agent.id,
        )
        db.add(owned_task)
        await db.commit()

        patch = TaskUpdate(status=TaskStatus.COMPLETED, result={"success": True})

        updated = await update_task_endpoint(
            task_id=owned_task.id,
            task_update=patch,
            db=db,
            current_agent=agent,
        )

        assert updated.status == TaskStatus.COMPLETED.value
        assert updated.result == {"success": True}

    async def test_update_task_not_found(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """An id that matches no stored task raises a 404 ``HTTPException``."""
        agent, _ = test_agent_with_secret
        # Random id that was never inserted.
        missing_id = uuid.uuid4()

        patch = TaskUpdate(status=TaskStatus.COMPLETED)

        with pytest.raises(HTTPException) as raised:
            await update_task_endpoint(
                task_id=missing_id,
                task_update=patch,
                db=db,
                current_agent=agent,
            )

        assert raised.value.status_code == 404

    async def test_update_task_wrong_tenant(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A task owned by another tenant is rejected with 403."""
        agent, _ = test_agent_with_secret

        # A separate tenant with a freshly generated id.
        foreign_tenant = Tenant(id=uuid.uuid4(), name="Other")
        db.add(foreign_tenant)

        # The task names the caller as its agent, so only the tenant differs.
        foreign_task = Task(
            tenant_id=foreign_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.RUNNING.value,
            agent_id=agent.id,
        )
        db.add(foreign_task)
        await db.commit()

        patch = TaskUpdate(status=TaskStatus.COMPLETED)

        with pytest.raises(HTTPException) as raised:
            await update_task_endpoint(
                task_id=foreign_task.id,
                task_update=patch,
                db=db,
                current_agent=agent,
            )

        assert raised.value.status_code == 403
        assert "does not belong to this tenant" in str(raised.value.detail)

    async def test_update_task_not_assigned(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A task assigned to a different agent id is rejected with 403."""
        agent, _ = test_agent_with_secret

        # Random id standing in for some other agent; no Agent row is created.
        someone_else = uuid.uuid4()
        taken_task = Task(
            tenant_id=test_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.RUNNING.value,
            agent_id=someone_else,
        )
        db.add(taken_task)
        await db.commit()

        patch = TaskUpdate(status=TaskStatus.COMPLETED)

        with pytest.raises(HTTPException) as raised:
            await update_task_endpoint(
                task_id=taken_task.id,
                task_update=patch,
                db=db,
                current_agent=agent,
            )

        assert raised.value.status_code == 403
        assert "not assigned to this agent" in str(raised.value.detail)

    async def test_update_task_unassigned_forbidden(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A task with ``agent_id=None`` is rejected with 403."""
        agent, _ = test_agent_with_secret

        # Pending task in the fixture tenant that nobody has claimed.
        unclaimed = Task(
            tenant_id=test_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.PENDING.value,
            agent_id=None,
        )
        db.add(unclaimed)
        await db.commit()

        patch = TaskUpdate(status=TaskStatus.COMPLETED)

        with pytest.raises(HTTPException) as raised:
            await update_task_endpoint(
                task_id=unclaimed.id,
                task_update=patch,
                db=db,
                current_agent=agent,
            )

        # Only the status code is checked here; the detail text is not.
        assert raised.value.status_code == 403
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSharedAgentBehavior:
    """Endpoint behaviour for agents whose ``tenant_id`` is ``None``."""

    @staticmethod
    def _make_shared_agent() -> Agent:
        """Build an unsaved ``Agent`` with a random id and no tenant."""
        return Agent(
            id=uuid.uuid4(),
            name="shared-agent",
            version="1.0.0",
            status="online",
            token="shared-token",
            secret_hash="dummy-hash",
            tenant_id=None,
        )

    async def test_shared_agent_can_claim_any_task(
        self, db: AsyncSession, test_tenant: Tenant
    ):
        """A tenant-less agent can claim a pending task owned by a tenant."""
        tenantless = self._make_shared_agent()
        db.add(tenantless)

        # Pending task owned by the fixture tenant.
        claimable = Task(
            tenant_id=test_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.PENDING.value,
        )
        db.add(claimable)
        await db.commit()

        claimed = await get_next_task_endpoint(db=db, current_agent=tenantless)

        assert claimed is not None
        assert claimed.agent_id == tenantless.id

    async def test_shared_agent_can_update_any_tenant_task(
        self, db: AsyncSession, test_tenant: Tenant
    ):
        """A tenant-less agent can complete a tenant's task assigned to it."""
        tenantless = self._make_shared_agent()
        db.add(tenantless)

        # RUNNING task owned by the fixture tenant and assigned to the agent.
        owned_task = Task(
            tenant_id=test_tenant.id,
            type="TEST",
            payload={},
            status=TaskStatus.RUNNING.value,
            agent_id=tenantless.id,
        )
        db.add(owned_task)
        await db.commit()

        patch = TaskUpdate(status=TaskStatus.COMPLETED)

        updated = await update_task_endpoint(
            task_id=owned_task.id,
            task_update=patch,
            db=db,
            current_agent=tenantless,
        )

        assert updated.status == TaskStatus.COMPLETED.value
|