"""Tests for the API Gateway trading endpoints (Task 14).

Every test drives the FastAPI application through ``TestClient`` while the
database session factory and the Redis client are replaced by mocks.
"""

from __future__ import annotations

import json
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from services.api_gateway.auth.jwt import create_access_token
from services.api_gateway.auth.middleware import get_config, get_current_user
from services.api_gateway.config import ApiGatewayConfig
from services.api_gateway.main import create_app


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def config() -> ApiGatewayConfig:
    """Gateway configuration built from test-only secret and connection URLs."""
    settings = {
        "jwt_secret_key": "test-secret-for-routes",
        "database_url": "sqlite+aiosqlite:///:memory:",
        "redis_url": "redis://localhost:6379/0",
    }
    return ApiGatewayConfig(**settings)


@pytest.fixture()
def mock_user() -> dict:
    """User dict that the ``client`` fixture serves from ``get_current_user``."""
    return dict(sub="user-test", username="tester", type="access")


@pytest.fixture()
def auth_headers(config: ApiGatewayConfig) -> dict[str, str]:
    """Authorization header carrying an access token for the test user."""
    bearer = create_access_token("user-test", "tester", config)
    return {"Authorization": f"Bearer {bearer}"}


@pytest.fixture()
def mock_redis() -> AsyncMock:
    """Redis stand-in: ``get`` resolves to ``None``; other commands are bare mocks."""
    fake = AsyncMock()
    fake.get = AsyncMock(return_value=None)
    for command in ("set", "setex", "delete", "publish"):
        setattr(fake, command, AsyncMock())
    return fake


@pytest.fixture()
def mock_session_factory():
    """Return ``(factory, session)``; calling ``factory`` hands back ``session``.

    ``session`` is an async mock whose ``__aenter__`` resolves to the session
    itself and whose ``__aexit__`` resolves to ``False``.
    """
    db_session = AsyncMock()
    db_session.__aenter__ = AsyncMock(return_value=db_session)
    db_session.__aexit__ = AsyncMock(return_value=False)
    return MagicMock(return_value=db_session), db_session


@pytest.fixture()
def client(
    config: ApiGatewayConfig,
    mock_user: dict,
    mock_redis: AsyncMock,
    mock_session_factory,
) -> TestClient:
    """Test client for an app whose auth, config, Redis and DB are all mocked."""
    session_factory, _session = mock_session_factory
    app = create_app(config)

    # Replace the auth/config dependencies so JWT validation is bypassed.
    overrides = app.dependency_overrides
    overrides[get_current_user] = lambda: mock_user
    overrides[get_config] = lambda: config

    # Place the mocked collaborators on the application state.
    state = app.state
    state.redis = mock_redis
    state.db_session_factory = session_factory
    state.db_engine = MagicMock()
    state.config = config

    return TestClient(app, raise_server_exceptions=False)


# ---------------------------------------------------------------------------
# Helper: build mock execute results
# ---------------------------------------------------------------------------


def _make_execute_result(rows, scalar=None):
    """Build a mock of the object returned by ``session.execute()``.

    ``rows`` is served by ``.all()``, by ``.scalars().all()`` and by iterating
    ``.scalars()``.  ``scalar`` is served by ``.scalar_one_or_none()``;
    ``.scalar()`` returns it as well, falling back to ``len(rows)`` when
    ``scalar`` is ``None``.
    """
    execute_result = MagicMock()

    scalars_view = execute_result.scalars.return_value
    scalars_view.all.return_value = rows
    scalars_view.__iter__ = lambda self: iter(rows)

    execute_result.scalar_one_or_none.return_value = scalar
    if scalar is None:
        execute_result.scalar.return_value = len(rows)
    else:
        execute_result.scalar.return_value = scalar

    execute_result.all.return_value = rows
    return execute_result


# ---------------------------------------------------------------------------
# Portfolio Tests
# ---------------------------------------------------------------------------


