Add API rate limiting, metrics guard, and audit middleware
Per-user rate limits via Redis sliding window, IP-restricted /metrics endpoint, audit logging of all requests, CORS tightening, and export caps on listing/geojson endpoints.
This commit is contained in:
parent
08ac72bbfc
commit
87b5bd8676
8 changed files with 756 additions and 2 deletions
28
api/app.py
28
api/app.py
|
|
@ -7,6 +7,10 @@ from typing import Annotated, AsyncGenerator, Optional
|
|||
from api.auth import get_current_user
|
||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
||||
from api.passkey_routes import passkey_router
|
||||
from api.rate_limit_config import RateLimitConfig
|
||||
from api.rate_limiter import RateLimitMiddleware
|
||||
from api.audit_middleware import AuditLogMiddleware
|
||||
from api.metrics_guard import MetricsGuardMiddleware
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Depends, FastAPI, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
|
@ -33,6 +37,7 @@ load_dotenv()
|
|||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
DEFAULT_BATCH_SIZE = 50
|
||||
_rate_limit_config = RateLimitConfig.from_env()
|
||||
|
||||
|
||||
def get_query_parameters(
|
||||
|
|
@ -82,10 +87,18 @@ hist = meter.create_histogram(
|
|||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
# Security middleware (added bottom-to-top; last added = outermost)
|
||||
# 3. Rate limiting — enforces per-user limits
|
||||
app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
|
||||
# 2. Metrics guard — blocks unauthorized /metrics access
|
||||
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
|
||||
# 1. Audit logging — logs everything including 429s and 403s
|
||||
app.add_middleware(AuditLogMiddleware)
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status() -> dict[str, str]:
|
||||
|
|
@ -100,6 +113,7 @@ async def get_listing(
|
|||
limit: int = 5,
|
||||
) -> dict[str, list]:
|
||||
"""Get listings from the database."""
|
||||
limit = min(limit, _rate_limit_config.listing_limit_cap)
|
||||
repository = ListingRepository(engine)
|
||||
result = await listing_service.get_listings(repository, limit=limit)
|
||||
logger.info(f"Fetched {result.total_count} listings for {user.email}")
|
||||
|
|
@ -113,6 +127,10 @@ async def get_listing_geojson(
|
|||
limit: int | None = None,
|
||||
) -> dict:
|
||||
"""Get listings as GeoJSON for map display."""
|
||||
if limit is not None:
|
||||
limit = min(limit, _rate_limit_config.geojson_limit_cap)
|
||||
else:
|
||||
limit = _rate_limit_config.geojson_limit_cap
|
||||
repository = ListingRepository(engine)
|
||||
result = await export_service.export_to_geojson(
|
||||
repository,
|
||||
|
|
@ -204,6 +222,12 @@ async def stream_listing_geojson(
|
|||
- batch: Array of GeoJSON features
|
||||
- complete: Final message with total count
|
||||
"""
|
||||
batch_size = min(batch_size, _rate_limit_config.geojson_stream_batch_size_cap)
|
||||
if limit is not None:
|
||||
limit = min(limit, _rate_limit_config.geojson_stream_limit_cap)
|
||||
else:
|
||||
limit = _rate_limit_config.geojson_stream_limit_cap
|
||||
|
||||
cached_count = get_cached_count(query_parameters)
|
||||
if cached_count is not None and cached_count > 0:
|
||||
generator = _stream_from_cache(query_parameters, batch_size, limit)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue