diff --git a/repositories/decision_repository.py b/repositories/decision_repository.py new file mode 100644 index 0000000..fbed9df --- /dev/null +++ b/repositories/decision_repository.py @@ -0,0 +1,75 @@ +"""Repository for listing decisions (like/dislike).""" +from datetime import datetime + +from models.decision import ListingDecision +from sqlalchemy import Engine +from sqlmodel import Session, select + + +class DecisionRepository: + engine: Engine + + def __init__(self, engine: Engine) -> None: + self.engine = engine + + def upsert_decision( + self, user_id: int, listing_id: int, listing_type: str, decision: str + ) -> ListingDecision: + with Session(self.engine) as session: + statement = select(ListingDecision).where( + ListingDecision.user_id == user_id, + ListingDecision.listing_id == listing_id, + ListingDecision.listing_type == listing_type, + ) + existing = session.exec(statement).first() + if existing: + existing.decision = decision + existing.updated_at = datetime.utcnow() + session.add(existing) + session.commit() + session.refresh(existing) + return existing + new_decision = ListingDecision( + user_id=user_id, + listing_id=listing_id, + listing_type=listing_type, + decision=decision, + ) + session.add(new_decision) + session.commit() + session.refresh(new_decision) + return new_decision + + def get_decisions_for_user(self, user_id: int) -> list[ListingDecision]: + with Session(self.engine) as session: + statement = select(ListingDecision).where( + ListingDecision.user_id == user_id + ) + return list(session.exec(statement).all()) + + def delete_decision( + self, user_id: int, listing_id: int, listing_type: str + ) -> bool: + with Session(self.engine) as session: + statement = select(ListingDecision).where( + ListingDecision.user_id == user_id, + ListingDecision.listing_id == listing_id, + ListingDecision.listing_type == listing_type, + ) + existing = session.exec(statement).first() + if existing is None: + return False + session.delete(existing) + session.commit() + return True + + def get_disliked_listing_ids( + self, user_id: int, listing_type: str + ) -> set[int]: + with Session(self.engine) as session: + statement = select(ListingDecision.listing_id).where( + ListingDecision.user_id == user_id, + ListingDecision.listing_type == listing_type, + ListingDecision.decision == "disliked", + ) + return {row for row in session.exec(statement).all()} diff --git a/tests/unit/test_decision_repository.py b/tests/unit/test_decision_repository.py new file mode 100644 index 0000000..58afd91 --- /dev/null +++ b/tests/unit/test_decision_repository.py @@ -0,0 +1,88 @@ +"""Unit tests for DecisionRepository.""" +import pytest +from sqlalchemy import Engine +from sqlmodel import SQLModel, Session, create_engine + +from models.decision import ListingDecision +from models.user import User +from repositories.decision_repository import DecisionRepository + + +@pytest.fixture +def decision_engine() -> Engine: + engine = create_engine( + "sqlite:///:memory:", + echo=False, + connect_args={"check_same_thread": False}, + ) + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + session.add(User(id=1, email="test@example.com")) + session.commit() + yield engine # type: ignore[misc] + SQLModel.metadata.drop_all(engine) + + +@pytest.fixture +def repo(decision_engine: Engine) -> DecisionRepository: + return DecisionRepository(decision_engine) + + +class TestDecisionRepository: + def test_upsert_creates_new_decision(self, repo: DecisionRepository) -> None: + result = repo.upsert_decision( + user_id=1, listing_id=100, listing_type="RENT", decision="liked" + ) + assert result.decision == "liked" + assert result.listing_id == 100 + + def test_upsert_updates_existing_decision(self, repo: DecisionRepository) -> None: + repo.upsert_decision( + user_id=1, listing_id=100, listing_type="RENT", decision="liked" + ) + result = repo.upsert_decision( + user_id=1, listing_id=100, listing_type="RENT", decision="disliked" + ) + assert result.decision == "disliked" + # Should still be one record + all_decisions = repo.get_decisions_for_user(1) + assert len(all_decisions) == 1 + + def test_get_decisions_for_user(self, repo: DecisionRepository) -> None: + repo.upsert_decision( + user_id=1, listing_id=100, listing_type="RENT", decision="liked" + ) + repo.upsert_decision( + user_id=1, listing_id=200, listing_type="RENT", decision="disliked" + ) + decisions = repo.get_decisions_for_user(1) + assert len(decisions) == 2 + + def test_delete_decision(self, repo: DecisionRepository) -> None: + repo.upsert_decision( + user_id=1, listing_id=100, listing_type="RENT", decision="liked" + ) + deleted = repo.delete_decision( + user_id=1, listing_id=100, listing_type="RENT" + ) + assert deleted is True + assert len(repo.get_decisions_for_user(1)) == 0 + + def test_delete_nonexistent_decision(self, repo: DecisionRepository) -> None: + deleted = repo.delete_decision( + user_id=1, listing_id=999, listing_type="RENT" + ) + assert deleted is False + + def test_get_disliked_listing_ids(self, repo: DecisionRepository) -> None: + repo.upsert_decision( + user_id=1, listing_id=100, listing_type="RENT", decision="liked" + ) + repo.upsert_decision( + user_id=1, listing_id=200, listing_type="RENT", decision="disliked" + ) + repo.upsert_decision( + user_id=1, listing_id=300, listing_type="RENT", decision="disliked" + ) + disliked = repo.get_disliked_listing_ids(user_id=1, listing_type="RENT") + assert disliked == {200, 300}