451 lines
18 KiB
Python
451 lines
18 KiB
Python
"""Unit tests for PlaywrightExecutor.
|
|
|
|
These tests focus on validation logic without launching browsers.
|
|
Browser-based integration tests are skipped by default (SKIP_BROWSER_TESTS=true).
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
# Register stand-ins for the Playwright package before anything below imports
# code that may pull it in; presumably this lets the module import where
# Playwright is not installed.
# NOTE(review): sys.modules is process-wide, so these stubs shadow a real
# Playwright installation for the rest of the pytest session - including the
# browser integration tests at the bottom of this file.
for _stubbed_module in ("playwright", "playwright.async_api"):
    sys.modules[_stubbed_module] = MagicMock()
del _stubbed_module

# Import the allowlist helpers while the application's logger factory is
# replaced with a mock.
with patch("app.utils.logger.get_logger", return_value=MagicMock()):
    from app.utils.validation import (
        ValidationError,
        is_domain_allowed,
        validate_allowed_domains,
    )
|
|
|
|
|
|
class TestDomainValidation:
    """Test the domain allowlist helpers ``is_domain_allowed`` and ``validate_allowed_domains``."""

    # ==================== is_domain_allowed Tests ====================

    def test_exact_domain_match(self):
        """An exact host pattern matches that host regardless of scheme or path."""
        assert is_domain_allowed("https://cloud.example.com/path", ["cloud.example.com"]) is True
        assert is_domain_allowed("https://cloud.example.com", ["cloud.example.com"]) is True
        assert is_domain_allowed("http://cloud.example.com", ["cloud.example.com"]) is True

    def test_exact_domain_no_match(self):
        """An exact host pattern rejects other hosts, including its own subdomains."""
        assert is_domain_allowed("https://evil.com/path", ["cloud.example.com"]) is False
        assert is_domain_allowed("https://sub.cloud.example.com", ["cloud.example.com"]) is False

    def test_wildcard_subdomain_match(self):
        """A ``*.`` pattern matches subdomains at any depth and the bare apex domain."""
        assert is_domain_allowed("https://sub.example.com", ["*.example.com"]) is True
        assert is_domain_allowed("https://deep.sub.example.com", ["*.example.com"]) is True
        # The apex itself is accepted by the wildcard as well.
        assert is_domain_allowed("https://example.com", ["*.example.com"]) is True

    def test_wildcard_subdomain_no_match(self):
        """A ``*.`` pattern rejects unrelated domains and other TLDs."""
        assert is_domain_allowed("https://evil.com", ["*.example.com"]) is False
        assert is_domain_allowed("https://example.org", ["*.example.com"]) is False

    def test_lookalike_hosts_rejected(self):
        """Hosts that merely contain or extend an allowed domain must not match.

        These guard the allowlist against common bypasses: a suffix that does not
        fall on a label boundary, an allowed name used as the prefix of a different
        domain, and an allowed name placed in the URL's userinfo section.
        """
        # "evilexample.com" ends with "example.com" but is not a subdomain of it.
        assert is_domain_allowed("https://evilexample.com", ["*.example.com"]) is False
        assert is_domain_allowed("https://evilexample.com", ["example.com"]) is False
        # Allowed name used as a prefix of an attacker-controlled domain.
        assert is_domain_allowed("https://example.com.evil.com", ["example.com"]) is False
        assert is_domain_allowed("https://example.com.evil.com", ["*.example.com"]) is False
        # Allowed name in the userinfo; the host actually contacted is evil.com.
        assert is_domain_allowed("https://cloud.example.com@evil.com/", ["cloud.example.com"]) is False

    def test_domain_with_port(self):
        """A pattern that includes a port only matches URLs using that exact port."""
        assert is_domain_allowed("https://cloud.example.com:8443/path", ["cloud.example.com:8443"]) is True
        assert is_domain_allowed("https://cloud.example.com:8443", ["cloud.example.com:8443"]) is True
        # Wrong port should not match
        assert is_domain_allowed("https://cloud.example.com:9000", ["cloud.example.com:8443"]) is False
        # No port in URL should not match port-specific pattern
        assert is_domain_allowed("https://cloud.example.com", ["cloud.example.com:8443"]) is False

    def test_multiple_allowed_domains(self):
        """A URL is allowed if it matches any entry of a mixed exact/wildcard list."""
        allowed = ["cloud.example.com", "mail.example.com", "*.internal.com"]
        assert is_domain_allowed("https://cloud.example.com", allowed) is True
        assert is_domain_allowed("https://mail.example.com", allowed) is True
        assert is_domain_allowed("https://app.internal.com", allowed) is True
        assert is_domain_allowed("https://evil.com", allowed) is False

    def test_empty_inputs(self):
        """An empty URL or an empty allowlist never allows anything."""
        assert is_domain_allowed("", ["example.com"]) is False
        assert is_domain_allowed("https://example.com", []) is False
        assert is_domain_allowed("", []) is False

    def test_case_insensitive(self):
        """Host and pattern comparison ignores letter case on both sides."""
        assert is_domain_allowed("https://Cloud.Example.COM", ["cloud.example.com"]) is True
        assert is_domain_allowed("https://cloud.example.com", ["Cloud.Example.COM"]) is True

    # ==================== validate_allowed_domains Tests ====================

    def test_validate_valid_domains(self):
        """Plain lowercase host names are returned unchanged and in order."""
        result = validate_allowed_domains(["example.com", "cloud.example.com"])
        assert result == ["example.com", "cloud.example.com"]

    def test_validate_wildcard_domains(self):
        """Leading ``*.`` wildcard patterns are accepted unchanged."""
        result = validate_allowed_domains(["*.example.com", "*.internal.org"])
        assert result == ["*.example.com", "*.internal.org"]

    def test_validate_with_ports(self):
        """Host patterns carrying an explicit port are accepted unchanged."""
        result = validate_allowed_domains(["example.com:8080", "cloud.example.com:8443"])
        assert result == ["example.com:8080", "cloud.example.com:8443"]

    def test_validate_empty_list_raises(self):
        """An empty allowlist is rejected with ValidationError."""
        with pytest.raises(ValidationError, match="cannot be empty"):
            validate_allowed_domains([])

    def test_validate_protocol_raises(self):
        """Patterns that include a URL scheme are rejected with ValidationError."""
        with pytest.raises(ValidationError, match="should not include protocol"):
            validate_allowed_domains(["https://example.com"])

    def test_validate_invalid_wildcard_raises(self):
        """Wildcards anywhere but a leading ``*.`` (including a lone ``*``) are rejected."""
        with pytest.raises(ValidationError, match="Wildcards must be at the start"):
            validate_allowed_domains(["example.*.com"])

        with pytest.raises(ValidationError, match="Wildcards must be at the start"):
            validate_allowed_domains(["*"])

    def test_validate_normalizes_case(self):
        """Validated patterns are returned lower-cased."""
        result = validate_allowed_domains(["Example.COM", "CLOUD.Example.com"])
        assert result == ["example.com", "cloud.example.com"]
|
|
|
|
|
|
class TestPlaywrightExecutor:
    """Request-validation tests for PlaywrightExecutor.execute and its task type."""

    @pytest.fixture
    def executor(self):
        """Build a PlaywrightExecutor while ``app.executors.base.get_logger`` is mocked."""
        with patch("app.executors.base.get_logger", return_value=MagicMock()):
            from app.executors.playwright_executor import PlaywrightExecutor

            instance = PlaywrightExecutor()
        return instance

    @pytest.fixture
    def mock_settings(self, tmp_path):
        """Settings stand-in whose artifacts directory lives under pytest's ``tmp_path``."""
        fake = MagicMock()
        fake.playwright_artifacts_dir = str(tmp_path / "playwright-artifacts")
        fake.playwright_default_timeout_ms = 60000
        fake.playwright_navigation_timeout_ms = 120000
        return fake

    @staticmethod
    async def _run(executor, settings, payload):
        """Await ``executor.execute(payload)`` with the executor's ``get_settings`` returning *settings*."""
        with patch("app.executors.playwright_executor.get_settings", return_value=settings):
            return await executor.execute(payload)

    # ==================== Validation Error Tests ====================

    @pytest.mark.asyncio
    async def test_missing_scenario_field(self, executor, mock_settings):
        """A payload without ``scenario`` fails and names that field."""
        payload = {
            "inputs": {"base_url": "https://example.com"},
            "options": {"allowed_domains": ["example.com"]},
        }
        outcome = await self._run(executor, mock_settings, payload)

        assert outcome.success is False
        assert "Missing required fields: scenario" in outcome.error

    @pytest.mark.asyncio
    async def test_missing_inputs_field(self, executor, mock_settings):
        """A payload without ``inputs`` fails and names that field."""
        payload = {
            "scenario": "test_scenario",
            "options": {"allowed_domains": ["example.com"]},
        }
        outcome = await self._run(executor, mock_settings, payload)

        assert outcome.success is False
        assert "Missing required fields: inputs" in outcome.error

    @pytest.mark.asyncio
    async def test_missing_allowed_domains(self, executor, mock_settings):
        """An ``options`` dict lacking ``allowed_domains`` fails with a 'required' error."""
        payload = {
            "scenario": "test_scenario",
            "inputs": {"base_url": "https://example.com"},
            "options": {},
        }
        outcome = await self._run(executor, mock_settings, payload)

        assert outcome.success is False
        assert "allowed_domains" in outcome.error
        assert "required" in outcome.error.lower()

    @pytest.mark.asyncio
    async def test_missing_options_means_no_domains(self, executor, mock_settings):
        """Leaving out ``options`` altogether is treated as having no allowed_domains."""
        payload = {
            "scenario": "test_scenario",
            "inputs": {"base_url": "https://example.com"},
        }
        outcome = await self._run(executor, mock_settings, payload)

        assert outcome.success is False
        assert "allowed_domains" in outcome.error

    @pytest.mark.asyncio
    async def test_invalid_allowed_domains_format(self, executor, mock_settings):
        """A malformed allowlist entry (here: one carrying a scheme) is rejected."""
        payload = {
            "scenario": "test_scenario",
            "inputs": {"base_url": "https://example.com"},
            "options": {"allowed_domains": ["https://example.com"]},  # Protocol not allowed
        }
        outcome = await self._run(executor, mock_settings, payload)

        assert outcome.success is False
        assert "Invalid allowed_domains" in outcome.error

    @pytest.mark.asyncio
    async def test_unknown_scenario(self, executor, mock_settings):
        """An unregistered scenario name fails and the result lists the available ones."""
        payload = {
            "scenario": "nonexistent_scenario",
            "inputs": {"base_url": "https://example.com"},
            "options": {"allowed_domains": ["example.com"]},
        }
        outcome = await self._run(executor, mock_settings, payload)

        assert outcome.success is False
        assert "Unknown scenario" in outcome.error
        assert "nonexistent_scenario" in outcome.error
        assert "available_scenarios" in outcome.data

    # ==================== Task Type Tests ====================

    def test_task_type_is_playwright(self, executor):
        """The executor identifies itself with the ``PLAYWRIGHT`` task type."""
        assert executor.task_type == "PLAYWRIGHT"
