From 2271d7d5e5aca9df9c83a06c72212244b1f6a1de Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Thu, 28 May 2026 22:33:41 +0000 Subject: [PATCH] examples: orchestrator + click CLI (ingest sub-command) --- fire_planner/__main__.py | 4 + fire_planner/examples/cli.py | 151 +++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- tests/test_examples_cli.py | 100 +++++++++++++++++++++++ 4 files changed, 256 insertions(+), 1 deletion(-) create mode 100644 fire_planner/examples/cli.py create mode 100644 tests/test_examples_cli.py diff --git a/fire_planner/__main__.py b/fire_planner/__main__.py index 4aa7992..77eb3e3 100644 --- a/fire_planner/__main__.py +++ b/fire_planner/__main__.py @@ -24,6 +24,7 @@ import numpy as np from sqlalchemy.ext.asyncio import async_sessionmaker from fire_planner.db import create_engine_from_env, make_session_factory +from fire_planner.examples.cli import examples_cli from fire_planner.glide_path import get as get_glide from fire_planner.ingest.wealthfolio import upsert_snapshots from fire_planner.ingest.wealthfolio_pg import ( @@ -363,5 +364,8 @@ def serve() -> None: uvicorn.run("fire_planner.app:app", host="0.0.0.0", port=8080) +cli.add_command(examples_cli) + + if __name__ == "__main__": cli() diff --git a/fire_planner/examples/cli.py b/fire_planner/examples/cli.py new file mode 100644 index 0000000..19ded31 --- /dev/null +++ b/fire_planner/examples/cli.py @@ -0,0 +1,151 @@ +"""Orchestrator + click CLI for the examples ingest pipeline. + +`ingest_subreddit(...)` is the testable async unit: fetch → filter → +extract (Tier 1+2) → upsert → return (inserted, skipped) counts. + +`ingest_all(...)` fans out across the 12 target subreddits with +`asyncio.gather(..., return_exceptions=True)` so a single sub's failure +doesn't sink the others. Job exits 0 when >=half succeed, else exits 2. + +The click commands at the bottom of the file are the entrypoints the +K8s Job + CronJob use. +""" +from __future__ import annotations + +import asyncio +import logging +import os +from datetime import date +from decimal import Decimal +from typing import Any, cast + +import asyncpraw +import click +import httpx + +from fire_planner.db import create_engine_from_env, make_session_factory +from fire_planner.examples.filters import is_candidate +from fire_planner.examples.llm_extract import extract_with_fallback +from fire_planner.examples.praw_source import TopWhen, fetch_top +from fire_planner.examples.service import upsert_example +from fire_planner.fx import fetch_rates + +log = logging.getLogger(__name__) + +DEFAULT_SUBS: list[str] = [ + "financialindependence", "leanfire", "fatFIRE", "coastFIRE", + "baristaFIRE", "ExpatFIRE", "EuropeFIRE", "FIRE_Ind", + "AusFinance", "CanadianFIRE", "UKPersonalFinance", + "financialindependence_UK", +] + + +async def ingest_subreddit( + session: Any, + reddit: Any, + *, + sub: str, + when: TopWhen, + limit: int, + llama_url: str, + claude_url: str, + claude_bearer: str, + client: httpx.AsyncClient, + fx_rates: dict[str, Decimal], +) -> tuple[int, int]: + inserted = 0 + skipped = 0 + async for post in fetch_top(reddit, sub, when, limit=limit): + if not is_candidate(post): + skipped += 1 + continue + extracted = await extract_with_fallback( + post, + llama_url=llama_url, + claude_url=claude_url, + claude_bearer=claude_bearer, + client=client, + ) + if extracted is None: + log.info("dropping %s — both LLM tiers failed", post.reddit_id) + skipped += 1 + continue + did_insert = await upsert_example(session, post, extracted, fx_rates) + if did_insert: + inserted += 1 + else: + skipped += 1 + return inserted, skipped + + +async def _ingest_all( + when_list: list[TopWhen], + limit: int, + subs: list[str], +) -> tuple[int, int, int]: + engine = create_engine_from_env() + factory = make_session_factory(engine) + rates = await fetch_rates(date.today()) + + reddit = asyncpraw.Reddit( + client_id=os.environ["REDDIT_CLIENT_ID"], + client_secret=os.environ["REDDIT_CLIENT_SECRET"], + user_agent=os.environ.get("REDDIT_USER_AGENT", "fire-planner/0.1"), + ) + llama_url = os.environ["LLAMA_CPP_BASE_URL"] + claude_url = os.environ["CLAUDE_AGENT_SERVICE_URL"] + claude_bearer = os.environ["CLAUDE_AGENT_BEARER"] + + async def _one(sub: str, when: TopWhen) -> tuple[int, int]: + async with factory() as session, httpx.AsyncClient() as client: + return await ingest_subreddit( + session, reddit, + sub=sub, when=when, limit=limit, + llama_url=llama_url, + claude_url=claude_url, + claude_bearer=claude_bearer, + client=client, + fx_rates=rates, + ) + + tasks = [_one(s, w) for s in subs for w in when_list] + results = await asyncio.gather(*tasks, return_exceptions=True) + await reddit.close() + await engine.dispose() + + n_succ = sum(1 for r in results if not isinstance(r, Exception)) + total_inserted = sum(r[0] for r in results if isinstance(r, tuple)) + total_skipped = sum(r[1] for r in results if isinstance(r, tuple)) + return total_inserted, total_skipped, n_succ + + +@click.group(name="examples") +def examples_cli() -> None: + """Reddit FIRE examples ingest commands.""" + + +@examples_cli.command("ingest") +@click.option("--top", "top_csv", default="all,year", + help="Comma-list of top-of-X windows (all,year,week).") +@click.option("--limit", default=1000, show_default=True) +@click.option("--sub", "subs_csv", default=None, + help="Comma-list of subs (default: all 12).") +def ingest_cmd(top_csv: str, limit: int, subs_csv: str | None) -> None: + """Bulk one-shot ingest. Used by the K8s Job.""" + logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO")) + # click hands us free-form strings; narrow to TopWhen at the boundary. + # Invalid values surface as an asyncpraw API error rather than a type error. + when_list = cast( + list[TopWhen], + [w.strip() for w in top_csv.split(",") if w.strip()], + ) + subs = [s.strip() for s in subs_csv.split(",")] if subs_csv else DEFAULT_SUBS + + inserted, skipped, succ = asyncio.run(_ingest_all(when_list, limit, subs)) + total = len(subs) * len(when_list) + log.info("ingest done: inserted=%d skipped=%d sub_runs_succ=%d/%d", + inserted, skipped, succ, total) + + # Exit 2 if fewer than half the (sub, when) pairs succeeded. + if succ < (total // 2 + 1): + raise click.exceptions.Exit(code=2) diff --git a/pyproject.toml b/pyproject.toml index bf27f01..981f1d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ strict = true files = ["fire_planner", "tests"] [[tool.mypy.overrides]] -module = ["respx.*", "pandas.*"] +module = ["respx.*", "pandas.*", "asyncpraw.*"] ignore_missing_imports = true [tool.ruff] diff --git a/tests/test_examples_cli.py b/tests/test_examples_cli.py new file mode 100644 index 0000000..89086eb --- /dev/null +++ b/tests/test_examples_cli.py @@ -0,0 +1,100 @@ +"""End-to-end pipeline test — mocked PRAW + respx-mocked LLM + in-memory DB.""" +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from dataclasses import dataclass +from datetime import datetime +from decimal import Decimal +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest +import respx +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from fire_planner.db import FireExample +from fire_planner.examples.cli import ingest_subreddit + +LLAMA_URL = "http://llama-cpp.llama-cpp.svc.cluster.local:8000/v1/chat/completions" +CLAUDE_URL = "http://claude-agent-service.claude-agent.svc.cluster.local:8080/v1/chat/completions" + + +@dataclass +class _FakeSub: + id: str + title: str + selftext: str + permalink: str + created_utc: float + + +def _async_iter(items: list[_FakeSub]) -> AsyncIterator[_FakeSub]: + async def _gen() -> AsyncIterator[_FakeSub]: + for it in items: + yield it + return _gen() + + +@respx.mock +@pytest.mark.asyncio +async def test_ingest_subreddit_end_to_end(session: AsyncSession) -> None: + fakes = [ + _FakeSub( + id="ok1", + title="FIRE at 38 in Manila", + selftext="Net worth £1m, family of 3, retired last year", + permalink="/r/ExpatFIRE/comments/ok1/", + created_utc=datetime(2026, 1, 1).timestamp(), + ), + _FakeSub( # filter should drop this — no money signal + id="drop1", + title="Thinking about moving to Lisbon", + selftext="No specifics yet", + permalink="/r/ExpatFIRE/comments/drop1/", + created_utc=datetime(2026, 1, 2).timestamp(), + ), + ] + mock_subreddit = MagicMock() + mock_subreddit.top = MagicMock(return_value=_async_iter(fakes)) + mock_reddit = MagicMock() + mock_reddit.subreddit = AsyncMock(return_value=mock_subreddit) + + payload = { + "country": "Philippines", + "city": "Manila", + "portfolio_native": 1000000, + "raw_currency": "GBP", + "age": 38, + "family_size": 3, + "fi_status": "FIRE", + "is_retired": True, + "confidence": 0.8, + } + respx.post(LLAMA_URL).respond( + 200, + json={"choices": [{"message": {"content": json.dumps(payload)}}]}, + ) + + fx_rates = {"GBP": Decimal("1"), "USD": Decimal("0.80")} + async with httpx.AsyncClient() as client: + n_inserted, n_skipped = await ingest_subreddit( + session, + mock_reddit, + sub="ExpatFIRE", + when="all", + limit=10, + llama_url=LLAMA_URL, + claude_url=CLAUDE_URL, + claude_bearer="t", + client=client, + fx_rates=fx_rates, + ) + + assert n_inserted == 1 + assert n_skipped == 1 + rows = (await session.execute(select(FireExample))).scalars().all() + assert len(rows) == 1 + assert rows[0].country == "Philippines" + assert rows[0].portfolio_gbp == Decimal("1000000.00")