83 lines
2.4 KiB
Python
83 lines
2.4 KiB
Python
"""Domain filtering and allowlist validation."""
|
|
|
|
import fnmatch
|
|
import re
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
class DomainFilter:
    """
    Validates URLs against a domain allowlist.

    Supports:
    - Exact domain matching: "example.com"
    - Wildcard subdomains: "*.example.com" (also matches "example.com" itself)
    - Domains with ports: "example.com:8443"

    Matching is case-insensitive and is performed against the URL's full
    netloc. The port is therefore part of the comparison: the pattern
    "example.com" does not allow "https://example.com:8443/". URLs carrying
    userinfo (e.g. "user@example.com") never match any pattern.
    """

    def __init__(self, allowed_domains: list[str]):
        """
        Initialize the domain filter.

        Args:
            allowed_domains: List of allowed domain patterns. Any iterable of
                strings is accepted. It is copied, so later changes to the
                caller's list cannot make the stored allowlist disagree with
                the compiled patterns.

        Raises:
            ValueError: If ``allowed_domains`` is empty.
            TypeError: If ``allowed_domains`` is a single string instead of
                a collection of domain patterns.
        """
        if not allowed_domains:
            raise ValueError("allowed_domains cannot be empty")

        # A bare string is iterable, so without this guard "example.com"
        # would silently become one single-character pattern per letter.
        if isinstance(allowed_domains, (str, bytes)):
            raise TypeError(
                "allowed_domains must be a collection of domain patterns, "
                "not a single string"
            )

        domains = list(allowed_domains)
        # Iterators are truthy even when empty, so re-check after copying.
        if not domains:
            raise ValueError("allowed_domains cannot be empty")

        self.allowed_domains = domains
        self._patterns = self._compile_patterns(domains)

    def _compile_patterns(self, domains: list[str]) -> list[re.Pattern]:
        """
        Compile domain patterns into case-insensitive, fully anchored regexes.

        The patterns are anchored at the absolute start and end of the
        string. "$" is not used because it also matches just before a
        trailing newline, which would let "example.com" followed by a
        newline pass as "example.com".
        """
        patterns = []
        for domain in domains:
            if domain.startswith("*."):
                # "*.example.com" -> the base domain itself, or any chain of
                # subdomain labels in front of it.
                base = re.escape(domain[2:])
                pattern = rf"\A([a-zA-Z0-9-]+\.)*{base}\Z"
            else:
                # Exact match (port included, if the pattern has one).
                pattern = rf"\A{re.escape(domain)}\Z"
            patterns.append(re.compile(pattern, re.IGNORECASE))
        return patterns

    def is_allowed(self, url: str) -> bool:
        """
        Check if a URL's domain is in the allowlist.

        Args:
            url: The URL to check

        Returns:
            True if the URL's netloc (host plus port, if present) matches an
            allowed pattern. False otherwise, including when the URL has no
            host or cannot be parsed.
        """
        try:
            # netloc keeps the port, so port-qualified patterns can match.
            host = urlparse(url).netloc
            if not host:
                return False
            return any(pattern.fullmatch(host) for pattern in self._patterns)
        except Exception:
            # Fail closed: anything we cannot parse is never allowed. This
            # covers urlparse's ValueError on a malformed IPv6 literal and
            # errors from non-str input.
            return False

    def get_blocked_reason(self, url: str) -> str:
        """Get a human-readable reason for why a URL was blocked."""
        try:
            host = urlparse(url).netloc
        except Exception:
            # Mirrors is_allowed: any parse failure is reported as invalid.
            return f"Invalid URL: {url}"
        return f"Domain '{host}' not in allowlist: {self.allowed_domains}"
|