class TestPortfolioEndpoint:
    """Covers ``GET /api/portfolio``."""

    def test_portfolio_returns_defaults_when_empty(
        self, client: TestClient, mock_session_factory
    ) -> None:
        session = mock_session_factory[1]
        empty_result = _make_execute_result([], scalar=None)
        session.execute = AsyncMock(return_value=empty_result)

        response = client.get("/api/portfolio")

        assert response.status_code == 200
        body = response.json()
        for field in ("total_value", "cash", "daily_pnl"):
            assert body[field] == 0.0


class TestPositionsEndpoint:
    """Covers ``GET /api/portfolio/positions``."""

    def test_positions_returns_list(
        self, client: TestClient, mock_session_factory
    ) -> None:
        session = mock_session_factory[1]

        # A single mocked position row.
        position = MagicMock(
            id=uuid.uuid4(),
            ticker="AAPL",
            qty=10.0,
            avg_entry=150.0,
            unrealized_pnl=50.0,
            stop_loss=145.0,
            take_profit=160.0,
        )
        session.execute = AsyncMock(
            return_value=_make_execute_result([position])
        )

        response = client.get("/api/portfolio/positions")

        assert response.status_code == 200
        body = response.json()
        assert len(body) == 1
        first = body[0]
        assert first["ticker"] == "AAPL"
        assert first["qty"] == 10.0


# ---------------------------------------------------------------------------
# Trades Tests
# ---------------------------------------------------------------------------


class TestTradesListEndpoint:
    """Covers ``GET /api/trades`` with default query parameters."""

    def test_trades_returns_paginated_list(
        self, client: TestClient, mock_session_factory
    ) -> None:
        session = mock_session_factory[1]

        trade = MagicMock(
            id=uuid.uuid4(),
            ticker="TSLA",
            qty=5.0,
            price=200.0,
            pnl=25.0,
            strategy_id=None,
            signal_id=None,
            created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
            # Dotted keys configure attributes of the child mocks.
            **{"side.value": "BUY", "status.value": "FILLED"},
        )

        # execute() is called twice: the count first, then the data rows,
        # which are (Trade, strategy_name) tuples.
        rows_result = MagicMock()
        rows_result.all.return_value = [(trade, None)]
        session.execute = AsyncMock(
            side_effect=[_make_execute_result([], scalar=1), rows_result]
        )

        response = client.get("/api/trades")

        assert response.status_code == 200
        body = response.json()
        for key in ("trades", "total", "page"):
            assert key in body


class TestTradesPagination:
    """Covers ``GET /api/trades`` with explicit ``page`` and ``per_page``."""

    def test_trades_page_and_per_page(
        self, client: TestClient, mock_session_factory
    ) -> None:
        session = mock_session_factory[1]

        total_result = _make_execute_result([], scalar=50)
        page_result = _make_execute_result([])
        session.execute = AsyncMock(side_effect=[total_result, page_result])

        response = client.get("/api/trades?page=3&per_page=10")

        assert response.status_code == 200
        body = response.json()
        assert body["page"] == 3
        assert body["per_page"] == 10
        assert body["pages"] == 5  # 50 / 10


# ---------------------------------------------------------------------------
# Strategies Tests
# ---------------------------------------------------------------------------


class TestStrategiesEndpoint:
    """Covers ``GET /api/strategies``."""

    def test_strategies_returns_list(
        self, client: TestClient, mock_session_factory
    ) -> None:
        session = mock_session_factory[1]

        strategy = MagicMock(
            id=uuid.uuid4(),
            description="Momentum strategy",
            current_weight=0.333,
            active=True,
            created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        )
        # The Mock constructor treats ``name`` as the mock's own name, so the
        # attribute has to be assigned after construction.
        strategy.name = "momentum"

        # First execute() call lists strategies; the next one fetches the
        # trades of a strategy (none here).
        listing = _make_execute_result([strategy])
        no_trades = _make_execute_result([])
        session.execute = AsyncMock(side_effect=[listing, no_trades])

        response = client.get("/api/strategies")

        assert response.status_code == 200
        body = response.json()
        assert len(body) == 1
        first = body[0]
        assert first["name"] == "momentum"
        assert first["current_weight"] == 0.333


