diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index d71d12d..9a4f891 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -445,6 +445,15 @@ class TestQueryParametersValidation: max_bedrooms=3, ) + def test_invalid_sqm_range_raises(self) -> None: + """min_sqm > max_sqm should raise ValidationError.""" + with pytest.raises(ValidationError, match="min_sqm.*must be <= max_sqm"): + QueryParameters( + listing_type=ListingType.RENT, + min_sqm=100, + max_sqm=50, + ) + class TestDestinationMode: """Tests for DestinationMode.""" diff --git a/tests/unit/test_passkey_error_handling.py b/tests/unit/test_passkey_error_handling.py new file mode 100644 index 0000000..a991a5a --- /dev/null +++ b/tests/unit/test_passkey_error_handling.py @@ -0,0 +1,71 @@ +"""Unit tests for passkey route error handling.""" +from unittest.mock import patch, MagicMock + +import pytest +from starlette.testclient import TestClient + +# We need to test through the FastAPI app or build a minimal test client +from api.passkey_routes import passkey_router +from fastapi import FastAPI + + +def _build_app() -> FastAPI: + app = FastAPI() + app.include_router(passkey_router) + return app + + +class TestPasskeyErrorHandling: + """Tests that passkey routes return generic error messages for internal exceptions.""" + + @patch("api.passkey_routes.passkey_service") + @patch("api.passkey_routes.UserRepository") + def test_register_begin_internal_error_returns_generic_message( + self, mock_user_repo: MagicMock, mock_service: MagicMock + ) -> None: + mock_service.begin_registration.side_effect = RuntimeError("DB connection lost") + client = TestClient(_build_app()) + resp = client.post("/api/passkey/register/begin", json={"email": "test@example.com"}) + assert resp.status_code == 400 + assert resp.json()["detail"] == "Registration failed. Please try again." + assert "DB connection lost" not in resp.json()["detail"] + + @patch("api.passkey_routes.passkey_service") + @patch("api.passkey_routes.UserRepository") + def test_register_begin_value_error_returns_user_message( + self, mock_user_repo: MagicMock, mock_service: MagicMock + ) -> None: + mock_service.begin_registration.side_effect = ValueError("Email already registered") + client = TestClient(_build_app()) + resp = client.post("/api/passkey/register/begin", json={"email": "test@example.com"}) + assert resp.status_code == 400 + assert resp.json()["detail"] == "Email already registered" + + @patch("api.passkey_routes.passkey_service") + @patch("api.passkey_routes.UserRepository") + def test_login_complete_internal_error_returns_generic_message( + self, mock_user_repo: MagicMock, mock_service: MagicMock + ) -> None: + mock_service.complete_authentication.side_effect = RuntimeError("Crypto failure") + client = TestClient(_build_app()) + resp = client.post( + "/api/passkey/login/complete", + json={"session_id": "abc", "credential": {"id": "x"}}, + ) + assert resp.status_code == 400 + assert resp.json()["detail"] == "Login could not be completed." + assert "Crypto failure" not in resp.json()["detail"] + + @patch("api.passkey_routes.passkey_service") + @patch("api.passkey_routes.UserRepository") + def test_login_complete_value_error_returns_user_message( + self, mock_user_repo: MagicMock, mock_service: MagicMock + ) -> None: + mock_service.complete_authentication.side_effect = ValueError("Invalid credential") + client = TestClient(_build_app()) + resp = client.post( + "/api/passkey/login/complete", + json={"session_id": "abc", "credential": {"id": "x"}}, + ) + assert resp.status_code == 400 + assert resp.json()["detail"] == "Invalid credential" diff --git a/tests/unit/test_poi_validation.py b/tests/unit/test_poi_validation.py new file mode 100644 index 0000000..e1cfe62 --- /dev/null +++ b/tests/unit/test_poi_validation.py @@ -0,0 +1,58 @@ +"""Unit tests for POI request validation.""" +import pytest +from pydantic import ValidationError + +from api.poi_routes import CreatePOIRequest, UpdatePOIRequest + + +class TestCreatePOIValidation: + """Tests for CreatePOIRequest field validation.""" + + def test_valid_request(self) -> None: + req = CreatePOIRequest(name="Office", address="123 Main St", latitude=51.5, longitude=-0.1) + assert req.name == "Office" + + def test_name_too_long(self) -> None: + with pytest.raises(ValidationError): + CreatePOIRequest(name="A" * 201, address="addr", latitude=0, longitude=0) + + def test_address_too_long(self) -> None: + with pytest.raises(ValidationError): + CreatePOIRequest(name="ok", address="A" * 501, latitude=0, longitude=0) + + def test_latitude_too_high(self) -> None: + with pytest.raises(ValidationError): + CreatePOIRequest(name="ok", address="addr", latitude=91.0, longitude=0) + + def test_latitude_too_low(self) -> None: + with pytest.raises(ValidationError): + CreatePOIRequest(name="ok", address="addr", latitude=-91.0, longitude=0) + + def test_longitude_too_high(self) -> None: + with pytest.raises(ValidationError): + CreatePOIRequest(name="ok", address="addr", latitude=0, longitude=181.0) + + def test_longitude_too_low(self) -> None: + with pytest.raises(ValidationError): + CreatePOIRequest(name="ok", address="addr", latitude=0, longitude=-181.0) + + +class TestUpdatePOIValidation: + """Tests for UpdatePOIRequest field validation.""" + + def test_valid_partial_update(self) -> None: + req = UpdatePOIRequest(name="New Name") + assert req.name == "New Name" + assert req.latitude is None + + def test_name_too_long(self) -> None: + with pytest.raises(ValidationError): + UpdatePOIRequest(name="A" * 201) + + def test_latitude_out_of_range(self) -> None: + with pytest.raises(ValidationError): + UpdatePOIRequest(latitude=91.0) + + def test_longitude_out_of_range(self) -> None: + with pytest.raises(ValidationError): + UpdatePOIRequest(longitude=181.0) diff --git a/tests/unit/test_rate_limiter.py b/tests/unit/test_rate_limiter.py index 68b0b57..d3a8011 100644 --- a/tests/unit/test_rate_limiter.py +++ b/tests/unit/test_rate_limiter.py @@ -30,6 +30,7 @@ def _make_config(**overrides: object) -> RateLimitConfig: "geojson_stream_batch_size_cap": 200, "rate_limit_redis_db": 3, "metrics_allowed_ips": "127.0.0.1", + "trusted_proxy_depth": 1, } defaults.update(overrides) return RateLimitConfig(**defaults) # type: ignore[arg-type] @@ -236,3 +237,92 @@ class TestRateLimitConfig: # Should fall back to default assert config.endpoint_limits["/api/listing"].max_requests == 30 assert config.endpoint_limits["/api/listing"].window_seconds == 60 + + +from api.rate_limiter import _client_ip + + +class TestClientIp: + """Tests for _client_ip with trusted proxy depth.""" + + def test_uses_rightmost_ip_with_depth_1(self) -> None: + scope = { + "type": "http", + "headers": [(b"x-forwarded-for", b"spoofed.ip, proxy1, real.ip")], + } + request = Request(scope) + assert _client_ip(request, depth=1) == "real.ip" + + def test_uses_second_from_right_with_depth_2(self) -> None: + scope = { + "type": "http", + "headers": [(b"x-forwarded-for", b"spoofed.ip, real.ip, proxy1")], + } + request = Request(scope) + assert _client_ip(request, depth=2) == "real.ip" + + def test_falls_back_to_connection_ip(self) -> None: + scope = { + "type": "http", + "headers": [], + "client": ("192.168.1.1", 12345), + } + request = Request(scope) + assert _client_ip(request) == "192.168.1.1" + + def test_no_client_returns_unknown(self) -> None: + scope = {"type": "http", "headers": []} + request = Request(scope) + assert _client_ip(request) == "unknown" + + def test_single_ip_with_depth_1(self) -> None: + scope = { + "type": "http", + "headers": [(b"x-forwarded-for", b"single.ip")], + } + request = Request(scope) + assert _client_ip(request, depth=1) == "single.ip" + + +class TestInMemoryFallback: + """Tests for in-memory rate limit fallback when Redis is unavailable.""" + + @mock.patch("api.rate_limiter._get_rate_limit_redis") + def test_fallback_enforces_limits_when_redis_unavailable(self, mock_get_redis: mock.MagicMock) -> None: + """When Redis is unavailable at startup, in-memory fallback should enforce limits.""" + import redis as redis_lib + mock_get_redis.side_effect = redis_lib.RedisError("connection refused") + + config = _make_config() + app = _build_app(config) + client = TestClient(app) + + # Should allow first 3 requests (limit=3) + for i in range(3): + resp = client.get("/api/listing") + assert resp.status_code == 200, f"Request {i+1} should be allowed" + + # 4th request should be rate limited + resp = client.get("/api/listing") + assert resp.status_code == 429 + + @mock.patch("api.rate_limiter._get_rate_limit_redis") + def test_fallback_activates_on_redis_error_during_request(self, mock_get_redis: mock.MagicMock) -> None: + """When Redis errors during request handling, fallback should activate.""" + import redis as redis_lib + + mock_redis = mock.MagicMock() + mock_pipe = mock.MagicMock() + mock_pipe.execute.side_effect = redis_lib.RedisError("connection lost") + mock_redis.pipeline.return_value = mock_pipe + mock_redis.ping.return_value = True + mock_get_redis.return_value = mock_redis + + config = _make_config() + app = _build_app(config) + client = TestClient(app) + + # First request should succeed via fallback + resp = client.get("/api/listing") + assert resp.status_code == 200 + assert "X-RateLimit-Limit" in resp.headers diff --git a/tests/unit/test_security_headers.py b/tests/unit/test_security_headers.py new file mode 100644 index 0000000..1e61423 --- /dev/null +++ b/tests/unit/test_security_headers.py @@ -0,0 +1,56 @@ +"""Unit tests for api/security_headers.py.""" +from starlette.testclient import TestClient +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from api.security_headers import SecurityHeadersMiddleware + + +async def _ok_endpoint(request: Request) -> JSONResponse: + return JSONResponse({"ok": True}) + + +def _build_app() -> Starlette: + app = Starlette(routes=[Route("/test", _ok_endpoint)]) + app.add_middleware(SecurityHeadersMiddleware) + return app + + +class TestSecurityHeaders: + """Tests for SecurityHeadersMiddleware.""" + + def test_x_content_type_options(self) -> None: + client = TestClient(_build_app()) + resp = client.get("/test") + assert resp.headers["X-Content-Type-Options"] == "nosniff" + + def test_x_frame_options(self) -> None: + client = TestClient(_build_app()) + resp = client.get("/test") + assert resp.headers["X-Frame-Options"] == "DENY" + + def test_referrer_policy(self) -> None: + client = TestClient(_build_app()) + resp = client.get("/test") + assert resp.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + + def test_content_security_policy(self) -> None: + client = TestClient(_build_app()) + resp = client.get("/test") + assert "Content-Security-Policy" in resp.headers + csp = resp.headers["Content-Security-Policy"] + assert "default-src 'self'" in csp + assert "frame-ancestors 'none'" in csp + + def test_hsts_set_for_https(self) -> None: + client = TestClient(_build_app()) + resp = client.get("/test", headers={"x-forwarded-proto": "https"}) + assert "Strict-Transport-Security" in resp.headers + assert "max-age=" in resp.headers["Strict-Transport-Security"] + + def test_hsts_not_set_for_http(self) -> None: + client = TestClient(_build_app()) + resp = client.get("/test") + assert "Strict-Transport-Security" not in resp.headers diff --git a/tests/unit/test_task_service.py b/tests/unit/test_task_service.py index e71d311..e121cb5 100644 --- a/tests/unit/test_task_service.py +++ b/tests/unit/test_task_service.py @@ -304,3 +304,84 @@ class TestClearAllTasks: # Should not raise despite revoke failure count = clear_all_tasks("test@example.com", revoke=True) assert count == 1 + + +from fastapi.testclient import TestClient + + +class TestTaskStatusSecurity: + """Tests for task status endpoint security (IDOR, traceback suppression).""" + + def _get_client(self) -> TestClient: + """Create test client with mocked auth.""" + from api.app import app + from api.auth import get_current_user, User + + mock_user = User(sub="test", email="test@example.com", name="Test") + app.dependency_overrides[get_current_user] = lambda: mock_user + client = TestClient(app, raise_server_exceptions=False) + return client + + @patch("api.app.task_service") + def test_returns_404_for_unowned_task(self, mock_task_service: MagicMock) -> None: + mock_task_service.get_user_tasks.return_value = ["task-abc"] + client = self._get_client() + try: + resp = client.get("/api/task_status", params={"task_id": "task-xyz"}) + assert resp.status_code == 404 + finally: + from api.app import app + from api.auth import get_current_user + app.dependency_overrides.pop(get_current_user, None) + + @patch("api.app.task_service") + def test_returns_200_for_owned_task(self, mock_task_service: MagicMock) -> None: + mock_status = MagicMock() + mock_status.task_id = "task-abc" + mock_status.status = "SUCCESS" + mock_status.result = None + mock_status.progress = None + mock_status.processed = None + mock_status.total = None + mock_status.message = None + mock_status.error = None + mock_status.traceback = None + + mock_task_service.get_user_tasks.return_value = ["task-abc"] + mock_task_service.get_task_status.return_value = mock_status + client = self._get_client() + try: + resp = client.get("/api/task_status", params={"task_id": "task-abc"}) + assert resp.status_code == 200 + finally: + from api.app import app + from api.auth import get_current_user + app.dependency_overrides.pop(get_current_user, None) + + @patch("api.app.APP_ENV", "production") + @patch("api.app.task_service") + def test_traceback_suppressed_in_production(self, mock_task_service: MagicMock) -> None: + mock_status = MagicMock() + mock_status.task_id = "task-abc" + mock_status.status = "FAILURE" + mock_status.result = None + mock_status.progress = None + mock_status.processed = None + mock_status.total = None + mock_status.message = None + mock_status.error = "Internal error" + mock_status.traceback = "Traceback (most recent call last)..." + + mock_task_service.get_user_tasks.return_value = ["task-abc"] + mock_task_service.get_task_status.return_value = mock_status + client = self._get_client() + try: + resp = client.get("/api/task_status", params={"task_id": "task-abc"}) + assert resp.status_code == 200 + data = resp.json() + assert data["traceback"] is None + assert data["error"] is None + finally: + from api.app import app + from api.auth import get_current_user + app.dependency_overrides.pop(get_current_user, None)