examples: orchestrator + click CLI (ingest sub-command)
Some checks failed
ci/woodpecker/push/woodpecker Pipeline was canceled
Some checks failed
ci/woodpecker/push/woodpecker Pipeline was canceled
This commit is contained in:
parent
a10d7fe2a6
commit
2271d7d5e5
4 changed files with 256 additions and 1 deletions
|
|
@ -24,6 +24,7 @@ import numpy as np
|
|||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from fire_planner.db import create_engine_from_env, make_session_factory
|
||||
from fire_planner.examples.cli import examples_cli
|
||||
from fire_planner.glide_path import get as get_glide
|
||||
from fire_planner.ingest.wealthfolio import upsert_snapshots
|
||||
from fire_planner.ingest.wealthfolio_pg import (
|
||||
|
|
@ -363,5 +364,8 @@ def serve() -> None:
|
|||
uvicorn.run("fire_planner.app:app", host="0.0.0.0", port=8080)
|
||||
|
||||
|
||||
cli.add_command(examples_cli)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
|
|
|||
151
fire_planner/examples/cli.py
Normal file
151
fire_planner/examples/cli.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""Orchestrator + click CLI for the examples ingest pipeline.
|
||||
|
||||
`ingest_subreddit(...)` is the testable async unit: fetch → filter →
|
||||
extract (Tier 1+2) → upsert → return (inserted, skipped) counts.
|
||||
|
||||
`_ingest_all(...)` fans out across the 12 target subreddits with
|
||||
`asyncio.gather(..., return_exceptions=True)` so a single sub's failure
|
||||
doesn't sink the others. Job exits 0 when >=half succeed, else exits 2.
|
||||
|
||||
The click commands at the bottom of the file are the entrypoints the
|
||||
K8s Job + CronJob use.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
|
||||
import asyncpraw
|
||||
import click
|
||||
import httpx
|
||||
|
||||
from fire_planner.db import create_engine_from_env, make_session_factory
|
||||
from fire_planner.examples.filters import is_candidate
|
||||
from fire_planner.examples.llm_extract import extract_with_fallback
|
||||
from fire_planner.examples.praw_source import TopWhen, fetch_top
|
||||
from fire_planner.examples.service import upsert_example
|
||||
from fire_planner.fx import fetch_rates
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SUBS: list[str] = [
|
||||
"financialindependence", "leanfire", "fatFIRE", "coastFIRE",
|
||||
"baristaFIRE", "ExpatFIRE", "EuropeFIRE", "FIRE_Ind",
|
||||
"AusFinance", "CanadianFIRE", "UKPersonalFinance",
|
||||
"financialindependence_UK",
|
||||
]
|
||||
|
||||
|
||||
async def ingest_subreddit(
    session: Any,
    reddit: Any,
    *,
    sub: str,
    when: TopWhen,
    limit: int,
    llama_url: str,
    claude_url: str,
    claude_bearer: str,
    client: httpx.AsyncClient,
    fx_rates: dict[str, Decimal],
) -> tuple[int, int]:
    """Ingest the top posts of one subreddit and tally the outcome.

    Every post yielded by ``fetch_top`` is screened with ``is_candidate``,
    passed to ``extract_with_fallback`` and, when extraction returned
    something, handed to ``upsert_example``.

    Returns:
        ``(inserted, skipped)``: ``inserted`` counts posts for which
        ``upsert_example`` returned a truthy value; ``skipped`` counts every
        other post (rejected by the filter, extraction returned ``None``,
        or the upsert returned a falsy value).
    """
    n_inserted = 0
    n_skipped = 0
    async for submission in fetch_top(reddit, sub, when, limit=limit):
        # A post only counts as stored once the upsert says so; every other
        # path through the loop body leaves this False.
        stored = False
        if is_candidate(submission):
            fields = await extract_with_fallback(
                submission,
                llama_url=llama_url,
                claude_url=claude_url,
                claude_bearer=claude_bearer,
                client=client,
            )
            if fields is None:
                log.info("dropping %s — both LLM tiers failed", submission.reddit_id)
            else:
                stored = bool(
                    await upsert_example(session, submission, fields, fx_rates)
                )
        if stored:
            n_inserted += 1
        else:
            n_skipped += 1
    return n_inserted, n_skipped
|
||||
|
||||
|
||||
async def _ingest_all(
    when_list: list[TopWhen],
    limit: int,
    subs: list[str],
) -> tuple[int, int, int]:
    """Run ``ingest_subreddit`` concurrently for every (sub, when) pair.

    One DB engine, one FX-rate lookup and one ``asyncpraw.Reddit`` client are
    shared by all runs; each run gets its own DB session and its own
    ``httpx.AsyncClient``. Runs are gathered with ``return_exceptions=True``
    so a failing pair does not abort the others.

    Args:
        when_list: "top of X" windows to fetch for each subreddit.
        limit: value forwarded as ``limit`` to every ``ingest_subreddit`` call.
        subs: subreddit names to ingest.

    Returns:
        ``(total_inserted, total_skipped, n_succeeded)`` where the first two
        are summed over the runs that returned normally and ``n_succeeded``
        is the number of such runs.

    Raises:
        KeyError: if ``REDDIT_CLIENT_ID``, ``REDDIT_CLIENT_SECRET``,
            ``LLAMA_CPP_BASE_URL``, ``CLAUDE_AGENT_SERVICE_URL`` or
            ``CLAUDE_AGENT_BEARER`` is missing from the environment.
    """
    engine = create_engine_from_env()
    try:
        factory = make_session_factory(engine)
        rates = await fetch_rates(date.today())

        reddit = asyncpraw.Reddit(
            client_id=os.environ["REDDIT_CLIENT_ID"],
            client_secret=os.environ["REDDIT_CLIENT_SECRET"],
            user_agent=os.environ.get("REDDIT_USER_AGENT", "fire-planner/0.1"),
        )
        try:
            llama_url = os.environ["LLAMA_CPP_BASE_URL"]
            claude_url = os.environ["CLAUDE_AGENT_SERVICE_URL"]
            claude_bearer = os.environ["CLAUDE_AGENT_BEARER"]

            async def _one(sub: str, when: TopWhen) -> tuple[int, int]:
                async with factory() as session, httpx.AsyncClient() as client:
                    return await ingest_subreddit(
                        session, reddit,
                        sub=sub, when=when, limit=limit,
                        llama_url=llama_url,
                        claude_url=claude_url,
                        claude_bearer=claude_bearer,
                        client=client,
                        fx_rates=rates,
                    )

            # Keep the pairs so each result can be attributed when logging;
            # gather() returns results in the order the awaitables were given.
            pairs = [(s, w) for s in subs for w in when_list]
            results = await asyncio.gather(
                *(_one(s, w) for s, w in pairs),
                return_exceptions=True,
            )
        finally:
            # Close the Reddit client even if setup or the gather itself
            # raised, so its HTTP session is not leaked.
            await reddit.close()
    finally:
        # Same for the engine: dispose on every exit path.
        await engine.dispose()

    total_inserted = 0
    total_skipped = 0
    n_succ = 0
    for (sub, when), outcome in zip(pairs, results):
        # With return_exceptions=True a failed run shows up as the raised
        # exception object. Check BaseException (not Exception) so that a
        # cancelled run (CancelledError) is not counted as a success.
        if isinstance(outcome, BaseException):
            log.error(
                "ingest failed for r/%s (top=%s)", sub, when, exc_info=outcome,
            )
            continue
        n_succ += 1
        total_inserted += outcome[0]
        total_skipped += outcome[1]
    return total_inserted, total_skipped, n_succ
|
||||
|
||||
|
||||
# Click command group registered under the name "examples"; commands
# decorated with ``@examples_cli.command(...)`` attach to this group.
@click.group(name="examples")
def examples_cli() -> None:
    """Reddit FIRE examples ingest commands."""
|
||||
|
||||
|
||||
@examples_cli.command("ingest")
@click.option("--top", "top_csv", default="all,year",
              help="Comma-list of top-of-X windows (all,year,week).")
@click.option("--limit", default=1000, show_default=True)
@click.option("--sub", "subs_csv", default=None,
              help="Comma-list of subs (default: all 12).")
def ingest_cmd(top_csv: str, limit: int, subs_csv: str | None) -> None:
    """Bulk one-shot ingest. Used by the K8s Job."""
    # The docstring above is kept short because click shows it as help text.
    #
    # Flow: parse the comma-separated options, run `_ingest_all` over every
    # (sub, window) pair, log the totals, then exit non-zero when fewer than
    # half of the pairs succeeded.
    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))

    # click hands us free-form strings; narrow to TopWhen at the boundary.
    # Invalid values surface as an asyncpraw API error rather than a type error.
    when_list = cast(
        list[TopWhen],
        [w.strip() for w in top_csv.split(",") if w.strip()],
    )
    if not when_list:
        # Nothing to run; reject up front instead of reporting a 0/0 result.
        raise click.BadParameter(
            "expected at least one window", param_hint="--top",
        )

    if subs_csv:
        # Drop empty entries (e.g. a trailing comma) the same way --top does,
        # so "" is never sent to Reddit as a subreddit name.
        subs = [s.strip() for s in subs_csv.split(",") if s.strip()]
        if not subs:
            raise click.BadParameter(
                "expected at least one subreddit", param_hint="--sub",
            )
    else:
        subs = DEFAULT_SUBS

    inserted, skipped, succ = asyncio.run(_ingest_all(when_list, limit, subs))
    total = len(subs) * len(when_list)
    log.info("ingest done: inserted=%d skipped=%d sub_runs_succ=%d/%d",
             inserted, skipped, succ, total)

    # Exit 2 if fewer than half the (sub, when) pairs succeeded. Compared as
    # 2 * succ < total so that exactly half counts as a pass; the earlier
    # `total // 2 + 1` form demanded a strict majority.
    if 2 * succ < total:
        raise click.exceptions.Exit(code=2)
