Refactor codebase following Clean Code principles and add 229 tests

- Extract helpers to reduce function sizes (listing_tasks, app.py, query.py, listing_fetcher)
  - Replace nonlocal mutations with _PipelineState dataclass in listing_tasks
  - Fix bugs: isinstance→equality check in repository, verify_exp for OIDC tokens
  - Consolidate duplicate filter methods in listing_repository
  - Move hardcoded config to env vars with backward-compatible defaults
  - Simplify CLI decorator to auto-build QueryParameters
  - Add deprecation docstring to data_access.py
  - Test count: 158 → 387 (all passing)
This commit is contained in:
Viktor Barzin 2026-02-07 20:19:57 +00:00
parent 7e05b3c971
commit 150342bb9e
No known key found for this signature in database
GPG key ID: 0EB088298288D958
48 changed files with 5029 additions and 990 deletions

95
.claude/CLAUDE.md Normal file
View file

@ -0,0 +1,95 @@
# Realestate Crawler
## Project Overview
A real estate listing aggregation platform that scrapes Rightmove UK listings, extracts square meter data from floorplan images via OCR, calculates transit routes, and serves an interactive map-based web UI. The repo contains three sub-projects:
- **`crawler/`** — Main application (Python backend + React frontend). Has its own `CLAUDE.md` with detailed architecture docs.
- **`immoweb/`** — Separate scraper (Node.js, legacy/reference).
- **`vqa/`** — Visual QA / testing tooling.
## Command Execution
**All commands run inside Docker containers.** The dev environment uses Docker Compose — start it first, then exec into containers for any operations.
- **Infrastructure commands** (docker compose, kubectl) — Run locally on the Mac
- **All project commands** (pytest, poetry, alembic, python, mypy, ruff, etc.) — Run inside the `app` container via `docker compose exec app <command>`
See `.claude/skills/` for detailed skills on dev environment, building, and deploying.
## Quick Reference
| Action | Command | Where |
|--------|---------|-------|
| Start all services | `./start.sh` | `crawler/` |
| Rebuild & start | `./start.sh --build` | `crawler/` |
| Stop services | `./start.sh --down` | `crawler/` |
| Run tests | `pytest tests/ -v --cov=. --cov-report=term-missing` | `crawler/` |
| Type check | `mypy .` | `crawler/` |
| Format code | `yapf --style .style.yapf --recursive .` | `crawler/` |
| Lint (CI runs Ruff) | `ruff check .` | `crawler/` |
| DB migration | `alembic upgrade head` | `crawler/` |
| New migration | `alembic revision -m "description"` | `crawler/` |
| Frontend dev | `cd frontend && ./start.sh` | `crawler/` |
## Tech Stack
- **Backend:** Python 3.13, FastAPI, SQLModel (SQLAlchemy 2), Celery + Redis, pytesseract/OpenCV
- **Frontend:** React 19, TypeScript, Vite, Tailwind CSS, Radix UI, Mapbox GL
- **Database:** MySQL 9 (prod) / SQLite (local dev), Alembic migrations
- **Infrastructure:** Docker Compose (dev), Kubernetes (prod), Drone CI
## Code Conventions
- **Type checking:** Strict mypy (`disallow_untyped_defs=true`). All functions need type annotations.
- **Formatting:** YAPF (`.style.yapf` config) + Ruff linter (CI on PRs).
- **Async:** Scraper layer uses `async/await` throughout with `aiohttp`.
- **Models:** SQLModel for DB entities, Pydantic for request/response validation.
- **Service layer:** `services/` contains unified handlers used by both CLI and API — add new business logic here, not in API routes or CLI commands directly.
- **Repository pattern:** `repositories/` for database queries. Don't put raw SQL in services or API.
- **Tests:** pytest with `pytest-asyncio` (auto mode). Unit tests in `tests/unit/`, integration in `tests/integration/`.
## Architecture Layers (in `crawler/`)
```
API (api/app.py) ←→ Services (services/) ←→ Repositories (repositories/)
↑ ↑ ↑
Auth (OIDC) Core Logic (rec/) SQLModel (models/)
Celery Tasks (tasks/)
```
- `rec/` — Core scraping, OCR, routing logic. Contains circuit breaker and throttle detection.
- `services/` — Orchestration. Query splitting, listing fetching, caching.
- `api/` — FastAPI routes, auth middleware, metrics.
- `models/``RentListing`, `BuyListing` SQLModel entities.
- `tasks/` — Celery background tasks with Redis broker.
- `config/` — Env-var-based configuration (scraper settings, schedules).
## Key Design Decisions
- **Query splitting** works around Rightmove's ~1500-result API cap by adaptively splitting queries by district, bedrooms, and price bands (binary search). See `services/query_splitter.py`.
- **Circuit breaker** (`rec/circuit_breaker.py`) and **throttle detection** (`rec/throttle_detector.py`) protect against rate limiting.
- **Streaming API** uses NDJSON for progressive loading of large result sets.
- **Redis** serves dual duty as Celery broker and GeoJSON cache.
## Environment Variables
See `crawler/.env.sample` for the full list. Key ones:
- `DB_CONNECTION_STRING` — Database URL
- `CELERY_BROKER_URL` / `CELERY_RESULT_BACKEND` — Redis URLs
- `ROUTING_API_KEY` — Google Maps API key
- `RIGHTMOVE_*` — Scraper tuning (concurrency, delays, thresholds, proxy)
- `SCRAPE_SCHEDULES` — JSON array of periodic scrape configs
## Git Workflow
- CI: Drone CI builds Docker images on push to `master`, deploys to K8s.
- Linting: GitHub Actions runs Ruff on PR diffs.
- Keep commits focused — one logical change per commit.
- Group related files (e.g., code + its tests) in the same commit.
## Directories to Ignore
- `node_modules/`, `__pycache__/`, `.idea/`, `crawler/data/`, `venv/`, `_cache/`

View file

@ -0,0 +1,113 @@
---
name: build-and-push
description: |
Build Docker images for the API and frontend, and push them to Docker Hub.
Use when: (1) user wants to build new Docker images locally, (2) push images
to the registry before deploying, (3) tag images for a release. Covers both
the Python/FastAPI backend and the React/Nginx frontend.
author: Claude Code
version: 1.0.0
date: 2026-02-06
---
# Build and Push Docker Images
All commands run locally. Images are pushed to Docker Hub under the `viktorbarzin` namespace.
## Image Registries
| Component | Docker Hub repo | Dockerfile location |
|-----------|------------------------------------|------------------------------|
| API | `viktorbarzin/realestatecrawler` | `crawler/Dockerfile` |
| Frontend | `viktorbarzin/immoweb` | `crawler/frontend/Dockerfile`|
## Building Images
### Build API image
```bash
docker build -t viktorbarzin/realestatecrawler:latest crawler/
```
### Build Frontend image
```bash
docker build -t viktorbarzin/immoweb:latest crawler/frontend/
```
### Build both
```bash
docker build -t viktorbarzin/realestatecrawler:latest crawler/ && \
docker build -t viktorbarzin/immoweb:latest crawler/frontend/
```
### Build with a specific tag (recommended for production)
```bash
# Use git commit SHA
GIT_SHA=$(git rev-parse --short HEAD)
docker build -t viktorbarzin/realestatecrawler:${GIT_SHA} -t viktorbarzin/realestatecrawler:latest crawler/
docker build -t viktorbarzin/immoweb:${GIT_SHA} -t viktorbarzin/immoweb:latest crawler/frontend/
```
## Pushing Images
### Login to Docker Hub (if not already)
```bash
docker login -u viktorbarzin
```
### Push API image
```bash
docker push viktorbarzin/realestatecrawler:latest
```
### Push Frontend image
```bash
docker push viktorbarzin/immoweb:latest
```
### Push with specific tag
```bash
GIT_SHA=$(git rev-parse --short HEAD)
docker push viktorbarzin/realestatecrawler:${GIT_SHA}
docker push viktorbarzin/realestatecrawler:latest
docker push viktorbarzin/immoweb:${GIT_SHA}
docker push viktorbarzin/immoweb:latest
```
## Build and Push Everything (Full Release)
```bash
GIT_SHA=$(git rev-parse --short HEAD)
# Build
docker build -t viktorbarzin/realestatecrawler:${GIT_SHA} -t viktorbarzin/realestatecrawler:latest crawler/
docker build -t viktorbarzin/immoweb:${GIT_SHA} -t viktorbarzin/immoweb:latest crawler/frontend/
# Push
docker push viktorbarzin/realestatecrawler:${GIT_SHA}
docker push viktorbarzin/realestatecrawler:latest
docker push viktorbarzin/immoweb:${GIT_SHA}
docker push viktorbarzin/immoweb:latest
```
## CI/CD Note
Drone CI automatically builds and pushes images on push to `master` (see `.drone.yml`).
The manual process above is for when you need to build/push outside of CI, such as:
- Hotfix deployments
- Testing image builds locally before pushing
- Deploying from a non-master branch
## Notes
- The API Dockerfile installs system deps (OpenCV, Tesseract, MariaDB client) and Python deps via Poetry
- The Frontend Dockerfile is a multi-stage build: Node builder -> Nginx runtime
- Always tag with both `:latest` and a specific tag (git SHA or version) for traceability
- Use `docker buildx` for cross-platform builds if deploying to ARM nodes

View file

@ -0,0 +1,158 @@
---
name: deploy-to-kubernetes
description: |
Deploy the realestate-crawler to the Kubernetes cluster. Use when: (1) user wants
to deploy after building new images, (2) rollout restart to pick up new images,
(3) check deployment status, pod health, or logs in production, (4) scale
deployments up or down, (5) debug production issues.
author: Claude Code
version: 1.0.0
date: 2026-02-06
---
# Deploy to Kubernetes
All kubectl commands run locally against the K8s cluster at `10.0.20.100:6443`.
Namespace: `realestate-crawler`.
## Deployments
| Deployment | Image | Component |
|--------------------------|------------------------------------|-----------|
| realestate-crawler-api | viktorbarzin/realestatecrawler | API |
| realestate-crawler-ui | viktorbarzin/immoweb | Frontend |
## Deploying New Images
### Rolling restart (picks up :latest after push)
```bash
# Restart API deployment
kubectl rollout restart deployment/realestate-crawler-api -n realestate-crawler
# Restart Frontend deployment
kubectl rollout restart deployment/realestate-crawler-ui -n realestate-crawler
# Restart both
kubectl rollout restart deployment/realestate-crawler-api deployment/realestate-crawler-ui -n realestate-crawler
```
### Full deploy workflow (build, push, restart)
```bash
GIT_SHA=$(git rev-parse --short HEAD)
# Build and push API
docker build -t viktorbarzin/realestatecrawler:${GIT_SHA} -t viktorbarzin/realestatecrawler:latest crawler/
docker push viktorbarzin/realestatecrawler:${GIT_SHA}
docker push viktorbarzin/realestatecrawler:latest
# Build and push Frontend
docker build -t viktorbarzin/immoweb:${GIT_SHA} -t viktorbarzin/immoweb:latest crawler/frontend/
docker push viktorbarzin/immoweb:${GIT_SHA}
docker push viktorbarzin/immoweb:latest
# Restart deployments to pick up new images
kubectl rollout restart deployment/realestate-crawler-api -n realestate-crawler
kubectl rollout restart deployment/realestate-crawler-ui -n realestate-crawler
```
### Deploy a specific image tag
```bash
# Set API to a specific image version
kubectl set image deployment/realestate-crawler-api \
realestate-crawler-api=viktorbarzin/realestatecrawler:abc1234 \
-n realestate-crawler
# Set Frontend to a specific version
kubectl set image deployment/realestate-crawler-ui \
realestate-crawler-ui=viktorbarzin/immoweb:abc1234 \
-n realestate-crawler
```
## Checking Deployment Status
```bash
# List all resources in namespace
kubectl get all -n realestate-crawler
# Check deployment status
kubectl get deployments -n realestate-crawler
# Check rollout status (waits for completion)
kubectl rollout status deployment/realestate-crawler-api -n realestate-crawler
kubectl rollout status deployment/realestate-crawler-ui -n realestate-crawler
# Check pods
kubectl get pods -n realestate-crawler
# Describe a specific pod (events, conditions, image)
kubectl describe pod <pod-name> -n realestate-crawler
# Check which image a pod is running
kubectl get pods -n realestate-crawler -o jsonpath='{range .items[*]}{.metadata.name}{"\t"}{.spec.containers[*].image}{"\n"}{end}'
```
## Viewing Logs
```bash
# API logs
kubectl logs deployment/realestate-crawler-api -n realestate-crawler --tail=100 -f
# Frontend logs
kubectl logs deployment/realestate-crawler-ui -n realestate-crawler --tail=100 -f
# Logs from a specific pod
kubectl logs <pod-name> -n realestate-crawler --tail=100 -f
# Previous container logs (if pod crashed/restarted)
kubectl logs <pod-name> -n realestate-crawler --previous
```
## Scaling
```bash
# Scale API
kubectl scale deployment/realestate-crawler-api --replicas=2 -n realestate-crawler
# Scale down (e.g., for maintenance)
kubectl scale deployment/realestate-crawler-api --replicas=0 -n realestate-crawler
```
## Rollback
```bash
# View rollout history
kubectl rollout history deployment/realestate-crawler-api -n realestate-crawler
# Rollback to previous version
kubectl rollout undo deployment/realestate-crawler-api -n realestate-crawler
# Rollback to specific revision
kubectl rollout undo deployment/realestate-crawler-api --to-revision=3 -n realestate-crawler
```
## Debugging Production Issues
```bash
# Exec into a running API pod
kubectl exec -it deployment/realestate-crawler-api -n realestate-crawler -- bash
# Run a one-off command
kubectl exec deployment/realestate-crawler-api -n realestate-crawler -- python -c "print('hello')"
# Check pod events (useful for crash loops, image pull errors)
kubectl get events -n realestate-crawler --sort-by='.lastTimestamp' | tail -20
# Port-forward to a pod for local debugging
kubectl port-forward deployment/realestate-crawler-api 5001:5001 -n realestate-crawler
```
## Notes
- Drone CI handles automated deployments on push to `master` (see `.drone.yml`)
- Use manual deployment for hotfixes, testing, or deploying from non-master branches
- The K8s cluster is at `10.0.20.100:6443` (context: `kubernetes-admin@kubernetes`)
- If pods aren't picking up new `:latest` images, check the `kubernetes-latest-tag-image-pull` skill
- Always verify the rollout completed with `kubectl rollout status` after deploying

View file

@ -0,0 +1,144 @@
---
name: dev-environment
description: |
Start, stop, rebuild, and manage the local Docker Compose development environment
for the realestate-crawler project. Use when: (1) user wants to start/stop the dev
environment, (2) needs to rebuild after code changes, (3) wants to check service
status or view logs, (4) needs to run database migrations.
author: Claude Code
version: 1.0.0
date: 2026-02-06
---
# Dev Environment Management
Docker Compose orchestrates the dev environment locally from `crawler/`. All project
commands (pytest, alembic, mypy, python, etc.) must run inside the `app` container
via `docker compose exec app <command>`. Only docker/kubectl commands run on the host.
## Starting the Dev Environment
```bash
# Start all services (Redis, MySQL, API, Celery worker, Celery beat)
cd crawler && docker compose up
# Start in detached mode (background)
cd crawler && docker compose up -d
# Rebuild images and start (after Dockerfile or dependency changes)
cd crawler && docker compose up --build
# Or use the start.sh helper
cd crawler && ./start.sh # foreground
cd crawler && ./start.sh --build # rebuild first
```
## Stopping the Dev Environment
```bash
cd crawler && docker compose down
# Also remove volumes (fresh database, fresh Redis)
cd crawler && docker compose down -v
# Or use the helper
cd crawler && ./start.sh --down
```
## Checking Status
```bash
# List running containers
cd crawler && docker compose ps
# Check health status
cd crawler && docker compose ps --format "table {{.Name}}\t{{.Status}}"
```
## Viewing Logs
```bash
# Follow all service logs
cd crawler && docker compose logs -f
# Follow specific service logs
cd crawler && docker compose logs -f app
cd crawler && docker compose logs -f celery
cd crawler && docker compose logs -f celery-beat
cd crawler && docker compose logs -f mysql
cd crawler && docker compose logs -f redis
# Or use the helper
cd crawler && ./start.sh --logs
```
## Restarting Individual Services
```bash
# Restart just the API (e.g., after config change)
cd crawler && docker compose restart app
# Restart Celery worker
cd crawler && docker compose restart celery
# Rebuild and restart a single service
cd crawler && docker compose up --build app
```
## Running Database Migrations
```bash
# Apply pending migrations
cd crawler && docker compose exec app alembic upgrade head
# Create a new migration
cd crawler && docker compose exec app alembic revision -m "description"
```
## Running Tests Inside Container
```bash
cd crawler && docker compose exec app pytest tests/ -v --cov=. --cov-report=term-missing
```
## Running Any Command Inside Container
All project commands must be run inside the `app` container:
```bash
# General pattern
cd crawler && docker compose exec app <command>
# Examples
cd crawler && docker compose exec app python main.py dump-listings --type rent
cd crawler && docker compose exec app mypy .
cd crawler && docker compose exec app ruff check .
cd crawler && docker compose exec app poetry install
cd crawler && docker compose exec app bash # interactive shell
```
## Services and Ports
| Service | Container | Port | Description |
|-------------|-----------------|-------|--------------------------------|
| redis | rec-redis | 6379 | Celery broker + GeoJSON cache |
| mysql | rec-mysql | 3306 | Primary database |
| app | rec-app | 5001 | FastAPI server (hot-reload) |
| celery | rec-celery | - | Background task worker |
| celery-beat | rec-celery-beat | - | Periodic task scheduler |
## Environment Variables
Key env vars are set in `docker-compose.yml`. To override locally, create a `.env` file
in `crawler/` (see `.env.sample`). Key overrides:
- `ROUTING_API_KEY` - Google Maps API key (passed from host env)
- `SCRAPE_SCHEDULES` - JSON array of periodic scrape configs (passed from host env)
## Notes
- The API server has hot-reload enabled for `api/`, `services/`, `repositories/`, and `models/` directories
- Source code is bind-mounted into containers, so local edits are reflected immediately
- Python virtualenv is stored in a named Docker volume (`app_venv`) shared across app, celery, and celery-beat
- MySQL data persists in the `mysql_data` volume; Redis data in `redis_data`
- Use `docker compose down -v` to reset all data (volumes)

View file

@ -3,7 +3,7 @@ from datetime import datetime, timedelta
import json
import logging
import logging.config
from typing import Annotated, Optional
from typing import Annotated, AsyncGenerator, Optional
from api.auth import get_current_user
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
from api.passkey_routes import passkey_router
@ -32,6 +32,8 @@ from opentelemetry.metrics import get_meter
load_dotenv()
logger = logging.getLogger("uvicorn")
DEFAULT_BATCH_SIZE = 50
def get_query_parameters(
listing_type: ListingType,
@ -120,11 +122,79 @@ async def get_listing_geojson(
return result.data
async def _stream_from_cache(
query_parameters: QueryParameters,
batch_size: int,
limit: int | None,
) -> AsyncGenerator[str, None]:
"""Stream GeoJSON features from the Redis cache (cache-hit path)."""
cached_count = get_cached_count(query_parameters)
effective_total = min(limit, cached_count) if limit and cached_count else cached_count
yield json.dumps({
"type": "metadata",
"batch_size": batch_size,
"total_expected": effective_total,
"cached": True,
}) + "\n"
count = 0
for feature_batch in get_cached_features(query_parameters, batch_size=batch_size):
if limit and count + len(feature_batch) > limit:
feature_batch = feature_batch[:limit - count]
count += len(feature_batch)
yield json.dumps({"type": "batch", "features": feature_batch}) + "\n"
if limit and count >= limit:
break
yield json.dumps({"type": "complete", "total": count}) + "\n"
async def _stream_from_db(
query_parameters: QueryParameters,
batch_size: int,
limit: int | None,
) -> AsyncGenerator[str, None]:
"""Stream GeoJSON features from the database, populating the cache as we go."""
repository = ListingRepository(engine)
total = repository.count_listings(query_parameters)
effective_total = min(limit, total) if limit else total
yield json.dumps({
"type": "metadata",
"batch_size": batch_size,
"total_expected": effective_total,
"cached": False,
}) + "\n"
count = 0
batch: list[dict] = []
for row in repository.stream_listings_optimized(
query_parameters, limit=limit, page_size=batch_size
):
feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
batch.append(feature)
count += 1
if len(batch) >= batch_size:
cache_features_batch(query_parameters, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n"
batch = []
if batch:
cache_features_batch(query_parameters, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n"
yield json.dumps({"type": "complete", "total": count}) + "\n"
@app.get("/api/listing_geojson/stream")
async def stream_listing_geojson(
user: Annotated[User, Depends(get_current_user)],
query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)],
batch_size: int = 50,
batch_size: int = DEFAULT_BATCH_SIZE,
limit: int | None = None,
) -> StreamingResponse:
"""Stream listings as NDJSON for progressive map loading.
@ -134,71 +204,14 @@ async def stream_listing_geojson(
- batch: Array of GeoJSON features
- complete: Final message with total count
"""
async def generate():
# Check cache first
cached_count = get_cached_count(query_parameters)
if cached_count is not None and cached_count > 0:
# Cache HIT
effective_total = min(limit, cached_count) if limit else cached_count
yield json.dumps({
"type": "metadata",
"batch_size": batch_size,
"total_expected": effective_total,
"cached": True,
}) + "\n"
count = 0
for feature_batch in get_cached_features(query_parameters, batch_size=batch_size):
if limit and count + len(feature_batch) > limit:
feature_batch = feature_batch[:limit - count]
count += len(feature_batch)
yield json.dumps({"type": "batch", "features": feature_batch}) + "\n"
if limit and count >= limit:
break
yield json.dumps({"type": "complete", "total": count}) + "\n"
else:
# Cache MISS - query DB and populate cache
repository = ListingRepository(engine)
# Phase 1: Fast count for progress estimation
total = repository.count_listings(query_parameters)
effective_total = min(limit, total) if limit else total
yield json.dumps({
"type": "metadata",
"batch_size": batch_size,
"total_expected": effective_total,
"cached": False,
}) + "\n"
# Phase 2: Stream with column projection and keyset pagination
count = 0
batch = []
for row in repository.stream_listings_optimized(
query_parameters, limit=limit, page_size=batch_size
):
feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
batch.append(feature)
count += 1
if len(batch) >= batch_size:
cache_features_batch(query_parameters, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n"
batch = []
# Send remaining
if batch:
cache_features_batch(query_parameters, batch)
yield json.dumps({"type": "batch", "features": batch}) + "\n"
# Final message
yield json.dumps({"type": "complete", "total": count}) + "\n"
cached_count = get_cached_count(query_parameters)
if cached_count is not None and cached_count > 0:
generator = _stream_from_cache(query_parameters, batch_size, limit)
else:
generator = _stream_from_db(query_parameters, batch_size, limit)
return StreamingResponse(
generate(),
generator,
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",

View file

@ -59,7 +59,6 @@ async def _verify_authentik_token(token: str) -> User:
algorithms=["RS256"],
audience=OIDC_CLIENT_ID,
issuer=metadata["issuer"],
options={"verify_exp": False},
)
return User(**payload)
@ -84,7 +83,9 @@ async def get_current_user(
) -> User:
token = credentials.credentials
try:
# Peek at unverified issuer to route verification
# Decode WITHOUT verification just to read the "iss" claim for routing.
# This is safe: we only use the issuer to decide which verified decode
# path to take next; the actual security check happens in the branch below.
unverified = jwt.decode(
token, options={"verify_signature": False, "verify_exp": False}
)

View file

@ -1,10 +1,13 @@
from datetime import timedelta
import logging
import os
_logger = logging.getLogger(__name__)
# Authentik OIDC Configuration
AUTHENTIK_URL = "https://authentik.viktorbarzin.me"
OIDC_CLIENT_ID = "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe"
AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me")
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe")
OIDC_METADATA_URL = (
f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration"
)
@ -23,6 +26,8 @@ WEBAUTHN_ORIGIN = os.getenv("WEBAUTHN_ORIGIN", "https://localhost")
# JWT Configuration (for passkey-issued tokens)
JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production")
if JWT_SECRET == "change-me-in-production":
_logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.")
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))
JWT_ISSUER = os.getenv("JWT_ISSUER", "wrongmove")

View file

@ -1,3 +1,11 @@
"""Legacy filesystem-based data access.
.. deprecated::
This module is only used by the ``populate_db`` CLI command for migrating
old filesystem data into the database. Do not import from this module in
new code. Use ``models.listing.RentListing`` or ``models.listing.BuyListing``
and ``repositories.listing_repository.ListingRepository`` instead.
"""
import asyncio
from collections import defaultdict
from dataclasses import dataclass
@ -381,8 +389,6 @@ class Listing:
for item in data
]
@property
def listing_site(self) -> ListingSite:
return ListingSite.RIGHTMOVE # this class supports only right move

View file

