152 lines
5.1 KiB
Python
152 lines
5.1 KiB
Python
|
|
"""Orchestrator + click CLI for the examples ingest pipeline.

`ingest_subreddit(...)` is the testable async unit: fetch → filter →
extract (Tier 1+2) → upsert → return (inserted, skipped) counts.

`_ingest_all(...)` fans out across every (subreddit, top-window) pair — by
default the 12 target subreddits — with
`asyncio.gather(..., return_exceptions=True)`, so a single run's failure
doesn't sink the others. The job exits 0 when at least half of the runs
succeed, otherwise it exits 2.

The click commands at the bottom of the file are the entrypoints used by
the K8s Job and CronJob.
"""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import logging
|
||
|
|
import os
|
||
|
|
from datetime import date
|
||
|
|
from decimal import Decimal
|
||
|
|
from typing import Any, cast
|
||
|
|
|
||
|
|
import asyncpraw
|
||
|
|
import click
|
||
|
|
import httpx
|
||
|
|
|
||
|
|
from fire_planner.db import create_engine_from_env, make_session_factory
|
||
|
|
from fire_planner.examples.filters import is_candidate
|
||
|
|
from fire_planner.examples.llm_extract import extract_with_fallback
|
||
|
|
from fire_planner.examples.praw_source import TopWhen, fetch_top
|
||
|
|
from fire_planner.examples.service import upsert_example
|
||
|
|
from fire_planner.fx import fetch_rates
|
||
|
|
|
||
|
|
log = logging.getLogger(__name__)

# Subreddits ingested when `ingest` is run without `--sub` (12 in total;
# the `--sub` help text refers to this count).
DEFAULT_SUBS: list[str] = [
    "financialindependence",
    "leanfire",
    "fatFIRE",
    "coastFIRE",
    "baristaFIRE",
    "ExpatFIRE",
    "EuropeFIRE",
    "FIRE_Ind",
    "AusFinance",
    "CanadianFIRE",
    "UKPersonalFinance",
    "financialindependence_UK",
]
|
||
|
|
|
||
|
|
|
||
|
|
async def ingest_subreddit(
    session: Any,
    reddit: Any,
    *,
    sub: str,
    when: TopWhen,
    limit: int,
    llama_url: str,
    claude_url: str,
    claude_bearer: str,
    client: httpx.AsyncClient,
    fx_rates: dict[str, Decimal],
) -> tuple[int, int]:
    """Ingest one subreddit's top-of-`when` listing into the examples store.

    Each post yielded by ``fetch_top(reddit, sub, when, limit=limit)`` ends
    up in exactly one of two buckets:

    * skipped — it fails ``is_candidate``, or ``extract_with_fallback``
      (given both the llama and Claude endpoints) returns ``None``, or
      ``upsert_example`` returns a falsy value (NOTE(review): presumably an
      already-stored post — confirm against ``upsert_example``);
    * inserted — ``upsert_example`` returns a truthy value.

    Posts rejected by ``is_candidate`` never reach the LLM extractor.
    Exceptions from fetching, extraction or the upsert are not caught here
    and abort the whole listing. Whether the session is committed is up to
    ``upsert_example``.

    Returns:
        ``(inserted, skipped)`` counts for this (sub, when) run.
    """
    inserted_count = 0
    skipped_count = 0

    async for candidate in fetch_top(reddit, sub, when, limit=limit):
        stored = False
        if is_candidate(candidate):
            parsed = await extract_with_fallback(
                candidate,
                llama_url=llama_url,
                claude_url=claude_url,
                claude_bearer=claude_bearer,
                client=client,
            )
            if parsed is None:
                log.info("dropping %s — both LLM tiers failed", candidate.reddit_id)
            else:
                stored = bool(
                    await upsert_example(session, candidate, parsed, fx_rates)
                )

        if stored:
            inserted_count += 1
        else:
            skipped_count += 1

    return inserted_count, skipped_count
