from datetime import datetime, timezone
import time

from api.metrics import record_db_query
from models.decision import ListingDecision
from sqlalchemy import Engine
from sqlmodel import Session, select


class DecisionRepository:
    """Data-access methods for ``ListingDecision`` rows.

    Every public method opens its own short-lived ``Session`` on ``engine``
    and reports its duration to ``record_db_query`` under the ``"decision"``
    table label.
    """

    engine: Engine

    def __init__(self, engine: Engine) -> None:
        self.engine = engine

    @staticmethod
    def _select_decision(user_id: int, listing_id: int, listing_type: str):
        """Build a SELECT for the decision keyed by (user_id, listing_id, listing_type)."""
        return select(ListingDecision).where(
            ListingDecision.user_id == user_id,
            ListingDecision.listing_id == listing_id,
            ListingDecision.listing_type == listing_type,
        )

    def _upsert_statement(self, values: dict[str, object]):
        """Build the dialect-specific INSERT-or-UPDATE statement for ``values``.

        MySQL uses ``ON DUPLICATE KEY UPDATE``; every other dialect gets the
        SQLite ``ON CONFLICT (user_id, listing_id, listing_type) DO UPDATE``
        form. On conflict only ``decision``, ``price_at_decision`` and
        ``updated_at`` are overwritten, so ``created_at`` keeps the value
        from the first insert.
        """
        if self.engine.dialect.name == "mysql":
            from sqlalchemy.dialects.mysql import insert as mysql_insert

            mysql_stmt = mysql_insert(ListingDecision).values(**values)
            return mysql_stmt.on_duplicate_key_update(
                decision=mysql_stmt.inserted.decision,
                price_at_decision=mysql_stmt.inserted.price_at_decision,
                updated_at=mysql_stmt.inserted.updated_at,
            )

        # NOTE(review): every non-MySQL dialect (e.g. PostgreSQL) falls through
        # to SQLite's insert construct; add an explicit branch if another
        # backend is ever used.
        from sqlalchemy.dialects.sqlite import insert as sqlite_insert

        sqlite_stmt = sqlite_insert(ListingDecision).values(**values)
        return sqlite_stmt.on_conflict_do_update(
            index_elements=["user_id", "listing_id", "listing_type"],
            set_={
                "decision": sqlite_stmt.excluded.decision,
                "price_at_decision": sqlite_stmt.excluded.price_at_decision,
                "updated_at": sqlite_stmt.excluded.updated_at,
            },
        )

    def upsert_decision(
        self,
        user_id: int,
        listing_id: int,
        listing_type: str,
        decision: str,
        price_at_decision: float | None = None,
    ) -> ListingDecision:
        """Create or update a decision. Uses dialect-specific upsert.

        Args:
            user_id: Owner of the decision.
            listing_id: Listing the decision applies to.
            listing_type: Listing category; part of the row's unique key.
            decision: Decision value to store (e.g. ``"liked"``/``"disliked"``).
            price_at_decision: Price to record. On update this column is
                overwritten as well, including with ``None`` when omitted.

        Returns:
            The persisted row, re-read from the database after commit.

        Raises:
            RuntimeError: If the row cannot be read back after the upsert.
        """
        t0 = time.monotonic()
        with Session(self.engine) as session:
            # Naive UTC timestamp: the same value ``datetime.utcnow()`` returned,
            # without that API's Python 3.12 deprecation.
            now = datetime.now(timezone.utc).replace(tzinfo=None)
            values: dict[str, object] = {
                "user_id": user_id,
                "listing_id": listing_id,
                "listing_type": listing_type,
                "decision": decision,
                "price_at_decision": price_at_decision,
                "created_at": now,
                "updated_at": now,
            }
            session.execute(self._upsert_statement(values))
            session.commit()

            # The upsert statement does not hand back the model, so re-read it.
            result = session.exec(
                self._select_decision(user_id, listing_id, listing_type)
            ).first()
            if result is None:
                # Explicit raise rather than ``assert``, which ``python -O`` strips.
                raise RuntimeError(
                    "upsert_decision: row not found after commit for "
                    f"user_id={user_id}, listing_id={listing_id}, "
                    f"listing_type={listing_type!r}"
                )
            record_db_query("upsert_decision", "decision", time.monotonic() - t0)
            return result

    def get_decisions_for_user(self, user_id: int) -> list[ListingDecision]:
        """Return every decision row belonging to ``user_id``."""
        t0 = time.monotonic()
        with Session(self.engine) as session:
            statement = select(ListingDecision).where(
                ListingDecision.user_id == user_id
            )
            results = list(session.exec(statement).all())
            record_db_query(
                "get_decisions_for_user",
                "decision",
                time.monotonic() - t0,
                len(results),
            )
            return results

    def delete_decision(
        self,
        user_id: int,
        listing_id: int,
        listing_type: str,
    ) -> bool:
        """Delete the decision keyed by (user_id, listing_id, listing_type).

        Returns:
            True if a row existed and was deleted, False if there was none.
        """
        t0 = time.monotonic()
        with Session(self.engine) as session:
            existing = session.exec(
                self._select_decision(user_id, listing_id, listing_type)
            ).first()
            if existing is not None:
                session.delete(existing)
                session.commit()
            # Recorded on both paths, consistent with the other query methods.
            record_db_query("delete_decision", "decision", time.monotonic() - t0)
        return existing is not None

    def _get_listing_ids_by_decision(
        self,
        user_id: int,
        listing_type: str,
        decision: str,
        metric_name: str,
    ) -> set[int]:
        """Return listing ids for ``user_id``/``listing_type`` with the given ``decision``.

        ``metric_name`` is the query name reported to ``record_db_query``.
        """
        t0 = time.monotonic()
        with Session(self.engine) as session:
            statement = select(ListingDecision.listing_id).where(
                ListingDecision.user_id == user_id,
                ListingDecision.listing_type == listing_type,
                ListingDecision.decision == decision,
            )
            ids = set(session.exec(statement).all())
            record_db_query(metric_name, "decision", time.monotonic() - t0, len(ids))
            return ids

    def get_disliked_listing_ids(
        self,
        user_id: int,
        listing_type: str,
    ) -> set[int]:
        """Return ids of listings of ``listing_type`` that ``user_id`` disliked."""
        return self._get_listing_ids_by_decision(
            user_id, listing_type, "disliked", "get_disliked_listing_ids"
        )

    def get_liked_listing_ids(
        self,
        user_id: int,
        listing_type: str,
    ) -> set[int]:
        """Return ids of listings of ``listing_type`` that ``user_id`` liked."""
        return self._get_listing_ids_by_decision(
            user_id, listing_type, "liked", "get_liked_listing_ids"
        )