examples: orchestrator + click CLI (ingest sub-command)
Some checks failed
ci/woodpecker/push/woodpecker Pipeline was canceled

This commit is contained in:
Viktor Barzin 2026-05-28 22:33:41 +00:00
parent a10d7fe2a6
commit 2271d7d5e5
4 changed files with 256 additions and 1 deletions

View file

@ -24,6 +24,7 @@ import numpy as np
from sqlalchemy.ext.asyncio import async_sessionmaker
from fire_planner.db import create_engine_from_env, make_session_factory
from fire_planner.examples.cli import examples_cli
from fire_planner.glide_path import get as get_glide
from fire_planner.ingest.wealthfolio import upsert_snapshots
from fire_planner.ingest.wealthfolio_pg import (
@ -363,5 +364,8 @@ def serve() -> None:
uvicorn.run("fire_planner.app:app", host="0.0.0.0", port=8080)
cli.add_command(examples_cli)
if __name__ == "__main__":
cli()

View file

@ -0,0 +1,151 @@
"""Orchestrator + click CLI for the examples ingest pipeline.
`ingest_subreddit(...)` is the testable async unit: fetch → filter →
extract (Tier 1+2) → upsert → return (inserted, skipped) counts.
`_ingest_all(...)` fans out across the 12 target subreddits with
`asyncio.gather(..., return_exceptions=True)` so a single sub's failure
doesn't sink the others. Job exits 0 when >=half succeed, else exits 2.
The click commands at the bottom of the file are the entrypoints the
K8s Job + CronJob use.
"""
from __future__ import annotations
import asyncio
import logging
import os
from datetime import date
from decimal import Decimal
from typing import Any, cast
import asyncpraw
import click
import httpx
from fire_planner.db import create_engine_from_env, make_session_factory
from fire_planner.examples.filters import is_candidate
from fire_planner.examples.llm_extract import extract_with_fallback
from fire_planner.examples.praw_source import TopWhen, fetch_top
from fire_planner.examples.service import upsert_example
from fire_planner.fx import fetch_rates
log = logging.getLogger(__name__)
# Subreddits ingested when the `--sub` option is not supplied.
DEFAULT_SUBS: list[str] = [
    "financialindependence",
    "leanfire",
    "fatFIRE",
    "coastFIRE",
    "baristaFIRE",
    "ExpatFIRE",
    "EuropeFIRE",
    "FIRE_Ind",
    "AusFinance",
    "CanadianFIRE",
    "UKPersonalFinance",
    "financialindependence_UK",
]
async def ingest_subreddit(
    session: Any,
    reddit: Any,
    *,
    sub: str,
    when: TopWhen,
    limit: int,
    llama_url: str,
    claude_url: str,
    claude_bearer: str,
    client: httpx.AsyncClient,
    fx_rates: dict[str, Decimal],
) -> tuple[int, int]:
    """Ingest the top posts of one subreddit for one time window.

    Every post yielded by `fetch_top` is run through `is_candidate`, then
    `extract_with_fallback`, then `upsert_example`. A post counts as skipped
    when the filter rejects it, when extraction returns ``None``, or when the
    upsert reports a falsy result.

    Args:
        session: DB session handed to `upsert_example`.
        reddit: Reddit client handed to `fetch_top`.
        sub: Subreddit name.
        when: Top-of-X window.
        limit: Maximum number of posts requested from `fetch_top`.
        llama_url: Endpoint forwarded to `extract_with_fallback`.
        claude_url: Endpoint forwarded to `extract_with_fallback`.
        claude_bearer: Bearer token forwarded to `extract_with_fallback`.
        client: Shared HTTP client used for the extraction calls.
        fx_rates: Currency rates handed to `upsert_example`.

    Returns:
        ``(inserted, skipped)`` post counts.
    """
    n_inserted = 0
    n_skipped = 0

    async for submission in fetch_top(reddit, sub, when, limit=limit):
        stored = False
        if is_candidate(submission):
            fields = await extract_with_fallback(
                submission,
                llama_url=llama_url,
                claude_url=claude_url,
                claude_bearer=claude_bearer,
                client=client,
            )
            if fields is None:
                log.info("dropping %s — both LLM tiers failed", submission.reddit_id)
            else:
                stored = bool(
                    await upsert_example(session, submission, fields, fx_rates)
                )

        # Anything that did not end up stored counts as skipped.
        if stored:
            n_inserted += 1
        else:
            n_skipped += 1

    return n_inserted, n_skipped
async def _ingest_all(
    when_list: list[TopWhen],
    limit: int,
    subs: list[str],
) -> tuple[int, int, int]:
    """Run `ingest_subreddit` concurrently for every (sub, when) pair.

    Each pair gets its own DB session and HTTP client; the Reddit client and
    the FX rates are shared. Failures of individual pairs are captured by
    `asyncio.gather(..., return_exceptions=True)`, logged, and excluded from
    the success count, so one failing pair does not abort the others.

    Args:
        when_list: Top-of-X windows to ingest for each subreddit.
        limit: Per-pair post limit forwarded to `ingest_subreddit`.
        subs: Subreddit names.

    Returns:
        ``(total_inserted, total_skipped, n_succ)`` where ``n_succ`` is the
        number of (sub, when) pairs that completed without raising.

    Raises:
        KeyError: If a required environment variable is missing.
    """
    engine = create_engine_from_env()
    try:
        factory = make_session_factory(engine)
        rates = await fetch_rates(date.today())
        reddit = asyncpraw.Reddit(
            client_id=os.environ["REDDIT_CLIENT_ID"],
            client_secret=os.environ["REDDIT_CLIENT_SECRET"],
            user_agent=os.environ.get("REDDIT_USER_AGENT", "fire-planner/0.1"),
        )
        try:
            llama_url = os.environ["LLAMA_CPP_BASE_URL"]
            claude_url = os.environ["CLAUDE_AGENT_SERVICE_URL"]
            claude_bearer = os.environ["CLAUDE_AGENT_BEARER"]

            async def _one(sub: str, when: TopWhen) -> tuple[int, int]:
                async with factory() as session, httpx.AsyncClient() as client:
                    return await ingest_subreddit(
                        session, reddit,
                        sub=sub, when=when, limit=limit,
                        llama_url=llama_url,
                        claude_url=claude_url,
                        claude_bearer=claude_bearer,
                        client=client,
                        fx_rates=rates,
                    )

            pairs = [(s, w) for s in subs for w in when_list]
            results = await asyncio.gather(
                *(_one(s, w) for s, w in pairs),
                return_exceptions=True,
            )
        finally:
            # Close the Reddit client even when setup or the fan-out raised.
            await reddit.close()
    finally:
        await engine.dispose()

    succeeded: list[tuple[int, int]] = []
    for (sub, when), outcome in zip(pairs, results):
        # gather() may hand back any BaseException (e.g. CancelledError), not
        # only Exception subclasses; none of them may count as a success.
        if isinstance(outcome, BaseException):
            log.error("ingest failed for r/%s top=%s", sub, when, exc_info=outcome)
        else:
            succeeded.append(outcome)

    total_inserted = sum(ins for ins, _ in succeeded)
    total_skipped = sum(skp for _, skp in succeeded)
    return total_inserted, total_skipped, len(succeeded)
# Click command group named "examples"; sub-commands register themselves on it
# via `@examples_cli.command(...)`, and the main CLI mounts it with
# `cli.add_command(examples_cli)`.
# NOTE: click uses the function docstring as the group's help text, so the
# docstring is user-facing output rather than internal documentation.
@click.group(name="examples")
def examples_cli() -> None:
    """Reddit FIRE examples ingest commands."""
# `examples ingest`: parse the comma-separated options, run the fan-out, and
# map the outcome onto the process exit status:
#   0 -> at least half of the (sub, when) runs succeeded
#   2 -> fewer than half succeeded, or the options named nothing to run
# The docstring stays a one-liner because click shows it as the --help text.
@examples_cli.command("ingest")
@click.option("--top", "top_csv", default="all,year",
              help="Comma-list of top-of-X windows (all,year,week).")
@click.option("--limit", default=1000, show_default=True)
@click.option("--sub", "subs_csv", default=None,
              help="Comma-list of subs (default: all 12).")
def ingest_cmd(top_csv: str, limit: int, subs_csv: str | None) -> None:
    """Bulk one-shot ingest. Used by the K8s Job."""
    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
    # click hands us free-form strings; narrow to TopWhen at the boundary.
    # Invalid values surface as an asyncpraw API error rather than a type error.
    when_list = cast(
        list[TopWhen],
        [w.strip() for w in top_csv.split(",") if w.strip()],
    )
    # Drop empty entries ("a,,b", trailing comma) the same way --top does, so
    # we never try to ingest a subreddit with an empty name.
    if subs_csv:
        subs = [s.strip() for s in subs_csv.split(",") if s.strip()]
    else:
        subs = DEFAULT_SUBS

    # Nothing to run is a caller mistake; fail before touching DB or network.
    # click.UsageError exits with status 2, the same code as a failed ingest.
    if not when_list or not subs:
        raise click.UsageError(
            "--top and --sub must each list at least one non-empty value"
        )

    inserted, skipped, succ = asyncio.run(_ingest_all(when_list, limit, subs))
    total = len(subs) * len(when_list)
    log.info("ingest done: inserted=%d skipped=%d sub_runs_succ=%d/%d",
             inserted, skipped, succ, total)
    # Exit 2 if fewer than half the (sub, when) pairs succeeded; exactly half
    # still counts as success.
    if succ * 2 < total:
        raise click.exceptions.Exit(code=2)

View file

@ -45,7 +45,7 @@ strict = true
files = ["fire_planner", "tests"]
[[tool.mypy.overrides]]
module = ["respx.*", "pandas.*"]
module = ["respx.*", "pandas.*", "asyncpraw.*"]
ignore_missing_imports = true
[tool.ruff]

100
tests/test_examples_cli.py Normal file
View file

@ -0,0 +1,100 @@
"""End-to-end pipeline test — mocked PRAW + respx-mocked LLM + in-memory DB."""
from __future__ import annotations
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
import respx
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import FireExample
from fire_planner.examples.cli import ingest_subreddit
LLAMA_URL = "http://llama-cpp.llama-cpp.svc.cluster.local:8000/v1/chat/completions"
CLAUDE_URL = "http://claude-agent-service.claude-agent.svc.cluster.local:8080/v1/chat/completions"
@dataclass
class _FakeSub:
    """Minimal stand-in for the items yielded by the mocked `subreddit.top(...)`.

    NOTE(review): presumably these are the only attributes the ingest pipeline
    reads from a submission — confirm against `praw_source.fetch_top`.
    """

    id: str  # short post identifier, e.g. "ok1"
    title: str  # post headline
    selftext: str  # post body text
    permalink: str  # site-relative path, e.g. "/r/ExpatFIRE/comments/ok1/"
    created_utc: float  # POSIX timestamp; the test builds it with datetime(...).timestamp()
def _async_iter(items: list[_FakeSub]) -> AsyncIterator[_FakeSub]:
    """Return an async iterator that yields *items* one by one, in order."""

    async def _emit() -> AsyncIterator[_FakeSub]:
        for entry in items:
            yield entry

    return _emit()
@respx.mock
@pytest.mark.asyncio
async def test_ingest_subreddit_end_to_end(session: AsyncSession) -> None:
    """One candidate post is extracted and stored; one non-candidate is skipped."""
    # Canned extraction result served by the mocked llama endpoint. Only the
    # llama URL is mocked here; no route is registered for CLAUDE_URL.
    llm_fields = {
        "country": "Philippines",
        "city": "Manila",
        "portfolio_native": 1000000,
        "raw_currency": "GBP",
        "age": 38,
        "family_size": 3,
        "fi_status": "FIRE",
        "is_retired": True,
        "confidence": 0.8,
    }
    llm_body = {"choices": [{"message": {"content": json.dumps(llm_fields)}}]}
    respx.post(LLAMA_URL).respond(200, json=llm_body)

    keeper = _FakeSub(
        id="ok1",
        title="FIRE at 38 in Manila",
        selftext="Net worth £1m, family of 3, retired last year",
        permalink="/r/ExpatFIRE/comments/ok1/",
        created_utc=datetime(2026, 1, 1).timestamp(),
    )
    # filter should drop this — no money signal
    filtered_out = _FakeSub(
        id="drop1",
        title="Thinking about moving to Lisbon",
        selftext="No specifics yet",
        permalink="/r/ExpatFIRE/comments/drop1/",
        created_utc=datetime(2026, 1, 2).timestamp(),
    )

    # `reddit.subreddit(...)` is awaited and yields an object whose `.top(...)`
    # returns the async stream of fake posts.
    listing = MagicMock()
    listing.top = MagicMock(return_value=_async_iter([keeper, filtered_out]))
    reddit_stub = MagicMock()
    reddit_stub.subreddit = AsyncMock(return_value=listing)

    rates = {"GBP": Decimal("1"), "USD": Decimal("0.80")}

    async with httpx.AsyncClient() as http:
        stored_count, skipped_count = await ingest_subreddit(
            session,
            reddit_stub,
            sub="ExpatFIRE",
            when="all",
            limit=10,
            llama_url=LLAMA_URL,
            claude_url=CLAUDE_URL,
            claude_bearer="t",
            client=http,
            fx_rates=rates,
        )

    assert stored_count == 1
    assert skipped_count == 1

    persisted = (await session.execute(select(FireExample))).scalars().all()
    assert len(persisted) == 1
    only_row = persisted[0]
    assert only_row.country == "Philippines"
    assert only_row.portfolio_gbp == Decimal("1000000.00")