|
||
|
|
|
||
|
|
|
||
|
|
async def _ingest_all(
    when_list: list[TopWhen],
    limit: int,
    subs: list[str],
) -> tuple[int, int, int]:
    """Run ``ingest_subreddit`` concurrently for every (sub, window) pair.

    Configuration comes from the environment: ``REDDIT_CLIENT_ID``,
    ``REDDIT_CLIENT_SECRET``, ``REDDIT_USER_AGENT`` (optional, defaults to
    ``"fire-planner/0.1"``), ``LLAMA_CPP_BASE_URL``,
    ``CLAUDE_AGENT_SERVICE_URL`` and ``CLAUDE_AGENT_BEARER``, plus whatever
    ``create_engine_from_env`` reads.

    Every pair gets its own DB session and ``httpx.AsyncClient``. All pairs
    share one Reddit client and one FX-rate snapshot fetched for today's
    date. A pair that raises is logged with its traceback and left out of
    the totals; it does not cancel the other pairs. The DB engine and the
    Reddit client are always released, even on failure.

    Returns:
        ``(total_inserted, total_skipped, n_succeeded)``, where
        ``n_succeeded`` is the number of pairs that completed without raising.

    Raises:
        KeyError: if a required environment variable is missing (raised
            before any connection is opened).
    """
    # Resolve required config up front so a missing variable fails before we
    # create the engine, fetch FX rates or open a Reddit session.
    client_id = os.environ["REDDIT_CLIENT_ID"]
    client_secret = os.environ["REDDIT_CLIENT_SECRET"]
    user_agent = os.environ.get("REDDIT_USER_AGENT", "fire-planner/0.1")
    llama_url = os.environ["LLAMA_CPP_BASE_URL"]
    claude_url = os.environ["CLAUDE_AGENT_SERVICE_URL"]
    claude_bearer = os.environ["CLAUDE_AGENT_BEARER"]

    pairs = [(sub, when) for sub in subs for when in when_list]

    engine = create_engine_from_env()
    try:
        factory = make_session_factory(engine)
        rates = await fetch_rates(date.today())
        reddit = asyncpraw.Reddit(
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent,
        )

        async def _one(sub: str, when: TopWhen) -> tuple[int, int]:
            # Fresh session + HTTP client per pair so concurrent runs don't
            # share per-connection state.
            async with factory() as session, httpx.AsyncClient() as client:
                return await ingest_subreddit(
                    session, reddit,
                    sub=sub, when=when, limit=limit,
                    llama_url=llama_url,
                    claude_url=claude_url,
                    claude_bearer=claude_bearer,
                    client=client,
                    fx_rates=rates,
                )

        try:
            results = await asyncio.gather(
                *(_one(sub, when) for sub, when in pairs),
                return_exceptions=True,
            )
        finally:
            await reddit.close()
    finally:
        await engine.dispose()

    total_inserted = 0
    total_skipped = 0
    n_succ = 0
    for (sub, when), result in zip(pairs, results):
        # With return_exceptions=True a failed pair yields its exception
        # object. Check BaseException, not Exception: asyncio.CancelledError
        # is only a BaseException and must not be counted as a success.
        if isinstance(result, BaseException):
            log.error("ingest failed for r/%s (top=%s)", sub, when, exc_info=result)
            continue
        n_succ += 1
        total_inserted += result[0]
        total_skipped += result[1]
    return total_inserted, total_skipped, n_succ
|
||
|
|
|
||
|
|
|
||
|
|
@click.group(name="examples")
def examples_cli() -> None:
    """Reddit FIRE examples ingest commands."""
    # The docstring above doubles as the group's `--help` text. The group
    # callback itself has no work to do; subcommands attached with
    # `@examples_cli.command(...)` carry the behaviour.
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def _split_csv(raw: str) -> list[str]:
    """Split a comma-separated CLI value, trimming whitespace and dropping empties.

    >>> _split_csv(" all, year ,,")
    ['all', 'year']
    """
    return [part.strip() for part in raw.split(",") if part.strip()]


@examples_cli.command("ingest")
@click.option("--top", "top_csv", default="all,year",
              help="Comma-list of top-of-X windows (all,year,week).")
@click.option("--limit", default=1000, show_default=True)
@click.option("--sub", "subs_csv", default=None,
              help="Comma-list of subs (default: all 12).")
def ingest_cmd(top_csv: str, limit: int, subs_csv: str | None) -> None:
    """Bulk one-shot ingest. Used by the K8s Job."""
    # Exit status: 0 when at least half of the (sub, window) runs succeed,
    # otherwise 2. A --top/--sub value that parses to an empty list is a
    # click usage error, which click also reports with exit status 2.
    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))

    # click hands us free-form strings; narrow to TopWhen at the boundary.
    # Invalid values surface as an asyncpraw API error rather than a type
    # error (as a failed run, since _ingest_all isolates per-run failures).
    when_list = cast(list[TopWhen], _split_csv(top_csv))
    if not when_list:
        raise click.BadParameter("expected at least one window", param_hint="'--top'")

    # Same parsing as --top, so a trailing comma (e.g. "leanfire,") does not
    # schedule a run against an empty subreddit name.
    subs = _split_csv(subs_csv) if subs_csv else DEFAULT_SUBS
    if not subs:
        raise click.BadParameter("expected at least one subreddit", param_hint="'--sub'")

    inserted, skipped, succ = asyncio.run(_ingest_all(when_list, limit, subs))
    total = len(subs) * len(when_list)
    log.info("ingest done: inserted=%d skipped=%d sub_runs_succ=%d/%d",
             inserted, skipped, succ, total)

    # Exit 2 if fewer than half the (sub, when) pairs succeeded. Comparing
    # 2*succ with total (instead of succ with total//2 + 1) means an exact
    # half, e.g. 12 of 24, counts as success.
    if 2 * succ < total:
        raise click.exceptions.Exit(code=2)
|