@ -1 +1 @@
{"root":["./src/App.tsx","./src/AppSidebar.tsx","./src/main.tsx","./src/vite-env.d.ts","./src/auth/authService.ts","./src/auth/config.ts","./src/auth/errors.ts","./src/components/ActiveQuery.tsx","./src/components/AlertError.tsx","./src/components/AuthCallback.tsx","./src/components/FilterPanel.tsx","./src/components/Header.tsx","./src/components/HealthIndicator.tsx","./src/components/ListView.tsx","./src/components/LoginModal.tsx","./src/components/Map.tsx","./src/components/Parameters.tsx","./src/components/PropertyCard.tsx","./src/components/Spinner.tsx","./src/components/StatsBar.tsx","./src/components/StreamingProgressBar.tsx","./src/components/TaskIndicator.tsx","./src/components/ui/DatePicker.tsx","./src/components/ui/accordion.tsx","./src/components/ui/alert-dialog.tsx","./src/components/ui/breadcrumb.tsx","./src/components/ui/button.tsx","./src/components/ui/calendar.tsx","./src/components/ui/checkbox.tsx","./src/components/ui/dialog.tsx","./src/components/ui/form.tsx","./src/components/ui/hover-card.tsx","./src/components/ui/input.tsx","./src/components/ui/label.tsx","./src/components/ui/popover.tsx","./src/components/ui/progress.tsx","./src/components/ui/scroll-area.tsx","./src/components/ui/select.tsx","./src/components/ui/separator.tsx","./src/components/ui/sheet.tsx","./src/components/ui/sidebar.tsx","./src/components/ui/skeleton.tsx","./src/components/ui/slider.tsx","./src/components/ui/tooltip.tsx","./src/constants/colorSchemes.ts","./src/constants/index.ts","./src/hooks/use-mobile.ts","./src/lib/utils.ts","./src/services/apiClient.ts","./src/services/healthService.ts","./src/services/index.ts","./src/services/listingService.ts","./src/services/streamingService.ts","./src/services/taskService.ts","./src/types/index.ts","./src/utils/mapUtils.ts"],"version":"5.8.3"}
{"root":["./src/app.tsx","./src/appsidebar.tsx","./src/main.tsx","./src/vite-env.d.ts","./src/auth/authservice.ts","./src/auth/config.ts","./src/auth/errors.ts","./src/auth/passkeyservice.ts","./src/auth/types.ts","./src/components/activequery.tsx","./src/components/alerterror.tsx","./src/components/authcallback.tsx","./src/components/filterpanel.tsx","./src/components/header.tsx","./src/components/healthindicator.tsx","./src/components/listview.tsx","./src/components/loginmodal.tsx","./src/components/map.tsx","./src/components/parameters.tsx","./src/components/propertycard.tsx","./src/components/spinner.tsx","./src/components/statsbar.tsx","./src/components/streamingprogressbar.tsx","./src/components/taskindicator.tsx","./src/components/taskprogressdrawer.tsx","./src/components/ui/datepicker.tsx","./src/components/ui/accordion.tsx","./src/components/ui/alert-dialog.tsx","./src/components/ui/breadcrumb.tsx","./src/components/ui/button.tsx","./src/components/ui/calendar.tsx","./src/components/ui/checkbox.tsx","./src/components/ui/dialog.tsx","./src/components/ui/form.tsx","./src/components/ui/hover-card.tsx","./src/components/ui/input.tsx","./src/components/ui/label.tsx","./src/components/ui/popover.tsx","./src/components/ui/progress.tsx","./src/components/ui/scroll-area.tsx","./src/components/ui/select.tsx","./src/components/ui/separator.tsx","./src/components/ui/sheet.tsx","./src/components/ui/sidebar.tsx","./src/components/ui/skeleton.tsx","./src/components/ui/slider.tsx","./src/components/ui/tabs.tsx","./src/components/ui/tooltip.tsx","./src/constants/colorschemes.ts","./src/constants/index.ts","./src/hooks/use-mobile.ts","./src/lib/utils.ts","./src/services/apiclient.ts","./src/services/healthservice.ts","./src/services/index.ts","./src/services/listingservice.ts","./src/services/streamingservice.ts","./src/services/taskservice.ts","./src/types/index.ts","./src/utils/maputils.ts"],"version":"5.8.3"}

View file

