diff --git a/tests/conftest.py b/tests/conftest.py index 141488c..3cc078b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -165,25 +165,28 @@ async def async_client( in_memory_engine: Engine, mock_user: User ) -> AsyncGenerator[AsyncClient, None]: """Create an AsyncClient for API testing with mock auth.""" - from api.app import app + import database + import api.app as api_app from api.auth import get_current_user + app = api_app.app + # Override dependencies app.dependency_overrides[get_current_user] = lambda: mock_user - # Patch the engine used by the repository - original_engine = None - try: - from database import engine as db_engine - original_engine = db_engine - except Exception: - pass + # Patch the engine so the API uses the in-memory database + original_db_engine = database.engine + original_app_engine = api_app.engine + database.engine = in_memory_engine + api_app.engine = in_memory_engine transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: yield client - # Clean up dependency overrides + # Restore original engines and clean up + database.engine = original_db_engine + api_app.engine = original_app_engine app.dependency_overrides.clear()