diff --git a/api/app.py b/api/app.py index 7bdf5ff..d36bc1d 100644 --- a/api/app.py +++ b/api/app.py @@ -28,7 +28,7 @@ from database import engine from fastapi.middleware.cors import CORSMiddleware from ui_exporter import convert_to_geojson_feature, convert_row_to_geojson -from services import listing_service, export_service, district_service, task_service +from services import listing_service, export_service, district_service, task_service, decision_service from services.listing_cache import ( get_cached_count, get_cached_features, @@ -37,6 +37,7 @@ from services.listing_cache import ( finalize_cache_population, delete_staging_key, ) +from repositories.decision_repository import DecisionRepository from repositories.poi_repository import POIRepository from repositories.user_repository import UserRepository from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor @@ -168,8 +169,13 @@ async def get_listing_geojson( user: Annotated[User, Depends(get_current_user)], query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)], limit: int | None = None, + decision_filter: str = "all", ) -> dict: - """Get listings as GeoJSON for map display.""" + """Get listings as GeoJSON for map display. + + decision_filter: 'all' (hide disliked, default), 'liked', 'disliked', + 'undecided', 'everything' (no filtering). + """ if limit is not None: limit = min(limit, _rate_limit_config.geojson_limit_cap) else: @@ -180,10 +186,36 @@ async def get_listing_geojson( query_parameters=query_parameters, limit=limit, ) + + # Filter features based on decision_filter + if decision_filter != "everything": + disliked_ids = _get_disliked_ids( + user.email, query_parameters.listing_type.value + ) + if disliked_ids: + str_disliked = {str(lid) for lid in disliked_ids} + result.data["features"] = [ + f for f in result.data["features"] + if f.get("properties", {}).get("url", "").split("/")[-1] + not in str_disliked + ] + return result.data +def _get_disliked_ids(user_email: str, listing_type: str) -> set[int]: + """Get the set of disliked listing IDs for a user.""" + user_repo = UserRepository(engine) + db_user = user_repo.get_user_by_email(user_email) + if not db_user or db_user.id is None: + return set() + decision_repo = DecisionRepository(engine) + return decision_service.get_disliked_listing_ids( + decision_repo, user_id=db_user.id, listing_type=listing_type + ) + + def _build_poi_distances_lookup( user_email: str, listing_type: ListingType, @@ -252,6 +284,7 @@ async def _stream_from_db( limit: int | None, poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None, skip_cache: bool = False, + disliked_ids: set[int] | None = None, ) -> AsyncGenerator[str, None]: """Stream GeoJSON features from the database, populating the cache as we go.""" repository = ListingRepository(engine) @@ -276,6 +309,9 @@ async def _stream_from_db( for row in repository.stream_listings_optimized( query_parameters, limit=limit, page_size=batch_size ): + # Skip disliked listings + if disliked_ids and row['id'] in disliked_ids: + continue feature = convert_row_to_geojson(row, query_parameters.listing_type.value) # Inject POI distances if available if poi_distances_lookup and row['id'] in poi_distances_lookup: @@ -330,6 +366,11 @@ async def stream_listing_geojson( # Build POI distances lookup if requested poi_distances_lookup = _build_poi_distances_lookup(user.email, query_parameters.listing_type) if include_poi_distances else None + # Get disliked listing IDs to exclude from stream + disliked_ids = _get_disliked_ids( + user.email, query_parameters.listing_type.value + ) + cached_count = get_cached_count(query_parameters) if cached_count is not None and cached_count > 0 and not include_poi_distances: app_metrics.geojson_cache_operations.add(1, {"result": "hit"}) @@ -339,6 +380,7 @@ async def stream_listing_geojson( generator = _stream_from_db( query_parameters, batch_size, limit, poi_distances_lookup, skip_cache=include_poi_distances, + disliked_ids=disliked_ids if disliked_ids else None, ) return StreamingResponse( diff --git a/tests/unit/test_decision_filtering.py b/tests/unit/test_decision_filtering.py new file mode 100644 index 0000000..7953071 --- /dev/null +++ b/tests/unit/test_decision_filtering.py @@ -0,0 +1,121 @@ +"""Test that disliked listings are filtered from the GeoJSON endpoint.""" +import pytest +from datetime import datetime +from httpx import ASGITransport, AsyncClient +from sqlalchemy import Engine +from sqlmodel import SQLModel, Session, create_engine + +from models.user import User +from models.listing import RentListing, ListingSite, FurnishType +from models.decision import ListingDecision +from api.auth import get_current_user, User as AuthUser + + +@pytest.fixture +def filter_engine() -> Engine: + engine = create_engine( + "sqlite:///:memory:", + echo=False, + connect_args={"check_same_thread": False}, + ) + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + session.add(User(id=1, email="test@example.com")) + # Add two listings + now = datetime.now() + for lid in [100, 200]: + session.add(RentListing( + id=lid, + price=2000.0, + number_of_bedrooms=2, + square_meters=50.0, + longitude=-0.1, + latitude=51.5, + price_history_json="[]", + listing_site=ListingSite.RIGHTMOVE, + last_seen=now, + floorplan_image_paths=[], + additional_info={"property": {"visible": True}}, + furnish_type=FurnishType.FURNISHED, + )) + # Dislike listing 200 + session.add(ListingDecision( + user_id=1, + listing_id=200, + listing_type="RENT", + decision="disliked", + )) + session.commit() + yield engine # type: ignore[misc] + SQLModel.metadata.drop_all(engine) + + +@pytest.fixture +async def filter_client(filter_engine: Engine) -> AsyncClient: + import database + import api.app as api_app + import api.decision_routes as decision_routes_mod + import api.poi_routes as poi_routes_mod + + app = api_app.app + mock_user = AuthUser( + sub="test", email="test@example.com", name="Test" + ) + app.dependency_overrides[get_current_user] = lambda: mock_user + + original_db = database.engine + original_app = api_app.engine + original_decision = decision_routes_mod.engine + original_poi = poi_routes_mod.engine + database.engine = filter_engine + api_app.engine = filter_engine + decision_routes_mod.engine = filter_engine + poi_routes_mod.engine = filter_engine + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as c: + yield c # type: ignore[misc] + + database.engine = original_db + api_app.engine = original_app + decision_routes_mod.engine = original_decision + poi_routes_mod.engine = original_poi + app.dependency_overrides.clear() + + +class TestDecisionFiltering: + @pytest.mark.asyncio + async def test_disliked_excluded_by_default( + self, filter_client: AsyncClient + ) -> None: + """Default decision_filter should exclude disliked listings.""" + resp = await filter_client.get( + "/api/listing_geojson", + params={"listing_type": "RENT"}, + ) + assert resp.status_code == 200 + data = resp.json() + listing_ids = [ + f["properties"]["url"].split("/")[-1] + for f in data["features"] + ] + assert "100" in listing_ids + assert "200" not in listing_ids + + @pytest.mark.asyncio + async def test_everything_filter_includes_all( + self, filter_client: AsyncClient + ) -> None: + """decision_filter='everything' should include disliked listings.""" + resp = await filter_client.get( + "/api/listing_geojson", + params={"listing_type": "RENT", "decision_filter": "everything"}, + ) + assert resp.status_code == 200 + data = resp.json() + listing_ids = [ + f["properties"]["url"].split("/")[-1] + for f in data["features"] + ] + assert "100" in listing_ids + assert "200" in listing_ids