|
|
|
|
|
|
class TestScenarioRegistry:
    """Registering scenarios and looking them up by name."""

    def test_register_and_get_scenario(self):
        """A decorator-registered scenario is returned by name and appears in the name list."""
        from app.playwright_scenarios import (
            _SCENARIO_REGISTRY,
            BaseScenario,
            ScenarioResult,
            get_scenario,
            get_scenario_names,
            register_scenario,
        )

        # patch.dict snapshots the registry, empties it for a clean test, and
        # restores the snapshot on exit even when an assertion fails.
        with patch.dict(_SCENARIO_REGISTRY, clear=True):

            @register_scenario
            class TestScenario(BaseScenario):
                @property
                def name(self) -> str:
                    return "test_scenario"

                @property
                def required_inputs(self) -> list[str]:
                    return ["base_url"]

                async def execute(self, page, inputs, options) -> ScenarioResult:
                    return ScenarioResult(success=True, data={})

            # Lookup by name finds the freshly registered scenario.
            looked_up = get_scenario("test_scenario")
            assert looked_up is not None
            assert looked_up.name == "test_scenario"

            # ...and the name is advertised by the registry listing.
            assert "test_scenario" in get_scenario_names()

    def test_get_unknown_scenario_returns_none(self):
        """Looking up a name that was never registered yields None."""
        from app.playwright_scenarios import get_scenario

        assert get_scenario("definitely_does_not_exist_xyz123") is None
