Refactor codebase following Clean Code principles and add 229 tests

- Extract helpers to reduce function sizes (listing_tasks, app.py, query.py, listing_fetcher)
  - Replace nonlocal mutations with _PipelineState dataclass in listing_tasks
  - Fix bugs: isinstance→equality check in repository, verify_exp for OIDC tokens
  - Consolidate duplicate filter methods in listing_repository
  - Move hardcoded config to env vars with backward-compatible defaults
  - Simplify CLI decorator to auto-build QueryParameters
  - Add deprecation docstring to data_access.py
  - Test count: 158 → 387 (all passing)
This commit is contained in:
Viktor Barzin 2026-02-07 20:19:57 +00:00
parent 7e05b3c971
commit 150342bb9e
No known key found for this signature in database
GPG key ID: 0EB088298288D958
48 changed files with 5029 additions and 990 deletions

View file

@ -3,7 +3,7 @@ from datetime import datetime, timedelta
import json
import logging
import logging.config
from typing import Annotated, Optional
from typing import Annotated, AsyncGenerator, Optional
from api.auth import get_current_user
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
from api.passkey_routes import passkey_router
@ -32,6 +32,8 @@ from opentelemetry.metrics import get_meter
load_dotenv()
logger = logging.getLogger("uvicorn")
DEFAULT_BATCH_SIZE = 50
def get_query_parameters(
listing_type: ListingType,
@ -120,11 +122,79 @@ async def get_listing_geojson(
return result.data
async def _stream_from_cache(
query_parameters: QueryParameters,
batch_size: int,
limit: int | None,
) -> AsyncGenerator[str, None]:
"""Stream GeoJSON features from the Redis cache (cache-hit path)."""
cached_count = get_cached_count(query_parameters)
effective_total = min(limit, cached_count) if limit and cached_count else cached_count
yield json.dumps({
"type": "metadata",
"batch_size": batch_size,
"total_expected": effective_total,
"cached": True,
}) + "\n"
count = 0
for feature_batch in get_cached_features(query_parameters, batch_size=batch_size):
if limit and count + len(feature_batch) > limit:
feature_batch = feature_batch[:limit - count]
count += len(feature_batch)
yield json.dumps({"type": "batch", "features": feature_batch}) + "\n"
if limit and count >= limit:
break
yield json.dumps({"type": "complete", "total": count}) + "\n"
async def _stream_from_db(
query_parameters: QueryParameters,
batch_size: int,
limit: int | None,
) -> AsyncGenerator[str, None]:
"""Stream GeoJSON features from the database, populating the cache as we go."""
repository = ListingRepository(engine)
total = repository.count_listings(query_parameters)
effective_total = min(limit, total) if limit else total
yield json.dumps({
"type": "metadata",
"batch_size": batch_size,
"total_expected": effective_total,
"cached": False,
}) + "\n"
count = 0
batch: list[dict] = []
for row in repository.stream_listings_optimized(
query_parameters, limit=limit, page_size=batch_size
):
feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
batch.append(feature)
count += 1
if len(batch) >= batch_size:
cache_features_batch(query_parameters, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n"
batch = []
if batch:
cache_features_batch(query_parameters, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n"
yield json.dumps({"type": "complete", "total": count}) + "\n"
@app.get("/api/listing_geojson/stream")
async def stream_listing_geojson(
user: Annotated[User, Depends(get_current_user)],
query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)],
batch_size: int = 50,
batch_size: int = DEFAULT_BATCH_SIZE,
limit: int | None = None,
) -> StreamingResponse:
"""Stream listings as NDJSON for progressive map loading.
@ -134,71 +204,14 @@ async def stream_listing_geojson(
- batch: Array of GeoJSON features
- complete: Final message with total count
"""
async def generate():
# Check cache first
cached_count = get_cached_count(query_parameters)
if cached_count is not None and cached_count > 0:
# Cache HIT
effective_total = min(limit, cached_count) if limit else cached_count
yield json.dumps({
"type": "metadata",
"batch_size": batch_size,
"total_expected": effective_total,
"cached": True,
}) + "\n"
count = 0
for feature_batch in get_cached_features(query_parameters, batch_size=batch_size):
if limit and count + len(feature_batch) > limit:
feature_batch = feature_batch[:limit - count]
count += len(feature_batch)
yield json.dumps({"type": "batch", "features": feature_batch}) + "\n"
if limit and count >= limit:
break
yield json.dumps({"type": "complete", "total": count}) + "\n"
else:
# Cache MISS - query DB and populate cache
repository = ListingRepository(engine)
# Phase 1: Fast count for progress estimation
total = repository.count_listings(query_parameters)
effective_total = min(limit, total) if limit else total
yield json.dumps({
"type": "metadata",
"batch_size": batch_size,
"total_expected": effective_total,
"cached": False,
}) + "\n"
# Phase 2: Stream with column projection and keyset pagination
count = 0
batch = []
for row in repository.stream_listings_optimized(
query_parameters, limit=limit, page_size=batch_size
):
feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
batch.append(feature)
count += 1
if len(batch) >= batch_size:
cache_features_batch(query_parameters, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n"
batch = []
# Send remaining
if batch:
cache_features_batch(query_parameters, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n"
# Final message
yield json.dumps({"type": "complete", "total": count}) + "\n"
cached_count = get_cached_count(query_parameters)
if cached_count is not None and cached_count > 0:
generator = _stream_from_cache(query_parameters, batch_size, limit)
else:
generator = _stream_from_db(query_parameters, batch_size, limit)
return StreamingResponse(
generate(),
generator,
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",

View file

@ -59,7 +59,6 @@ async def _verify_authentik_token(token: str) -> User:
algorithms=["RS256"],
audience=OIDC_CLIENT_ID,
issuer=metadata["issuer"],
options={"verify_exp": False},
)
return User(**payload)
@ -84,7 +83,9 @@ async def get_current_user(
) -> User:
token = credentials.credentials
try:
# Peek at unverified issuer to route verification
# Decode WITHOUT verification just to read the "iss" claim for routing.
# This is safe: we only use the issuer to decide which verified decode
# path to take next; the actual security check happens in the branch below.
unverified = jwt.decode(
token, options={"verify_signature": False, "verify_exp": False}
)

View file

@ -1,10 +1,13 @@
from datetime import timedelta
import logging
import os
_logger = logging.getLogger(__name__)
# Authentik OIDC Configuration
AUTHENTIK_URL = "https://authentik.viktorbarzin.me"
OIDC_CLIENT_ID = "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe"
AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me")
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe")
OIDC_METADATA_URL = (
f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration"
)
@ -23,6 +26,8 @@ WEBAUTHN_ORIGIN = os.getenv("WEBAUTHN_ORIGIN", "https://localhost")
# JWT Configuration (for passkey-issued tokens)
JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production")
if JWT_SECRET == "change-me-in-production":
_logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.")
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))
JWT_ISSUER = os.getenv("JWT_ISSUER", "wrongmove")