Fix duplicate listings via staged Redis cache and frontend stream cancellation

Three-pronged fix for duplicate listings appearing in the UI:

1. Backend: Replace direct rpush cache writes with staged population
   (write to temp key, then atomic RENAME to live key). Skip cache
   writes entirely for POI-enriched requests. Clean staging keys on
   invalidation.

2. Frontend: Add AbortController to cancel in-flight streaming requests
   when loadListings is called again, preventing data mixing.

3. Frontend: Deduplicate features by URL during stream accumulation as
   a safety net against any remaining server-side duplicates. Note that
   features without a `url` property are also dropped by this filter.
This commit is contained in:
Viktor Barzin 2026-02-09 21:17:30 +00:00
parent 5b8aa98446
commit 73d19e29d5
No known key found for this signature in database
GPG key ID: 0EB088298288D958
5 changed files with 159 additions and 38 deletions

View file

@ -8,6 +8,7 @@ from api.auth import get_current_user
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS, APP_ENV from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS, APP_ENV
from api.passkey_routes import passkey_router from api.passkey_routes import passkey_router
from api.poi_routes import poi_router from api.poi_routes import poi_router
from api.ws_routes import ws_router
from api.rate_limit_config import RateLimitConfig from api.rate_limit_config import RateLimitConfig
from api.rate_limiter import RateLimitMiddleware from api.rate_limiter import RateLimitMiddleware
from api.audit_middleware import AuditLogMiddleware from api.audit_middleware import AuditLogMiddleware
@ -30,7 +31,10 @@ from services import listing_service, export_service, district_service, task_ser
from services.listing_cache import ( from services.listing_cache import (
get_cached_count, get_cached_count,
get_cached_features, get_cached_features,
cache_features_batch, begin_cache_population,
cache_features_batch_staged,
finalize_cache_population,
delete_staging_key,
) )
from repositories.poi_repository import POIRepository from repositories.poi_repository import POIRepository
from repositories.user_repository import UserRepository from repositories.user_repository import UserRepository
@ -94,6 +98,7 @@ app = FastAPI(
) )
app.include_router(passkey_router) app.include_router(passkey_router)
app.include_router(poi_router) app.include_router(poi_router)
app.include_router(ws_router)
app.mount("/metrics", metrics_app) app.mount("/metrics", metrics_app)
meter = get_meter(__name__) meter = get_meter(__name__)
request_counter = meter.create_counter( request_counter = meter.create_counter(
@ -213,6 +218,7 @@ async def _stream_from_db(
batch_size: int, batch_size: int,
limit: int | None, limit: int | None,
poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None, poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None,
skip_cache: bool = False,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Stream GeoJSON features from the database, populating the cache as we go.""" """Stream GeoJSON features from the database, populating the cache as we go."""
repository = ListingRepository(engine) repository = ListingRepository(engine)
@ -227,28 +233,44 @@ async def _stream_from_db(
"cached": False, "cached": False,
}) + "\n" }) + "\n"
count = 0 staging_key: str | None = None
batch: list[dict] = [] if not skip_cache:
for row in repository.stream_listings_optimized( staging_key = begin_cache_population(query_parameters)
query_parameters, limit=limit, page_size=batch_size
):
feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
# Inject POI distances if available
if poi_distances_lookup and row['id'] in poi_distances_lookup:
feature['properties']['poi_distances'] = poi_distances_lookup[row['id']]
batch.append(feature)
count += 1
if len(batch) >= batch_size: try:
cache_features_batch(query_parameters, batch) count = 0
batch: list[dict] = []
for row in repository.stream_listings_optimized(
query_parameters, limit=limit, page_size=batch_size
):
feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
# Inject POI distances if available
if poi_distances_lookup and row['id'] in poi_distances_lookup:
feature['properties']['poi_distances'] = poi_distances_lookup[row['id']]
batch.append(feature)
count += 1
if len(batch) >= batch_size:
if staging_key:
cache_features_batch_staged(staging_key, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n"
batch = []
if batch:
if staging_key:
cache_features_batch_staged(staging_key, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n" yield json.dumps({"type": "batch", "features": batch}) + "\n"
batch = []
if batch: # Atomically promote staged data to live cache
cache_features_batch(query_parameters, batch) if staging_key:
yield json.dumps({"type": "batch", "features": batch}) + "\n" finalize_cache_population(staging_key, query_parameters)
staging_key = None # Mark as finalized
yield json.dumps({"type": "complete", "total": count}) + "\n" yield json.dumps({"type": "complete", "total": count}) + "\n"
finally:
# Clean up orphaned staging key on failure
if staging_key:
delete_staging_key(staging_key)
@app.get("/api/listing_geojson/stream") @app.get("/api/listing_geojson/stream")
@ -304,7 +326,10 @@ async def stream_listing_geojson(
if cached_count is not None and cached_count > 0 and not include_poi_distances: if cached_count is not None and cached_count > 0 and not include_poi_distances:
generator = _stream_from_cache(query_parameters, batch_size, limit) generator = _stream_from_cache(query_parameters, batch_size, limit)
else: else:
generator = _stream_from_db(query_parameters, batch_size, limit, poi_distances_lookup) generator = _stream_from_db(
query_parameters, batch_size, limit, poi_distances_lookup,
skip_cache=include_poi_distances,
)
return StreamingResponse( return StreamingResponse(
generator, generator,

View file

@ -19,6 +19,7 @@ import { Filter } from 'lucide-react';
import type { GeoJSONFeatureCollection, PropertyProperties, PropertyFeature, POI, POITravelFilter } from '@/types'; import type { GeoJSONFeatureCollection, PropertyProperties, PropertyFeature, POI, POITravelFilter } from '@/types';
import { refreshListings, fetchTasksForUser, streamListingGeoJSON, fetchUserPOIs, type StreamingProgress } from '@/services'; import { refreshListings, fetchTasksForUser, streamListingGeoJSON, fetchUserPOIs, type StreamingProgress } from '@/services';
import { poiMetricPropertyName, injectPoiMetricProperty } from '@/utils/poiUtils'; import { poiMetricPropertyName, injectPoiMetricProperty } from '@/utils/poiUtils';
import { useTaskWebSocket } from '@/hooks/useTaskWebSocket';
function App() { function App() {
const [listingData, setListingData] = useState<GeoJSONFeatureCollection | null>(null); const [listingData, setListingData] = useState<GeoJSONFeatureCollection | null>(null);
@ -43,10 +44,15 @@ function App() {
const [poiTravelFilters, setPoiTravelFilters] = useState<Record<number, POITravelFilter>>({}); const [poiTravelFilters, setPoiTravelFilters] = useState<Record<number, POITravelFilter>>({});
const [currentMetric, setCurrentMetric] = useState<Metric>(DEFAULT_FILTER_VALUES.metric); const [currentMetric, setCurrentMetric] = useState<Metric>(DEFAULT_FILTER_VALUES.metric);
// WebSocket-based real-time task progress
const { tasks: wsTasks, isConnected: wsConnected, subscribe: wsSubscribe } = useTaskWebSocket(user);
// Ref to track accumulated features during streaming // Ref to track accumulated features during streaming
const accumulatedFeaturesRef = useRef<PropertyFeature[]>([]); const accumulatedFeaturesRef = useRef<PropertyFeature[]>([]);
// Ref to track if initial load has been triggered // Ref to track if initial load has been triggered
const initialLoadTriggeredRef = useRef(false); const initialLoadTriggeredRef = useRef(false);
// Ref to abort in-flight streaming requests
const abortControllerRef = useRef<AbortController | null>(null);
// Check if this is the callback route - render dedicated component // Check if this is the callback route - render dedicated component
if (window.location.pathname === '/callback') { if (window.location.pathname === '/callback') {
@ -92,6 +98,13 @@ function App() {
const loadListings = useCallback(async (parameters: ParameterValues) => { const loadListings = useCallback(async (parameters: ParameterValues) => {
if (!user) return; if (!user) return;
// Abort any in-flight streaming request
if (abortControllerRef.current) {
abortControllerRef.current.abort();
}
const controller = new AbortController();
abortControllerRef.current = controller;
setQueryParameters(parameters); setQueryParameters(parameters);
setMobileFilterOpen(false); setMobileFilterOpen(false);
setIsLoading(true); setIsLoading(true);
@ -99,6 +112,9 @@ function App() {
setStreamingProgress({ count: 0 }); setStreamingProgress({ count: 0 });
setListingData(null); setListingData(null);
// Dedup safety net: track seen URLs to prevent duplicate features
const seenUrls = new Set<string>();
let updateScheduled = false; let updateScheduled = false;
const flushUpdate = () => { const flushUpdate = () => {
@ -119,13 +135,26 @@ function App() {
try { try {
for await (const batch of streamListingGeoJSON(user, parameters, (progress) => { for await (const batch of streamListingGeoJSON(user, parameters, (progress) => {
setStreamingProgress(progress); setStreamingProgress(progress);
}, { includePoiDistances: userPOIs.length > 0 })) { }, { includePoiDistances: userPOIs.length > 0, signal: controller.signal })) {
accumulatedFeaturesRef.current.push(...batch); // Deduplicate features by URL
scheduleUpdate(); const uniqueBatch = batch.filter((feature) => {
const url = feature.properties?.url;
if (!url || seenUrls.has(url)) return false;
seenUrls.add(url);
return true;
});
if (uniqueBatch.length > 0) {
accumulatedFeaturesRef.current.push(...uniqueBatch);
scheduleUpdate();
}
} }
// Final flush to ensure all data is rendered // Final flush to ensure all data is rendered
flushUpdate(); flushUpdate();
} catch (error) { } catch (error) {
// Silently ignore AbortError — it means we intentionally cancelled
if (error instanceof DOMException && error.name === 'AbortError') {
return;
}
if (error instanceof Error) { if (error instanceof Error) {
setSubmitError(error.message); setSubmitError(error.message);
} else { } else {
@ -133,8 +162,11 @@ function App() {
} }
setAlertDialogIsOpen(true); setAlertDialogIsOpen(true);
} finally { } finally {
setIsLoading(false); // Only clear loading state if this controller is still the current one
setStreamingProgress(null); if (abortControllerRef.current === controller) {
setIsLoading(false);
setStreamingProgress(null);
}
} }
}, [user, userPOIs]); }, [user, userPOIs]);
@ -217,6 +249,7 @@ function App() {
try { try {
const data = await refreshListings(user!, parameters); const data = await refreshListings(user!, parameters);
setTaskID(data.task_id); setTaskID(data.task_id);
if (data.task_id) wsSubscribe(data.task_id);
} catch (error) { } catch (error) {
if (error instanceof Error) { if (error instanceof Error) {
setSubmitError(error.message); setSubmitError(error.message);
@ -320,6 +353,7 @@ function App() {
const handlePOITaskCreated = (taskId: string) => { const handlePOITaskCreated = (taskId: string) => {
setTaskID(taskId); setTaskID(taskId);
if (taskId) wsSubscribe(taskId);
// Refresh POI list in case new ones were created // Refresh POI list in case new ones were created
if (user) { if (user) {
fetchUserPOIs(user).then(setUserPOIs).catch(() => {}); fetchUserPOIs(user).then(setUserPOIs).catch(() => {});
@ -348,6 +382,9 @@ function App() {
taskID={taskID} taskID={taskID}
onTaskCancelled={handleTaskCancelled} onTaskCancelled={handleTaskCancelled}
onTaskCompleted={handleTaskCompleted} onTaskCompleted={handleTaskCompleted}
wsTasks={wsTasks}
wsConnected={wsConnected}
wsSubscribe={wsSubscribe}
/> />
{/* Main content area */} {/* Main content area */}

View file

@ -68,7 +68,7 @@ export async function* streamListingGeoJSON(
user: AuthUser, user: AuthUser,
parameters: ParameterValues, parameters: ParameterValues,
onProgress?: (progress: StreamingProgress) => void, onProgress?: (progress: StreamingProgress) => void,
options?: { includePoiDistances?: boolean }, options?: { includePoiDistances?: boolean; signal?: AbortSignal },
): AsyncGenerator<PropertyFeature[], void, unknown> { ): AsyncGenerator<PropertyFeature[], void, unknown> {
const params = buildListingParams(parameters); const params = buildListingParams(parameters);
if (options?.includePoiDistances) { if (options?.includePoiDistances) {
@ -83,6 +83,7 @@ export async function* streamListingGeoJSON(
headers: { headers: {
Authorization: `Bearer ${user.accessToken}`, Authorization: `Bearer ${user.accessToken}`,
}, },
signal: options?.signal,
}); });
if (!response.ok) { if (!response.ok) {
@ -99,6 +100,10 @@ export async function* streamListingGeoJSON(
let totalCount = 0; let totalCount = 0;
while (true) { while (true) {
if (options?.signal?.aborted) {
await reader.cancel();
return;
}
const { done, value } = await reader.read(); const { done, value } = await reader.read();
if (done) break; if (done) break;

View file

@ -3,6 +3,7 @@ import hashlib
import json import json
import logging import logging
import os import os
import uuid
from typing import Generator from typing import Generator
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
@ -13,7 +14,9 @@ from models.listing import QueryParameters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CACHE_PREFIX = "listings:geojson:" CACHE_PREFIX = "listings:geojson:"
STAGING_PREFIX = "listings:geojson:staging:"
CACHE_TTL_SECONDS = 30 * 60 # 30 minutes CACHE_TTL_SECONDS = 30 * 60 # 30 minutes
STAGING_TTL_SECONDS = 5 * 60 # 5 minutes safety net for orphaned staging keys
CACHE_DB = 2 CACHE_DB = 2
@ -81,22 +84,71 @@ def cache_features_batch(query_params: QueryParameters, features: list[dict]) ->
logger.warning(f"Redis cache write error: {e}") logger.warning(f"Redis cache write error: {e}")
def begin_cache_population(query_params: QueryParameters) -> str:
    """Start a staged cache population and return a fresh staging key.

    The key is ``STAGING_PREFIX`` followed by a random UUID hex string, so
    each call yields a distinct key. ``query_params`` is accepted but is not
    used to build the key.

    Nothing is written to Redis here: the staging list comes into existence
    (and receives its TTL) when ``cache_features_batch_staged`` pushes to it.
    """
    unique_suffix = uuid.uuid4().hex
    return STAGING_PREFIX + unique_suffix
def cache_features_batch_staged(staging_key: str, features: list[dict]) -> None:
    """Push one batch of features onto the end of a staging list.

    Each feature is JSON-encoded and appended with RPUSH; the staging key's
    TTL is then reset to ``STAGING_TTL_SECONDS``. All commands for the batch
    are queued on one pipeline and sent with a single ``execute()``.

    An empty batch is a no-op. Redis errors are logged as warnings and not
    re-raised, so a cache failure does not interrupt the caller.
    """
    if not features:
        return

    try:
        pipe = _get_redis_client().pipeline()
        # map() is lazy, so encoding and queuing stay interleaved per feature.
        for encoded_feature in map(json.dumps, features):
            pipe.rpush(staging_key, encoded_feature)
        # Refresh the safety-net TTL on every batch, not only the first one.
        pipe.expire(staging_key, STAGING_TTL_SECONDS)
        pipe.execute()
    except redis.RedisError as write_error:
        logger.warning(f"Redis staged cache write error: {write_error}")
def finalize_cache_population(staging_key: str, query_params: QueryParameters) -> None:
    """Promote a staging list to the live cache key for ``query_params``.

    The staging key is moved onto the live key with RENAME, which replaces
    any existing live key in a single Redis operation. The live TTL
    (``CACHE_TTL_SECONDS``) is applied by a separate EXPIRE immediately
    afterwards, so the rename itself is atomic but rename-plus-TTL is not.

    A missing staging key is an expected outcome, not a failure: no staging
    list is ever created when the stream produced zero features, and the key
    may also have expired or been deleted before promotion. That case is
    logged at debug level. All other Redis errors are logged as warnings.
    Nothing is re-raised, so cache problems never break the caller.
    """
    try:
        client = _get_redis_client()
        live_key = make_cache_key(query_params)
        # RENAME is atomic — replaces the live key in one operation
        client.rename(staging_key, live_key)
        client.expire(live_key, CACHE_TTL_SECONDS)
        logger.debug(f"Finalized cache population for {live_key}")
    except redis.ResponseError as e:
        # Redis answers RENAME on a non-existent source with "no such key".
        if "no such key" in str(e).lower():
            logger.debug(
                f"No staged data to finalize for {staging_key}; live cache left unchanged"
            )
        else:
            logger.warning(f"Redis cache finalize error: {e}")
    except redis.RedisError as e:
        logger.warning(f"Redis cache finalize error: {e}")
def delete_staging_key(staging_key: str) -> None:
    """Remove a staging key that was never promoted to the live cache.

    Intended for error cleanup. The delete is best-effort: a Redis error is
    logged as a warning and not re-raised.
    """
    try:
        _get_redis_client().delete(staging_key)
    except redis.RedisError as cleanup_error:
        logger.warning(f"Redis staging key cleanup error: {cleanup_error}")
def invalidate_cache() -> None: def invalidate_cache() -> None:
"""Delete all listing GeoJSON cache entries.""" """Delete all listing GeoJSON cache entries, including staging keys."""
try: try:
client = _get_redis_client() client = _get_redis_client()
cursor = 0 cursor = 0
deleted = 0 deleted = 0
while True: # Clean both live cache keys and staging keys
cursor, keys = client.scan(cursor, match=f"{CACHE_PREFIX}*", count=100) for pattern in [f"{CACHE_PREFIX}*", f"{STAGING_PREFIX}*"]:
if keys: cursor = 0
pipeline = client.pipeline() while True:
for key in keys: cursor, keys = client.scan(cursor, match=pattern, count=100)
pipeline.delete(key) if keys:
pipeline.execute() pipeline = client.pipeline()
deleted += len(keys) for key in keys:
if cursor == 0: pipeline.delete(key)
break pipeline.execute()
deleted += len(keys)
if cursor == 0:
break
if deleted: if deleted:
logger.info(f"Invalidated {deleted} listing cache entries") logger.info(f"Invalidated {deleted} listing cache entries")
except redis.RedisError as e: except redis.RedisError as e:

View file

@ -268,7 +268,9 @@ class TestStreamListingGeoJsonEndpoint:
with patch("api.app.get_cached_count", return_value=None), \ with patch("api.app.get_cached_count", return_value=None), \
patch("api.app.ListingRepository", return_value=mock_repo), \ patch("api.app.ListingRepository", return_value=mock_repo), \
patch("api.app.cache_features_batch"): patch("api.app.begin_cache_population", return_value="staging:test"), \
patch("api.app.cache_features_batch_staged"), \
patch("api.app.finalize_cache_population"):
response = await async_client.get( response = await async_client.get(
"/api/listing_geojson/stream", "/api/listing_geojson/stream",
params={"listing_type": "RENT"}, params={"listing_type": "RENT"},