|
||||
|
|
@ -45,7 +45,7 @@ strict = true
|
|||
files = ["fire_planner", "tests"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["respx.*", "pandas.*"]
|
||||
module = ["respx.*", "pandas.*", "asyncpraw.*"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.ruff]
|
||||
|
|
|
|||
100
tests/test_examples_cli.py
Normal file
100
tests/test_examples_cli.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
"""End-to-end pipeline test — mocked PRAW + respx-mocked LLM + in-memory DB."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.db import FireExample
|
||||
from fire_planner.examples.cli import ingest_subreddit
|
||||
|
||||
LLAMA_URL = "http://llama-cpp.llama-cpp.svc.cluster.local:8000/v1/chat/completions"
|
||||
CLAUDE_URL = "http://claude-agent-service.claude-agent.svc.cluster.local:8080/v1/chat/completions"
|
||||
|
||||
|
||||
@dataclass
class _FakeSub:
    """Minimal fake of the items yielded by the mocked ``subreddit.top``.

    Carries only the attributes the test populates below.
    """

    id: str
    title: str
    selftext: str
    permalink: str
    # Float timestamp; the test builds it with ``datetime(...).timestamp()``.
    created_utc: float
|
||||
|
||||
|
||||
def _async_iter(items: list[_FakeSub]) -> AsyncIterator[_FakeSub]:
    """Return an async iterator that yields the elements of ``items`` in order."""

    async def _replay() -> AsyncIterator[_FakeSub]:
        for entry in items:
            yield entry

    # Call the generator function here so callers receive a ready-to-iterate
    # object rather than a factory.
    replayed = _replay()
    return replayed
|
||||
|
||||
|
||||
@respx.mock
@pytest.mark.asyncio
async def test_ingest_subreddit_end_to_end(session: AsyncSession) -> None:
    """Two fake posts go in: one is expected to be stored, one skipped."""
    keeper = _FakeSub(
        id="ok1",
        title="FIRE at 38 in Manila",
        selftext="Net worth £1m, family of 3, retired last year",
        permalink="/r/ExpatFIRE/comments/ok1/",
        created_utc=datetime(2026, 1, 1).timestamp(),
    )
    # filter should drop this — no money signal
    reject = _FakeSub(
        id="drop1",
        title="Thinking about moving to Lisbon",
        selftext="No specifics yet",
        permalink="/r/ExpatFIRE/comments/drop1/",
        created_utc=datetime(2026, 1, 2).timestamp(),
    )

    # reddit.subreddit(...) is awaited and returns an object whose .top(...)
    # hands back the async iterator over the two fakes.
    listing = MagicMock(
        top=MagicMock(return_value=_async_iter([keeper, reject])),
    )
    reddit = MagicMock(subreddit=AsyncMock(return_value=listing))

    # Canned reply served for POSTs to LLAMA_URL.
    extraction = {
        "country": "Philippines",
        "city": "Manila",
        "portfolio_native": 1000000,
        "raw_currency": "GBP",
        "age": 38,
        "family_size": 3,
        "fi_status": "FIRE",
        "is_retired": True,
        "confidence": 0.8,
    }
    completion = {"choices": [{"message": {"content": json.dumps(extraction)}}]}
    respx.post(LLAMA_URL).respond(200, json=completion)

    rates = {"GBP": Decimal("1"), "USD": Decimal("0.80")}
    async with httpx.AsyncClient() as http:
        inserted, skipped = await ingest_subreddit(
            session,
            reddit,
            sub="ExpatFIRE",
            when="all",
            limit=10,
            llama_url=LLAMA_URL,
            claude_url=CLAUDE_URL,
            claude_bearer="t",
            client=http,
            fx_rates=rates,
        )

    assert inserted == 1
    assert skipped == 1

    query_result = await session.execute(select(FireExample))
    stored = query_result.scalars().all()
    assert len(stored) == 1
    only_row = stored[0]
    assert only_row.country == "Philippines"
    assert only_row.portfolio_gbp == Decimal("1000000.00")
|
||||
Loading…
Add table
Add a link
Reference in a new issue