|
|
|
|
|
|
class TestScenarioOptions:
    """Construction of the ScenarioOptions dataclass."""

    def test_default_values(self):
        """A bare ScenarioOptions() carries the expected defaults."""
        from app.playwright_scenarios import ScenarioOptions

        opts = ScenarioOptions()

        assert opts.timeout_ms == 60000
        # Screenshots: taken on failure only, by default.
        assert opts.screenshot_on_failure is True
        assert opts.screenshot_on_success is False
        assert opts.save_trace is False
        assert opts.allowed_domains == []
        assert opts.artifacts_dir is None

    def test_custom_values(self):
        """Every field passed to the constructor is stored as given."""
        from app.playwright_scenarios import ScenarioOptions

        artifacts = Path("/tmp/artifacts")
        opts = ScenarioOptions(
            timeout_ms=30000,
            screenshot_on_failure=False,
            screenshot_on_success=True,
            save_trace=True,
            allowed_domains=["example.com"],
            artifacts_dir=artifacts,
        )

        assert opts.timeout_ms == 30000
        assert opts.screenshot_on_failure is False
        assert opts.screenshot_on_success is True
        assert opts.save_trace is True
        assert opts.allowed_domains == ["example.com"]
        assert opts.artifacts_dir == Path("/tmp/artifacts")

    def test_string_artifacts_dir_converted(self):
        """A string artifacts_dir is coerced into a pathlib.Path."""
        from app.playwright_scenarios import ScenarioOptions

        opts = ScenarioOptions(artifacts_dir="/tmp/artifacts")

        assert isinstance(opts.artifacts_dir, Path)
        # Comparing Path objects rather than strings keeps the check independent
        # of the OS path separator.
        assert opts.artifacts_dir == Path("/tmp/artifacts")