# ---------------------------------------------------------------------------
# News Tests
# ---------------------------------------------------------------------------


class TestNewsEndpoint:
    """Covers ``GET /api/news``."""

    def test_news_returns_paginated_articles(
        self, client: TestClient, mock_session_factory
    ) -> None:
        session = mock_session_factory[1]

        stamp = datetime(2024, 1, 1, tzinfo=timezone.utc)
        article = MagicMock(
            id=uuid.uuid4(),
            source="reuters",
            url="https://reuters.com/article/1",
            title="Stock rises",
            published_at=stamp,
            fetched_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
        )
        sentiment = MagicMock(
            ticker="AAPL",
            score=0.8,
            confidence=0.9,
            model_used="finbert",
        )

        # Count result first, then the (article, sentiment) rows.
        rows_result = MagicMock()
        rows_result.all.return_value = [(article, sentiment)]
        session.execute = AsyncMock(
            side_effect=[_make_execute_result([], scalar=1), rows_result]
        )

        response = client.get("/api/news")

        assert response.status_code == 200
        body = response.json()
        assert "articles" in body
        assert body["total"] == 1


# ---------------------------------------------------------------------------
# Controls Tests
# ---------------------------------------------------------------------------


class TestControlsPauseResume:
    """Covers ``POST /api/controls/pause`` and ``POST /api/controls/resume``."""

    def test_pause_sets_redis_key(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        response = client.post("/api/controls/pause")

        assert response.status_code == 200
        assert response.json()["status"] == "paused"
        mock_redis.set.assert_called_once_with("trading:paused", "1")

    def test_resume_clears_redis_key(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        response = client.post("/api/controls/resume")

        assert response.status_code == 200
        assert response.json()["status"] == "active"
        mock_redis.delete.assert_called_once_with("trading:paused")


class TestControlsStatus:
    """Covers ``GET /api/controls/status``."""

    def test_status_active_when_not_paused(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        # Redis reports no stored value.
        mock_redis.get = AsyncMock(return_value=None)

        response = client.get("/api/controls/status")

        assert response.status_code == 200
        assert response.json()["status"] == "active"

    def test_status_paused_when_flag_set(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        # Redis reports the value "1".
        mock_redis.get = AsyncMock(return_value="1")

        response = client.get("/api/controls/status")

        assert response.status_code == 200
        assert response.json()["status"] == "paused"


# ---------------------------------------------------------------------------
# Backtest Tests
# ---------------------------------------------------------------------------


class TestBacktestRunEndpoint:
    """Covers ``POST /api/backtest/run`` and ``GET /api/backtest/{id}``."""

    def test_backtest_run_returns_run_id(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        request_body = {
            "start_date": "2024-01-01T00:00:00Z",
            "end_date": "2024-06-01T00:00:00Z",
            "initial_capital": 100000,
        }

        response = client.post("/api/backtest/run", json=request_body)

        assert response.status_code == 200
        body = response.json()
        assert "run_id" in body
        assert body["status"] == "running"
        mock_redis.setex.assert_called()

    def test_backtest_get_not_found(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        # Nothing stored in Redis for the requested id.
        mock_redis.get = AsyncMock(return_value=None)

        response = client.get("/api/backtest/nonexistent-id")

        assert response.status_code == 404

    def test_backtest_get_returns_result(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        record = {
            "status": "completed",
            "result": {"total_return": 0.15, "sharpe_ratio": 1.2},
        }
        mock_redis.get = AsyncMock(return_value=json.dumps(record))

        response = client.get("/api/backtest/some-run-id")

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "completed"
        assert body["result"]["total_return"] == 0.15