"""DB-backed cache for fundamental data."""

from __future__ import annotations

import logging
from datetime import datetime, timezone, timedelta

from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker

from shared.fundamentals.base import FundamentalsProvider
from shared.schemas.trading import FundamentalsSnapshot
from shared.models.fundamentals import Fundamentals

logger = logging.getLogger(__name__)

# Attributes copied one-to-one between a ``FundamentalsSnapshot`` and a
# ``Fundamentals`` ORM row. ``ticker`` is the lookup key and is handled
# separately, so it is intentionally not listed here.
_SNAPSHOT_FIELDS: tuple[str, ...] = (
    "eps_ttm",
    "pe_ratio",
    "peg_ratio",
    "revenue_growth_yoy",
    "profit_margin",
    "debt_to_equity",
    "market_cap",
    "fetched_at",
)


def _as_utc(value: datetime) -> datetime:
    """Return *value* as a timezone-aware datetime suitable for UTC arithmetic.

    Naive values are interpreted as UTC (the same assumption the cache has
    always made for values read back from the DB). Already-aware values are
    converted to UTC instead of being relabelled, so a non-UTC offset is not
    silently discarded.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


class CachedFundamentalsProvider:
    """Wraps a FundamentalsProvider with DB-backed caching.

    Reads go to the ``Fundamentals`` table first. Only when no row exists, or
    the row is older than the configured TTL, is the wrapped provider called.
    The provider's result is then written back to the table.
    """

    def __init__(
        self,
        provider: FundamentalsProvider,
        session_factory: async_sessionmaker,
        cache_ttl_hours: float = 24,
    ) -> None:
        """Create the caching wrapper.

        Args:
            provider: Upstream source queried on a cache miss or stale entry.
            session_factory: Factory producing async sessions for the cache table.
            cache_ttl_hours: Maximum age of a cached row before it is refreshed.
        """
        self._provider = provider
        self._session_factory = session_factory
        self._cache_ttl = timedelta(hours=cache_ttl_hours)

    async def fetch(self, ticker: str) -> FundamentalsSnapshot | None:
        """Return fundamentals for *ticker*, preferring a fresh cached copy.

        A cached row younger than the TTL is returned as-is. Otherwise the
        wrapped provider is queried. A non-``None`` result is persisted and
        returned.

        Returns:
            The cached or freshly fetched snapshot. Returns ``None`` when the
            cache is missing or stale and the provider returns ``None``.
        """
        cached = await self._load_from_db(ticker)
        if cached is not None:
            if cached.fetched_at is None:
                # Without a timestamp the age is unknown; treat it as stale
                # instead of crashing on the arithmetic below.
                logger.debug("Cache entry for %s has no fetched_at, refreshing", ticker)
            else:
                age = datetime.now(timezone.utc) - _as_utc(cached.fetched_at)
                if age < self._cache_ttl:
                    logger.debug("Cache hit for %s (age=%s)", ticker, age)
                    return cached
                logger.debug("Cache stale for %s (age=%s), refreshing", ticker, age)

        result = await self._provider.fetch(ticker)
        if result is not None:
            await self._save_to_db(result)
        return result

    async def _load_from_db(self, ticker: str) -> FundamentalsSnapshot | None:
        """Load the cached snapshot for *ticker*, or ``None`` if there is none.

        Any exception is logged and reported as ``None``. ``fetch`` then falls
        through to the provider, so a DB failure degrades to a cache miss.
        """
        try:
            async with self._session_factory() as session:
                stmt = select(Fundamentals).where(Fundamentals.ticker == ticker)
                row = (await session.execute(stmt)).scalar_one_or_none()
                if row is None:
                    return None
                return FundamentalsSnapshot(
                    ticker=row.ticker,
                    **{name: getattr(row, name) for name in _SNAPSHOT_FIELDS},
                )
        except Exception:
            logger.exception("Failed to load fundamentals from DB for %s", ticker)
            return None

    async def _save_to_db(self, snapshot: FundamentalsSnapshot) -> None:
        """Insert or update the cached row for ``snapshot.ticker`` (best effort).

        Any exception is logged and swallowed, so a DB problem never prevents
        ``fetch`` from returning the provider's result.
        """
        try:
            values = {name: getattr(snapshot, name) for name in _SNAPSHOT_FIELDS}
            async with self._session_factory() as session:
                stmt = select(Fundamentals).where(
                    Fundamentals.ticker == snapshot.ticker
                )
                existing = (await session.execute(stmt)).scalar_one_or_none()
                if existing is not None:
                    for name, value in values.items():
                        setattr(existing, name, value)
                else:
                    # NOTE(review): select-then-insert is not atomic. Two
                    # concurrent fetches of a new ticker may both reach this
                    # branch; if ``ticker`` is unique the second commit should
                    # fail and be logged below. Confirm against the table schema.
                    session.add(Fundamentals(ticker=snapshot.ticker, **values))
                await session.commit()
                logger.debug("Saved fundamentals for %s to DB", snapshot.ticker)
        except Exception:
            logger.exception("Failed to save fundamentals for %s to DB", snapshot.ticker)