|
|
|
|
|
|
class TestScenarioResult:
    """Construction of the ScenarioResult dataclass."""

    def test_success_result(self):
        """A successful result keeps its data and screenshots and has no error."""
        from app.playwright_scenarios import ScenarioResult

        payload = {"setup": "complete"}
        shots = ["/tmp/success.png"]
        outcome = ScenarioResult(success=True, data=payload, screenshots=shots)

        assert outcome.success is True
        assert outcome.data == {"setup": "complete"}
        assert outcome.screenshots == ["/tmp/success.png"]
        # ``error`` is left unset when not supplied.
        assert outcome.error is None

    def test_failure_result(self):
        """A failed result exposes the error message it was built with."""
        from app.playwright_scenarios import ScenarioResult

        outcome = ScenarioResult(success=False, data={}, error="Element not found")

        assert outcome.success is False
        assert outcome.error == "Element not found"
|
|
|
|
|
|
class TestBaseScenario:
    """Default behaviour inherited from the BaseScenario abstract base class."""

    @staticmethod
    def _make_scenario(scenario_name, required):
        """Instantiate a minimal concrete BaseScenario with the given name and required inputs."""
        from app.playwright_scenarios import BaseScenario, ScenarioResult

        class _StubScenario(BaseScenario):
            @property
            def name(self) -> str:
                return scenario_name

            @property
            def required_inputs(self) -> list[str]:
                return list(required)

            async def execute(self, page, inputs, options) -> ScenarioResult:
                return ScenarioResult(success=True, data={})

        return _StubScenario()

    def test_validate_inputs_missing(self):
        """validate_inputs reports exactly the required keys absent from the inputs."""
        scenario = self._make_scenario("test", ["base_url", "username", "password"])

        # Nothing supplied: every required key is reported.
        missing = scenario.validate_inputs({})
        for key in ("base_url", "username", "password"):
            assert key in missing

        # Only base_url supplied: the other two are still reported.
        missing = scenario.validate_inputs({"base_url": "https://example.com"})
        assert "base_url" not in missing
        for key in ("username", "password"):
            assert key in missing

        # Everything supplied: nothing is reported.
        complete = {
            "base_url": "https://example.com",
            "username": "admin",
            "password": "secret",
        }
        assert scenario.validate_inputs(complete) == []

    def test_default_optional_inputs(self):
        """Subclasses that do not override optional_inputs get an empty list."""
        scenario = self._make_scenario("test", ["base_url"])
        assert scenario.optional_inputs == []

    def test_default_description(self):
        """The default description mentions the scenario's name."""
        scenario = self._make_scenario("my_test_scenario", [])
        assert "my_test_scenario" in scenario.description
|
|
|
|
|
|
# Values (after stripping and lower-casing) that count as "enabled" for boolean
# environment flags.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def env_flag(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        True when the variable is set to one of ``1``/``true``/``yes``/``on``
        (case-insensitive, surrounding whitespace ignored), otherwise False.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    # Accepting the usual spellings means e.g. SKIP_BROWSER_TESTS=1 skips
    # instead of silently running the browser tests.
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


# Skip browser tests by default
SKIP_BROWSER_TESTS = env_flag("SKIP_BROWSER_TESTS", default=True)
|
|
|
|
|
|
@pytest.mark.skipif(SKIP_BROWSER_TESTS, reason="Browser tests skipped (set SKIP_BROWSER_TESTS=false to run)")
class TestPlaywrightExecutorIntegration:
    """Integration tests that require a real browser.

    These tests are skipped by default. Set SKIP_BROWSER_TESTS=false to run.
    """

    @pytest.fixture
    def executor(self):
        """Build a PlaywrightExecutor while ``app.executors.base.get_logger`` is mocked."""
        with patch("app.executors.base.get_logger", return_value=MagicMock()):
            from app.executors.playwright_executor import PlaywrightExecutor

            return PlaywrightExecutor()

    @pytest.mark.asyncio
    async def test_domain_blocking_in_browser(self, executor, tmp_path):
        """Test that blocked domains are actually blocked in a real browser."""
        # TODO: needs a local HTTP server plus a real browser; until then this is
        # covered by manual testing. Skip explicitly: an empty body would report
        # PASS whenever browser tests are enabled, without verifying anything.
        pytest.skip("Not implemented: domain blocking in a real browser is verified manually")
|