Fix 7 bugs: security, memory leak, stale state, error handling

- WebSocket: verify task ownership before allowing subscribe (security)
- POI routes: replace assert with HTTPException for production safety
- cancel_task: return HTTP 404 instead of 200 for missing tasks
- routing_config: add descriptive ValueError for invalid env vars
- POIManager: show error feedback instead of silently swallowing failures
- VisualizationCard: reset POI/travel mode state on metric switch
- Map: clean up heatmap layers/sources on unmount to prevent memory leak
- Update test to expect 404 from cancel_task ownership check
This commit is contained in:
Viktor Barzin 2026-02-13 19:36:43 +00:00
parent 25c87da1cf
commit 41b7d221e4
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 45 additions and 9 deletions

View file

@ -415,7 +415,7 @@ async def cancel_task(
# Verify user owns this task
user_tasks = task_service.get_user_tasks(user.email)
if task_id not in user_tasks:
return {"success": False, "message": "Task not found or not owned by user"}
raise HTTPException(status_code=404, detail="Task not found or not owned by user")
try:
task_service.cancel_task(task_id, user_email=user.email)

View file

@ -60,7 +60,8 @@ def _get_user_id(user: User) -> int:
if db_user is None:
# Auto-create user on first POI interaction
db_user = user_repo.create_user(user.email)
assert db_user.id is not None
if db_user.id is None:
raise HTTPException(status_code=500, detail="Failed to create user")
return db_user.id

View file

@ -137,6 +137,10 @@ async def ws_task_progress(websocket: WebSocket) -> None:
if msg_type == "subscribe":
new_task_id = msg.get("task_id")
if new_task_id:
# Verify task belongs to the authenticated user
user_tasks = task_service.get_user_tasks(user.email)
if new_task_id not in user_tasks:
continue
channel = f"task_progress:{new_task_id}"
if channel not in subscribed_channels:
await pubsub.subscribe(channel)

View file

@ -6,6 +6,15 @@ from dataclasses import dataclass
from typing import Self
def _int_env(name: str, default: str) -> int:
"""Parse an integer environment variable with a descriptive error."""
raw = os.environ.get(name, default)
try:
return int(raw)
except ValueError:
raise ValueError(f"Environment variable {name}={raw!r} must be an integer")
@dataclass(frozen=True)
class RoutingConfig:
"""Configuration for self-hosted routing engines (OSRM + OTP).
@ -39,8 +48,8 @@ class RoutingConfig:
osrm_foot_url=os.environ.get("OSRM_FOOT_URL", "http://osrm-foot:5000"),
osrm_bicycle_url=os.environ.get("OSRM_BICYCLE_URL", "http://osrm-bicycle:5000"),
otp_url=os.environ.get("OTP_URL", "http://otp:8080"),
osrm_batch_size=int(os.environ.get("OSRM_BATCH_SIZE", "50")),
otp_max_concurrent=int(os.environ.get("OTP_MAX_CONCURRENT", "10")),
osrm_batch_size=_int_env("OSRM_BATCH_SIZE", "50"),
otp_max_concurrent=_int_env("OTP_MAX_CONCURRENT", "10"),
)
def get_osrm_url(self, profile: str) -> str:

View file

@ -194,6 +194,19 @@ export function Map(props: MapProps) {
if (updateTimeoutRef.current) {
clearTimeout(updateTimeoutRef.current);
}
// Remove heatmap layers and sources before destroying the map
if (heatmapRef.current && mapRef.current) {
for (const layerId of ['hexgrid-heatmap', 'hexgrid-heatmap-back']) {
if (mapRef.current.getLayer(layerId)) {
mapRef.current.removeLayer(layerId);
}
}
for (const sourceId of ['hexgrid-heatmap', 'hexgrid-heatmap-back']) {
if (mapRef.current.getSource(sourceId)) {
mapRef.current.removeSource(sourceId);
}
}
}
heatmapRef.current = null;
isMapLoadedRef.current = false;
mapRef.current?.remove();

View file

@ -23,6 +23,7 @@ export function POIManager({ user, listingType, onTaskCreated, pickedLocation, o
const [lng, setLng] = useState('');
const [calculating, setCalculating] = useState<number | null>(null);
const [selectedModes, setSelectedModes] = useState<string[]>(['WALK', 'BICYCLE', 'TRANSIT']);
const [error, setError] = useState<string | null>(null);
useEffect(() => {
fetchUserPOIs(user).then(setPois).catch(() => {});
@ -38,6 +39,7 @@ export function POIManager({ user, listingType, onTaskCreated, pickedLocation, o
const handleCreate = async () => {
if (!name || !lat || !lng) return;
setError(null);
try {
const poi = await createPOI(user, {
name,
@ -62,16 +64,17 @@ export function POIManager({ user, listingType, onTaskCreated, pickedLocation, o
// Non-fatal: POI created successfully, calculation can be retried manually.
}
} catch {
// silently fail
setError('Failed to create POI. Please try again.');
}
};
const handleDelete = async (poiId: number) => {
setError(null);
try {
await deletePOI(user, poiId);
setPois(prev => prev.filter(p => p.id !== poiId));
} catch {
// silently fail
setError('Failed to delete POI. Please try again.');
}
};
@ -95,6 +98,12 @@ export function POIManager({ user, listingType, onTaskCreated, pickedLocation, o
return (
<div className="space-y-3">
{error && (
<div className="text-xs text-destructive bg-destructive/10 border border-destructive/20 rounded-md px-3 py-2 flex items-center justify-between">
<span>{error}</span>
<button type="button" className="ml-2 text-destructive hover:text-destructive/80" onClick={() => setError(null)}>&times;</button>
</div>
)}
{/* POI List */}
{pois.map(poi => (
<div key={poi.id} className="flex items-center gap-2 p-2 rounded-md border text-sm">

View file

@ -24,6 +24,8 @@ export function VisualizationCard({ metric, onMetricChange, userPOIs, onPoiMetri
onValueChange={(value) => {
onMetricChange(value as Metric);
if (value !== Metric.poi_travel) {
setSelectedPoiId('');
setSelectedTravelMode('');
onPoiMetricChange?.(null);
}
}}

View file

@ -296,9 +296,7 @@ async def test_cancel_task_not_owned(
monkeypatch.setattr("services.task_service.get_user_tasks", lambda email: [])
resp = await async_client.post("/api/cancel_task?task_id=not-mine")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is False
assert resp.status_code == 404
@pytest.mark.asyncio