Add security regression tests for all hardening fixes
- New: test_security_headers.py — verify all headers present, HSTS conditional on HTTPS - New: test_passkey_error_handling.py — generic vs user-facing error messages - New: test_poi_validation.py — field length and coordinate range constraints - Extend test_rate_limiter.py — client IP depth selection, in-memory fallback enforcement - Extend test_models.py — sqm range validation - Extend test_task_service.py — IDOR 404, ownership 200, traceback suppression in production
This commit is contained in:
parent
727dd537ef
commit
492921424e
6 changed files with 365 additions and 0 deletions
|
|
@ -445,6 +445,15 @@ class TestQueryParametersValidation:
|
|||
max_bedrooms=3,
|
||||
)
|
||||
|
||||
def test_invalid_sqm_range_raises(self) -> None:
|
||||
"""min_sqm > max_sqm should raise ValidationError."""
|
||||
with pytest.raises(ValidationError, match="min_sqm.*must be <= max_sqm"):
|
||||
QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_sqm=100,
|
||||
max_sqm=50,
|
||||
)
|
||||
|
||||
|
||||
class TestDestinationMode:
|
||||
"""Tests for DestinationMode."""
|
||||
|
|
|
|||
71
tests/unit/test_passkey_error_handling.py
Normal file
71
tests/unit/test_passkey_error_handling.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""Unit tests for passkey route error handling."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
# We need to test through the FastAPI app or build a minimal test client
|
||||
from api.passkey_routes import passkey_router
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def _build_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(passkey_router)
|
||||
return app
|
||||
|
||||
|
||||
class TestPasskeyErrorHandling:
|
||||
"""Tests that passkey routes return generic error messages for internal exceptions."""
|
||||
|
||||
@patch("api.passkey_routes.passkey_service")
|
||||
@patch("api.passkey_routes.UserRepository")
|
||||
def test_register_begin_internal_error_returns_generic_message(
|
||||
self, mock_user_repo: MagicMock, mock_service: MagicMock
|
||||
) -> None:
|
||||
mock_service.begin_registration.side_effect = RuntimeError("DB connection lost")
|
||||
client = TestClient(_build_app())
|
||||
resp = client.post("/api/passkey/register/begin", json={"email": "test@example.com"})
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["detail"] == "Registration failed. Please try again."
|
||||
assert "DB connection lost" not in resp.json()["detail"]
|
||||
|
||||
@patch("api.passkey_routes.passkey_service")
|
||||
@patch("api.passkey_routes.UserRepository")
|
||||
def test_register_begin_value_error_returns_user_message(
|
||||
self, mock_user_repo: MagicMock, mock_service: MagicMock
|
||||
) -> None:
|
||||
mock_service.begin_registration.side_effect = ValueError("Email already registered")
|
||||
client = TestClient(_build_app())
|
||||
resp = client.post("/api/passkey/register/begin", json={"email": "test@example.com"})
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["detail"] == "Email already registered"
|
||||
|
||||
@patch("api.passkey_routes.passkey_service")
|
||||
@patch("api.passkey_routes.UserRepository")
|
||||
def test_login_complete_internal_error_returns_generic_message(
|
||||
self, mock_user_repo: MagicMock, mock_service: MagicMock
|
||||
) -> None:
|
||||
mock_service.complete_authentication.side_effect = RuntimeError("Crypto failure")
|
||||
client = TestClient(_build_app())
|
||||
resp = client.post(
|
||||
"/api/passkey/login/complete",
|
||||
json={"session_id": "abc", "credential": {"id": "x"}},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["detail"] == "Login could not be completed."
|
||||
assert "Crypto failure" not in resp.json()["detail"]
|
||||
|
||||
@patch("api.passkey_routes.passkey_service")
|
||||
@patch("api.passkey_routes.UserRepository")
|
||||
def test_login_complete_value_error_returns_user_message(
|
||||
self, mock_user_repo: MagicMock, mock_service: MagicMock
|
||||
) -> None:
|
||||
mock_service.complete_authentication.side_effect = ValueError("Invalid credential")
|
||||
client = TestClient(_build_app())
|
||||
resp = client.post(
|
||||
"/api/passkey/login/complete",
|
||||
json={"session_id": "abc", "credential": {"id": "x"}},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["detail"] == "Invalid credential"
|
||||
58
tests/unit/test_poi_validation.py
Normal file
58
tests/unit/test_poi_validation.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Unit tests for POI request validation."""
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.poi_routes import CreatePOIRequest, UpdatePOIRequest
|
||||
|
||||
|
||||
class TestCreatePOIValidation:
|
||||
"""Tests for CreatePOIRequest field validation."""
|
||||
|
||||
def test_valid_request(self) -> None:
|
||||
req = CreatePOIRequest(name="Office", address="123 Main St", latitude=51.5, longitude=-0.1)
|
||||
assert req.name == "Office"
|
||||
|
||||
def test_name_too_long(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CreatePOIRequest(name="A" * 201, address="addr", latitude=0, longitude=0)
|
||||
|
||||
def test_address_too_long(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CreatePOIRequest(name="ok", address="A" * 501, latitude=0, longitude=0)
|
||||
|
||||
def test_latitude_too_high(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CreatePOIRequest(name="ok", address="addr", latitude=91.0, longitude=0)
|
||||
|
||||
def test_latitude_too_low(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CreatePOIRequest(name="ok", address="addr", latitude=-91.0, longitude=0)
|
||||
|
||||
def test_longitude_too_high(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CreatePOIRequest(name="ok", address="addr", latitude=0, longitude=181.0)
|
||||
|
||||
def test_longitude_too_low(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CreatePOIRequest(name="ok", address="addr", latitude=0, longitude=-181.0)
|
||||
|
||||
|
||||
class TestUpdatePOIValidation:
|
||||
"""Tests for UpdatePOIRequest field validation."""
|
||||
|
||||
def test_valid_partial_update(self) -> None:
|
||||
req = UpdatePOIRequest(name="New Name")
|
||||
assert req.name == "New Name"
|
||||
assert req.latitude is None
|
||||
|
||||
def test_name_too_long(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
UpdatePOIRequest(name="A" * 201)
|
||||
|
||||
def test_latitude_out_of_range(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
UpdatePOIRequest(latitude=91.0)
|
||||
|
||||
def test_longitude_out_of_range(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
UpdatePOIRequest(longitude=181.0)
|
||||
|
|
@ -30,6 +30,7 @@ def _make_config(**overrides: object) -> RateLimitConfig:
|
|||
"geojson_stream_batch_size_cap": 200,
|
||||
"rate_limit_redis_db": 3,
|
||||
"metrics_allowed_ips": "127.0.0.1",
|
||||
"trusted_proxy_depth": 1,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return RateLimitConfig(**defaults) # type: ignore[arg-type]
|
||||
|
|
@ -236,3 +237,92 @@ class TestRateLimitConfig:
|
|||
# Should fall back to default
|
||||
assert config.endpoint_limits["/api/listing"].max_requests == 30
|
||||
assert config.endpoint_limits["/api/listing"].window_seconds == 60
|
||||
|
||||
|
||||
from api.rate_limiter import _client_ip
|
||||
|
||||
|
||||
class TestClientIp:
|
||||
"""Tests for _client_ip with trusted proxy depth."""
|
||||
|
||||
def test_uses_rightmost_ip_with_depth_1(self) -> None:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"headers": [(b"x-forwarded-for", b"spoofed.ip, proxy1, real.ip")],
|
||||
}
|
||||
request = Request(scope)
|
||||
assert _client_ip(request, depth=1) == "real.ip"
|
||||
|
||||
def test_uses_second_from_right_with_depth_2(self) -> None:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"headers": [(b"x-forwarded-for", b"spoofed.ip, real.ip, proxy1")],
|
||||
}
|
||||
request = Request(scope)
|
||||
assert _client_ip(request, depth=2) == "real.ip"
|
||||
|
||||
def test_falls_back_to_connection_ip(self) -> None:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"headers": [],
|
||||
"client": ("192.168.1.1", 12345),
|
||||
}
|
||||
request = Request(scope)
|
||||
assert _client_ip(request) == "192.168.1.1"
|
||||
|
||||
def test_no_client_returns_unknown(self) -> None:
|
||||
scope = {"type": "http", "headers": []}
|
||||
request = Request(scope)
|
||||
assert _client_ip(request) == "unknown"
|
||||
|
||||
def test_single_ip_with_depth_1(self) -> None:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"headers": [(b"x-forwarded-for", b"single.ip")],
|
||||
}
|
||||
request = Request(scope)
|
||||
assert _client_ip(request, depth=1) == "single.ip"
|
||||
|
||||
|
||||
class TestInMemoryFallback:
|
||||
"""Tests for in-memory rate limit fallback when Redis is unavailable."""
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_fallback_enforces_limits_when_redis_unavailable(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
"""When Redis is unavailable at startup, in-memory fallback should enforce limits."""
|
||||
import redis as redis_lib
|
||||
mock_get_redis.side_effect = redis_lib.RedisError("connection refused")
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
# Should allow first 3 requests (limit=3)
|
||||
for i in range(3):
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 200, f"Request {i+1} should be allowed"
|
||||
|
||||
# 4th request should be rate limited
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 429
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_fallback_activates_on_redis_error_during_request(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
"""When Redis errors during request handling, fallback should activate."""
|
||||
import redis as redis_lib
|
||||
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_pipe = mock.MagicMock()
|
||||
mock_pipe.execute.side_effect = redis_lib.RedisError("connection lost")
|
||||
mock_redis.pipeline.return_value = mock_pipe
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
# First request should succeed via fallback
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 200
|
||||
assert "X-RateLimit-Limit" in resp.headers
|
||||
|
|
|
|||
56
tests/unit/test_security_headers.py
Normal file
56
tests/unit/test_security_headers.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""Unit tests for api/security_headers.py."""
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from api.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
|
||||
async def _ok_endpoint(request: Request) -> JSONResponse:
|
||||
return JSONResponse({"ok": True})
|
||||
|
||||
|
||||
def _build_app() -> Starlette:
|
||||
app = Starlette(routes=[Route("/test", _ok_endpoint)])
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
return app
|
||||
|
||||
|
||||
class TestSecurityHeaders:
|
||||
"""Tests for SecurityHeadersMiddleware."""
|
||||
|
||||
def test_x_content_type_options(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/test")
|
||||
assert resp.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_x_frame_options(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/test")
|
||||
assert resp.headers["X-Frame-Options"] == "DENY"
|
||||
|
||||
def test_referrer_policy(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/test")
|
||||
assert resp.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
|
||||
|
||||
def test_content_security_policy(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/test")
|
||||
assert "Content-Security-Policy" in resp.headers
|
||||
csp = resp.headers["Content-Security-Policy"]
|
||||
assert "default-src 'self'" in csp
|
||||
assert "frame-ancestors 'none'" in csp
|
||||
|
||||
def test_hsts_set_for_https(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/test", headers={"x-forwarded-proto": "https"})
|
||||
assert "Strict-Transport-Security" in resp.headers
|
||||
assert "max-age=" in resp.headers["Strict-Transport-Security"]
|
||||
|
||||
def test_hsts_not_set_for_http(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/test")
|
||||
assert "Strict-Transport-Security" not in resp.headers
|
||||
|
|
@ -304,3 +304,84 @@ class TestClearAllTasks:
|
|||
# Should not raise despite revoke failure
|
||||
count = clear_all_tasks("test@example.com", revoke=True)
|
||||
assert count == 1
|
||||
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestTaskStatusSecurity:
|
||||
"""Tests for task status endpoint security (IDOR, traceback suppression)."""
|
||||
|
||||
def _get_client(self) -> TestClient:
|
||||
"""Create test client with mocked auth."""
|
||||
from api.app import app
|
||||
from api.auth import get_current_user, User
|
||||
|
||||
mock_user = User(sub="test", email="test@example.com", name="Test")
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
return client
|
||||
|
||||
@patch("api.app.task_service")
|
||||
def test_returns_404_for_unowned_task(self, mock_task_service: MagicMock) -> None:
|
||||
mock_task_service.get_user_tasks.return_value = ["task-abc"]
|
||||
client = self._get_client()
|
||||
try:
|
||||
resp = client.get("/api/task_status", params={"task_id": "task-xyz"})
|
||||
assert resp.status_code == 404
|
||||
finally:
|
||||
from api.app import app
|
||||
from api.auth import get_current_user
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
@patch("api.app.task_service")
|
||||
def test_returns_200_for_owned_task(self, mock_task_service: MagicMock) -> None:
|
||||
mock_status = MagicMock()
|
||||
mock_status.task_id = "task-abc"
|
||||
mock_status.status = "SUCCESS"
|
||||
mock_status.result = None
|
||||
mock_status.progress = None
|
||||
mock_status.processed = None
|
||||
mock_status.total = None
|
||||
mock_status.message = None
|
||||
mock_status.error = None
|
||||
mock_status.traceback = None
|
||||
|
||||
mock_task_service.get_user_tasks.return_value = ["task-abc"]
|
||||
mock_task_service.get_task_status.return_value = mock_status
|
||||
client = self._get_client()
|
||||
try:
|
||||
resp = client.get("/api/task_status", params={"task_id": "task-abc"})
|
||||
assert resp.status_code == 200
|
||||
finally:
|
||||
from api.app import app
|
||||
from api.auth import get_current_user
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
@patch("api.app.APP_ENV", "production")
|
||||
@patch("api.app.task_service")
|
||||
def test_traceback_suppressed_in_production(self, mock_task_service: MagicMock) -> None:
|
||||
mock_status = MagicMock()
|
||||
mock_status.task_id = "task-abc"
|
||||
mock_status.status = "FAILURE"
|
||||
mock_status.result = None
|
||||
mock_status.progress = None
|
||||
mock_status.processed = None
|
||||
mock_status.total = None
|
||||
mock_status.message = None
|
||||
mock_status.error = "Internal error"
|
||||
mock_status.traceback = "Traceback (most recent call last)..."
|
||||
|
||||
mock_task_service.get_user_tasks.return_value = ["task-abc"]
|
||||
mock_task_service.get_task_status.return_value = mock_status
|
||||
client = self._get_client()
|
||||
try:
|
||||
resp = client.get("/api/task_status", params={"task_id": "task-abc"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["traceback"] is None
|
||||
assert data["error"] is None
|
||||
finally:
|
||||
from api.app import app
|
||||
from api.auth import get_current_user
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue