99 lines
4.1 KiB
Python
99 lines
4.1 KiB
Python
"""DB-backed cache for fundamental data."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import datetime, timezone, timedelta
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
|
|
|
from shared.fundamentals.base import FundamentalsProvider
|
|
from shared.schemas.trading import FundamentalsSnapshot
|
|
from shared.models.fundamentals import Fundamentals
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CachedFundamentalsProvider:
    """Wraps a FundamentalsProvider with DB-backed caching.

    ``fetch`` first looks up the stored row for a ticker. If the row is
    younger than the configured TTL it is returned directly; otherwise the
    wrapped provider is queried and any non-``None`` result is written back
    to the database.

    Database errors are logged and never propagated: a failed load is treated
    as a cache miss and a failed save is ignored, so the provider result is
    still returned to the caller.
    """

    # Value columns copied one-to-one between the ``Fundamentals`` ORM row and
    # ``FundamentalsSnapshot``. ``ticker`` is the lookup key and is handled
    # separately so an update never rewrites it.
    _VALUE_FIELDS = (
        "eps_ttm",
        "pe_ratio",
        "peg_ratio",
        "revenue_growth_yoy",
        "profit_margin",
        "debt_to_equity",
        "market_cap",
        "fetched_at",
    )

    def __init__(
        self,
        provider: FundamentalsProvider,
        session_factory: async_sessionmaker,
        cache_ttl_hours: int = 24,
    ) -> None:
        """Store the wrapped provider, the session factory and the TTL.

        Args:
            provider: Upstream source queried on a cache miss or stale entry.
            session_factory: Called with no arguments to open a DB session
                used as an async context manager.
            cache_ttl_hours: Maximum age, in hours, for a cached row to be
                served without refreshing.
        """
        self._provider = provider
        self._session_factory = session_factory
        self._cache_ttl = timedelta(hours=cache_ttl_hours)

    @staticmethod
    def _cache_age(
        fetched_at: datetime | None,
        now: datetime | None = None,
    ) -> timedelta | None:
        """Return how old ``fetched_at`` is, or ``None`` if it is missing.

        A naive timestamp is assumed to be UTC. An aware timestamp keeps its
        own offset: relabelling it with ``replace(tzinfo=UTC)`` would shift
        the instant by that offset and produce a wrong age.

        Args:
            fetched_at: Timestamp stored with the cached entry.
            now: Reference instant (timezone-aware); defaults to the current
                UTC time.
        """
        if fetched_at is None:
            return None
        if fetched_at.utcoffset() is None:
            # Naive value: interpret as UTC (the original behaviour).
            fetched_at = fetched_at.replace(tzinfo=timezone.utc)
        if now is None:
            now = datetime.now(timezone.utc)
        return now - fetched_at

    async def fetch(self, ticker: str) -> FundamentalsSnapshot | None:
        """Return fundamentals for ``ticker``, preferring a fresh cached row.

        Returns:
            The cached snapshot when its age is below the TTL; otherwise
            whatever the wrapped provider returns (possibly ``None``). A
            non-``None`` provider result is saved to the database first.
        """
        cached = await self._load_from_db(ticker)
        if cached is not None:
            age = self._cache_age(cached.fetched_at)
            # A missing timestamp cannot be proven fresh, so treat it as stale.
            if age is not None and age < self._cache_ttl:
                logger.debug("Cache hit for %s (age=%s)", ticker, age)
                return cached
            logger.debug("Cache stale for %s (age=%s), refreshing", ticker, age)

        result = await self._provider.fetch(ticker)
        if result is not None:
            await self._save_to_db(result)
        return result

    async def _load_from_db(self, ticker: str) -> FundamentalsSnapshot | None:
        """Load the stored row for ``ticker`` as a snapshot.

        Returns:
            ``None`` when no row exists or when any exception occurs while
            querying (the exception is logged, not raised).
        """
        try:
            async with self._session_factory() as session:
                stmt = select(Fundamentals).where(Fundamentals.ticker == ticker)
                row = (await session.execute(stmt)).scalar_one_or_none()
                if row is None:
                    return None
                values = {name: getattr(row, name) for name in self._VALUE_FIELDS}
                return FundamentalsSnapshot(ticker=row.ticker, **values)
        except Exception:
            logger.exception("Failed to load fundamentals from DB for %s", ticker)
            return None

    async def _save_to_db(self, snapshot: FundamentalsSnapshot) -> None:
        """Insert or update the row for ``snapshot.ticker`` and commit.

        Any exception is logged and swallowed so a cache write failure does
        not prevent ``fetch`` from returning the provider result.
        """
        try:
            async with self._session_factory() as session:
                stmt = select(Fundamentals).where(Fundamentals.ticker == snapshot.ticker)
                existing = (await session.execute(stmt)).scalar_one_or_none()

                if existing is not None:
                    for name in self._VALUE_FIELDS:
                        setattr(existing, name, getattr(snapshot, name))
                else:
                    values = {
                        name: getattr(snapshot, name) for name in self._VALUE_FIELDS
                    }
                    session.add(Fundamentals(ticker=snapshot.ticker, **values))

                await session.commit()
                logger.debug("Saved fundamentals for %s to DB", snapshot.ticker)
        except Exception:
            logger.exception("Failed to save fundamentals for %s to DB", snapshot.ticker)
|