@ -6,6 +6,7 @@ from datetime import datetime
import logging
import multiprocessing
from pathlib import Path
from urllib.parse import urlparse
import aiohttp
from models.listing import FurnishType, Listing, ListingSite, RentListing
from rec import floorplan
@ -14,8 +15,33 @@ from repositories.listing_repository import ListingRepository
logger = logging.getLogger("uvicorn.error")
# Also use celery task logger for visibility in worker output
celery_logger = logging.getLogger("celery.task")
# Limit OCR threads to 25% of available cores to avoid starving other work.
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
def _parse_furnish_type(raw: str | None) -> FurnishType:
"""Normalise the raw furnish-type string from the API into a FurnishType enum."""
if raw is None:
return FurnishType.UNKNOWN
if "landlord" in raw.lower():
return FurnishType.ASK_LANDLORD
lowered = raw.lower()
try:
return FurnishType(lowered)
except ValueError:
return FurnishType.UNKNOWN
def _parse_available_from(raw: str | None) -> datetime | None:
"""Parse the available-from date string into a datetime, or None."""
if raw is None:
return None
if raw.lower() == "now":
return datetime.now()
try:
return datetime.strptime(raw, "%d/%m/%Y")
except ValueError:
return None
class ListingProcessor:
@ -62,7 +88,6 @@ class ListingProcessor:
on_step_complete(short_name)
except Exception as e:
logger.error(f"[{listing_id}] {step_class_name} failed: {e}")
celery_logger.error(f"[{listing_id}] {step_class_name} failed: {e}")
return None
return listing
@ -92,7 +117,7 @@ class FetchListingDetailsStep(Step):
async def process(self, listing_id: int) -> Listing:
logger.debug(f"[{listing_id}] Fetching property details from API")
celery_logger.info(f"[{listing_id}] Fetching details...")
logger.info(f"[{listing_id}] Fetching details...")
existing_listings = await self.listing_repository.get_listings(
only_ids=[listing_id]
@ -105,30 +130,15 @@ class FetchListingDetailsStep(Step):
listing_details = await detail_query(listing_id)
furnish_type_str = listing_details["property"].get("letFurnishType", "unknown")
if furnish_type_str is None:
furnish_type_str = "unknown"
elif "landlord" in furnish_type_str.lower():
furnish_type_str = "ask landlord"
else:
furnish_type_str = furnish_type_str.lower()
furnish_type = FurnishType(furnish_type_str)
furnish_type = _parse_furnish_type(
listing_details["property"].get("letFurnishType", "unknown")
)
available_from: datetime | None = None
available_from_str: str | None = listing_details["property"]["letDateAvailable"]
if available_from_str is None:
available_from = None
elif available_from_str.lower() == "now":
available_from = datetime.now()
else:
try:
available_from = datetime.strptime(available_from_str, "%d/%m/%Y")
except ValueError:
# If the date format is not as expected, return None
available_from = None
available_from = _parse_available_from(
listing_details["property"]["letDateAvailable"]
)
photos = listing_details["property"]["photos"]
# listing = Listing(
listing = RentListing( # TODO: should pick based on price?
id=listing_id,
price=listing_details["property"]["price"],
@ -150,7 +160,7 @@ class FetchListingDetailsStep(Step):
)
await self.listing_repository.upsert_listings([listing])
celery_logger.info(
logger.info(
f"[{listing_id}] Details fetched: £{listing.price}, "
f"{listing.number_of_bedrooms}BR, {listing.agency}"
)
@ -190,13 +200,13 @@ class FetchImagesStep(Step):
downloaded = 0
client_timeout = aiohttp.ClientTimeout(total=30)
for floorplan_obj in all_floorplans:
url = floorplan_obj["url"]
picname = url.split("/")[-1]
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
if floorplan_path.exists():
continue
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession() as session:
for floorplan_obj in all_floorplans:
url = floorplan_obj["url"]
picname = Path(urlparse(url).path).name
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
if floorplan_path.exists():
continue
async with session.get(url, timeout=client_timeout) as response:
if response.status == 404:
return listing
@ -210,7 +220,7 @@ class FetchImagesStep(Step):
await self.listing_repository.upsert_listings([listing])
celery_logger.info(f"[{listing_id}] Downloaded {downloaded} floorplan images")
logger.info(f"[{listing_id}] Downloaded {downloaded} floorplan images")
logger.debug(f"[{listing_id}] Image fetch complete")
return listing
@ -220,7 +230,7 @@ class DetectFloorplanStep(Step):
def __init__(self, listing_repository: ListingRepository):
super().__init__(listing_repository)
self.ocr_semaphore = asyncio.Semaphore(multiprocessing.cpu_count() // 4)
self.ocr_semaphore = asyncio.Semaphore(MAX_OCR_WORKERS)
async def needs_processing(self, listing_id: int) -> bool:
listings = await self.listing_repository.get_listings(only_ids=[listing_id])
@ -256,7 +266,7 @@ class DetectFloorplanStep(Step):
await self.listing_repository.upsert_listings([listing])
if max_sqm > 0:
celery_logger.info(f"[{listing_id}] OCR detected {max_sqm} sqm")
logger.info(f"[{listing_id}] OCR detected {max_sqm} sqm")
else:
logger.debug(f"[{listing_id}] OCR: no square meters detected")

View file

@ -22,13 +22,50 @@ P = ParamSpec("P")
R = TypeVar("R")
def build_query_parameters(
type: str,
district: list[str] | tuple[str, ...] | None,
min_bedrooms: int,
max_bedrooms: int,
min_price: int,
max_price: int,
furnish_types: list[str] | tuple[str, ...],
available_from: datetime | None,
last_seen_days: int,
min_sqm: int | None = None,
radius: int = 0,
page_size: int = 500,
max_days_since_added: int = 14,
) -> QueryParameters:
"""Build QueryParameters from CLI options."""
return QueryParameters(
listing_type=ListingType[type],
district_names=set(district) if district else set(),
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
min_price=min_price,
max_price=max_price,
furnish_types=[FurnishType[ft] for ft in furnish_types] if furnish_types else None,
let_date_available_from=available_from,
last_seen_days=last_seen_days,
min_sqm=min_sqm,
radius=radius,
page_size=page_size,
max_days_since_added=max_days_since_added,
)
def listing_filter_options(func: Callable[P, R]) -> Callable[P, R]:
"""Decorator to add common options for filtering listings."""
"""Decorator that adds common listing filter options and builds QueryParameters.
The wrapped function receives a `query_parameters: QueryParameters` kwarg
instead of individual filter values.
"""
@click.option(
"--type",
"-t",
help="Type of listing to scrape",
help="Type of listing to scrape (BUY or RENT)",
type=click.Choice(
ListingType.__members__.keys(),
case_sensitive=False,
@ -50,26 +87,26 @@ def listing_filter_options(func: Callable[P, R]) -> Callable[P, R]:
@click.option(
"--min-price",
default=0,
help="Minimum price",
help="Minimum price in GBP",
type=click.IntRange(min=0),
)
@click.option(
"--max-price",
default=999_999,
help="Maximum price",
help="Maximum price in GBP",
type=click.IntRange(min=0),
)
@click.option(
"--district",
default=None,
help="Districts to scrape",
help="District to filter by (can be repeated for multiple districts)",
type=click.Choice(district_service.get_district_names(), case_sensitive=False),
multiple=True,
)
@click.option(
"--furnish-types",
"-f",
help="Furnish types for rented listings",
help="Furnish type filter for rented listings (can be repeated)",
type=click.Choice(
[furnish_type.name for furnish_type in FurnishType.__members__.values()],
case_sensitive=False,
@ -78,13 +115,13 @@ def listing_filter_options(func: Callable[P, R]) -> Callable[P, R]:
)
@click.option(
"--available-from",
help="Let date available from",
help="Only include listings available from this date (format: YYYY-MM-DD)",
default=None,
type=click.DateTime(),
)
@click.option(
"--last-seen-days",
help="Last seen (days). If set, only listings that were seen in the last N days will be included.",
help="Only include listings seen in the last N days",
default=14,
type=int,
)
@ -95,45 +132,37 @@ def listing_filter_options(func: Callable[P, R]) -> Callable[P, R]:
type=int,
)
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
def wrapper(
*args: P.args,
type: str,
district: tuple[str, ...],
min_bedrooms: int,
max_bedrooms: int,
min_price: int,
max_price: int,
furnish_types: tuple[str, ...],
available_from: datetime | None,
last_seen_days: int,
min_sqm: int | None,
**kwargs: P.kwargs,
) -> R:
query_parameters = build_query_parameters(
type=type,
district=district,
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
min_price=min_price,
max_price=max_price,
furnish_types=furnish_types,
available_from=available_from,
last_seen_days=last_seen_days,
min_sqm=min_sqm,
)
return func(*args, query_parameters=query_parameters, **kwargs)
return wrapper
def build_query_parameters(
type: str,
district: list[str],
min_bedrooms: int,
max_bedrooms: int,
min_price: int,
max_price: int,
furnish_types: list[str],
available_from: datetime | None,
last_seen_days: int,
min_sqm: int | None = None,
radius: int = 0,
page_size: int = 500,
max_days_since_added: int = 14,
) -> QueryParameters:
"""Build QueryParameters from CLI options."""
return QueryParameters(
listing_type=ListingType[type],
district_names=set(district) if district else None,
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
min_price=min_price,
max_price=max_price,
furnish_types=[FurnishType[ft] for ft in furnish_types] if furnish_types else None,
let_date_available_from=available_from,
last_seen_days=last_seen_days,
min_sqm=min_sqm,
radius=radius,
page_size=page_size,
max_days_since_added=max_days_since_added,
)
@click.group()
@click.option(
"--data-dir",
@ -155,46 +184,28 @@ def cli(ctx: click.Context, data_dir: str) -> None:
@cli.command()
@listing_filter_options
@click.option("--full", is_flag=True, help="Include images and floorplan detection")
@click.option(
"--include-processing",
"-p",
is_flag=True,
help="Also download images and run floorplan OCR detection",
)
@click.pass_context
def dump_listings(
ctx: click.Context,
full: bool,
district: list[str],
min_bedrooms: int,
max_bedrooms: int,
min_price: int,
max_price: int,
type: str,
furnish_types: list[str],
available_from: datetime | None,
last_seen_days: int,
min_sqm: int | None = None,
query_parameters: QueryParameters,
include_processing: bool,
) -> None:
"""Fetch listings from Rightmove API."""
data_dir: pathlib.Path = ctx.obj["data_dir"]
repository: ListingRepository = ctx.obj["repository"]
query_parameters = build_query_parameters(
type=type,
district=district,
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
min_price=min_price,
max_price=max_price,
furnish_types=furnish_types,
available_from=available_from,
last_seen_days=last_seen_days,
min_sqm=min_sqm,
)
click.echo(f"Fetching listings with parameters: {query_parameters}")
result = asyncio.run(
listing_service.refresh_listings(
repository,
query_parameters,
full=full,
full=include_processing,
async_mode=False,
)
)
@ -240,14 +251,14 @@ def detect_floorplan(ctx: click.Context) -> None:
@click.option(
"--travel-mode",
"-m",
help="Travel mode for routing",
help="Travel mode for routing (e.g. transit, driving, walking, bicycling)",
type=click.Choice(TravelMode.__members__.keys(), case_sensitive=False),
required=True,
)
@click.option(
"--limit",
"-l",
help="Limit the number of listings to process",
help="Maximum number of listings to calculate routes for",
type=click.IntRange(min=1),
default=1,
)
@ -293,33 +304,11 @@ def routing(
def export_csv(
ctx: click.Context,
output_file: str,
district: list[str],
min_bedrooms: int,
max_bedrooms: int,
min_price: int,
max_price: int,
type: str,
furnish_types: list[str],
available_from: datetime | None,
last_seen_days: int,
min_sqm: int | None = None,
query_parameters: QueryParameters,
) -> None:
"""Export listings to CSV file."""
repository: ListingRepository = ctx.obj["repository"]
query_parameters = build_query_parameters(
type=type,
district=district,
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
min_price=min_price,
max_price=max_price,
furnish_types=furnish_types,
available_from=available_from,
last_seen_days=last_seen_days,
min_sqm=min_sqm,
)
click.echo(f"Exporting to {output_file}")
result = asyncio.run(
@ -346,33 +335,11 @@ def export_csv(
def export_immoweb(
ctx: click.Context,
output_file: str,
district: list[str],
min_bedrooms: int,
max_bedrooms: int,
min_price: int,
max_price: int,
type: str,
furnish_types: list[str],
available_from: datetime | None,
last_seen_days: int,
min_sqm: int | None = None,
query_parameters: QueryParameters,
) -> None:
"""Export listings to GeoJSON file for map visualization."""
repository: ListingRepository = ctx.obj["repository"]
query_parameters = build_query_parameters(
type=type,
district=district,
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
min_price=min_price,
max_price=max_price,
furnish_types=furnish_types,
available_from=available_from,
last_seen_days=last_seen_days,
min_sqm=min_sqm,
)
click.echo(f"Exporting to {output_file}")
result = asyncio.run(

View file

@ -5,7 +5,7 @@ from datetime import datetime, timedelta
import enum
import json
from typing import Any, Dict, List
from pydantic import BaseModel, Field as PydanticField
from pydantic import BaseModel, Field as PydanticField, model_validator
from rec import routing
from sqlmodel import JSON, TEXT, SQLModel, Field
@ -52,6 +52,21 @@ class ListingSite(enum.StrEnum):
# ... add more
def _parse_price_history(price_history_json: str) -> list[PriceHistoryItem]:
"""Parse a JSON string into a list of PriceHistoryItem objects."""
if not price_history_json:
return []
parsed: list = json.loads(str(price_history_json))
return [
PriceHistoryItem(
first_seen=datetime.fromisoformat(item["first_seen"]),
last_seen=datetime.fromisoformat(item["last_seen"]),
price=item["price"],
)
for item in parsed
]
class Listing(SQLModel, table=False):
id: int = Field(primary_key=True)
price: float = Field(nullable=False, index=True)
@ -61,7 +76,6 @@ class Listing(SQLModel, table=False):
council_tax_band: str | None = Field(default=None, nullable=True)
longitude: float = Field(nullable=False)
latitude: float = Field(nullable=False)
# price_history: List[Dict[str, Any]] = Field(default_factory=list, sa_type=JSON)
price_history_json: str = Field(sa_type=TEXT)
listing_site: ListingSite = Field(nullable=False)
last_seen: datetime = Field(
@ -103,20 +117,7 @@ class Listing(SQLModel, table=False):
"""
Returns a list of PriceHistoryItem objects from the price_history_json.
"""
if not self.price_history_json:
return []
parsed: list = json.loads(str(self.price_history_json))
for item in parsed:
item["first_seen"] = datetime.fromisoformat(item["first_seen"])
item["last_seen"] = datetime.fromisoformat(item["last_seen"])
return [
PriceHistoryItem(
first_seen=item["first_seen"],
last_seen=item["last_seen"],
price=item["price"],
)
for item in parsed
]
return _parse_price_history(self.price_history_json)
@staticmethod
def serialize_price_history(price_history: List[PriceHistoryItem]) -> str:
@ -142,36 +143,8 @@ class Listing(SQLModel, table=False):
"""
if not self.routing_info_json:
return {}
# TODO: move to a separate serializer class
json_data = json.loads(self.routing_info_json)
destimation_routes = {}
for destination_mode_str, routes_json in json_data.items():
destination_mode = DestinationMode(
destination_address=json.loads(destination_mode_str)[
"destination_address"
],
travel_mode=routing.TravelMode(
json.loads(destination_mode_str)["travel_mode"]
),
)
parsed_route = json.loads(routes_json[0])
routes = [
Route(
legs=[
RouteLegStep(
distance_meters=step["distance_meters"],
duration_s=step["duration_s"],
travel_mode=routing.TravelMode(step["travel_mode"]),
)
for step in parsed_route["legs"]
],
distance_meters=parsed_route["distance_meters"],
duration_s=int(parsed_route["duration_s"]),
)
]
destimation_routes[destination_mode] = routes
return destimation_routes
from rec.route_serializer import RouteSerializer
return RouteSerializer.deserialize(self.routing_info_json)
def serialize_routing_info(
self, routing_info: dict[DestinationMode, list[Route]]
@ -179,17 +152,8 @@ class Listing(SQLModel, table=False):
"""
Serializes the routing_info to a JSON string.
"""
# TODO: move to a separate serializer class
# for destination_mode, routes in routing_info.items():
serialized = json.dumps(
{
json.dumps(dataclasses.asdict(destination_mode)): [
json.dumps(dataclasses.asdict(route)) for route in routes
]
for destination_mode, routes in routing_info.items()
}
)
return serialized
from rec.route_serializer import RouteSerializer
return RouteSerializer.serialize(routing_info)
class FurnishType(enum.StrEnum):
@ -224,9 +188,9 @@ class DestinationMode:
# This allows serializers to pick up a dict representation
return asdict(self)
def __iter__(self):
# Makes it behave like a dict when expected
return iter(asdict(self).items())
def to_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of this DestinationMode."""
return asdict(self)
class ListingType(enum.StrEnum):
@ -254,3 +218,23 @@ class QueryParameters(BaseModel):
let_date_available_from: datetime | None = None
last_seen_days: int | None = None
min_sqm: int | None = None
@model_validator(mode="after")
def _validate_ranges(self) -> QueryParameters:
if self.min_price > self.max_price:
raise ValueError(
f"min_price ({self.min_price}) must be <= max_price ({self.max_price})"
)
if self.min_bedrooms < 0:
raise ValueError(
f"min_bedrooms ({self.min_bedrooms}) must be non-negative"
)
if self.max_bedrooms < 0:
raise ValueError(
f"max_bedrooms ({self.max_bedrooms}) must be non-negative"
)
if self.min_bedrooms > self.max_bedrooms:
raise ValueError(
f"min_bedrooms ({self.min_bedrooms}) must be <= max_bedrooms ({self.max_bedrooms})"
)
return self

View file

@ -36,3 +36,8 @@ def get_districts() -> dict[str, str]:
"Wandsworth": "REGION^93977",
"Westminster": "REGION^93980",
}
def get_district_by_name(name: str) -> str | None:
"""Return the region ID for a district name, or None if not found."""
return get_districts().get(name)

View file

@ -72,3 +72,14 @@ class CircuitBreakerOpenError(RightmoveAPIError):
"""
pass
class RoutingApiError(Exception):
"""Error from the Google Routes API."""
def __init__(self, status_code: int, response_body: dict):
self.status_code = status_code
self.response_body = response_body
super().__init__(
f"Routes API returned status {status_code}: {response_body}"
)

View file

@ -1,3 +1,4 @@
import logging
import re
from pathlib import Path
from typing import Any
@ -5,6 +6,11 @@ from PIL import Image
import cv2
import numpy as np
logger = logging.getLogger(__name__)
MIN_SQM = 30
MAX_SQM = 160
def inference(image_path: str | Path) -> tuple[str, Any]:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
@ -22,26 +28,21 @@ def inference(image_path: str | Path) -> tuple[str, Any]:
def extract_total_sqm(input_str: str) -> float | None:
# Note: can be used on the output of inference() to extract sqm from model predictions.
sqmregex = r"(\d+\.?\d*) ?(sq ?m|sq. ?m)"
matches = re.findall(sqmregex, input_str.lower())
sqms = [float(m[0]) for m in matches]
filtered = [sqm for sqm in sqms if 30 < sqm < 160]
filtered = [sqm for sqm in sqms if MIN_SQM < sqm < MAX_SQM]
if len(filtered) == 0:
return None
return max(filtered)
def calculate_model(image_path: str | Path) -> tuple[float | None, str, Any]:
output, predictions_tensor = inference(image_path)
estimated_sqm = extract_total_sqm(output)
return estimated_sqm, output, predictions_tensor
def improve_img_for_ocr(img: Image.Image) -> Image.Image:
img2 = np.array(img.convert("L"))
cv2.resize(img2, None, fx=1.2, fy=1.2, interpolation=cv2.INTER_CUBIC)
grayscale_image = np.array(img.convert("L"))
grayscale_image = cv2.resize(grayscale_image, None, fx=1.2, fy=1.2, interpolation=cv2.INTER_CUBIC)
thresh = cv2.adaptiveThreshold(
img2, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
grayscale_image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
)
return Image.fromarray(thresh)
@ -49,15 +50,18 @@ def improve_img_for_ocr(img: Image.Image) -> Image.Image:
def calculate_ocr(image_path: str | Path) -> tuple[float | None, str]:
import pytesseract
img = Image.open(image_path)
path = Path(image_path)
if not path.exists():
raise FileNotFoundError(f"Image not found: {image_path}")
img = Image.open(path)
text = pytesseract.image_to_string(img)
estimated_sqm = extract_total_sqm(text)
if estimated_sqm is None:
improved_img = improve_img_for_ocr(img)
text2 = pytesseract.image_to_string(improved_img)
estimated_sqm2 = extract_total_sqm(text2)
with open("recalculating.log", "a") as f:
f.write(f"before: {estimated_sqm} after: {estimated_sqm2} - {image_path}\n")
logger.debug(f"before: {estimated_sqm} after: {estimated_sqm2} - {image_path}")
return estimated_sqm2, text2
return estimated_sqm, text

View file

@ -28,6 +28,11 @@ logger = logging.getLogger("uvicorn.error")
# Global circuit breaker instance
_circuit_breaker: CircuitBreaker | None = None
# API constants
ANDROID_APP_VERSION = "3.70.0"
ANDROID_APP_VERSION_LISTING = "4.28.0"
RIGHTMOVE_API_BASE = "https://api.rightmove.co.uk/api"
PROPERTY_LISTING_ENDPOINT = f"{RIGHTMOVE_API_BASE}/property-listing"
DEFAULT_HEADERS = {
"Host": "api.rightmove.co.uk",
@ -35,6 +40,11 @@ DEFAULT_HEADERS = {
"Connection": "keep-alive",
}
LISTING_HEADERS = {
**DEFAULT_HEADERS,
"Accept-Encoding": "gzip, deflate, br",
}
class PropertyType(enum.StrEnum):
BUNGALOW = "bungalow"
@ -129,6 +139,177 @@ def check_circuit_breaker(config: ScraperConfig | None = None) -> None:
cb.call()
def _build_base_params(
*,
channel: ListingType,
page: int,
page_size: int,
radius: float,
min_price: int,
max_price: int,
min_bedrooms: int,
max_bedrooms: int,
district: str,
) -> dict[str, str]:
return {
"locationIdentifier": districts.get_districts()[district],
"channel": str(channel).upper(),
"page": str(page),
"numberOfPropertiesPerPage": str(page_size),
"radius": str(radius),
"sortBy": "distance",
"includeUnavailableProperties": "false",
"minPrice": str(min_price),
"maxPrice": str(max_price),
"minBedrooms": str(min_bedrooms),
"maxBedrooms": str(max_bedrooms),
"apiApplication": "ANDROID",
"appVersion": ANDROID_APP_VERSION_LISTING,
}
def _build_listing_params(
*,
page: int,
channel: ListingType,
min_bedrooms: int,
max_bedrooms: int,
radius: float,
min_price: int,
max_price: int,
district: str,
mustNewHome: bool,
max_days_since_added: int,
property_type: list[PropertyType],
page_size: int,
furnish_types: list[FurnishType],
) -> dict[str, str]:
params = _build_base_params(
channel=channel,
page=page,
page_size=page_size,
radius=radius,
min_price=min_price,
max_price=max_price,
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
district=district,
)
if channel is ListingType.BUY:
params["dontShow"] = "sharedOwnership,retirement"
if len(property_type) > 0:
params["propertyTypes"] = ",".join(property_type)
if max_days_since_added is not None and max_days_since_added not in [
1,
3,
7,
14,
]:
raise Exception(
f"Invalid max days - {max_days_since_added} Can only be got",
[1, 3, 7, 14],
)
params["maxDaysSinceAdded"] = str(max_days_since_added)
if mustNewHome:
params["mustHave"] = "newHome"
if channel is ListingType.RENT:
if furnish_types:
params["furnishTypes"] = ",".join(furnish_types)
return params
def _build_probe_params(
*,
channel: ListingType,
min_bedrooms: int,
max_bedrooms: int,
radius: float,
min_price: int,
max_price: int,
district: str,
max_days_since_added: int,
furnish_types: list[FurnishType],
) -> dict[str, str]:
params = _build_base_params(
channel=channel,
page=1,
page_size=1, # Minimal page size for probing
radius=radius,
min_price=min_price,
max_price=max_price,
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
district=district,
)
if channel is ListingType.BUY:
params["dontShow"] = "sharedOwnership,retirement"
if max_days_since_added is not None and max_days_since_added in [
1,
3,
7,
14,
]:
params["maxDaysSinceAdded"] = str(max_days_since_added)
if channel is ListingType.RENT:
if furnish_types:
params["furnishTypes"] = ",".join(furnish_types)
return params
async def _execute_api_request(
*,
url: str,
params: dict[str, str],
headers: dict[str, str],
session: aiohttp.ClientSession | None,
config: ScraperConfig,
expect_data: bool = True,
error_context: str = "",
) -> dict[str, Any]:
check_circuit_breaker(config)
cb = get_circuit_breaker(config)
async def do_request(s: aiohttp.ClientSession) -> dict[str, Any]:
start_time = time.time()
try:
async with s.get(url, params=params, headers=headers) as response:
response_time = time.time() - start_time
body = await response.json() if response.status == 200 else None
validate_response(
response,
response_time,
body,
config.slow_response_threshold,
expect_data=expect_data,
)
if response.status != 200:
raise Exception(
f"{error_context}Failed due to: {await response.text()}"
)
if cb is not None:
cb.record_success()
return body # type: ignore
except ThrottlingError:
if cb is not None:
cb.record_failure()
raise
except Exception as e:
if cb is not None:
cb.record_failure()
raise e
if session:
return await do_request(session)
else:
async with aiohttp.ClientSession(trust_env=True) as new_session:
return await do_request(new_session)
@retry(
retry=retry_if_exception_type(ThrottlingError),
wait=wait_exponential(multiplier=2, min=2, max=120),
@ -156,54 +337,21 @@ async def detail_query(
if config is None:
config = ScraperConfig.from_env()
check_circuit_breaker(config)
cb = get_circuit_breaker(config)
params = {
"apiApplication": "ANDROID",
"appVersion": "3.70.0",
"appVersion": ANDROID_APP_VERSION,
}
url = f"https://api.rightmove.co.uk/api/property/{detail_id}"
url = f"{RIGHTMOVE_API_BASE}/property/{detail_id}"
async def do_request(s: aiohttp.ClientSession) -> dict[str, Any]:
start_time = time.time()
try:
async with s.get(url, params=params, headers=DEFAULT_HEADERS) as response:
response_time = time.time() - start_time
body = await response.json() if response.status == 200 else None
# Validate response for throttling
validate_response(
response,
response_time,
body,
config.slow_response_threshold,
expect_data=True,
)
if response.status != 200:
raise Exception(
f"""id: {detail_id}. Status Code: {response.status}."""
f"""Failed due to: {await response.text()}"""
)
if cb is not None:
cb.record_success()
return body # type: ignore
except ThrottlingError:
if cb is not None:
cb.record_failure()
raise
except Exception as e:
if cb is not None:
cb.record_failure()
raise e
if session:
return await do_request(session)
else:
async with aiohttp.ClientSession(trust_env=True) as new_session:
return await do_request(new_session)
return await _execute_api_request(
url=url,
params=params,
headers=DEFAULT_HEADERS,
session=session,
config=config,
expect_data=True,
error_context=f"id: {detail_id}. Status Code: ",
)
@retry(
@ -223,9 +371,9 @@ async def listing_query(
district: str, # = "STATION^5168", # kings cross station
mustNewHome: bool = False,
max_days_since_added: int = 30,
property_type: list[PropertyType] = [],
property_type: list[PropertyType] | None = None,
page_size: int = 25,
furnish_types: list[FurnishType] = [],
furnish_types: list[FurnishType] | None = None,
session: aiohttp.ClientSession | None = None,
config: ScraperConfig | None = None,
) -> dict[str, Any]:
@ -257,94 +405,35 @@ async def listing_query(
"""
if config is None:
config = ScraperConfig.from_env()
if property_type is None:
property_type = []
if furnish_types is None:
furnish_types = []
check_circuit_breaker(config)
cb = get_circuit_breaker(config)
params = _build_listing_params(
page=page,
channel=channel,
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
radius=radius,
min_price=min_price,
max_price=max_price,
district=district,
mustNewHome=mustNewHome,
max_days_since_added=max_days_since_added,
property_type=property_type,
page_size=page_size,
furnish_types=furnish_types,
)
params: dict[str, str] = {
"locationIdentifier": districts.get_districts()[district],
"channel": str(channel).upper(),
"page": str(page),
"numberOfPropertiesPerPage": str(page_size),
"radius": str(radius),
"sortBy": "distance",
"includeUnavailableProperties": "false",
"minPrice": str(min_price),
"maxPrice": str(max_price),
"minBedrooms": str(min_bedrooms),
"maxBedrooms": str(max_bedrooms),
"apiApplication": "ANDROID",
"appVersion": "4.28.0",
}
if channel is ListingType.BUY:
params["dontShow"] = "sharedOwnership,retirement"
if len(property_type) > 0:
params["propertyTypes"] = ",".join(property_type)
if max_days_since_added is not None and max_days_since_added not in [
1,
3,
7,
14,
]:
raise Exception(
f"Invalid max days - {max_days_since_added} Can only be got",
[1, 3, 7, 14],
)
params["maxDaysSinceAdded"] = str(max_days_since_added)
if mustNewHome:
params["mustHave"] = "newHome"
if channel is ListingType.RENT:
if furnish_types:
params["furnishTypes"] = ",".join(furnish_types)
request_headers = {
"Host": "api.rightmove.co.uk",
"Accept-Encoding": "gzip, deflate, br",
"User-Agent": "okhttp/4.12.0",
"Connection": "keep-alive",
}
async def do_request(s: aiohttp.ClientSession) -> dict[str, Any]:
start_time = time.time()
try:
async with s.get(
"https://api.rightmove.co.uk/api/property-listing",
params=params,
headers=request_headers,
) as response:
response_time = time.time() - start_time
body = await response.json() if response.status == 200 else None
# Validate response for throttling
validate_response(
response,
response_time,
body,
config.slow_response_threshold,
expect_data=(page == 1), # Only expect data on first page
)
if response.status != 200:
raise Exception(f"Failed due to: {await response.text()}")
if cb is not None:
cb.record_success()
return body # type: ignore
except ThrottlingError:
if cb is not None:
cb.record_failure()
raise
except Exception as e:
if cb is not None:
cb.record_failure()
raise e
if session:
return await do_request(session)
else:
async with aiohttp.ClientSession(trust_env=True) as new_session:
return await do_request(new_session)
return await _execute_api_request(
url=PROPERTY_LISTING_ENDPOINT,
params=params,
headers=LISTING_HEADERS,
session=session,
config=config,
expect_data=(page == 1),
)
@retry(
@ -363,7 +452,7 @@ async def probe_query(
max_price: int,
district: str,
max_days_since_added: int = 30,
furnish_types: list[FurnishType] = [],
furnish_types: list[FurnishType] | None = None,
config: ScraperConfig | None = None,
) -> dict[str, Any]:
"""Probe the API to get result count without fetching full results.
@ -392,77 +481,27 @@ async def probe_query(
"""
if config is None:
config = ScraperConfig.from_env()
if furnish_types is None:
furnish_types = []
check_circuit_breaker(config)
cb = get_circuit_breaker(config)
params = _build_probe_params(
channel=channel,
min_bedrooms=min_bedrooms,
max_bedrooms=max_bedrooms,
radius=radius,
min_price=min_price,
max_price=max_price,
district=district,
max_days_since_added=max_days_since_added,
furnish_types=furnish_types,
)
params: dict[str, str] = {
"locationIdentifier": districts.get_districts()[district],
"channel": str(channel).upper(),
"page": "1",
"numberOfPropertiesPerPage": "1", # Minimal page size for probing
"radius": str(radius),
"sortBy": "distance",
"includeUnavailableProperties": "false",
"minPrice": str(min_price),
"maxPrice": str(max_price),
"minBedrooms": str(min_bedrooms),
"maxBedrooms": str(max_bedrooms),
"apiApplication": "ANDROID",
"appVersion": "4.28.0",
}
if channel is ListingType.BUY:
params["dontShow"] = "sharedOwnership,retirement"
if max_days_since_added is not None and max_days_since_added in [
1,
3,
7,
14,
]:
params["maxDaysSinceAdded"] = str(max_days_since_added)
if channel is ListingType.RENT:
if furnish_types:
params["furnishTypes"] = ",".join(furnish_types)
request_headers = {
"Host": "api.rightmove.co.uk",
"Accept-Encoding": "gzip, deflate, br",
"User-Agent": "okhttp/4.12.0",
"Connection": "keep-alive",
}
start_time = time.time()
try:
async with session.get(
"https://api.rightmove.co.uk/api/property-listing",
params=params,
headers=request_headers,
) as response:
response_time = time.time() - start_time
body = await response.json() if response.status == 200 else None
# Validate response for throttling
validate_response(
response,
response_time,
body,
config.slow_response_threshold,
expect_data=False, # Probe doesn't need data, just count
)
if response.status != 200:
raise Exception(f"Probe failed: {await response.text()}")
if cb is not None:
cb.record_success()
return body # type: ignore
except ThrottlingError:
if cb is not None:
cb.record_failure()
raise
except Exception as e:
if cb is not None:
cb.record_failure()
raise e
return await _execute_api_request(
url=PROPERTY_LISTING_ENDPOINT,
params=params,
headers=LISTING_HEADERS,
session=session,
config=config,
expect_data=False,
error_context="Probe failed: ",
)

View file

@ -1,3 +1,4 @@
import dataclasses
import json
from typing import List
@ -7,20 +8,25 @@ from rec import routing
class RouteSerializer:
@staticmethod
def serialize(route): ...
def serialize(routing_info: dict[DestinationMode, list[Route]]) -> str:
return json.dumps(
{
json.dumps(dataclasses.asdict(destination_mode)): [
json.dumps(dataclasses.asdict(route)) for route in routes
]
for destination_mode, routes in routing_info.items()
}
)
@staticmethod
def deserialize(route_data_json: str) -> dict[DestinationMode, List[Route]]:
json_data = json.loads(route_data_json)
destimation_routes = {}
destination_routes = {}
for destination_mode_str, routes_json in json_data.items():
parsed_destination = json.loads(destination_mode_str)
destination_mode = DestinationMode(
destination_address=json.loads(destination_mode_str)[
"destination_address"
],
travel_mode=routing.TravelMode(
json.loads(destination_mode_str)["travel_mode"]
),
destination_address=parsed_destination["destination_address"],
travel_mode=routing.TravelMode(parsed_destination["travel_mode"]),
)
parsed_route = json.loads(routes_json[0])
routes = [
@ -37,5 +43,5 @@ class RouteSerializer:
duration_s=int(parsed_route["duration_s"]),
)
]
destimation_routes[destination_mode] = routes
return destimation_routes
destination_routes[destination_mode] = routes
return destination_routes

View file

@ -3,9 +3,18 @@ import os
from typing import Any
import requests
from rec.utils import nextMonday
from rec.exceptions import RoutingApiError
url = "https://routes.googleapis.com/directions/v2:computeRoutes"
ROUTES_API_URL = "https://routes.googleapis.com/directions/v2:computeRoutes"
API_KEY_ENVIRONMENT_VARIABLE = "ROUTING_API_KEY"
ROUTES_FIELD_MASK = (
"routes.distanceMeters,"
"routes.duration,"
"routes.staticDuration,"
"routes.legs.steps.distanceMeters,"
"routes.legs.steps.staticDuration,"
"routes.legs.steps.travelMode"
)
class TravelMode(enum.StrEnum):
@ -20,7 +29,7 @@ def transit_route(
origin_lon: float,
dest_address: str,
travel_mode: TravelMode,
compute_alternative_routes=True,
compute_alternative_routes: bool = True,
) -> dict[str, Any]:
monday9am = nextMonday()
@ -30,38 +39,25 @@ def transit_route(
header = {
"X-Goog-Api-Key": api_key,
"Content-Type": "application/json",
"X-Goog-FieldMask": "routes.distanceMeters,routes.duration,routes.staticDuration,routes.legs.steps.distanceMeters,routes.legs.steps.staticDuration,routes.legs.steps.travelMode", # "routes.*",
"X-Goog-FieldMask": ROUTES_FIELD_MASK,
}
body = {
"origin": {
# "address": origin_address
"location": {"latLng": {"latitude": origin_lat, "longitude": origin_lon}}
},
"destination": {
"address": dest_address
# "location": {
# "latLng": {
# "latitude": dest_lat,
# "longitude": dest_lon
# }
# }
},
"travelMode": travel_mode.value,
# "2023-10-15T15:01:23.045123456Z"
"departureTime": monday9am.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
"computeAlternativeRoutes": compute_alternative_routes,
# "routeModifiers": {
# "avoidTolls": false,
# "avoidHighways": false,
# "avoidFerries": false
# },
"languageCode": "en-US",
"units": "METRIC",
}
r = requests.post(url, json=body, headers=header)
r = requests.post(ROUTES_API_URL, json=body, headers=header)
if r.status_code == 200:
return r.json()
raise Exception(r.json())
raise RoutingApiError(r.status_code, r.json())

View file

@ -1,6 +1,6 @@
from datetime import datetime, timedelta
import logging
from typing import Generator
from typing import Any, Generator
from data_access import Listing
from models.listing import (
BuyListing,
@ -12,7 +12,6 @@ from models.listing import (
)
from sqlalchemy import Engine, func, select as sa_select
from sqlmodel import Session, select
from sqlmodel.sql.expression import SelectOfScalar
from tqdm import tqdm
logger = logging.getLogger("uvicorn.error")
@ -27,8 +26,10 @@ STREAMING_COLUMNS = [
class ListingRepository:
engine: Engine
# anything more than 10k is considered buy type
buy_listing_price_threshold: int = 20_000
# Monthly rent prices in the UK are always below 20,000 GBP.
# Any listing priced at or above this threshold is treated as a purchase (buy) listing.
BUY_LISTING_PRICE_THRESHOLD: int = 20_000
def __init__(self, engine: Engine):
self.engine = engine
@ -44,24 +45,16 @@ class ListingRepository:
"""
only_ids = only_ids or []
model = RentListing # if no query params, default to renting listings
if query_parameters:
model = (
RentListing
if query_parameters.listing_type == ListingType.RENT
else BuyListing
# else RentListing
)
model = self._get_model_for_query(query_parameters)
query = select(model)
if only_ids:
query = query.where(model.id.in_(only_ids)) # type: ignore
query = self._add_where_from_query_parameters(query, model, query_parameters)
query = self._apply_query_filters(query, model, query_parameters)
if limit:
query = query.limit(limit)
with Session(self.engine) as session:
# query = select(modelListing)
rows = list(session.exec(query).all())
logging.debug(f"Found {len(rows)} listings")
return rows
@ -81,16 +74,10 @@ class ListingRepository:
limit: Maximum number of listings to yield
chunk_size: Number of rows to fetch at a time from the database
"""
model = RentListing # if no query params, default to renting listings
if query_parameters:
model = (
RentListing
if query_parameters.listing_type == ListingType.RENT
else BuyListing
)
model = self._get_model_for_query(query_parameters)
query = select(model)
query = self._add_where_from_query_parameters(query, model, query_parameters)
query = self._apply_query_filters(query, model, query_parameters)
if limit:
query = query.limit(limit)
@ -111,7 +98,7 @@ class ListingRepository:
model = self._get_model_for_query(query_parameters)
query = sa_select(func.count(model.id))
query = self._add_where_from_query_parameters_raw(query, model, query_parameters)
query = self._apply_query_filters(query, model, query_parameters)
with Session(self.engine) as session:
return session.execute(query).scalar() or 0
@ -147,7 +134,7 @@ class ListingRepository:
break
query = sa_select(*columns)
query = self._add_where_from_query_parameters_raw(
query = self._apply_query_filters(
query, model, query_parameters
)
@ -174,13 +161,25 @@ class ListingRepository:
if len(results) < page_size:
break
def _add_where_from_query_parameters_raw(
def _apply_query_filters(
self,
query,
query: Any,
model: type[RentListing] | type[BuyListing],
query_parameters: QueryParameters | None = None,
):
"""Add WHERE clauses from query parameters (for raw SQLAlchemy selects)."""
) -> Any:
"""Apply WHERE clauses from query parameters to a query.
Works with both SQLModel select() and raw SQLAlchemy sa_select() queries,
since both support the .where() interface.
Args:
query: A SQLModel or SQLAlchemy select query
model: The listing model class (RentListing or BuyListing)
query_parameters: Optional filtering parameters
Returns:
The query with WHERE clauses applied
"""
if query_parameters is None:
return query
query = query.where(
@ -207,38 +206,6 @@ class ListingRepository:
query = query.where(model.last_seen >= last_seen_threshold)
return query
def _add_where_from_query_parameters(
self,
query: SelectOfScalar[Listing],
model: type[Listing],
query_parameters: QueryParameters | None = None,
) -> SelectOfScalar[Listing]:
if query_parameters is None:
return query
query = query.where(
model.number_of_bedrooms.between(
query_parameters.min_bedrooms, query_parameters.max_bedrooms
),
model.price.between(query_parameters.min_price, query_parameters.max_price),
)
if query_parameters.min_sqm is not None:
query = query.where(model.square_meters >= query_parameters.min_sqm)
if query_parameters.furnish_types and model == RentListing:
query = query.where(model.furnish_type.in_(query_parameters.furnish_types))
if (
isinstance(model, RentListing)
and query_parameters.let_date_available_from is not None
):
query = query.where(
model.available_from >= query_parameters.let_date_available_from
)
if query_parameters.last_seen_days is not None:
last_seen_threshold = datetime.now() - timedelta(
days=query_parameters.last_seen_days
)
query = query.where(model.last_seen >= last_seen_threshold)
return query
async def upsert_listings(
self,
listings: list[modelListing],
@ -258,50 +225,74 @@ class ListingRepository:
self,
listings: list[Listing],
) -> list[modelListing]:
"""
Upsert listings into the database.
"""Upsert legacy Listing objects into the database.
.. deprecated::
This method converts legacy data_access.Listing objects to SQLModel
entities. Use upsert_listings() with RentListing/BuyListing directly.
Legacy Listing objects from filesystem-based data may contain malformed
or incomplete data, so conversion errors are logged and skipped rather
than aborting the entire batch.
"""
models = []
failed_to_upsert = []
with Session(self.engine) as session:
for listing in tqdm(listings, desc="Upserting listings"):
# Convert Listing to modelListing
# Convert legacy Listing to the appropriate SQLModel entity
try:
model_listing = await self._get_concrete_listing(listing)
except Exception as e: # WHY SO MANY ERORRS??
# If for whatever reason we cannot add listing, ignore and retry
print(f"Error converting listing {listing.identifier}: {e}")
except Exception as e:
# Legacy Listing -> model conversion may fail for malformed data
# (e.g. missing required fields, invalid types). Log and skip.
logger.error(f"Error converting listing {listing.identifier}: {e}")
failed_to_upsert.append(listing)
continue
session.merge(model_listing)
models.append(model_listing)
session.commit()
print(f"Failed to upsert {len(failed_to_upsert)} listings.")
if failed_to_upsert:
logger.warning(f"Failed to upsert {len(failed_to_upsert)} listings.")
return models
@staticmethod
def _parse_furnish_type(listing: Listing) -> FurnishType:
"""Extract and normalize the furnish type from a legacy Listing's detail object.
Handles missing/null detailobject, missing property key, missing or null
letFurnishType value, and normalizes "landlord" variants to ASK_LANDLORD.
Args:
listing: A legacy data_access.Listing object
Returns:
The parsed FurnishType enum value
"""
if (
listing.detailobject is None
or listing.detailobject.get("property") is None
or listing.detailobject["property"].get("letFurnishType") is None
):
return FurnishType.UNKNOWN
furnish_type_str = listing.detailobject["property"]["letFurnishType"]
if furnish_type_str is None:
return FurnishType.UNKNOWN
elif "landlord" in furnish_type_str.lower():
furnish_type_str = "ask landlord"
else:
furnish_type_str = furnish_type_str.lower()
return FurnishType(furnish_type_str)
async def _get_concrete_listing(
self,
listing: Listing,
) -> modelListing:
now = datetime.now()
furnish_type = self._parse_furnish_type(listing)
if (
listing.detailobject is None
or listing.detailobject.get("property") is None
or listing.detailobject["property"].get("letFurnishType") is None
):
furnish_type_str = "unknown"
else:
furnish_type_str = listing.detailobject["property"]["letFurnishType"]
if furnish_type_str is None:
furnish_type_str = "unknown"
elif "landlord" in furnish_type_str.lower():
furnish_type_str = "ask landlord"
else:
furnish_type_str = furnish_type_str.lower()
furnish_type = FurnishType(furnish_type_str)
if listing.price < self.buy_listing_price_threshold:
if listing.price < self.BUY_LISTING_PRICE_THRESHOLD:
model_listing = RentListing(
id=listing.identifier,
price=listing.price,

View file

@ -24,15 +24,14 @@ def get_district_names() -> list[str]:
return list(_get_districts().keys())
def validate_districts(district_names: list[str]) -> tuple[bool, list[str]]:
def validate_districts(district_names: list[str]) -> list[str]:
"""Validate that district names exist.
Args:
district_names: List of district names to validate
Returns:
Tuple of (all_valid, invalid_names)
List of invalid district names (empty if all valid)
"""
valid_districts = set(_get_districts().keys())
invalid = [d for d in district_names if d not in valid_districts]
return len(invalid) == 0, invalid
return [d for d in district_names if d not in valid_districts]

View file

@ -6,12 +6,14 @@ from repositories.listing_repository import ListingRepository
from tqdm.asyncio import tqdm
import multiprocessing
# Use a quarter of available CPUs to avoid starving other processes
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
async def detect_floorplan(repository: ListingRepository) -> None:
"""Detect square meters from floorplan images for all listings."""
listings = await repository.get_listings()
cpu_count = multiprocessing.cpu_count() // 4
semaphore = asyncio.Semaphore(cpu_count)
semaphore = asyncio.Semaphore(MAX_OCR_WORKERS)
updated_listings = [
listing
@ -29,6 +31,9 @@ async def _calculate_sqm_ocr(
"""Calculate square meters from floorplan images using OCR."""
if listing.square_meters is not None:
return None
if not listing.floorplan_image_paths:
listing.square_meters = 0
return listing
sqms: list[float] = []
for floorplan_path in listing.floorplan_image_paths:
async with semaphore:

View file

@ -1,6 +1,9 @@
"""Image fetcher service - downloads floorplan images for listings."""
import asyncio
import logging
from pathlib import Path
from urllib.parse import urlparse
import aiohttp
from repositories import ListingRepository
from tenacity import retry, stop_after_attempt, wait_random
@ -8,8 +11,12 @@ from tqdm.asyncio import tqdm
from models import Listing
# Setting this too high either crashes rightmove or gets us blocked
semaphore = asyncio.Semaphore(5)
logger = logging.getLogger(__name__)
# Maximum number of concurrent image downloads.
# Setting this too high either crashes Rightmove or gets us blocked.
MAX_CONCURRENT_DOWNLOADS = 5
semaphore = asyncio.Semaphore(MAX_CONCURRENT_DOWNLOADS)
async def dump_images(
@ -18,38 +25,64 @@ async def dump_images(
) -> None:
"""Download floorplan images for all listings."""
listings = await repository.get_listings()
updated_listings = await tqdm.gather(
*[dump_images_for_listing(listing, image_base_path) for listing in listings]
)
async with aiohttp.ClientSession() as session:
updated_listings = await tqdm.gather(
*[
dump_images_for_listing(listing, image_base_path, session=session)
for listing in listings
]
)
await repository.upsert_listings(
[listing for listing in updated_listings if listing is not None]
)
@retry(wait=wait_random(min=1, max=2), stop=stop_after_attempt(3))
async def dump_images_for_listing(listing: Listing, base_path: Path) -> Listing | None:
async def dump_images_for_listing(
listing: Listing,
base_path: Path,
session: aiohttp.ClientSession | None = None,
) -> Listing | None:
"""Download floorplan images for a single listing."""
all_floorplans = listing.additional_info.get("property", {}).get("floorplans", [])
for floorplan in all_floorplans:
url = floorplan["url"]
picname = url.split("/")[-1]
picname = Path(urlparse(url).path).name
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
if floorplan_path.exists():
continue
try:
async with semaphore:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
owns_session = session is None
active_session = session or aiohttp.ClientSession()
try:
async with semaphore:
async with active_session.get(url) as response:
if response.status == 404:
logger.warning(
"Listing %s: floorplan not found (404) at %s",
listing.id,
url,
)
return None
if response.status != 200:
raise Exception(f"Error for {url}: {response.status}")
raise Exception(
f"Error downloading floorplan for listing {listing.id} "
f"from {url}: HTTP {response.status}"
)
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
with open(floorplan_path, "wb") as f:
f.write(await response.read())
listing.floorplan_image_paths.append(str(floorplan_path))
return listing
finally:
if owns_session:
await active_session.close()
except Exception as e:
tqdm.write(f"Error for {url}: {e}")
raise e # raise so that we retry it
logger.error(
"Listing %s: error downloading floorplan from %s: %s",
listing.id,
url,
e,
)
raise
return None

View file

@ -4,12 +4,13 @@ import json
import logging
import os
from typing import Generator
from urllib.parse import urlparse, urlunparse
import redis
from models.listing import QueryParameters
logger = logging.getLogger("uvicorn.error")
logger = logging.getLogger(__name__)
CACHE_PREFIX = "listings:geojson:"
CACHE_TTL_SECONDS = 30 * 60 # 30 minutes
@ -19,9 +20,9 @@ CACHE_DB = 2
def _get_redis_client() -> redis.Redis:
"""Get Redis client using Celery broker URL but overriding to db=2."""
broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
# Replace the db number in the URL
base_url = broker_url.rsplit("/", 1)[0]
return redis.from_url(f"{base_url}/{CACHE_DB}", decode_responses=True)
parsed = urlparse(broker_url)
cache_url = urlunparse(parsed._replace(path=f"/{CACHE_DB}"))
return redis.from_url(cache_url, decode_responses=True)
def make_cache_key(query_params: QueryParameters) -> str:
@ -89,7 +90,10 @@ def invalidate_cache() -> None:
while True:
cursor, keys = client.scan(cursor, match=f"{CACHE_PREFIX}*", count=100)
if keys:
client.delete(*keys)
pipeline = client.pipeline()
for key in keys:
pipeline.delete(key)
pipeline.execute()
deleted += len(keys)
if cursor == 0:
break

View file

@ -13,6 +13,8 @@ from services.query_splitter import QuerySplitter, SubQuery
logger = logging.getLogger("uvicorn.error")
# Number of concurrent workers that process listing details (fetch details,
# download images, run OCR) from the streaming queue in parallel.
NUM_WORKERS = 20
@ -23,10 +25,104 @@ async def dump_listings_full(
"""Fetches all listings, images as well as detects floorplans."""
new_listings = await dump_listings(parameters, repository)
logger.debug(f"Upserted {len(new_listings)} new listings")
# refresh listings
listings = await repository.get_listings(parameters) # this can be better
new_listings = [x for x in listings if x.id in new_listings]
return new_listings
new_listing_ids = [listing.id for listing in new_listings]
return await repository.get_listings(only_ids=new_listing_ids)
async def _fetch_subquery(
sq: SubQuery,
parameters: QueryParameters,
session: object,
config: ScraperConfig,
semaphore: asyncio.Semaphore,
existing_ids: set[int],
queue: asyncio.Queue[int | None],
) -> int:
"""Fetch listing IDs for a single subquery and enqueue new ones.
Iterates through pages of results for the given subquery, adding any
newly discovered listing IDs to the processing queue.
Args:
sq: The subquery to fetch results for.
parameters: The original query parameters (for page_size, etc.).
session: The aiohttp session for making requests.
config: Scraper configuration.
semaphore: Concurrency limiter for HTTP requests.
existing_ids: Set of already-known listing IDs (mutated in place).
queue: Queue to push new listing IDs onto for processing.
Returns:
The number of new IDs discovered and enqueued.
"""
estimated = sq.estimated_results or 0
if estimated == 0:
return 0
ids_found = 0
page_size = parameters.page_size
max_pages = min(
config.max_pages_per_query,
(estimated // page_size) + 1,
)
for page_id in range(1, max_pages + 1):
async with semaphore:
await asyncio.sleep(config.request_delay_ms / 1000)
try:
result = await listing_query(
page=page_id,
channel=parameters.listing_type,
min_bedrooms=sq.min_bedrooms,
max_bedrooms=sq.max_bedrooms,
radius=parameters.radius,
min_price=sq.min_price,
max_price=sq.max_price,
district=sq.district,
page_size=page_size,
max_days_since_added=parameters.max_days_since_added,
furnish_types=parameters.furnish_types or [],
session=session,
config=config,
)
# Extract and enqueue new IDs inline
properties = result.get("properties", [])
for prop in properties:
identifier = prop.get("identifier")
if identifier and identifier not in existing_ids:
existing_ids.add(identifier)
ids_found += 1
await queue.put(identifier)
if len(properties) < page_size:
break
except CircuitBreakerOpenError as e:
logger.error(f"Circuit breaker open: {e}")
break
except ThrottlingError as e:
logger.warning(
f"Throttling error on page {page_id} for "
f"{sq.district}: {e}"
)
break
except Exception as e:
# Rightmove returns GENERIC_ERROR when requesting pages
# past the last page of results. This is expected behavior
# and signals we've exhausted this subquery's results.
if "GENERIC_ERROR" in str(e):
logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
logger.warning(
f"Error fetching page {page_id} for "
f"{sq.district}: {e}"
)
break
return ids_found
async def dump_listings(
@ -63,82 +159,23 @@ async def dump_listings(
# Phase 2: Streaming fetch & process
queue: asyncio.Queue[int | None] = asyncio.Queue()
semaphore = asyncio.Semaphore(config.max_concurrent_requests)
ids_collected = 0
processed_listings: list[Listing] = []
async def fetch_subquery(sq: SubQuery) -> None:
nonlocal ids_collected
estimated = sq.estimated_results or 0
if estimated == 0:
return
page_size = parameters.page_size
max_pages = min(
config.max_pages_per_query,
(estimated // page_size) + 1,
)
for page_id in range(1, max_pages + 1):
async with semaphore:
await asyncio.sleep(config.request_delay_ms / 1000)
try:
result = await listing_query(
page=page_id,
channel=parameters.listing_type,
min_bedrooms=sq.min_bedrooms,
max_bedrooms=sq.max_bedrooms,
radius=parameters.radius,
min_price=sq.min_price,
max_price=sq.max_price,
district=sq.district,
page_size=page_size,
max_days_since_added=parameters.max_days_since_added,
furnish_types=parameters.furnish_types or [],
session=session,
config=config,
)
# Extract and enqueue new IDs inline
properties = result.get("properties", [])
for prop in properties:
identifier = prop.get("identifier")
if identifier and identifier not in existing_ids:
existing_ids.add(identifier)
ids_collected += 1
await queue.put(identifier)
if len(properties) < page_size:
break
except CircuitBreakerOpenError as e:
logger.error(f"Circuit breaker open: {e}")
break
except ThrottlingError as e:
logger.warning(
f"Throttling error on page {page_id} for "
f"{sq.district}: {e}"
)
break
except Exception as e:
if "GENERIC_ERROR" in str(e):
logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
logger.warning(
f"Error fetching page {page_id} for "
f"{sq.district}: {e}"
)
break
async def producer() -> None:
await asyncio.gather(
*[fetch_subquery(sq) for sq in subqueries]
)
async def producer() -> int:
"""Fetch all subqueries and send sentinel values to workers."""
tasks = [
_fetch_subquery(
sq, parameters, session, config,
semaphore, existing_ids, queue,
)
for sq in subqueries
]
counts = await asyncio.gather(*tasks)
ids_collected = sum(counts)
logger.info(f"Fetch complete: {ids_collected} new IDs found")
for _ in range(NUM_WORKERS):
await queue.put(None)
return ids_collected
async def worker() -> None:
while True:
@ -150,10 +187,11 @@ async def dump_listings(
if listing is not None:
processed_listings.append(listing)
await asyncio.gather(
results = await asyncio.gather(
producer(),
*[worker() for _ in range(NUM_WORKERS)],
)
ids_collected = results[0]
except CircuitBreakerOpenError as e:
logger.error(f"Circuit breaker prevented listing fetch: {e}")

View file

@ -6,6 +6,11 @@ from rec import routing
from models import Listing
def _parse_duration(duration_str: str) -> int:
"""Parse a duration string like '123s' to integer seconds."""
return int(duration_str.rstrip("s"))
async def calculate_route(
repository: ListingRepository,
destination_address: str,
@ -18,9 +23,9 @@ async def calculate_route(
if limit is not None:
listings = listings[:limit]
destimation_mode = DestinationMode(destination_address, travel_mode)
destination_mode = DestinationMode(destination_address, travel_mode)
updated_listings = await tqdm.gather(
*[update_routing_info(listing, destimation_mode) for listing in listings],
*[update_routing_info(listing, destination_mode) for listing in listings],
total=len(listings),
desc="Updating routing info",
)
@ -46,12 +51,12 @@ async def update_routing_info(
routes: list[Route] = []
for route_data in routes_data["routes"]:
duration_s = int(route_data["duration"].split("s")[0])
duration_s = _parse_duration(route_data["duration"])
route = Route(
legs=[
RouteLegStep(
distance_meters=step_data["distanceMeters"],
duration_s=int(step_data["staticDuration"].split("s")[0]),
duration_s=_parse_duration(step_data["staticDuration"]),
travel_mode=routing.TravelMode(step_data["travelMode"]),
)
for step_data in route_data["legs"][0]["steps"]
@ -63,4 +68,4 @@ async def update_routing_info(
listing.routing_info_json = listing.serialize_routing_info(
{**listing.routing_info, **{destination_mode: routes}}
)
return listing
return listing

View file

@ -5,6 +5,16 @@ Manages background task operations using Celery.
from dataclasses import dataclass
from typing import Any
import json
import logging
logger = logging.getLogger(__name__)
# Standard Celery states; anything else is treated as a custom state
# whose name is used as the human-readable status message.
_CELERY_STANDARD_STATES = frozenset(
{"PENDING", "STARTED", "SUCCESS", "FAILURE", "REVOKED", "RETRY"}
)
@dataclass
@ -21,6 +31,68 @@ class TaskStatus:
traceback: str | None # Full traceback if failed
def _make_system_user(email: str) -> Any:
"""Create a minimal User object used only for Redis key generation.
These are *not* real authenticated users -- they exist solely so that
RedisRepository can derive the per-user storage key from the email.
"""
# Lazy import: api.auth imports from api.app which eventually imports
# services, so importing at module level would create a circular dependency.
from api.auth import User
return User(sub="", email=email, name="")
def _extract_result(task_result: Any) -> tuple[Any, str | None]:
"""Extract a serialisable result and an error string from a Celery AsyncResult.
Returns:
(result, error) -- exactly one of the two will be non-None (or both None
for tasks that haven't produced output yet).
"""
if task_result.failed():
error = str(task_result.result) if task_result.result else None
return None, error
try:
result = json.loads(json.dumps(task_result.result))
except (TypeError, json.JSONDecodeError):
result = str(task_result.result) if task_result.result else None
return result, None
def _extract_progress_info(task_result: Any) -> dict[str, Any]:
"""Extract progress metadata from a Celery AsyncResult's ``info`` dict.
Returns a dict with keys ``progress``, ``processed``, ``total``, and
``message`` (any of which may be None).
"""
progress: float | None = None
processed: int | None = None
total: int | None = None
message: str | None = None
if task_result.info and isinstance(task_result.info, dict):
progress = task_result.info.get("progress")
processed = task_result.info.get("processed")
total = task_result.info.get("total")
# Use 'message' if available, fall back to 'reason' for SKIPPED tasks
message = task_result.info.get("message") or task_result.info.get("reason")
# For custom states (like "Fetching listings"), use the state as message
# if no message was provided in info
if not message and task_result.status not in _CELERY_STANDARD_STATES:
message = task_result.status
return {
"progress": progress,
"processed": processed,
"total": total,
"message": message,
}
def get_task_status(task_id: str) -> TaskStatus:
"""Get the status of a background task.
@ -33,55 +105,24 @@ def get_task_status(task_id: str) -> TaskStatus:
Returns:
TaskStatus with current state
"""
# Lazy import: listing_tasks imports the Celery app which in turn
# pulls in broker configuration; importing at module level would
# create a circular dependency chain.
from tasks.listing_tasks import dump_listings_task
task_result = dump_listings_task.AsyncResult(task_id)
# Try to serialize result
result = None
error = None
if task_result.failed():
# Extract error message from failed task
error = str(task_result.result) if task_result.result else None
else:
try:
result = json.loads(json.dumps(task_result.result))
except (TypeError, json.JSONDecodeError):
result = str(task_result.result) if task_result.result else None
# Extract traceback if available
result, error = _extract_result(task_result)
task_traceback = task_result.traceback if task_result.failed() else None
# Extract progress, processed, total, and message from task meta
progress = None
processed = None
total = None
message = None
if task_result.info and isinstance(task_result.info, dict):
progress = task_result.info.get("progress")
processed = task_result.info.get("processed")
total = task_result.info.get("total")
# Use 'message' if available, fall back to 'reason' for SKIPPED tasks
message = task_result.info.get("message") or task_result.info.get("reason")
# For custom states (like "Fetching listings"), use the state as message
# if no message was provided in info
if not message and task_result.status not in (
"PENDING", "STARTED", "SUCCESS", "FAILURE", "REVOKED", "RETRY"
):
message = task_result.status
progress_info = _extract_progress_info(task_result)
return TaskStatus(
task_id=task_id,
status=task_result.status,
result=result,
progress=progress,
processed=processed,
total=total,
message=message,
error=error,
traceback=task_traceback,
**progress_info,
)
@ -97,12 +138,12 @@ def get_user_tasks(user_email: str) -> list[str]:
Returns:
List of task IDs
"""
# Lazy import: RedisRepository depends on redis which may not be
# available at import time in all contexts (CLI, tests).
from redis_repository import RedisRepository
from api.auth import User
redis_repo = RedisRepository.instance()
# Create a minimal User object for the lookup
user = User(sub="", email=user_email, name="")
user = _make_system_user(user_email)
return redis_repo.get_tasks_for_user(user)
@ -116,11 +157,11 @@ def add_task_for_user(user_email: str, task_id: str) -> None:
user_email: The user's email address
task_id: The Celery task ID
"""
# Lazy import: see get_user_tasks for rationale.
from redis_repository import RedisRepository
from api.auth import User
redis_repo = RedisRepository.instance()
user = User(sub="", email=user_email, name="")
user = _make_system_user(user_email)
redis_repo.add_task_for_user(user, task_id)
@ -134,8 +175,10 @@ def cancel_task(task_id: str, user_email: str | None = None) -> bool:
Returns:
True if task was cancelled successfully
"""
# Lazy import: celery_app bootstraps the broker connection.
from celery_app import app as celery_app
logger.info("Cancelling task %s (user=%s)", task_id, user_email)
# Revoke the task in Celery
celery_app.control.revoke(task_id, terminate=True)
@ -158,11 +201,11 @@ def remove_task_from_user(user_email: str, task_id: str) -> bool:
Returns:
True if task was removed, False if not found
"""
# Lazy import: see get_user_tasks for rationale.
from redis_repository import RedisRepository
from api.auth import User
redis_repo = RedisRepository.instance()
user = User(sub="", email=user_email, name="")
user = _make_system_user(user_email)
return redis_repo.remove_task_for_user(user, task_id)
@ -176,12 +219,14 @@ def clear_all_tasks(user_email: str, revoke: bool = True) -> int:
Returns:
Number of tasks cleared
"""
# Lazy imports: see get_user_tasks and cancel_task for rationale.
from redis_repository import RedisRepository
from celery_app import app as celery_app
from api.auth import User
redis_repo = RedisRepository.instance()
user = User(sub="", email=user_email, name="")
user = _make_system_user(user_email)
logger.info("Clearing all tasks for user %s (revoke=%s)", user_email, revoke)
# Get tasks before clearing to revoke them
if revoke:
@ -189,7 +234,9 @@ def clear_all_tasks(user_email: str, revoke: bool = True) -> int:
for task_id in tasks:
try:
celery_app.control.revoke(task_id, terminate=True)
except Exception:
pass # Best effort, continue clearing
except Exception as e:
logger.warning(
"Failed to revoke task %s: %s", task_id, e
)
return redis_repo.clear_tasks_for_user(user)

View file

@ -2,6 +2,7 @@ import asyncio
import logging
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any
from celery import Task
from celery.schedules import crontab
@ -34,11 +35,38 @@ if not celery_logger.handlers:
SCRAPE_LOCK_NAME = "scrape_listings"
LOG_BUFFER_MAX_LINES = 200
# Number of concurrent consumer workers that process listings from the queue.
NUM_WORKERS = 20
# Phase constants for task state reporting
PHASE_SPLITTING = "splitting"
PHASE_FETCHING = "fetching"
PHASE_PROCESSING = "processing"
PHASE_COMPLETED = "completed"
# Module-level log buffer — active only during task execution.
# The TaskLogHandler appends here; _update_task_state reads from here.
# This is safe as module-level mutable state because Celery workers use a
# prefork pool: each worker process handles one task at a time, so there is
# no concurrent access within a single process. The TaskLogHandler appends
# here; _update_task_state reads from here.
_active_log_buffer: deque[str] | None = None
@dataclass
class _PipelineState:
"""Shared mutable state for the streaming fetch-and-process pipeline."""
ids_collected: int = 0
completed_subqueries: int = 0
total_pages_fetched: int = 0
fetching_done: bool = False
processed_count: int = 0
failed_count: int = 0
details_fetched: int = 0
images_downloaded: int = 0
ocr_completed: int = 0
processed_listings: list[Listing] = field(default_factory=list)
class TaskLogHandler(logging.Handler):
"""Captures log records into a deque for inclusion in task state updates."""
@ -60,34 +88,204 @@ def _update_task_state(task: Task, state: str, meta: dict[str, Any]) -> None:
task.update_state(state=state, meta=meta)
async def _fetch_subquery(
sq: SubQuery,
parameters: QueryParameters,
session: object,
config: ScraperConfig,
semaphore: asyncio.Semaphore,
existing_ids: set[int],
queue: asyncio.Queue[int | None],
state: _PipelineState,
) -> None:
"""Fetch pages for a single subquery and enqueue new listing IDs."""
estimated = sq.estimated_results or 0
if estimated == 0:
state.completed_subqueries += 1
return
page_size = parameters.page_size
max_pages = min(
config.max_pages_per_query,
(estimated // page_size) + 1,
)
for page_id in range(1, max_pages + 1):
async with semaphore:
await asyncio.sleep(config.request_delay_ms / 1000)
try:
result = await listing_query(
page=page_id,
channel=parameters.listing_type,
min_bedrooms=sq.min_bedrooms,
max_bedrooms=sq.max_bedrooms,
radius=parameters.radius,
min_price=sq.min_price,
max_price=sq.max_price,
district=sq.district,
page_size=page_size,
max_days_since_added=parameters.max_days_since_added,
furnish_types=parameters.furnish_types or [],
session=session,
config=config,
)
state.total_pages_fetched += 1
properties = result.get("properties", [])
for prop in properties:
identifier = prop.get("identifier")
if identifier and identifier not in existing_ids:
existing_ids.add(identifier)
state.ids_collected += 1
await queue.put(identifier)
if len(properties) < page_size:
break
except CircuitBreakerOpenError as e:
celery_logger.error(f"Circuit breaker open: {e}")
break
except ThrottlingError as e:
celery_logger.warning(
f"Throttling on {sq.district} page {page_id}: {e}"
)
break
except Exception as e:
if "GENERIC_ERROR" in str(e):
celery_logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
celery_logger.warning(
f"Error fetching page {page_id} for "
f"{sq.district}: {e}"
)
break
state.completed_subqueries += 1
async def _process_worker(
queue: asyncio.Queue[int | None],
processor: ListingProcessor,
state: _PipelineState,
) -> None:
"""Consumer worker: pull listing IDs from the queue and process them."""
while True:
listing_id = await queue.get()
if listing_id is None:
break
def step_callback(step_name: str) -> None:
if step_name == "details":
state.details_fetched += 1
elif step_name == "images":
state.images_downloaded += 1
elif step_name == "ocr":
state.ocr_completed += 1
listing = await processor.process_listing(
listing_id, on_step_complete=step_callback
)
if listing is not None:
state.processed_count += 1
state.processed_listings.append(listing)
else:
state.failed_count += 1
async def _monitor_progress(
task: Task,
state: _PipelineState,
subqueries_total: int,
start_time: float,
) -> None:
"""Periodically report pipeline progress via task state updates."""
last_progress = 0.0
while True:
total = state.ids_collected
done = state.processed_count + state.failed_count
if state.fetching_done and done >= total and total > 0:
break
if state.fetching_done and total == 0:
break
phase = PHASE_PROCESSING if state.fetching_done else PHASE_FETCHING
if total > 0:
progress_ratio = round(done / total, 2)
else:
progress_ratio = 0.0
elapsed = time.time() - start_time
rate = done / elapsed if elapsed > 0 else 0
remaining = (total - done) if total > 0 else 0
eta = remaining / rate if rate > 0 else 0
if progress_ratio >= last_progress + 0.1 or done == 1:
celery_logger.info(
f"Progress: {progress_ratio * 100:.0f}% "
f"({done}/{total}) "
f"| Elapsed: {elapsed:.0f}s "
f"| Rate: {rate:.1f}/s "
f"| ETA: {eta:.0f}s"
)
last_progress = progress_ratio
_update_task_state(
task,
f"{'Processing' if state.fetching_done else 'Fetching & processing'}: "
f"{done}/{total}",
{
"phase": phase,
"progress": progress_ratio,
"processed": done,
"total": total,
"subqueries_completed": state.completed_subqueries,
"subqueries_total": subqueries_total,
"ids_collected": state.ids_collected,
"pages_fetched": state.total_pages_fetched,
"fetching_done": state.fetching_done,
"details_fetched": state.details_fetched,
"images_downloaded": state.images_downloaded,
"ocr_completed": state.ocr_completed,
"failed": state.failed_count,
"elapsed_seconds": round(elapsed, 1),
"rate_per_second": round(rate, 2),
"eta_seconds": round(eta, 1),
},
)
await asyncio.sleep(1)
@app.task(bind=True, pydantic=True)
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
with redis_lock(SCRAPE_LOCK_NAME) as acquired:
if not acquired:
msg = "Another scrape job is already running, skipping this execution"
logger.warning(msg)
celery_logger.warning(msg)
self.update_state(state="SKIPPED", meta={"reason": "Another scrape job is running"})
return {"status": "skipped", "reason": "another_job_running"}
celery_logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}")
logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}")
parsed_parameters = QueryParameters.model_validate_json(parameters_json)
celery_logger.info(f"Starting scrape with parameters: {parsed_parameters}")
self.update_state(state="Starting...", meta={"phase": "splitting", "progress": 0})
self.update_state(state="Starting...", meta={"phase": PHASE_SPLITTING, "progress": 0})
asyncio.run(dump_listings_full(task=self, parameters=parsed_parameters))
return {"phase": "completed", "progress": 1}
return {"phase": PHASE_COMPLETED, "progress": 1}
async def async_dump_listings_task(parameters_json: str) -> dict[str, Any]:
with redis_lock(SCRAPE_LOCK_NAME) as acquired:
if not acquired:
logger.warning("Another scrape job is already running, skipping this execution")
celery_logger.warning("Another scrape job is already running, skipping this execution")
return {"status": "skipped", "reason": "another_job_running"}
logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}")
celery_logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}")
parsed_parameters = QueryParameters.model_validate_json(parameters_json)
await dump_listings_full(task=Task(), parameters=parsed_parameters)
return {"progress": 0}
@ -141,17 +339,16 @@ async def _dump_listings_full_inner(
soon as IDs become available from each subquery.
"""
start_time = time.time()
NUM_WORKERS = 20
state = _PipelineState()
celery_logger.info("=" * 60)
celery_logger.info("PHASE 1: Splitting queries")
celery_logger.info(f"PHASE 1: Splitting queries")
celery_logger.info("=" * 60)
repository = ListingRepository(engine)
config = ScraperConfig.from_env()
splitter = QuerySplitter(config)
# Reset throttle metrics
reset_throttle_metrics()
def on_progress(phase: str, message: str, **kwargs: Any) -> None:
@ -161,7 +358,7 @@ async def _dump_listings_full_inner(
celery_logger.info(f"[{phase}] {message}")
_update_task_state(task, "Analyzing query and splitting by price bands...", {
"phase": "splitting", "progress": 0,
"phase": PHASE_SPLITTING, "progress": 0,
})
celery_logger.info("Starting query splitting and probing...")
@ -175,34 +372,22 @@ async def _dump_listings_full_inner(
f"~{total_estimated} estimated total results"
)
# Load existing IDs (fast, ID-only projection)
celery_logger.info("Loading existing listing IDs from database...")
existing_ids = repository.get_listing_ids(parameters.listing_type)
celery_logger.info(f"Found {len(existing_ids)} existing listings in DB")
celery_logger.info("=" * 60)
celery_logger.info("PHASE 2: Streaming fetch & process")
celery_logger.info(f"PHASE 2: Streaming fetch & process")
celery_logger.info("=" * 60)
# Shared state for the streaming pipeline
queue: asyncio.Queue[int | None] = asyncio.Queue()
ids_collected = 0
completed_subqueries = 0
total_pages_fetched = 0
fetching_done = False
processed_count = 0
failed_count = 0
details_fetched = 0
images_downloaded = 0
ocr_completed = 0
processed_listings: list[Listing] = []
semaphore = asyncio.Semaphore(config.max_concurrent_requests)
_update_task_state(
task,
f"Fetching listings from {len(subqueries)} subqueries...",
{
"phase": "fetching",
"phase": PHASE_FETCHING,
"subqueries_completed": 0,
"subqueries_total": len(subqueries),
"ids_collected": 0,
@ -214,190 +399,32 @@ async def _dump_listings_full_inner(
listing_processor = ListingProcessor(repository)
# --- Producer: fetch subquery pages and enqueue new IDs ---
# Producer: fetch all subqueries concurrently, then signal workers to stop
async def producer() -> None:
nonlocal ids_collected, completed_subqueries, total_pages_fetched
nonlocal fetching_done
async def fetch_subquery(sq: SubQuery) -> None:
nonlocal ids_collected, completed_subqueries, total_pages_fetched
estimated = sq.estimated_results or 0
if estimated == 0:
completed_subqueries += 1
return
page_size = parameters.page_size
max_pages = min(
config.max_pages_per_query,
(estimated // page_size) + 1,
)
for page_id in range(1, max_pages + 1):
async with semaphore:
await asyncio.sleep(config.request_delay_ms / 1000)
try:
result = await listing_query(
page=page_id,
channel=parameters.listing_type,
min_bedrooms=sq.min_bedrooms,
max_bedrooms=sq.max_bedrooms,
radius=parameters.radius,
min_price=sq.min_price,
max_price=sq.max_price,
district=sq.district,
page_size=page_size,
max_days_since_added=parameters.max_days_since_added,
furnish_types=parameters.furnish_types or [],
session=session,
config=config,
)
total_pages_fetched += 1
# Extract and enqueue new IDs inline
properties = result.get("properties", [])
for prop in properties:
identifier = prop.get("identifier")
if identifier and identifier not in existing_ids:
existing_ids.add(identifier)
ids_collected += 1
await queue.put(identifier)
if len(properties) < page_size:
break
except CircuitBreakerOpenError as e:
celery_logger.error(f"Circuit breaker open: {e}")
break
except ThrottlingError as e:
celery_logger.warning(
f"Throttling on {sq.district} page {page_id}: {e}"
)
break
except Exception as e:
if "GENERIC_ERROR" in str(e):
logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
logger.warning(
f"Error fetching page {page_id} for "
f"{sq.district}: {e}"
)
break
completed_subqueries += 1
# Fetch all subqueries concurrently
await asyncio.gather(
*[fetch_subquery(sq) for sq in subqueries]
*[
_fetch_subquery(
sq, parameters, session, config,
semaphore, existing_ids, queue, state,
)
for sq in subqueries
]
)
celery_logger.info(
f"Fetch complete: {total_pages_fetched} pages from "
f"{completed_subqueries} subqueries, "
f"{ids_collected} new IDs"
f"Fetch complete: {state.total_pages_fetched} pages from "
f"{state.completed_subqueries} subqueries, "
f"{state.ids_collected} new IDs"
)
fetching_done = True
state.fetching_done = True
# Send sentinel values to stop workers
for _ in range(NUM_WORKERS):
await queue.put(None)
# --- Consumer workers: process listings from queue ---
async def worker() -> None:
nonlocal processed_count, failed_count
nonlocal details_fetched, images_downloaded, ocr_completed
while True:
listing_id = await queue.get()
if listing_id is None:
break
def step_callback(step_name: str) -> None:
nonlocal details_fetched, images_downloaded, ocr_completed
if step_name == "details":
details_fetched += 1
elif step_name == "images":
images_downloaded += 1
elif step_name == "ocr":
ocr_completed += 1
listing = await listing_processor.process_listing(
listing_id, on_step_complete=step_callback
)
if listing is not None:
processed_count += 1
processed_listings.append(listing)
else:
failed_count += 1
# --- Monitor: reports combined progress ---
async def monitor() -> None:
last_progress = 0.0
while True:
total = ids_collected
done = processed_count + failed_count
if fetching_done and done >= total and total > 0:
break
if fetching_done and total == 0:
break
# Determine phase label
phase = "processing" if fetching_done else "fetching"
if total > 0:
progress_ratio = round(done / total, 2)
else:
progress_ratio = 0.0
elapsed = time.time() - start_time
rate = done / elapsed if elapsed > 0 else 0
remaining = (total - done) if total > 0 else 0
eta = remaining / rate if rate > 0 else 0
if progress_ratio >= last_progress + 0.1 or done == 1:
celery_logger.info(
f"Progress: {progress_ratio * 100:.0f}% "
f"({done}/{total}) "
f"| Elapsed: {elapsed:.0f}s "
f"| Rate: {rate:.1f}/s "
f"| ETA: {eta:.0f}s"
)
last_progress = progress_ratio
_update_task_state(
task,
f"{'Processing' if fetching_done else 'Fetching & processing'}: "
f"{done}/{total}",
{
"phase": phase,
"progress": progress_ratio,
"processed": done,
"total": total,
"subqueries_completed": completed_subqueries,
"subqueries_total": len(subqueries),
"ids_collected": ids_collected,
"pages_fetched": total_pages_fetched,
"fetching_done": fetching_done,
"details_fetched": details_fetched,
"images_downloaded": images_downloaded,
"ocr_completed": ocr_completed,
"failed": failed_count,
"elapsed_seconds": round(elapsed, 1),
"rate_per_second": round(rate, 2),
"eta_seconds": round(eta, 1),
},
)
await asyncio.sleep(1)
# Run producer, workers, and monitor concurrently
await asyncio.gather(
producer(),
*[worker() for _ in range(NUM_WORKERS)],
monitor(),
*[_process_worker(queue, listing_processor, state) for _ in range(NUM_WORKERS)],
_monitor_progress(task, state, len(subqueries), start_time),
)
except CircuitBreakerOpenError as e:
@ -418,19 +445,19 @@ async def _dump_listings_full_inner(
elapsed = time.time() - start_time
celery_logger.info("=" * 60)
celery_logger.info(
f"COMPLETED: Processed {len(processed_listings)} listings in {elapsed:.1f}s"
f"COMPLETED: Processed {len(state.processed_listings)} listings in {elapsed:.1f}s"
)
celery_logger.info("=" * 60)
invalidate_cache()
_update_task_state(task, "Completed", {
"phase": "completed", "progress": 1,
"processed": len(processed_listings), "total": ids_collected,
"message": f"Processed {len(processed_listings)} listings in {elapsed:.0f}s",
"phase": PHASE_COMPLETED, "progress": 1,
"processed": len(state.processed_listings), "total": state.ids_collected,
"message": f"Processed {len(state.processed_listings)} listings in {elapsed:.0f}s",
})
return processed_listings
return state.processed_listings
@app.on_after_finalize.connect
@ -439,11 +466,11 @@ def setup_periodic_tasks(sender, **kwargs):
try:
config = SchedulesConfig.from_env()
except ValueError as e:
logger.error(f"Failed to load schedule configuration: {e}")
celery_logger.error(f"Failed to load schedule configuration: {e}")
return
for schedule in config.get_enabled_schedules():
logger.info(
celery_logger.info(
f"Registering periodic task: {schedule.name} at {schedule.hour}:{schedule.minute}"
)

View file

@ -1,5 +1,6 @@
"""Integration tests for API endpoints."""
from unittest.mock import AsyncMock, patch
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import AsyncClient
@ -75,10 +76,12 @@ class TestListingGeoJsonEndpoint:
self, async_client: AsyncClient
) -> None:
"""Test that listing_geojson accepts filter parameters."""
mock_result = MagicMock()
mock_result.data = {"type": "FeatureCollection", "features": []}
with patch(
"api.app.export_immoweb",
"api.app.export_service.export_to_geojson",
new_callable=AsyncMock,
return_value={"type": "FeatureCollection", "features": []},
return_value=mock_result,
):
response = await async_client.get(
"/api/listing_geojson",
@ -178,3 +181,135 @@ class TestTaskStatusEndpoint:
)
# Should return 401 or 403 without valid auth
assert response.status_code in (401, 403)
class TestStreamListingGeoJsonEndpoint:
"""Tests for the /api/listing_geojson/stream endpoint."""
async def test_stream_returns_ndjson_with_metadata(
self, async_client: AsyncClient
) -> None:
"""Test that the stream endpoint returns valid NDJSON starting with a metadata message."""
fake_features = [
{"type": "Feature", "properties": {"id": 1}, "geometry": {"type": "Point", "coordinates": [0, 0]}},
{"type": "Feature", "properties": {"id": 2}, "geometry": {"type": "Point", "coordinates": [1, 1]}},
]
with patch("api.app.get_cached_count", return_value=2), \
patch("api.app.get_cached_features", return_value=iter([fake_features])):
response = await async_client.get(
"/api/listing_geojson/stream",
params={"listing_type": "RENT", "batch_size": 50},
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/x-ndjson"
lines = [line for line in response.text.strip().split("\n") if line]
assert len(lines) >= 2 # at least metadata + complete
metadata = json.loads(lines[0])
assert metadata["type"] == "metadata"
assert "batch_size" in metadata
assert "total_expected" in metadata
complete = json.loads(lines[-1])
assert complete["type"] == "complete"
assert "total" in complete
async def test_stream_cache_hit_path(
self, async_client: AsyncClient
) -> None:
"""Test that cache-hit path returns cached: True in metadata."""
fake_features = [
{"type": "Feature", "properties": {"id": 1}, "geometry": {"type": "Point", "coordinates": [0, 0]}},
]
with patch("api.app.get_cached_count", return_value=1), \
patch("api.app.get_cached_features", return_value=iter([fake_features])):
response = await async_client.get(
"/api/listing_geojson/stream",
params={"listing_type": "RENT"},
)
assert response.status_code == 200
lines = [line for line in response.text.strip().split("\n") if line]
metadata = json.loads(lines[0])
assert metadata["cached"] is True
assert metadata["total_expected"] == 1
batch_msg = json.loads(lines[1])
assert batch_msg["type"] == "batch"
assert len(batch_msg["features"]) == 1
async def test_stream_cache_miss_path(
self, async_client: AsyncClient
) -> None:
"""Test that cache-miss path queries DB and returns cached: False."""
from datetime import datetime
fake_rows = [
{
"id": 100,
"price": 2000.0,
"number_of_bedrooms": 2,
"square_meters": 50.0,
"longitude": -0.1,
"latitude": 51.5,
"photo_thumbnail": None,
"last_seen": datetime(2024, 1, 1),
"agency": "Test Agency",
"price_history_json": "[]",
"available_from": None,
},
]
mock_repo = MagicMock()
mock_repo.count_listings.return_value = 1
mock_repo.stream_listings_optimized.return_value = iter(fake_rows)
with patch("api.app.get_cached_count", return_value=None), \
patch("api.app.ListingRepository", return_value=mock_repo), \
patch("api.app.cache_features_batch"):
response = await async_client.get(
"/api/listing_geojson/stream",
params={"listing_type": "RENT"},
)
assert response.status_code == 200
lines = [line for line in response.text.strip().split("\n") if line]
metadata = json.loads(lines[0])
assert metadata["cached"] is False
assert metadata["total_expected"] == 1
batch_msg = json.loads(lines[1])
assert batch_msg["type"] == "batch"
assert len(batch_msg["features"]) == 1
assert batch_msg["features"][0]["type"] == "Feature"
assert batch_msg["features"][0]["properties"]["total_price"] == 2000.0
complete = json.loads(lines[-1])
assert complete["type"] == "complete"
assert complete["total"] == 1
async def test_stream_with_limit(
self, async_client: AsyncClient
) -> None:
"""Test that the limit parameter caps the number of streamed features."""
fake_features = [
{"type": "Feature", "properties": {"id": i}, "geometry": {"type": "Point", "coordinates": [0, 0]}}
for i in range(5)
]
with patch("api.app.get_cached_count", return_value=5), \
patch("api.app.get_cached_features", return_value=iter([fake_features])):
response = await async_client.get(
"/api/listing_geojson/stream",
params={"listing_type": "RENT", "limit": 3},
)
assert response.status_code == 200
lines = [line for line in response.text.strip().split("\n") if line]
metadata = json.loads(lines[0])
assert metadata["total_expected"] == 3
complete = json.loads(lines[-1])
assert complete["type"] == "complete"
assert complete["total"] == 3

View file

@ -77,7 +77,7 @@ class TestThrottlingRetryBehavior:
"""Test that 429 responses trigger retry with backoff."""
call_count = 0
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
def mock_get(*args: object, **kwargs: object) -> MockResponse:
nonlocal call_count
call_count += 1
if call_count < 3:
@ -117,7 +117,7 @@ class TestThrottlingRetryBehavior:
"""Test that 503 responses trigger retry."""
call_count = 0
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
def mock_get(*args: object, **kwargs: object) -> MockResponse:
nonlocal call_count
call_count += 1
if call_count < 2:
@ -157,7 +157,7 @@ class TestCircuitBreakerIntegration:
"""Test that circuit breaker opens after consecutive failures."""
call_count = 0
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
def mock_get(*args: object, **kwargs: object) -> MockResponse:
nonlocal call_count
call_count += 1
return MockResponse(status=429)
@ -223,14 +223,14 @@ class TestMetricsTracking:
@pytest.mark.asyncio
async def test_metrics_tracked_on_rate_limit(self, config: ScraperConfig) -> None:
"""Test that rate limit errors are tracked in metrics."""
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
def mock_get(*args: object, **kwargs: object) -> MockResponse:
return MockResponse(status=429)
mock_session = MagicMock()
mock_session.get = mock_get
with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
with pytest.raises(RateLimitError):
with pytest.raises((RateLimitError, CircuitBreakerOpenError)):
with patch("tenacity.wait_exponential.__call__", return_value=0):
await probe_query(
session=mock_session,
@ -250,7 +250,7 @@ class TestMetricsTracking:
@pytest.mark.asyncio
async def test_metrics_tracked_on_success(self, config: ScraperConfig) -> None:
"""Test that successful requests are tracked in metrics."""
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
def mock_get(*args: object, **kwargs: object) -> MockResponse:
return MockResponse(
status=200,
json_data={"totalAvailableResults": 10, "properties": []},

View file

@ -97,7 +97,7 @@ class TestListingGeoJsonEndpoint:
# Override auth dependency
async def mock_auth():
return User(email="test@example.com", name="Test User")
return User(sub="test-id", email="test@example.com", name="Test User")
app.dependency_overrides[get_current_user] = mock_auth
yield TestClient(app)

View file

@ -0,0 +1,151 @@
"""Unit tests for api/auth.py."""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch, MagicMock
import jwt as pyjwt
import pytest
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from api.auth import (
User,
_verify_passkey_token,
_verify_authentik_token,
get_current_user,
)
from api.config import JWT_SECRET, JWT_ALGORITHM, JWT_ISSUER
def _make_passkey_token(
sub: str = "user-123",
email: str = "test@example.com",
name: str = "Test User",
issuer: str = JWT_ISSUER,
secret: str = JWT_SECRET,
algorithm: str = JWT_ALGORITHM,
expires_delta: timedelta | None = timedelta(hours=1),
) -> str:
"""Helper to mint a passkey-style HS256 JWT."""
payload: dict = {"sub": sub, "email": email, "name": name, "iss": issuer}
if expires_delta is not None:
payload["exp"] = datetime.now(timezone.utc) + expires_delta
return pyjwt.encode(payload, secret, algorithm=algorithm)
class TestVerifyPasskeyToken:
"""Tests for _verify_passkey_token()."""
def test_valid_token_returns_user(self) -> None:
token = _make_passkey_token()
user = _verify_passkey_token(token)
assert isinstance(user, User)
assert user.sub == "user-123"
assert user.email == "test@example.com"
assert user.name == "Test User"
def test_valid_token_without_name_uses_email(self) -> None:
payload = {
"sub": "user-456",
"email": "noname@example.com",
"iss": JWT_ISSUER,
"exp": datetime.now(timezone.utc) + timedelta(hours=1),
}
token = pyjwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
user = _verify_passkey_token(token)
assert user.name == "noname@example.com"
def test_rejects_expired_token(self) -> None:
token = _make_passkey_token(expires_delta=timedelta(hours=-1))
with pytest.raises(pyjwt.ExpiredSignatureError):
_verify_passkey_token(token)
def test_rejects_wrong_secret(self) -> None:
token = _make_passkey_token(secret="wrong-secret")
with pytest.raises(pyjwt.InvalidSignatureError):
_verify_passkey_token(token)
def test_rejects_wrong_issuer(self) -> None:
token = _make_passkey_token(issuer="some-other-issuer")
with pytest.raises(pyjwt.InvalidIssuerError):
_verify_passkey_token(token)
class TestVerifyAuthentikToken:
"""Tests for _verify_authentik_token() — specifically that expiration is verified."""
async def test_verifies_expiration_after_fix(self) -> None:
"""After removing verify_exp: False, expired Authentik tokens should be rejected."""
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key = private_key.public_key()
public_pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
issuer = "https://authentik.viktorbarzin.me/application/o/wrongmove/"
payload = {
"sub": "authentik-user",
"email": "auth@example.com",
"name": "Auth User",
"iss": issuer,
"aud": "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe",
"exp": datetime.now(timezone.utc) - timedelta(hours=1), # expired
}
token = pyjwt.encode(payload, private_key, algorithm="RS256")
# Build a real PyJWK-compatible signing key mock so jwt.decode
# takes the PyJWK code path (uses key.key directly, skips prepare_key)
mock_signing_key = MagicMock(spec=pyjwt.PyJWK)
mock_signing_key.key = public_key
mock_signing_key.algorithm_name = "RS256"
mock_signing_key.Algorithm = pyjwt.get_algorithm_by_name("RS256")
mock_jwks_client = MagicMock()
mock_jwks_client.get_signing_key_from_jwt.return_value = mock_signing_key
mock_metadata = {
"issuer": issuer,
"jwks_uri": f"{issuer}jwks/",
}
with patch("api.auth.get_oidc_metadata", new_callable=AsyncMock, return_value=mock_metadata), \
patch("api.auth.get_cached_jwks_client", new_callable=AsyncMock, return_value=mock_jwks_client):
with pytest.raises(pyjwt.ExpiredSignatureError):
await _verify_authentik_token(token)
class TestGetCurrentUser:
"""Tests for get_current_user()."""
async def test_routes_to_passkey_verifier_for_matching_issuer(self) -> None:
token = _make_passkey_token()
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
user = await get_current_user(credentials)
assert user.sub == "user-123"
assert user.email == "test@example.com"
async def test_routes_to_authentik_for_other_issuer(self) -> None:
"""When issuer != JWT_ISSUER, should route to Authentik verifier."""
token = _make_passkey_token(issuer="https://authentik.viktorbarzin.me/application/o/wrongmove/")
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
mock_user = User(sub="authentik-user", email="auth@example.com", name="Auth User")
with patch("api.auth._verify_authentik_token", new_callable=AsyncMock, return_value=mock_user):
user = await get_current_user(credentials)
assert user.email == "auth@example.com"
async def test_raises_http_exception_for_invalid_token(self) -> None:
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="not.a.valid.token")
with pytest.raises(HTTPException) as exc_info:
await get_current_user(credentials)
assert exc_info.value.status_code == 401
assert "Invalid token" in exc_info.value.detail
async def test_raises_http_exception_for_garbage_token(self) -> None:
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="totalgarbage")
with pytest.raises(HTTPException) as exc_info:
await get_current_user(credentials)
assert exc_info.value.status_code == 401

View file

@ -0,0 +1,388 @@
"""Characterization and unit tests for the CLI (main.py)."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import click
import pytest
from click.testing import CliRunner
from models.listing import FurnishType, ListingType, QueryParameters
from main import build_query_parameters, cli, listing_filter_options
class TestBuildQueryParameters:
"""Tests for build_query_parameters()."""
def test_typical_rent_inputs(self) -> None:
qp = build_query_parameters(
type="RENT",
district=["London", "Camden"],
min_bedrooms=2,
max_bedrooms=4,
min_price=1000,
max_price=3000,
furnish_types=["FURNISHED"],
available_from=datetime(2025, 6, 1),
last_seen_days=7,
min_sqm=50,
)
assert qp.listing_type == ListingType.RENT
assert qp.district_names == {"London", "Camden"}
assert qp.min_bedrooms == 2
assert qp.max_bedrooms == 4
assert qp.min_price == 1000
assert qp.max_price == 3000
assert qp.furnish_types == [FurnishType.FURNISHED]
assert qp.let_date_available_from == datetime(2025, 6, 1)
assert qp.last_seen_days == 7
assert qp.min_sqm == 50
def test_typical_buy_inputs(self) -> None:
qp = build_query_parameters(
type="BUY",
district=["Barnet"],
min_bedrooms=3,
max_bedrooms=5,
min_price=200000,
max_price=500000,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.listing_type == ListingType.BUY
assert qp.district_names == {"Barnet"}
assert qp.furnish_types is None
assert qp.let_date_available_from is None
assert qp.min_sqm is None
def test_empty_districts_yields_empty_set(self) -> None:
qp = build_query_parameters(
type="RENT",
district=[],
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.district_names == set()
def test_none_districts_yields_empty_set(self) -> None:
qp = build_query_parameters(
type="RENT",
district=None,
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.district_names == set()
def test_furnish_types_conversion(self) -> None:
qp = build_query_parameters(
type="RENT",
district=["London"],
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=["FURNISHED", "UNFURNISHED"],
available_from=None,
last_seen_days=14,
)
assert qp.furnish_types == [FurnishType.FURNISHED, FurnishType.UNFURNISHED]
def test_empty_furnish_types_yields_none(self) -> None:
qp = build_query_parameters(
type="RENT",
district=["London"],
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.furnish_types is None
def test_default_optional_parameters(self) -> None:
qp = build_query_parameters(
type="RENT",
district=["London"],
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.radius == 0
assert qp.page_size == 500
assert qp.max_days_since_added == 14
class TestListingFilterOptionsDecorator:
"""Tests for the listing_filter_options decorator."""
def test_applies_all_expected_options(self) -> None:
@click.command()
@listing_filter_options
def dummy_cmd(**kwargs: object) -> None:
pass
expected_option_names = {
"type",
"min_bedrooms",
"max_bedrooms",
"min_price",
"max_price",
"district",
"furnish_types",
"available_from",
"last_seen_days",
"min_sqm",
}
param_names = {p.name for p in dummy_cmd.params}
assert expected_option_names.issubset(param_names), (
f"Missing options: {expected_option_names - param_names}"
)
def test_type_option_is_required(self) -> None:
@click.command()
@listing_filter_options
def dummy_cmd(**kwargs: object) -> None:
pass
type_param = next(p for p in dummy_cmd.params if p.name == "type")
assert type_param.required is True
def test_produces_query_parameters_kwarg(self) -> None:
"""After refactoring, the decorator should produce a query_parameters kwarg."""
captured: dict = {}
@click.command()
@listing_filter_options
def dummy_cmd(query_parameters: QueryParameters) -> None:
captured["qp"] = query_parameters
runner = CliRunner()
result = runner.invoke(dummy_cmd, ["--type", "RENT"])
assert result.exit_code == 0, f"Command failed: {result.output}"
assert isinstance(captured["qp"], QueryParameters)
assert captured["qp"].listing_type == ListingType.RENT
class TestDumpListingsCommand:
"""Tests for the dump-listings CLI command."""
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_calls_refresh_listings_with_correct_params(
self,
mock_engine: MagicMock,
mock_refresh: AsyncMock,
) -> None:
from services.listing_service import RefreshResult
mock_refresh.return_value = RefreshResult(
task_id=None,
new_listings_count=5,
message="Fetched 5 new listings",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"dump-listings",
"--type", "RENT",
"--min-bedrooms", "2",
"--max-bedrooms", "4",
"--min-price", "1000",
"--max-price", "3000",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_refresh.assert_called_once()
call_args = mock_refresh.call_args
qp: QueryParameters = call_args.args[1]
assert qp.listing_type == ListingType.RENT
assert qp.min_bedrooms == 2
assert qp.max_bedrooms == 4
assert qp.min_price == 1000
assert qp.max_price == 3000
assert call_args.kwargs.get("full") is not True
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_include_processing_flag_passes_full_true(
self,
mock_engine: MagicMock,
mock_refresh: AsyncMock,
) -> None:
from services.listing_service import RefreshResult
mock_refresh.return_value = RefreshResult(
task_id=None,
new_listings_count=0,
message="Fetched 0 new listings",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"dump-listings",
"--type", "RENT",
"--include-processing",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_refresh.assert_called_once()
call_kwargs = mock_refresh.call_args.kwargs
assert call_kwargs.get("full") is True
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_include_processing_short_flag(
self,
mock_engine: MagicMock,
mock_refresh: AsyncMock,
) -> None:
from services.listing_service import RefreshResult
mock_refresh.return_value = RefreshResult(
task_id=None,
new_listings_count=0,
message="Fetched 0 new listings",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"dump-listings",
"--type", "RENT",
"-p",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_refresh.assert_called_once()
call_kwargs = mock_refresh.call_args.kwargs
assert call_kwargs.get("full") is True
class TestExportCsvCommand:
"""Tests for the export-csv CLI command."""
@patch("main.export_service.export_to_csv", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_calls_export_to_csv(
self,
mock_engine: MagicMock,
mock_export: AsyncMock,
) -> None:
from services.export_service import ExportResult
mock_export.return_value = ExportResult(
success=True,
output_path="/tmp/test.csv",
data=None,
record_count=10,
message="Exported 10 listings to /tmp/test.csv",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"export-csv",
"--output-file", "/tmp/test.csv",
"--type", "RENT",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_export.assert_called_once()
call_args = mock_export.call_args
qp = call_args[0][2]
assert qp.listing_type == ListingType.RENT
class TestExportImmowebCommand:
"""Tests for the export-immoweb CLI command."""
@patch("main.export_service.export_to_geojson", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_calls_export_to_geojson(
self,
mock_engine: MagicMock,
mock_export: AsyncMock,
) -> None:
from services.export_service import ExportResult
mock_export.return_value = ExportResult(
success=True,
output_path="/tmp/test.geojson",
data=None,
record_count=5,
message="Exported 5 listings to /tmp/test.geojson",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"export-immoweb",
"--output-file", "/tmp/test.geojson",
"--type", "RENT",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_export.assert_called_once()
class TestListDistrictsCommand:
"""Tests for the list-districts CLI command."""
@patch("main.engine", new_callable=MagicMock)
def test_outputs_district_names(self, mock_engine: MagicMock) -> None:
runner = CliRunner()
result = runner.invoke(cli, ["list-districts"])
assert result.exit_code == 0
assert "London" in result.output
assert "Camden" in result.output
assert "Available districts" in result.output
class TestRoutingCommand:
"""Tests for the routing CLI command."""
@patch("main.engine", new_callable=MagicMock)
def test_requires_api_key_env_var(self, mock_engine: MagicMock) -> None:
runner = CliRunner(env={"ROUTING_API_KEY": None})
result = runner.invoke(
cli,
[
"routing",
"--destination-address", "London Bridge",
"--travel-mode", "transit",
"--limit", "1",
],
catch_exceptions=False,
)
assert result.exit_code != 0
assert "ROUTING_API_KEY" in result.output

View file

@ -0,0 +1,62 @@
"""Unit tests for rec/districts.py and services/district_service.py."""
from rec.districts import get_districts, get_district_by_name
from services.district_service import get_all_districts, get_district_names, validate_districts
class TestGetDistricts:
def test_returns_non_empty_dict(self) -> None:
districts = get_districts()
assert isinstance(districts, dict)
assert len(districts) > 0
def test_values_start_with_region_prefix(self) -> None:
for name, region_id in get_districts().items():
assert region_id.startswith("REGION^"), (
f"District '{name}' has value '{region_id}' that doesn't start with REGION^"
)
def test_contains_expected_london_boroughs(self) -> None:
districts = get_districts()
for borough in ("Camden", "Westminster", "Hackney"):
assert borough in districts, f"Expected borough '{borough}' not found"
class TestGetDistrictByName:
def test_valid_name_returns_region_id(self) -> None:
result = get_district_by_name("Camden")
assert result == "REGION^93941"
def test_invalid_name_returns_none(self) -> None:
result = get_district_by_name("Nonexistent District")
assert result is None
class TestGetDistrictNames:
def test_returns_list_matching_dict_keys(self) -> None:
names = get_district_names()
assert isinstance(names, list)
assert names == list(get_districts().keys())
class TestGetAllDistricts:
def test_returns_same_as_get_districts(self) -> None:
assert get_all_districts() == get_districts()
class TestValidateDistricts:
def test_all_valid_returns_empty_list(self) -> None:
result = validate_districts(["Camden", "Westminster", "Hackney"])
assert result == []
def test_some_invalid_returns_invalid_ones(self) -> None:
result = validate_districts(["Camden", "Faketown", "Westminster", "Nowhere"])
assert result == ["Faketown", "Nowhere"]
def test_all_invalid_returns_all(self) -> None:
invalid = ["Faketown", "Nowhere", "Neverland"]
result = validate_districts(invalid)
assert result == invalid
def test_empty_list_returns_empty_list(self) -> None:
result = validate_districts([])
assert result == []

View file

@ -0,0 +1,104 @@
"""Unit tests for rec/floorplan.py."""
from unittest.mock import patch
import numpy as np
from PIL import Image
import pytest
from rec.floorplan import extract_total_sqm, improve_img_for_ocr, calculate_ocr
class TestExtractTotalSqm:
def test_normal_value(self) -> None:
assert extract_total_sqm("Total area: 75.5 sq m") == 75.5
def test_multiple_values_returns_max_in_range(self) -> None:
assert extract_total_sqm("Room 1: 20 sqm, Total: 65 sq m") == 65.0
def test_no_match_returns_none(self) -> None:
assert extract_total_sqm("No area info") is None
def test_below_minimum_returns_none(self) -> None:
assert extract_total_sqm("Area: 15 sq m") is None
def test_above_maximum_returns_none(self) -> None:
assert extract_total_sqm("Area: 200 sq m") is None
def test_edge_just_above_min(self) -> None:
assert extract_total_sqm("Area: 30.1 sq m") == 30.1
def test_edge_just_below_max(self) -> None:
assert extract_total_sqm("Area: 159.9 sq m") == 159.9
def test_exactly_at_min_boundary_returns_none(self) -> None:
# MIN_SQM < sqm, so 30 is not strictly greater than 30
assert extract_total_sqm("Area: 30 sq m") is None
def test_exactly_at_max_boundary_returns_none(self) -> None:
# sqm < MAX_SQM, so 160 is not strictly less than 160
assert extract_total_sqm("Area: 160 sq m") is None
def test_format_sq_dot_m(self) -> None:
assert extract_total_sqm("Area: 80 sq. m") == 80.0
def test_format_sqm_no_space(self) -> None:
assert extract_total_sqm("Area: 80sqm") == 80.0
def test_format_sq_m_with_space(self) -> None:
assert extract_total_sqm("Area: 80 sq m") == 80.0
def test_empty_string(self) -> None:
assert extract_total_sqm("") is None
def test_multiple_valid_values_returns_max(self) -> None:
assert extract_total_sqm("Living: 40 sq m, Total: 100 sq m") == 100.0
class TestImproveImgForOcr:
def test_produces_valid_pil_image(self) -> None:
# Create a small test image (50x50 white image)
img = Image.fromarray(np.ones((50, 50, 3), dtype=np.uint8) * 200)
result = improve_img_for_ocr(img)
assert isinstance(result, Image.Image)
# Result should be a grayscale (thresholded) image
assert result.mode == "L"
def test_output_dimensions_scaled(self) -> None:
img = Image.fromarray(np.ones((100, 100, 3), dtype=np.uint8) * 128)
result = improve_img_for_ocr(img)
# After 1.2x resize, 100 -> 120
assert result.size[0] == 120
assert result.size[1] == 120
class TestCalculateOcr:
def test_invalid_path_raises_file_not_found(self) -> None:
with pytest.raises(FileNotFoundError):
calculate_ocr("/nonexistent/path/to/image.png")
def test_returns_sqm_from_first_pass(self, tmp_path) -> None: # type: ignore[no-untyped-def]
# Create a real image file so the path check passes
image_file = tmp_path / "test.png"
Image.fromarray(np.ones((10, 10, 3), dtype=np.uint8)).save(str(image_file))
with patch("pytesseract.image_to_string", return_value="Total: 85 sq m"):
result_sqm, result_text = calculate_ocr(str(image_file))
assert result_sqm == 85.0
assert result_text == "Total: 85 sq m"
def test_falls_back_to_improved_image(self, tmp_path) -> None: # type: ignore[no-untyped-def]
image_file = tmp_path / "test.png"
Image.fromarray(np.ones((10, 10, 3), dtype=np.uint8)).save(str(image_file))
# First call returns no sqm data, second (on improved image) returns valid data
with patch("pytesseract.image_to_string", side_effect=[
"No area info here",
"Total: 72 sq m",
]):
result_sqm, result_text = calculate_ocr(str(image_file))
assert result_sqm == 72.0
assert result_text == "Total: 72 sq m"

View file

@ -0,0 +1,110 @@
"""Unit tests for services/floorplan_detector.py."""
import asyncio
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
from models.listing import RentListing, ListingSite, FurnishType
from services.floorplan_detector import _calculate_sqm_ocr, detect_floorplan
def _make_listing(**kwargs) -> RentListing: # type: ignore[no-untyped-def]
defaults = dict(
id=1,
price=2000.0,
number_of_bedrooms=2,
square_meters=None,
agency="Test",
council_tax_band="C",
longitude=0.0,
latitude=0.0,
price_history_json="[]",
listing_site=ListingSite.RIGHTMOVE,
last_seen=datetime.now(),
photo_thumbnail=None,
floorplan_image_paths=[],
additional_info={"property": {"visible": True}},
routing_info_json=None,
furnish_type=FurnishType.FURNISHED,
available_from=None,
)
defaults.update(kwargs)
return RentListing(**defaults)
class TestCalculateSqmOcr:
async def test_skips_listing_with_existing_square_meters(self) -> None:
listing = _make_listing(square_meters=50.0)
semaphore = asyncio.Semaphore(1)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is None
async def test_empty_floorplan_paths_returns_listing_with_zero(self) -> None:
listing = _make_listing(floorplan_image_paths=[])
semaphore = asyncio.Semaphore(1)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is not None
assert result.square_meters == 0
@patch("services.floorplan_detector.floorplan")
async def test_with_mocked_ocr_returning_value(self, mock_floorplan: MagicMock) -> None:
mock_floorplan.calculate_ocr.return_value = (85.0, "Total: 85 sq m")
listing = _make_listing(floorplan_image_paths=["/fake/path.png"])
semaphore = asyncio.Semaphore(1)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is not None
assert result.square_meters == 85.0
@patch("services.floorplan_detector.floorplan")
async def test_with_mocked_ocr_returning_none(self, mock_floorplan: MagicMock) -> None:
mock_floorplan.calculate_ocr.return_value = (None, "no data")
listing = _make_listing(floorplan_image_paths=["/fake/path.png"])
semaphore = asyncio.Semaphore(1)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is not None
assert result.square_meters == 0
@patch("services.floorplan_detector.floorplan")
async def test_picks_max_from_multiple_floorplans(self, mock_floorplan: MagicMock) -> None:
mock_floorplan.calculate_ocr.side_effect = [
(50.0, "50 sq m"),
(90.0, "90 sq m"),
]
listing = _make_listing(floorplan_image_paths=["/fake/a.png", "/fake/b.png"])
semaphore = asyncio.Semaphore(2)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is not None
assert result.square_meters == 90.0
class TestDetectFloorplan:
@patch("services.floorplan_detector.floorplan")
async def test_detect_floorplan_with_mocked_repository(self, mock_floorplan: MagicMock) -> None:
mock_floorplan.calculate_ocr.return_value = (75.0, "75 sq m")
listing = _make_listing(
floorplan_image_paths=["/fake/path.png"],
)
repository = MagicMock()
repository.get_listings = AsyncMock(return_value=[listing])
repository.upsert_listings = AsyncMock(return_value=[])
await detect_floorplan(repository)
repository.upsert_listings.assert_called_once()
upserted = repository.upsert_listings.call_args[0][0]
assert len(upserted) == 1
assert upserted[0].square_meters == 75.0
async def test_detect_floorplan_skips_already_processed(self) -> None:
listing = _make_listing(square_meters=50.0)
repository = MagicMock()
repository.get_listings = AsyncMock(return_value=[listing])
repository.upsert_listings = AsyncMock(return_value=[])
await detect_floorplan(repository)
repository.upsert_listings.assert_called_once()
upserted = repository.upsert_listings.call_args[0][0]
assert len(upserted) == 0

View file

@ -0,0 +1,215 @@
"""Unit tests for the image fetcher service."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime
import aiohttp
import pytest
from tenacity import stop_after_attempt
from models.listing import RentListing, ListingSite, FurnishType
from services.image_fetcher import dump_images_for_listing, MAX_CONCURRENT_DOWNLOADS
def _make_listing(**kwargs) -> RentListing: # type: ignore[no-untyped-def]
"""Create a RentListing with sensible defaults for testing."""
defaults = dict(
id=12345,
price=2000.0,
number_of_bedrooms=2,
square_meters=None,
agency="Test Agency",
council_tax_band="C",
longitude=0.0,
latitude=0.0,
price_history_json="[]",
listing_site=ListingSite.RIGHTMOVE,
last_seen=datetime.now(),
photo_thumbnail=None,
floorplan_image_paths=[],
additional_info={
"property": {
"visible": True,
"floorplans": [
{"url": "https://media.rightmove.co.uk/imgs/floorplan_1.jpg"}
],
}
},
routing_info_json=None,
furnish_type=FurnishType.FURNISHED,
available_from=None,
)
defaults.update(kwargs)
return RentListing(**defaults)
class TestDumpImagesForListing:
"""Tests for dump_images_for_listing function."""
async def test_downloads_floorplan_image(self, tmp_path: Path) -> None:
"""Test successful floorplan image download."""
listing = _make_listing()
image_bytes = b"\x89PNG fake image data"
mock_response = AsyncMock()
mock_response.status = 200
mock_response.read = AsyncMock(return_value=image_bytes)
mock_session = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session.get = MagicMock(return_value=mock_cm)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
assert result is not None
assert result.id == 12345
assert len(result.floorplan_image_paths) == 1
# Verify the image was written
written_path = Path(result.floorplan_image_paths[0])
assert written_path.exists()
assert written_path.read_bytes() == image_bytes
async def test_skips_existing_images(self, tmp_path: Path) -> None:
"""Test that existing images are not re-downloaded."""
listing = _make_listing()
# Pre-create the floorplan file
floorplan_dir = tmp_path / str(listing.id) / "floorplans"
floorplan_dir.mkdir(parents=True)
existing_file = floorplan_dir / "floorplan_1.jpg"
existing_file.write_bytes(b"existing image")
mock_session = MagicMock(spec=aiohttp.ClientSession)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
# Should return None because the only floorplan was skipped (continue)
assert result is None
# Session.get should NOT have been called
mock_session.get.assert_not_called()
async def test_returns_none_on_404(self, tmp_path: Path) -> None:
"""Test that 404 responses return None (image not found)."""
listing = _make_listing()
mock_response = AsyncMock()
mock_response.status = 404
mock_session = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session.get = MagicMock(return_value=mock_cm)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
assert result is None
async def test_raises_on_non_200_status(self, tmp_path: Path) -> None:
"""Test that non-200/404 status raises exception."""
listing = _make_listing()
mock_response = AsyncMock()
mock_response.status = 500
mock_session = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session.get = MagicMock(return_value=mock_cm)
with pytest.raises(Exception, match="HTTP 500"):
# Disable tenacity retry for testing: stop after 1 attempt and reraise
await dump_images_for_listing.retry_with(
stop=stop_after_attempt(1),
reraise=True,
)(listing, tmp_path, session=mock_session)
async def test_returns_none_when_no_floorplans(self, tmp_path: Path) -> None:
"""Test listing with no floorplans returns None."""
listing = _make_listing(
additional_info={"property": {"visible": True, "floorplans": []}}
)
mock_session = MagicMock(spec=aiohttp.ClientSession)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
assert result is None
async def test_url_filename_extraction(self, tmp_path: Path) -> None:
"""Test that filenames are correctly extracted from URLs."""
listing = _make_listing(
additional_info={
"property": {
"visible": True,
"floorplans": [
{
"url": "https://media.rightmove.co.uk/dir/sub/my_floorplan.png"
}
],
}
}
)
image_bytes = b"fake png"
mock_response = AsyncMock()
mock_response.status = 200
mock_response.read = AsyncMock(return_value=image_bytes)
mock_session = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session.get = MagicMock(return_value=mock_cm)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
assert result is not None
written_path = Path(result.floorplan_image_paths[0])
assert written_path.name == "my_floorplan.png"
async def test_creates_session_when_none_provided(self, tmp_path: Path) -> None:
"""Test that a session is created and closed when none is provided."""
listing = _make_listing()
image_bytes = b"fake image"
mock_response = AsyncMock()
mock_response.status = 200
mock_response.read = AsyncMock(return_value=image_bytes)
mock_session_instance = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session_instance.get = MagicMock(return_value=mock_cm)
mock_session_instance.close = AsyncMock()
with patch(
"services.image_fetcher.aiohttp.ClientSession",
return_value=mock_session_instance,
):
result = await dump_images_for_listing(listing, tmp_path, session=None)
assert result is not None
mock_session_instance.close.assert_awaited_once()
class TestImageFetcherConfig:
"""Tests for image fetcher configuration."""
def test_max_concurrent_downloads_constant(self) -> None:
"""Test that MAX_CONCURRENT_DOWNLOADS is defined and reasonable."""
assert MAX_CONCURRENT_DOWNLOADS > 0
assert MAX_CONCURRENT_DOWNLOADS <= 20

View file

@ -0,0 +1,225 @@
"""Unit tests for services/listing_cache.py."""
import json
from unittest import mock
import pytest
import redis
from models.listing import ListingType, QueryParameters
from services.listing_cache import (
CACHE_PREFIX,
_get_redis_client,
cache_features_batch,
get_cached_count,
get_cached_features,
invalidate_cache,
make_cache_key,
)
def _make_query(**overrides) -> QueryParameters:
"""Create a QueryParameters with defaults for testing."""
defaults = {"listing_type": ListingType.RENT, "min_price": 1000, "max_price": 3000}
defaults.update(overrides)
return QueryParameters(**defaults)
class TestMakeCacheKey:
"""Tests for make_cache_key()."""
def test_deterministic_for_same_params(self):
"""Same parameters produce the same cache key."""
qp = _make_query()
assert make_cache_key(qp) == make_cache_key(qp)
def test_different_for_different_params(self):
"""Different parameters produce different cache keys."""
qp1 = _make_query(min_price=1000)
qp2 = _make_query(min_price=2000)
assert make_cache_key(qp1) != make_cache_key(qp2)
def test_key_starts_with_prefix(self):
"""Cache key starts with CACHE_PREFIX."""
qp = _make_query()
assert make_cache_key(qp).startswith(CACHE_PREFIX)
class TestGetRedisClient:
"""Tests for _get_redis_client() URL parsing."""
@mock.patch("services.listing_cache.redis")
def test_default_broker_url(self, mock_redis):
"""Uses default localhost URL when env var is not set."""
with mock.patch.dict("os.environ", {}, clear=True):
_get_redis_client()
mock_redis.from_url.assert_called_once_with(
"redis://localhost:6379/2", decode_responses=True
)
@mock.patch("services.listing_cache.redis")
def test_custom_broker_url(self, mock_redis):
"""Replaces db number from custom broker URL."""
with mock.patch.dict(
"os.environ", {"CELERY_BROKER_URL": "redis://myhost:1234/5"}
):
_get_redis_client()
mock_redis.from_url.assert_called_once_with(
"redis://myhost:1234/2", decode_responses=True
)
@mock.patch("services.listing_cache.redis")
def test_broker_url_with_password(self, mock_redis):
"""Preserves auth info in broker URL."""
with mock.patch.dict(
"os.environ",
{"CELERY_BROKER_URL": "redis://:secret@myhost:6379/0"},
):
_get_redis_client()
mock_redis.from_url.assert_called_once_with(
"redis://:secret@myhost:6379/2", decode_responses=True
)
@mock.patch("services.listing_cache.redis")
def test_broker_url_with_query_params(self, mock_redis):
"""Preserves query parameters in broker URL."""
with mock.patch.dict(
"os.environ",
{"CELERY_BROKER_URL": "redis://myhost:6379/0?timeout=5"},
):
_get_redis_client()
mock_redis.from_url.assert_called_once_with(
"redis://myhost:6379/2?timeout=5", decode_responses=True
)
class TestGetCachedCount:
"""Tests for get_cached_count()."""
@mock.patch("services.listing_cache._get_redis_client")
def test_returns_none_on_redis_error(self, mock_get_client):
"""Returns None when Redis raises an error."""
mock_get_client.side_effect = redis.RedisError("connection refused")
result = get_cached_count(_make_query())
assert result is None
@mock.patch("services.listing_cache._get_redis_client")
def test_returns_none_when_key_not_exists(self, mock_get_client):
"""Returns None when the cache key does not exist."""
mock_client = mock.MagicMock()
mock_client.exists.return_value = False
mock_get_client.return_value = mock_client
result = get_cached_count(_make_query())
assert result is None
@mock.patch("services.listing_cache._get_redis_client")
def test_returns_count_when_key_exists(self, mock_get_client):
"""Returns list length when key exists."""
mock_client = mock.MagicMock()
mock_client.exists.return_value = True
mock_client.llen.return_value = 42
mock_get_client.return_value = mock_client
result = get_cached_count(_make_query())
assert result == 42
class TestGetCachedFeatures:
"""Tests for get_cached_features()."""
@mock.patch("services.listing_cache._get_redis_client")
def test_yields_empty_on_redis_error(self, mock_get_client):
"""Yields nothing when Redis raises an error."""
mock_get_client.side_effect = redis.RedisError("connection refused")
batches = list(get_cached_features(_make_query()))
assert batches == []
@mock.patch("services.listing_cache._get_redis_client")
def test_yields_batches(self, mock_get_client):
"""Yields features in batches."""
features = [{"type": "Feature", "id": i} for i in range(3)]
mock_client = mock.MagicMock()
mock_client.llen.return_value = 3
mock_client.lrange.return_value = [json.dumps(f) for f in features]
mock_get_client.return_value = mock_client
batches = list(get_cached_features(_make_query(), batch_size=50))
assert len(batches) == 1
assert batches[0] == features
class TestCacheFeaturesBatch:
"""Tests for cache_features_batch()."""
@mock.patch("services.listing_cache._get_redis_client")
def test_empty_features_returns_early(self, mock_get_client):
"""Does not call Redis when features list is empty."""
cache_features_batch(_make_query(), [])
mock_get_client.assert_not_called()
@mock.patch("services.listing_cache._get_redis_client")
def test_writes_features_via_pipeline(self, mock_get_client):
"""Writes features and sets TTL through pipeline."""
mock_client = mock.MagicMock()
mock_pipeline = mock.MagicMock()
mock_client.pipeline.return_value = mock_pipeline
mock_get_client.return_value = mock_client
features = [{"type": "Feature", "id": 1}]
cache_features_batch(_make_query(), features)
mock_pipeline.rpush.assert_called_once()
mock_pipeline.expire.assert_called_once()
mock_pipeline.execute.assert_called_once()
@mock.patch("services.listing_cache._get_redis_client")
def test_handles_redis_error(self, mock_get_client):
"""Handles Redis error gracefully during write."""
mock_get_client.side_effect = redis.RedisError("write error")
# Should not raise
cache_features_batch(_make_query(), [{"id": 1}])
class TestInvalidateCache:
"""Tests for invalidate_cache()."""
@mock.patch("services.listing_cache._get_redis_client")
def test_handles_redis_error(self, mock_get_client):
"""Handles Redis error gracefully during invalidation."""
mock_get_client.side_effect = redis.RedisError("connection refused")
# Should not raise
invalidate_cache()
@mock.patch("services.listing_cache._get_redis_client")
def test_deletes_matching_keys_via_pipeline(self, mock_get_client):
"""Deletes keys matching the cache prefix using pipeline."""
mock_client = mock.MagicMock()
mock_pipeline = mock.MagicMock()
mock_client.pipeline.return_value = mock_pipeline
# Simulate one scan iteration that returns keys, then done
mock_client.scan.return_value = (0, ["listings:geojson:abc", "listings:geojson:def"])
mock_get_client.return_value = mock_client
invalidate_cache()
assert mock_pipeline.delete.call_count == 2
mock_pipeline.execute.assert_called_once()
@mock.patch("services.listing_cache._get_redis_client")
def test_no_keys_to_delete(self, mock_get_client):
"""Does nothing when no cache keys exist."""
mock_client = mock.MagicMock()
mock_client.scan.return_value = (0, [])
mock_get_client.return_value = mock_client
invalidate_cache()
mock_client.pipeline.assert_not_called()

View file

@ -0,0 +1,372 @@
"""Unit tests for the listing fetcher service."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from models.listing import ListingType, QueryParameters
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
from services.listing_fetcher import (
NUM_WORKERS,
_fetch_subquery,
dump_listings,
dump_listings_full,
)
from services.query_splitter import SubQuery
def _make_subquery(**kwargs) -> SubQuery:
"""Create a SubQuery with sensible defaults for testing."""
defaults = dict(
district="REGION^123",
min_bedrooms=1,
max_bedrooms=3,
min_price=1000,
max_price=3000,
estimated_results=50,
)
defaults.update(kwargs)
return SubQuery(**defaults)
class TestDumpListingsFull:
"""Tests for dump_listings_full."""
async def test_returns_empty_list_when_no_new_listings(self) -> None:
"""Test that empty results from dump_listings returns empty list."""
with patch(
"services.listing_fetcher.dump_listings",
new_callable=AsyncMock,
return_value=[],
):
mock_repo = AsyncMock()
mock_repo.get_listings = AsyncMock(return_value=[])
params = QueryParameters(listing_type=ListingType.RENT)
result = await dump_listings_full(params, mock_repo)
assert result == []
async def test_returns_only_new_listings_from_db(self) -> None:
"""Test that dump_listings_full fetches new listings by ID from the repository."""
mock_listing_1 = MagicMock()
mock_listing_1.id = 100
mock_listing_2 = MagicMock()
mock_listing_2.id = 200
with patch(
"services.listing_fetcher.dump_listings",
new_callable=AsyncMock,
return_value=[mock_listing_1, mock_listing_2],
):
mock_repo = AsyncMock()
mock_repo.get_listings = AsyncMock(
return_value=[mock_listing_1, mock_listing_2]
)
params = QueryParameters(listing_type=ListingType.RENT)
result = await dump_listings_full(params, mock_repo)
# Verify get_listings was called with the correct IDs
mock_repo.get_listings.assert_awaited_once_with(
only_ids=[100, 200]
)
assert len(result) == 2
class TestFetchSubquery:
"""Tests for _fetch_subquery helper."""
async def test_skips_subquery_with_zero_estimated_results(self) -> None:
"""Test that subqueries with 0 estimated results are skipped."""
sq = _make_subquery(estimated_results=0)
params = QueryParameters(listing_type=ListingType.RENT)
queue: asyncio.Queue[int | None] = asyncio.Queue()
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=MagicMock(),
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_skips_subquery_with_none_estimated_results(self) -> None:
"""Test that subqueries with None estimated results are skipped."""
sq = _make_subquery(estimated_results=None)
params = QueryParameters(listing_type=ListingType.RENT)
queue: asyncio.Queue[int | None] = asyncio.Queue()
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=MagicMock(),
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_enqueues_new_ids_only(self) -> None:
"""Test that only new (not existing) IDs are enqueued."""
sq = _make_subquery(estimated_results=10)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
existing_ids: set[int] = {101, 103}
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
mock_config.max_concurrent_requests = 5
api_result = {
"properties": [
{"identifier": 101}, # existing
{"identifier": 102}, # new
{"identifier": 103}, # existing
{"identifier": 104}, # new
]
}
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
return_value=api_result,
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=existing_ids,
queue=queue,
)
assert ids_found == 2
# Verify that queued IDs are the new ones
queued = []
while not queue.empty():
queued.append(queue.get_nowait())
assert 102 in queued
assert 104 in queued
assert 101 not in queued
assert 103 not in queued
async def test_stops_on_circuit_breaker_error(self) -> None:
"""Test that CircuitBreakerOpenError breaks the page loop."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
side_effect=CircuitBreakerOpenError("open"),
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_stops_on_throttling_error(self) -> None:
"""Test that ThrottlingError breaks the page loop."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
side_effect=ThrottlingError("throttled"),
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_stops_on_generic_error(self) -> None:
"""Test that GENERIC_ERROR (past last page) stops pagination."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
side_effect=Exception("GENERIC_ERROR: no more results"),
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_stops_on_unexpected_error(self) -> None:
"""Test that unexpected errors also stop pagination."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
side_effect=Exception("some network error"),
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_stops_when_fewer_results_than_page_size(self) -> None:
"""Test that pagination stops when a page has fewer results than page_size."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
# Return fewer properties than page_size
api_result = {
"properties": [
{"identifier": 1},
{"identifier": 2},
]
}
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
return_value=api_result,
) as mock_query:
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
# Should have called listing_query exactly once (then stopped)
assert mock_query.await_count == 1
assert ids_found == 2
class TestDumpListings:
"""Tests for dump_listings."""
async def test_circuit_breaker_returns_empty_list(self) -> None:
"""Test that CircuitBreakerOpenError returns empty list."""
mock_repo = AsyncMock()
params = QueryParameters(listing_type=ListingType.RENT)
with patch("services.listing_fetcher.create_session") as mock_cs:
mock_cs.side_effect = CircuitBreakerOpenError("open")
result = await dump_listings(params, mock_repo)
assert result == []
async def test_returns_processed_listings(self) -> None:
"""Test that dump_listings returns processed listings from the pipeline."""
mock_repo = AsyncMock()
mock_repo.get_listing_ids = MagicMock(return_value=set())
params = QueryParameters(listing_type=ListingType.RENT)
mock_listing = MagicMock()
mock_listing.id = 42
mock_session_cm = AsyncMock()
mock_session = MagicMock()
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_cm.__aexit__ = AsyncMock(return_value=False)
with (
patch(
"services.listing_fetcher.create_session",
return_value=mock_session_cm,
),
patch(
"services.listing_fetcher.QuerySplitter"
) as mock_splitter_cls,
patch(
"services.listing_fetcher._fetch_subquery",
new_callable=AsyncMock,
return_value=0,
),
):
mock_splitter = mock_splitter_cls.return_value
mock_splitter.split = AsyncMock(return_value=[])
mock_splitter.calculate_total_estimated_results = MagicMock(
return_value=0
)
result = await dump_listings(params, mock_repo)
# With no subqueries, no listings are processed
assert result == []
class TestNumWorkers:
"""Tests for NUM_WORKERS constant."""
def test_num_workers_is_positive(self) -> None:
"""Test that NUM_WORKERS is a positive integer."""
assert NUM_WORKERS > 0
def test_num_workers_value(self) -> None:
"""Test that NUM_WORKERS has the expected value."""
assert NUM_WORKERS == 20

View file

@ -0,0 +1,87 @@
"""Unit tests for the listing processor."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from models.listing import FurnishType
from listing_processor import (
_parse_furnish_type,
_parse_available_from,
ListingProcessor,
FetchListingDetailsStep,
MAX_OCR_WORKERS,
)
class TestParseFurnishType:
"""Tests for _parse_furnish_type helper."""
def test_none_returns_unknown(self):
assert _parse_furnish_type(None) == FurnishType.UNKNOWN
def test_ask_landlord_variant(self):
assert _parse_furnish_type("Ask landlord") == FurnishType.ASK_LANDLORD
def test_furnished_lowercased(self):
assert _parse_furnish_type("Furnished") == FurnishType.FURNISHED
def test_unfurnished(self):
assert _parse_furnish_type("Unfurnished") == FurnishType.UNFURNISHED
def test_part_furnished(self):
assert _parse_furnish_type("Part Furnished") == FurnishType.PART_FURNISHED
def test_unknown_string_returns_unknown(self):
assert _parse_furnish_type("unknown") == FurnishType.UNKNOWN
def test_garbage_string_returns_unknown(self):
assert _parse_furnish_type("xyzzy") == FurnishType.UNKNOWN
class TestParseAvailableFrom:
"""Tests for _parse_available_from helper."""
def test_none_returns_none(self):
assert _parse_available_from(None) is None
def test_now_returns_datetime(self):
result = _parse_available_from("Now")
assert isinstance(result, datetime)
def test_valid_date_string(self):
result = _parse_available_from("15/03/2024")
assert result is not None
assert result.day == 15
assert result.month == 3
def test_invalid_date_returns_none(self):
assert _parse_available_from("invalid") is None
class TestListingProcessor:
"""Tests for ListingProcessor."""
async def test_process_listing_marks_seen(self):
"""Test that process_listing calls mark_seen."""
mock_repo = AsyncMock()
mock_repo.get_listings = AsyncMock(return_value=[MagicMock()])
processor = ListingProcessor(mock_repo)
# Mock all steps to not need processing
for step in processor.process_steps:
step.needs_processing = AsyncMock(return_value=False)
await processor.process_listing(123)
mock_repo.mark_seen.assert_awaited_once_with(123)
async def test_process_listing_returns_none_on_step_failure(self):
"""Test that a step failure returns None."""
mock_repo = AsyncMock()
processor = ListingProcessor(mock_repo)
for step in processor.process_steps:
step.needs_processing = AsyncMock(return_value=True)
step.process = AsyncMock(side_effect=Exception("fail"))
result = await processor.process_listing(123)
assert result is None
class TestOcrWorkersConfig:
def test_max_ocr_workers_positive(self):
assert MAX_OCR_WORKERS >= 1

View file

@ -0,0 +1,295 @@
"""Unit tests for tasks/listing_tasks.py."""
import json
import os
from collections import deque
from unittest.mock import MagicMock, patch, AsyncMock, call
import pytest
import tasks.listing_tasks as module
from tasks.listing_tasks import (
_update_task_state,
_PipelineState,
TaskLogHandler,
SCRAPE_LOCK_NAME,
LOG_BUFFER_MAX_LINES,
NUM_WORKERS,
PHASE_SPLITTING,
PHASE_FETCHING,
PHASE_PROCESSING,
PHASE_COMPLETED,
)
class TestUpdateTaskState:
"""Tests for _update_task_state."""
def test_injects_logs_from_active_buffer(self):
task = MagicMock()
original = module._active_log_buffer
try:
module._active_log_buffer = deque(["log line 1", "log line 2"])
_update_task_state(task, "test_state", {"key": "value"})
task.update_state.assert_called_once()
call_meta = task.update_state.call_args[1]["meta"]
assert call_meta["logs"] == ["log line 1", "log line 2"]
assert call_meta["key"] == "value"
finally:
module._active_log_buffer = original
def test_works_when_buffer_is_none(self):
task = MagicMock()
original = module._active_log_buffer
try:
module._active_log_buffer = None
_update_task_state(task, "some_state", {"phase": "testing"})
task.update_state.assert_called_once_with(
state="some_state", meta={"phase": "testing"}
)
# No "logs" key should be injected
call_meta = task.update_state.call_args[1]["meta"]
assert "logs" not in call_meta
finally:
module._active_log_buffer = original
def test_state_string_is_passed_through(self):
task = MagicMock()
original = module._active_log_buffer
try:
module._active_log_buffer = None
_update_task_state(task, "PROGRESS", {})
task.update_state.assert_called_once_with(state="PROGRESS", meta={})
finally:
module._active_log_buffer = original
def test_empty_buffer_injects_empty_list(self):
task = MagicMock()
original = module._active_log_buffer
try:
module._active_log_buffer = deque()
_update_task_state(task, "state", {"a": 1})
call_meta = task.update_state.call_args[1]["meta"]
assert call_meta["logs"] == []
finally:
module._active_log_buffer = original
class TestTaskLogHandler:
"""Tests for the TaskLogHandler."""
def test_emit_appends_to_buffer(self):
buf = deque(maxlen=10)
handler = TaskLogHandler(buf)
handler.setFormatter(
__import__("logging").Formatter("%(message)s")
)
record = __import__("logging").LogRecord(
name="test", level=20, pathname="", lineno=0,
msg="hello", args=(), exc_info=None,
)
handler.emit(record)
assert "hello" in buf
def test_buffer_respects_maxlen(self):
buf = deque(maxlen=2)
handler = TaskLogHandler(buf)
handler.setFormatter(
__import__("logging").Formatter("%(message)s")
)
for i in range(5):
record = __import__("logging").LogRecord(
name="test", level=20, pathname="", lineno=0,
msg=f"msg{i}", args=(), exc_info=None,
)
handler.emit(record)
assert len(buf) == 2
assert list(buf) == ["msg3", "msg4"]
class TestDumpListingsTask:
"""Tests for dump_listings_task Celery task."""
@patch("tasks.listing_tasks.redis_lock")
def test_skips_when_lock_not_acquired(self, mock_redis_lock):
"""Task should skip when another scrape is running."""
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=False)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_redis_lock.return_value = mock_cm
from tasks.listing_tasks import dump_listings_task
# Use run() which handles bind=True properly
task_instance = dump_listings_task
task_instance.update_state = MagicMock()
result = dump_listings_task.run('{"listing_type": "RENT"}')
assert result["status"] == "skipped"
assert result["reason"] == "another_job_running"
mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
@patch("tasks.listing_tasks.asyncio.run")
@patch("tasks.listing_tasks.redis_lock")
def test_calls_dump_listings_full_when_lock_acquired(
self, mock_redis_lock, mock_asyncio_run
):
"""Task should call dump_listings_full when lock is acquired."""
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=True)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_redis_lock.return_value = mock_cm
mock_asyncio_run.return_value = []
from tasks.listing_tasks import dump_listings_task
task_instance = dump_listings_task
task_instance.update_state = MagicMock()
params_json = '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}'
result = dump_listings_task.run(params_json)
assert result["phase"] == "completed"
assert result["progress"] == 1
mock_asyncio_run.assert_called_once()
mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
class TestSetupPeriodicTasks:
"""Tests for setup_periodic_tasks."""
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
def test_registers_enabled_schedules(self, mock_from_env):
from config.schedule_config import ScheduleConfig
from models.listing import ListingType
schedule = ScheduleConfig(
name="Test Schedule",
listing_type=ListingType.RENT,
hour="3",
minute="30",
)
mock_config = MagicMock()
mock_config.get_enabled_schedules.return_value = [schedule]
mock_from_env.return_value = mock_config
sender = MagicMock()
module.setup_periodic_tasks(sender)
sender.add_periodic_task.assert_called_once()
call_args = sender.add_periodic_task.call_args
assert call_args[1]["name"] == "Test Schedule"
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
def test_handles_config_error_gracefully(self, mock_from_env):
mock_from_env.side_effect = ValueError("bad config")
sender = MagicMock()
module.setup_periodic_tasks(sender)
sender.add_periodic_task.assert_not_called()
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
def test_registers_nothing_when_no_schedules(self, mock_from_env):
mock_config = MagicMock()
mock_config.get_enabled_schedules.return_value = []
mock_from_env.return_value = mock_config
sender = MagicMock()
module.setup_periodic_tasks(sender)
sender.add_periodic_task.assert_not_called()
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
def test_registers_multiple_schedules(self, mock_from_env):
from config.schedule_config import ScheduleConfig
from models.listing import ListingType
schedules = [
ScheduleConfig(name="Rent", listing_type=ListingType.RENT, hour="2"),
ScheduleConfig(name="Buy", listing_type=ListingType.BUY, hour="4"),
]
mock_config = MagicMock()
mock_config.get_enabled_schedules.return_value = schedules
mock_from_env.return_value = mock_config
sender = MagicMock()
module.setup_periodic_tasks(sender)
assert sender.add_periodic_task.call_count == 2
class TestPipelineState:
"""Tests for _PipelineState dataclass."""
def test_default_initialization(self):
state = _PipelineState()
assert state.ids_collected == 0
assert state.completed_subqueries == 0
assert state.total_pages_fetched == 0
assert state.fetching_done is False
assert state.processed_count == 0
assert state.failed_count == 0
assert state.details_fetched == 0
assert state.images_downloaded == 0
assert state.ocr_completed == 0
assert state.processed_listings == []
def test_incrementing_counters(self):
state = _PipelineState()
state.ids_collected += 5
state.completed_subqueries += 3
state.total_pages_fetched += 10
state.processed_count += 4
state.failed_count += 1
state.details_fetched += 4
state.images_downloaded += 3
state.ocr_completed += 2
assert state.ids_collected == 5
assert state.completed_subqueries == 3
assert state.total_pages_fetched == 10
assert state.processed_count == 4
assert state.failed_count == 1
assert state.details_fetched == 4
assert state.images_downloaded == 3
assert state.ocr_completed == 2
def test_appending_to_processed_listings(self):
state = _PipelineState()
state.processed_listings.append("listing_a")
state.processed_listings.append("listing_b")
assert len(state.processed_listings) == 2
assert state.processed_listings == ["listing_a", "listing_b"]
def test_separate_instances_have_independent_lists(self):
state_a = _PipelineState()
state_b = _PipelineState()
state_a.processed_listings.append("only_a")
assert state_b.processed_listings == []
def test_fetching_done_toggle(self):
state = _PipelineState()
assert state.fetching_done is False
state.fetching_done = True
assert state.fetching_done is True
class TestPhaseConstants:
"""Tests for phase constant values."""
def test_phase_splitting(self):
assert PHASE_SPLITTING == "splitting"
def test_phase_fetching(self):
assert PHASE_FETCHING == "fetching"
def test_phase_processing(self):
assert PHASE_PROCESSING == "processing"
def test_phase_completed(self):
assert PHASE_COMPLETED == "completed"
def test_num_workers(self):
assert NUM_WORKERS == 20

View file

@ -1,16 +1,24 @@
"""Unit tests for Listing models."""
import dataclasses
from datetime import datetime
import json
import pytest
from pydantic import ValidationError
from models.listing import (
BuyListing,
DestinationMode,
FurnishType,
ListingSite,
ListingType,
PriceHistoryItem,
QueryParameters,
RentListing,
Listing,
Route,
RouteLegStep,
)
from rec.routing import TravelMode
class TestListing:
@ -341,3 +349,190 @@ class TestBuyListing:
lease_left=120,
)
assert listing.lease_left == 120
def _make_listing_with_routing(routing_info_json: str | None) -> RentListing:
"""Helper to create a RentListing with given routing_info_json."""
return RentListing(
id=1,
price=2000.0,
number_of_bedrooms=2,
square_meters=50.0,
agency="Test",
council_tax_band="C",
longitude=0.0,
latitude=0.0,
price_history_json="[]",
listing_site=ListingSite.RIGHTMOVE,
last_seen=datetime.now(),
photo_thumbnail=None,
floorplan_image_paths=[],
additional_info={"property": {"visible": True}},
routing_info_json=routing_info_json,
furnish_type=FurnishType.FURNISHED,
available_from=None,
)
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
"""Helper to create sample routing info for tests."""
destination_mode = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
routes = [
Route(
legs=[
RouteLegStep(
distance_meters=500,
duration_s=120,
travel_mode=TravelMode.WALK,
),
RouteLegStep(
distance_meters=4000,
duration_s=480,
travel_mode=TravelMode.TRANSIT,
),
],
distance_meters=4500,
duration_s=600,
)
]
return {destination_mode: routes}
class TestQueryParametersValidation:
"""Tests for QueryParameters validation."""
def test_valid_parameters(self) -> None:
"""Basic valid QueryParameters creation."""
params = QueryParameters(
listing_type=ListingType.RENT,
min_price=1000,
max_price=3000,
min_bedrooms=1,
max_bedrooms=3,
)
assert params.min_price == 1000
assert params.max_price == 3000
assert params.min_bedrooms == 1
assert params.max_bedrooms == 3
def test_invalid_price_range_raises(self) -> None:
"""min_price > max_price should raise ValidationError."""
with pytest.raises(ValidationError, match="min_price.*must be <= max_price"):
QueryParameters(
listing_type=ListingType.RENT,
min_price=5000,
max_price=1000,
)
def test_invalid_bedroom_range_raises(self) -> None:
"""min_bedrooms > max_bedrooms should raise ValidationError."""
with pytest.raises(ValidationError, match="min_bedrooms.*must be <= max_bedrooms"):
QueryParameters(
listing_type=ListingType.RENT,
min_bedrooms=5,
max_bedrooms=2,
)
def test_negative_bedrooms_raises(self) -> None:
"""Negative bedroom counts should raise ValidationError."""
with pytest.raises(ValidationError, match="min_bedrooms.*must be non-negative"):
QueryParameters(
listing_type=ListingType.RENT,
min_bedrooms=-1,
max_bedrooms=3,
)
class TestDestinationMode:
"""Tests for DestinationMode."""
def test_to_dict(self) -> None:
"""Test to_dict returns correct dict."""
dm = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
result = dm.to_dict()
assert result == {
"destination_address": "London Bridge",
"travel_mode": TravelMode.TRANSIT,
}
def test_hash(self) -> None:
"""Test hashing works correctly."""
dm1 = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
dm2 = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
dm3 = DestinationMode(
destination_address="King's Cross",
travel_mode=TravelMode.TRANSIT,
)
assert hash(dm1) == hash(dm2)
assert dm1 == dm2
assert hash(dm1) != hash(dm3)
# Can be used as dict key
d = {dm1: "route1"}
assert d[dm2] == "route1"
class TestRoutingInfoSerialization:
"""Tests for routing info via RouteSerializer."""
def test_routing_info_property_returns_parsed_routes(self) -> None:
"""Test routing_info property deserializes correctly."""
routing_info = _make_sample_routing_info()
listing = _make_listing_with_routing(None)
serialized = listing.serialize_routing_info(routing_info)
listing.routing_info_json = serialized
result = listing.routing_info
assert len(result) == 1
dest_mode = list(result.keys())[0]
assert dest_mode.destination_address == "London Bridge"
assert dest_mode.travel_mode == TravelMode.TRANSIT
routes = result[dest_mode]
assert len(routes) == 1
assert routes[0].distance_meters == 4500
assert routes[0].duration_s == 600
assert len(routes[0].legs) == 2
assert routes[0].legs[0].distance_meters == 500
assert routes[0].legs[0].travel_mode == TravelMode.WALK
def test_routing_info_empty_json(self) -> None:
"""Test routing_info with no routing data."""
listing = _make_listing_with_routing(None)
assert listing.routing_info == {}
def test_serialize_routing_info_roundtrip(self) -> None:
"""Test serialize then deserialize via routing_info property."""
routing_info = _make_sample_routing_info()
listing = _make_listing_with_routing(None)
# Serialize
serialized = listing.serialize_routing_info(routing_info)
assert isinstance(serialized, str)
# Assign and deserialize via property
listing.routing_info_json = serialized
deserialized = listing.routing_info
# Compare
orig_dm = list(routing_info.keys())[0]
result_dm = list(deserialized.keys())[0]
assert orig_dm.destination_address == result_dm.destination_address
assert orig_dm.travel_mode == result_dm.travel_mode
orig_route = routing_info[orig_dm][0]
result_route = deserialized[result_dm][0]
assert orig_route.distance_meters == result_route.distance_meters
assert orig_route.duration_s == result_route.duration_s
assert len(orig_route.legs) == len(result_route.legs)

View file

@ -0,0 +1,385 @@
"""Unit tests for rec/query.py."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
from rec.query import (
detail_query,
listing_query,
probe_query,
PropertyType,
create_session,
_build_base_params,
_build_listing_params,
_build_probe_params,
ANDROID_APP_VERSION,
ANDROID_APP_VERSION_LISTING,
RIGHTMOVE_API_BASE,
PROPERTY_LISTING_ENDPOINT,
DEFAULT_HEADERS,
LISTING_HEADERS,
check_circuit_breaker,
reset_circuit_breaker,
get_circuit_breaker,
)
from models.listing import ListingType, FurnishType
from config.scraper_config import ScraperConfig
from rec.exceptions import CircuitBreakerOpenError
from rec.throttle_detector import reset_throttle_metrics
@pytest.fixture
def config() -> ScraperConfig:
return ScraperConfig(
max_concurrent_requests=5,
request_delay_ms=10,
slow_response_threshold=10.0,
enable_circuit_breaker=True,
circuit_breaker_failure_threshold=3,
circuit_breaker_recovery_timeout=0.5,
)
@pytest.fixture
def config_no_cb() -> ScraperConfig:
return ScraperConfig(enable_circuit_breaker=False)
@pytest.fixture(autouse=True)
def reset_globals() -> None:
reset_throttle_metrics()
reset_circuit_breaker()
class MockResponse:
def __init__(
self,
status: int = 200,
json_data: dict | None = None,
text: str = "",
):
self.status = status
self._json_data = json_data or {}
self._text = text
async def json(self) -> dict:
return self._json_data
async def text(self) -> str:
return self._text
async def __aenter__(self) -> "MockResponse":
return self
async def __aexit__(self, *args: object) -> None:
pass
def make_mock_session(response: MockResponse) -> MagicMock:
"""Create a mock session whose .get() returns an async context manager."""
mock_session = MagicMock()
mock_session.get = MagicMock(return_value=response)
return mock_session
def make_mock_session_fn(get_fn: object) -> MagicMock:
"""Create a mock session whose .get() calls a function to produce responses."""
mock_session = MagicMock()
mock_session.get = MagicMock(side_effect=get_fn)
return mock_session
class TestBuildBaseParams:
def test_constructs_correct_params(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"TestDistrict": "REGION^123"}):
params = _build_base_params(
channel=ListingType.RENT,
page=2,
page_size=25,
radius=1.5,
min_price=1000,
max_price=3000,
min_bedrooms=1,
max_bedrooms=3,
district="TestDistrict",
)
assert params["locationIdentifier"] == "REGION^123"
assert params["channel"] == "RENT"
assert params["page"] == "2"
assert params["numberOfPropertiesPerPage"] == "25"
assert params["radius"] == "1.5"
assert params["sortBy"] == "distance"
assert params["includeUnavailableProperties"] == "false"
assert params["minPrice"] == "1000"
assert params["maxPrice"] == "3000"
assert params["minBedrooms"] == "1"
assert params["maxBedrooms"] == "3"
assert params["apiApplication"] == "ANDROID"
assert params["appVersion"] == ANDROID_APP_VERSION_LISTING
def test_buy_channel_includes_dont_show_and_max_days(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
params = _build_listing_params(
page=1,
channel=ListingType.BUY,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=100000,
max_price=500000,
district="D",
mustNewHome=False,
max_days_since_added=7,
property_type=[],
page_size=25,
furnish_types=[],
)
assert params["dontShow"] == "sharedOwnership,retirement"
assert params["maxDaysSinceAdded"] == "7"
def test_rent_channel_includes_furnish_types(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
params = _build_listing_params(
page=1,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=3000,
district="D",
mustNewHome=False,
max_days_since_added=30,
property_type=[],
page_size=25,
furnish_types=[FurnishType.FURNISHED, FurnishType.UNFURNISHED],
)
assert params["furnishTypes"] == "furnished,unfurnished"
assert "dontShow" not in params
assert "maxDaysSinceAdded" not in params
def test_buy_channel_probe_includes_dont_show(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
params = _build_probe_params(
channel=ListingType.BUY,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=100000,
max_price=500000,
district="D",
max_days_since_added=7,
furnish_types=[],
)
assert params["dontShow"] == "sharedOwnership,retirement"
assert params["maxDaysSinceAdded"] == "7"
assert params["numberOfPropertiesPerPage"] == "1"
def test_probe_buy_skips_max_days_if_not_valid(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
params = _build_probe_params(
channel=ListingType.BUY,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=100000,
max_price=500000,
district="D",
max_days_since_added=30,
furnish_types=[],
)
# 30 is not in [1, 3, 7, 14], so maxDaysSinceAdded is not added for probe
assert "maxDaysSinceAdded" not in params
class TestMutableDefaultArgFix:
@pytest.mark.asyncio
async def test_property_type_default_not_shared(self, config: ScraperConfig) -> None:
"""Calling listing_query with no property_type should not share state between calls."""
response = MockResponse(
status=200,
json_data={"totalAvailableResults": 0, "properties": []},
)
mock_session = make_mock_session(response)
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
# Call twice without explicit property_type
await listing_query(
page=1,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
session=mock_session,
config=config,
)
await listing_query(
page=1,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
session=mock_session,
config=config,
)
# If mutable default was shared, this test would detect mutations.
# The fact that it completes without error proves defaults are independent.
@pytest.mark.asyncio
async def test_furnish_types_default_not_shared(self, config: ScraperConfig) -> None:
"""Calling probe_query with no furnish_types should not share state between calls."""
response = MockResponse(
status=200,
json_data={"totalAvailableResults": 0, "properties": []},
)
mock_session = make_mock_session(response)
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
await probe_query(
session=mock_session,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
config=config,
)
await probe_query(
session=mock_session,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
config=config,
)
class TestPropertyTypeEnum:
def test_enum_values(self) -> None:
assert PropertyType.BUNGALOW == "bungalow"
assert PropertyType.DETACHED == "detached"
assert PropertyType.FLAT == "flat"
assert PropertyType.LAND == "land"
assert PropertyType.PARK_HOME == "park-home"
assert PropertyType.SEMI_DETACHED == "semi-detached"
assert PropertyType.TERRACED == "terraced"
def test_enum_is_str(self) -> None:
assert isinstance(PropertyType.FLAT, str)
assert ",".join([PropertyType.FLAT, PropertyType.DETACHED]) == "flat,detached"
class TestDetailQuery:
@pytest.mark.asyncio
async def test_success_200(self, config: ScraperConfig) -> None:
expected_body = {"id": 12345, "address": "123 Test St"}
response = MockResponse(status=200, json_data=expected_body)
mock_session = make_mock_session(response)
result = await detail_query(12345, session=mock_session, config=config)
assert result == expected_body
@pytest.mark.asyncio
async def test_raises_on_non_200(self, config: ScraperConfig) -> None:
response = MockResponse(status=404, text="Not Found")
mock_session = make_mock_session(response)
with pytest.raises(Exception, match="Failed due to"):
await detail_query(99999, session=mock_session, config=config)
class TestCircuitBreakerBlocksRequests:
@pytest.mark.asyncio
async def test_circuit_breaker_blocks_when_open(self, config: ScraperConfig) -> None:
cb = get_circuit_breaker(config)
assert cb is not None
for _ in range(config.circuit_breaker_failure_threshold):
cb.record_failure()
assert cb.is_open
mock_session = MagicMock()
with pytest.raises(CircuitBreakerOpenError):
await detail_query(1, session=mock_session, config=config)
@pytest.mark.asyncio
async def test_circuit_breaker_blocks_listing_query(self, config: ScraperConfig) -> None:
cb = get_circuit_breaker(config)
assert cb is not None
for _ in range(config.circuit_breaker_failure_threshold):
cb.record_failure()
mock_session = MagicMock()
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
with pytest.raises(CircuitBreakerOpenError):
await listing_query(
page=1,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
session=mock_session,
config=config,
)
@pytest.mark.asyncio
async def test_circuit_breaker_blocks_probe_query(self, config: ScraperConfig) -> None:
cb = get_circuit_breaker(config)
assert cb is not None
for _ in range(config.circuit_breaker_failure_threshold):
cb.record_failure()
mock_session = MagicMock()
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
with pytest.raises(CircuitBreakerOpenError):
await probe_query(
session=mock_session,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
config=config,
)
class TestConstants:
def test_android_app_version(self) -> None:
assert ANDROID_APP_VERSION == "3.70.0"
def test_android_app_version_listing(self) -> None:
assert ANDROID_APP_VERSION_LISTING == "4.28.0"
def test_rightmove_api_base(self) -> None:
assert RIGHTMOVE_API_BASE == "https://api.rightmove.co.uk/api"
def test_property_listing_endpoint(self) -> None:
assert PROPERTY_LISTING_ENDPOINT == "https://api.rightmove.co.uk/api/property-listing"
def test_listing_headers_extends_default(self) -> None:
for key, value in DEFAULT_HEADERS.items():
assert LISTING_HEADERS[key] == value
assert LISTING_HEADERS["Accept-Encoding"] == "gzip, deflate, br"

View file

@ -161,7 +161,7 @@ class TestQuerySplitter:
mock_session = AsyncMock()
# Mock the probe_query function
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
mock_probe.return_value = {"totalAvailableResults": 800}
count = await splitter.probe_result_count(sq, mock_session, parameters)
@ -184,7 +184,7 @@ class TestQuerySplitter:
mock_session = AsyncMock()
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
mock_probe.side_effect = Exception("API error")
count = await splitter.probe_result_count(sq, mock_session, parameters)
@ -208,7 +208,7 @@ class TestQuerySplitter:
mock_session = AsyncMock()
mock_semaphore = AsyncMock()
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
# First half has 600 results, second half has 500
mock_probe.side_effect = [
{"totalAvailableResults": 600},
@ -240,7 +240,7 @@ class TestQuerySplitter:
mock_session = AsyncMock()
mock_semaphore = AsyncMock()
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
# First split: 1000-3000 has 1300 (over threshold), 3000-5000 has 800
# Second split of 1000-3000: 1000-2000 has 700, 2000-3000 has 600
mock_probe.side_effect = [
@ -326,7 +326,7 @@ class TestQuerySplitter:
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
with patch("services.query_splitter.get_districts", return_value=mock_districts):
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
# Mock probe results for each initial subquery
# 2 districts × 2 bedroom counts = 4 initial subqueries
mock_probe.side_effect = [
@ -358,11 +358,11 @@ class TestQuerySplitter:
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
progress_calls = []
def on_progress(phase: str, message: str) -> None:
def on_progress(phase: str, message: str, **kwargs: object) -> None:
progress_calls.append((phase, message))
with patch("services.query_splitter.get_districts", return_value=mock_districts):
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
mock_probe.return_value = {"totalAvailableResults": 500}
await splitter.split(parameters, mock_session, on_progress)

View file

@ -1,5 +1,6 @@
"""Unit tests for ListingRepository."""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import pytest
from sqlalchemy import Engine
@ -225,3 +226,156 @@ class TestListingRepositoryFilters:
listings = await listing_repository.get_listings(query_parameters=query_params)
# Should match listings with 1-2 bedrooms in price range
assert len(listings) == 2
class TestListingRepositoryStreaming:
"""Tests for streaming and optimized query methods."""
async def test_count_listings_empty_db(
self, listing_repository: ListingRepository
) -> None:
"""Test count returns 0 for empty database."""
count = listing_repository.count_listings()
assert count == 0
async def test_count_listings_with_data(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test count returns correct number."""
await listing_repository.upsert_listings(sample_rent_listings)
count = listing_repository.count_listings()
assert count == 3
async def test_count_listings_with_filters(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test count respects query parameters."""
await listing_repository.upsert_listings(sample_rent_listings)
query_params = QueryParameters(
listing_type=ListingType.RENT,
min_bedrooms=2,
max_bedrooms=3,
)
count = listing_repository.count_listings(query_parameters=query_params)
assert count == 2
async def test_stream_listings_optimized_returns_dicts(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test optimized streaming returns dict rows."""
await listing_repository.upsert_listings(sample_rent_listings)
results = list(listing_repository.stream_listings_optimized())
assert len(results) == 3
# Each result should be a dict
for row in results:
assert isinstance(row, dict)
assert "id" in row
assert "price" in row
assert "number_of_bedrooms" in row
async def test_stream_listings_optimized_respects_limit(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test streaming limit parameter."""
await listing_repository.upsert_listings(sample_rent_listings)
results = list(listing_repository.stream_listings_optimized(limit=2))
assert len(results) == 2
async def test_get_listing_ids(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test get_listing_ids returns set of IDs."""
await listing_repository.upsert_listings(sample_rent_listings)
ids = listing_repository.get_listing_ids()
assert isinstance(ids, set)
assert ids == {1, 2, 3}
async def test_get_listing_ids_empty_db(
self,
listing_repository: ListingRepository,
) -> None:
"""Test get_listing_ids returns empty set for empty database."""
ids = listing_repository.get_listing_ids()
assert isinstance(ids, set)
assert len(ids) == 0
class TestFurnishTypeParsing:
"""Tests for _parse_furnish_type helper."""
def test_parse_furnish_type_none_detailobject(self) -> None:
"""Test that None detailobject returns UNKNOWN."""
listing = MagicMock()
listing.detailobject = None
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNKNOWN
def test_parse_furnish_type_missing_property_key(self) -> None:
"""Test that missing 'property' key returns UNKNOWN."""
listing = MagicMock()
listing.detailobject = {}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNKNOWN
def test_parse_furnish_type_missing_let_furnish_type(self) -> None:
"""Test that missing 'letFurnishType' key returns UNKNOWN."""
listing = MagicMock()
listing.detailobject = {"property": {}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNKNOWN
def test_parse_furnish_type_null_value(self) -> None:
"""Test that null letFurnishType value returns UNKNOWN."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": None}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNKNOWN
def test_parse_furnish_type_furnished(self) -> None:
"""Test that 'Furnished' is parsed correctly."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "Furnished"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.FURNISHED
def test_parse_furnish_type_unfurnished(self) -> None:
"""Test that 'Unfurnished' is parsed correctly."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "Unfurnished"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNFURNISHED
def test_parse_furnish_type_part_furnished(self) -> None:
"""Test that 'Part Furnished' is parsed correctly."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "Part Furnished"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.PART_FURNISHED
def test_parse_furnish_type_landlord_variant(self) -> None:
"""Test that landlord variants map to ASK_LANDLORD."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "Ask Landlord Please"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.ASK_LANDLORD
def test_parse_furnish_type_landlord_case_insensitive(self) -> None:
"""Test that landlord check is case-insensitive."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "LANDLORD decides"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.ASK_LANDLORD

View file

@ -0,0 +1,10 @@
"""Unit tests for services/route_calculator.py."""
from services.route_calculator import _parse_duration
class TestParseDuration:
def test_parse_normal_duration(self) -> None:
assert _parse_duration("123s") == 123
def test_parse_zero_duration(self) -> None:
assert _parse_duration("0s") == 0

View file

@ -0,0 +1,72 @@
"""Unit tests for rec/route_serializer.py."""
from models.listing import DestinationMode, Route, RouteLegStep
from rec.route_serializer import RouteSerializer
from rec.routing import TravelMode
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
destination_mode = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
routes = [
Route(
legs=[
RouteLegStep(
distance_meters=500,
duration_s=120,
travel_mode=TravelMode.WALK,
),
RouteLegStep(
distance_meters=4000,
duration_s=480,
travel_mode=TravelMode.TRANSIT,
),
],
distance_meters=4500,
duration_s=600,
)
]
return {destination_mode: routes}
class TestRouteSerializer:
def test_serialize_then_deserialize_roundtrip(self) -> None:
routing_info = _make_sample_routing_info()
serialized = RouteSerializer.serialize(routing_info)
deserialized = RouteSerializer.deserialize(serialized)
assert len(deserialized) == 1
dest_mode = list(deserialized.keys())[0]
assert dest_mode.destination_address == "London Bridge"
assert dest_mode.travel_mode == TravelMode.TRANSIT
routes = deserialized[dest_mode]
assert len(routes) == 1
assert routes[0].distance_meters == 4500
assert routes[0].duration_s == 600
assert len(routes[0].legs) == 2
assert routes[0].legs[0].distance_meters == 500
assert routes[0].legs[0].travel_mode == TravelMode.WALK
assert routes[0].legs[1].travel_mode == TravelMode.TRANSIT
def test_deserialize_sample_json(self) -> None:
import json
import dataclasses
routing_info = _make_sample_routing_info()
# Build the JSON manually to test deserialize independently
json_str = json.dumps(
{
json.dumps(dataclasses.asdict(dm)): [
json.dumps(dataclasses.asdict(r)) for r in routes
]
for dm, routes in routing_info.items()
}
)
result = RouteSerializer.deserialize(json_str)
assert len(result) == 1
dest_mode = list(result.keys())[0]
assert dest_mode.destination_address == "London Bridge"
assert result[dest_mode][0].duration_s == 600

View file

@ -0,0 +1,67 @@
"""Unit tests for rec/routing.py."""
import os
from unittest.mock import patch, MagicMock
import pytest
from rec.routing import TravelMode, transit_route, ROUTES_API_URL, ROUTES_FIELD_MASK
from rec.exceptions import RoutingApiError
class TestTravelMode:
def test_enum_values(self) -> None:
assert TravelMode.TRANSIT == "TRANSIT"
assert TravelMode.BICYCLE == "BICYCLE"
assert TravelMode.WALK == "WALK"
assert TravelMode.DRIVE == "DRIVE"
def test_enum_has_four_members(self) -> None:
assert len(TravelMode) == 4
class TestTransitRoute:
@patch("rec.routing.requests.post")
@patch("rec.routing.nextMonday")
def test_success_response(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
mock_monday.return_value = MagicMock(
strftime=MagicMock(return_value="2024-01-08T09:00:00.000000Z")
)
expected = {"routes": [{"duration": "600s", "distanceMeters": 5000}]}
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = expected
mock_post.return_value = mock_response
with patch.dict(os.environ, {"ROUTING_API_KEY": "test-key"}):
result = transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
assert result == expected
mock_post.assert_called_once()
call_kwargs = mock_post.call_args
assert call_kwargs.kwargs["headers"]["X-Goog-Api-Key"] == "test-key"
@patch("rec.routing.requests.post")
@patch("rec.routing.nextMonday")
def test_raises_routing_api_error_on_non_200(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
mock_monday.return_value = MagicMock(
strftime=MagicMock(return_value="2024-01-08T09:00:00.000000Z")
)
error_body = {"error": {"message": "Invalid API key", "status": "PERMISSION_DENIED"}}
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.json.return_value = error_body
mock_post.return_value = mock_response
with patch.dict(os.environ, {"ROUTING_API_KEY": "bad-key"}):
with pytest.raises(RoutingApiError) as exc_info:
transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
assert exc_info.value.status_code == 403
assert exc_info.value.response_body == error_body
def test_raises_key_error_when_api_key_not_set(self) -> None:
env = os.environ.copy()
env.pop("ROUTING_API_KEY", None)
with patch.dict(os.environ, env, clear=True):
with pytest.raises(KeyError):
transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)

View file

@ -0,0 +1,306 @@
"""Unit tests for services/task_service.py."""
from unittest.mock import MagicMock, patch
import pytest
from services.task_service import (
TaskStatus,
_extract_progress_info,
_extract_result,
_make_system_user,
get_task_status,
)
class TestMakeSystemUser:
"""Tests for _make_system_user helper."""
def test_creates_user_with_email(self) -> None:
user = _make_system_user("test@example.com")
assert user.email == "test@example.com"
assert user.sub == ""
assert user.name == ""
def test_different_emails_create_different_users(self) -> None:
u1 = _make_system_user("a@b.com")
u2 = _make_system_user("c@d.com")
assert u1.email != u2.email
class TestExtractResult:
"""Tests for _extract_result helper."""
def test_failed_task_returns_error(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = True
mock_result.result = Exception("something broke")
result, error = _extract_result(mock_result)
assert result is None
assert error is not None
assert "something broke" in error
def test_failed_task_with_no_result(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = True
mock_result.result = None
result, error = _extract_result(mock_result)
assert result is None
assert error is None
def test_successful_json_serializable_result(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = False
mock_result.result = {"count": 42, "status": "done"}
result, error = _extract_result(mock_result)
assert result == {"count": 42, "status": "done"}
assert error is None
def test_non_serializable_result_falls_back_to_str(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = False
mock_result.result = object() # not JSON-serializable
result, error = _extract_result(mock_result)
assert isinstance(result, str)
assert error is None
def test_none_result_stays_none(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = False
mock_result.result = None
result, error = _extract_result(mock_result)
assert result is None
assert error is None
class TestExtractProgressInfo:
"""Tests for _extract_progress_info helper."""
def test_extracts_progress_fields(self) -> None:
mock_result = MagicMock()
mock_result.info = {"progress": 0.5, "processed": 50, "total": 100}
mock_result.status = "STARTED"
info = _extract_progress_info(mock_result)
assert info["progress"] == 0.5
assert info["processed"] == 50
assert info["total"] == 100
assert info["message"] is None
def test_extracts_message_from_info(self) -> None:
mock_result = MagicMock()
mock_result.info = {"message": "Processing page 3"}
mock_result.status = "STARTED"
info = _extract_progress_info(mock_result)
assert info["message"] == "Processing page 3"
def test_falls_back_to_reason_for_skipped(self) -> None:
mock_result = MagicMock()
mock_result.info = {"reason": "Already running"}
mock_result.status = "SKIPPED"
info = _extract_progress_info(mock_result)
assert info["message"] == "Already running"
def test_custom_state_used_as_message(self) -> None:
mock_result = MagicMock()
mock_result.info = {}
mock_result.status = "Fetching listings"
info = _extract_progress_info(mock_result)
assert info["message"] == "Fetching listings"
def test_standard_state_not_used_as_message(self) -> None:
mock_result = MagicMock()
mock_result.info = {}
mock_result.status = "PENDING"
info = _extract_progress_info(mock_result)
assert info["message"] is None
def test_none_info_returns_all_none(self) -> None:
mock_result = MagicMock()
mock_result.info = None
mock_result.status = "PENDING"
info = _extract_progress_info(mock_result)
assert info == {"progress": None, "processed": None, "total": None, "message": None}
class TestGetTaskStatus:
"""Tests for get_task_status."""
def test_pending_task(self) -> None:
"""Test status for a pending task."""
mock_result = MagicMock()
mock_result.status = "PENDING"
mock_result.failed.return_value = False
mock_result.result = None
mock_result.info = None
mock_result.traceback = None
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
mock_task.AsyncResult.return_value = mock_result
# Patch the lazy import
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
status = get_task_status("test-id")
assert status.task_id == "test-id"
assert status.status == "PENDING"
assert status.error is None
def test_failed_task(self) -> None:
"""Test status for a failed task."""
mock_result = MagicMock()
mock_result.status = "FAILURE"
mock_result.failed.return_value = True
mock_result.result = Exception("something broke")
mock_result.info = None
mock_result.traceback = "Traceback..."
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
mock_task.AsyncResult.return_value = mock_result
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
status = get_task_status("test-id")
assert status.status == "FAILURE"
assert status.error is not None
assert status.traceback == "Traceback..."
def test_custom_state_with_progress(self) -> None:
"""Test that custom states with progress info are extracted correctly."""
mock_result = MagicMock()
mock_result.status = "Fetching listings"
mock_result.failed.return_value = False
mock_result.result = None
mock_result.info = {"progress": 0.5, "processed": 50, "total": 100}
mock_result.traceback = None
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
mock_task.AsyncResult.return_value = mock_result
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
status = get_task_status("test-id")
assert status.progress == 0.5
assert status.processed == 50
assert status.total == 100
def test_successful_task(self) -> None:
"""Test status for a successful task."""
mock_result = MagicMock()
mock_result.status = "SUCCESS"
mock_result.failed.return_value = False
mock_result.result = {"listings_count": 42}
mock_result.info = None
mock_result.traceback = None
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
mock_task.AsyncResult.return_value = mock_result
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
status = get_task_status("test-id")
assert status.status == "SUCCESS"
assert status.result == {"listings_count": 42}
assert status.error is None
class TestGetUserTasks:
"""Tests for get_user_tasks."""
def test_returns_task_list(self) -> None:
mock_redis = MagicMock()
mock_redis.get_tasks_for_user.return_value = ["task-1", "task-2"]
with patch("services.task_service.RedisRepository", create=True) as MockRedisRepo:
MockRedisRepo.instance.return_value = mock_redis
with patch.dict("sys.modules", {"redis_repository": MagicMock(RedisRepository=MockRedisRepo)}):
from services.task_service import get_user_tasks
result = get_user_tasks("test@example.com")
assert result == ["task-1", "task-2"]
def test_returns_empty_list_for_unknown_user(self) -> None:
mock_redis = MagicMock()
mock_redis.get_tasks_for_user.return_value = []
with patch("services.task_service.RedisRepository", create=True) as MockRedisRepo:
MockRedisRepo.instance.return_value = mock_redis
with patch.dict("sys.modules", {"redis_repository": MagicMock(RedisRepository=MockRedisRepo)}):
from services.task_service import get_user_tasks
result = get_user_tasks("nobody@example.com")
assert result == []
class TestCancelTask:
"""Tests for cancel_task."""
def test_cancel_revokes_and_removes(self) -> None:
mock_celery = MagicMock()
mock_redis = MagicMock()
mock_redis.remove_task_for_user.return_value = True
with patch.dict("sys.modules", {
"celery_app": MagicMock(app=mock_celery),
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
}):
from services.task_service import cancel_task
result = cancel_task("task-123", user_email="test@example.com")
assert result is True
mock_celery.control.revoke.assert_called_once_with("task-123", terminate=True)
def test_cancel_without_user_email(self) -> None:
mock_celery = MagicMock()
with patch.dict("sys.modules", {"celery_app": MagicMock(app=mock_celery)}):
from services.task_service import cancel_task
result = cancel_task("task-456")
assert result is True
mock_celery.control.revoke.assert_called_once_with("task-456", terminate=True)
class TestClearAllTasks:
"""Tests for clear_all_tasks."""
def test_clear_with_revoke(self) -> None:
mock_celery = MagicMock()
mock_redis = MagicMock()
mock_redis.get_tasks_for_user.return_value = ["t1", "t2"]
mock_redis.clear_tasks_for_user.return_value = 2
with patch.dict("sys.modules", {
"celery_app": MagicMock(app=mock_celery),
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
}):
from services.task_service import clear_all_tasks
count = clear_all_tasks("test@example.com", revoke=True)
assert count == 2
assert mock_celery.control.revoke.call_count == 2
def test_clear_without_revoke(self) -> None:
mock_celery = MagicMock()
mock_redis = MagicMock()
mock_redis.clear_tasks_for_user.return_value = 3
with patch.dict("sys.modules", {
"celery_app": MagicMock(app=mock_celery),
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
}):
from services.task_service import clear_all_tasks
count = clear_all_tasks("test@example.com", revoke=False)
assert count == 3
mock_celery.control.revoke.assert_not_called()
def test_revoke_failure_logs_warning_and_continues(self) -> None:
mock_celery = MagicMock()
mock_celery.control.revoke.side_effect = Exception("connection lost")
mock_redis = MagicMock()
mock_redis.get_tasks_for_user.return_value = ["t1"]
mock_redis.clear_tasks_for_user.return_value = 1
with patch.dict("sys.modules", {
"celery_app": MagicMock(app=mock_celery),
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
}):
from services.task_service import clear_all_tasks
# Should not raise despite revoke failure
count = clear_all_tasks("test@example.com", revoke=True)
assert count == 1