- Point Ollama to the local instance via host.docker.internal and use the gemma3 model
- Remove the Docker Ollama service (use the host's Ollama instead)
- Add a company-name-to-ticker mapping (Apple→AAPL, Tesla→TSLA, etc.) for RSS articles
- Lower signal thresholds for faster feedback with paper trading:
  - FinBERT: confidence 0.6→0.4, signal strength 0.3→0.15
  - News strategy: article_count 2→1, confidence 0.5→0.3, score ±0.3→±0.15
- Fix market data BarSet access bug (BarSet.__contains__ incorrectly returns False)
- Fix market data SIP feed error by switching to the IEX feed for free Alpaca accounts
- Fix nginx proxy routing of /api/auth/* to api-gateway /auth/*
- Add seed_sample_data script
- Update tests for the new thresholds and alpaca mock modules
390 lines
12 KiB
Python
390 lines
12 KiB
Python
"""Tests for API Gateway trading endpoints (Task 14).
|
|
|
|
Uses FastAPI TestClient with mocked DB sessions and Redis.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from services.api_gateway.auth.jwt import create_access_token
|
|
from services.api_gateway.auth.middleware import get_config, get_current_user
|
|
from services.api_gateway.config import ApiGatewayConfig
|
|
from services.api_gateway.main import create_app
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture()
def config() -> ApiGatewayConfig:
    """Gateway settings aimed at throwaway test backends.

    Uses a test-only JWT secret, an in-memory SQLite database URL and a
    localhost Redis URL. The ``client`` fixture swaps a mock into
    ``app.state.redis``, so the Redis URL is only a placeholder value.
    """
    settings = {
        "jwt_secret_key": "test-secret-for-routes",
        "database_url": "sqlite+aiosqlite:///:memory:",
        "redis_url": "redis://localhost:6379/0",
    }
    return ApiGatewayConfig(**settings)
|
|
|
|
|
|
@pytest.fixture()
def mock_user() -> dict[str, str]:
    """Token claims returned in place of a real authenticated user."""
    claims = dict(sub="user-test", username="tester", type="access")
    return claims
|
|
|
|
|
|
@pytest.fixture()
def auth_headers(config: ApiGatewayConfig) -> dict[str, str]:
    """HTTP headers carrying an access token signed with *config*'s secret."""
    access_token = create_access_token("user-test", "tester", config)
    headers = {"Authorization": f"Bearer {access_token}"}
    return headers
|
|
|
|
|
|
@pytest.fixture()
def mock_redis() -> AsyncMock:
    """Async Redis stand-in whose commands are awaitable mocks.

    ``get`` resolves to ``None`` (key absent) unless a test replaces it.
    ``set``, ``setex``, ``delete`` and ``publish`` are fresh ``AsyncMock``
    objects, so tests can assert on how they were awaited.
    """
    fake = AsyncMock()
    fake.get = AsyncMock(return_value=None)
    for command in ("set", "setex", "delete", "publish"):
        setattr(fake, command, AsyncMock())
    return fake
|
|
|
|
|
|
@pytest.fixture()
def mock_session_factory() -> tuple[MagicMock, AsyncMock]:
    """Build a fake async session factory together with the session it returns.

    Calling the factory yields the session. ``async with session`` enters
    as that same session object, and ``__aexit__`` returns ``False`` so
    exceptions are not suppressed. Tests configure ``session.execute`` on
    the returned session directly.

    Returns:
        ``(factory, session)``.
    """
    fake_session = AsyncMock()
    fake_session.__aenter__ = AsyncMock(return_value=fake_session)
    fake_session.__aexit__ = AsyncMock(return_value=False)

    return MagicMock(return_value=fake_session), fake_session
|
|
|
|
|
|
@pytest.fixture()
def client(
    config: ApiGatewayConfig,
    mock_user: dict,
    mock_redis: AsyncMock,
    mock_session_factory,
) -> TestClient:
    """Return a TestClient for a fresh app with auth, config and backends faked.

    - ``get_current_user`` always yields *mock_user*, bypassing JWT checks.
    - ``get_config`` yields *config*.
    - ``app.state`` is pre-populated with the mock Redis client, the mock
      session factory, a dummy engine and the config.

    The client is not entered as a context manager. In Starlette, lifespan
    handlers only run inside ``with TestClient(...)``, so they do not run
    here; presumably this keeps the injected state intact (confirm).
    ``raise_server_exceptions=False`` makes unhandled server errors come
    back as HTTP 500 responses instead of being re-raised in the test.
    """
    factory, _session = mock_session_factory
    app = create_app(config)

    overrides = app.dependency_overrides
    overrides[get_current_user] = lambda: mock_user
    overrides[get_config] = lambda: config

    state = app.state
    state.redis = mock_redis
    state.db_session_factory = factory
    state.db_engine = MagicMock()
    state.config = config

    return TestClient(app, raise_server_exceptions=False)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helper: build mock execute results
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_execute_result(rows, scalar=None):
    """Return a MagicMock shaped like the result of ``session.execute()``.

    Args:
        rows: Items exposed through ``.all()``, ``.scalars().all()`` and
            iteration over ``.scalars()``.
        scalar: Value for ``.scalar_one_or_none()``. It also overrides
            ``.scalar()``, which otherwise reports ``len(rows)`` (convenient
            for mocking COUNT queries).

    Returns:
        A configured ``MagicMock``.
    """
    count = scalar if scalar is not None else len(rows)

    scalars_view = MagicMock()
    scalars_view.all.return_value = rows
    # A function (not a fixed iterator) so every iteration starts afresh.
    scalars_view.__iter__ = lambda _mock: iter(rows)

    result = MagicMock()
    result.scalars.return_value = scalars_view
    result.scalar_one_or_none.return_value = scalar
    result.scalar.return_value = count
    result.all.return_value = rows
    return result
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Portfolio Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPortfolioEndpoint:
    """GET /api/portfolio."""

    def test_portfolio_returns_defaults_when_empty(
        self, client: TestClient, mock_session_factory
    ) -> None:
        """When every query returns no rows, the portfolio figures are zero."""
        session = mock_session_factory[1]
        empty_result = _make_execute_result([], scalar=None)
        session.execute = AsyncMock(return_value=empty_result)

        response = client.get("/api/portfolio")

        assert response.status_code == 200
        payload = response.json()
        for field in ("total_value", "cash", "daily_pnl"):
            assert payload[field] == 0.0
|
|
|
|
|
|
class TestPositionsEndpoint:
    """GET /api/portfolio/positions."""

    def test_positions_returns_list(
        self, client: TestClient, mock_session_factory
    ) -> None:
        """A single mocked position is returned as a one-element list."""
        _factory, session = mock_session_factory

        position = MagicMock()
        position.configure_mock(
            id=uuid.uuid4(),
            ticker="AAPL",
            qty=10.0,
            avg_entry=150.0,
            unrealized_pnl=50.0,
            stop_loss=145.0,
            take_profit=160.0,
        )
        session.execute = AsyncMock(
            return_value=_make_execute_result([position])
        )

        response = client.get("/api/portfolio/positions")

        assert response.status_code == 200
        rows = response.json()
        assert len(rows) == 1
        first = rows[0]
        assert first["ticker"] == "AAPL"
        assert first["qty"] == 10.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Trades Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestTradesListEndpoint:
    """GET /api/trades — paginated trade listing."""

    def test_trades_returns_paginated_list(
        self, client: TestClient, mock_session_factory
    ) -> None:
        """The single mocked trade comes back together with pagination metadata."""
        _, session = mock_session_factory

        trade = MagicMock()
        trade.id = uuid.uuid4()
        trade.ticker = "TSLA"
        trade.side.value = "BUY"
        trade.qty = 5.0
        trade.price = 200.0
        trade.status.value = "FILLED"
        trade.pnl = 25.0
        trade.strategy_id = None
        trade.signal_id = None
        trade.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)

        # session.execute is awaited twice, in order: the count query, then
        # the data query whose .all() yields (Trade, strategy_name) tuples.
        count_result = _make_execute_result([], scalar=1)
        data_result = MagicMock()
        data_result.all.return_value = [(trade, None)]
        session.execute = AsyncMock(side_effect=[count_result, data_result])

        resp = client.get("/api/trades")
        assert resp.status_code == 200
        data = resp.json()
        assert "trades" in data
        assert "total" in data
        assert "page" in data
        # Key presence alone would pass even if rows were dropped or the
        # count ignored; check that the mocked data actually round-trips.
        assert data["total"] == 1
        assert len(data["trades"]) == 1
        assert data["trades"][0]["ticker"] == "TSLA"
|
|
|
|
|
|
class TestTradesPagination:
    """Query-string pagination on GET /api/trades."""

    def test_trades_page_and_per_page(
        self, client: TestClient, mock_session_factory
    ) -> None:
        """page/per_page are echoed back and ``pages`` follows from the count."""
        session = mock_session_factory[1]
        session.execute = AsyncMock(
            side_effect=[
                _make_execute_result([], scalar=50),  # count query
                _make_execute_result([]),  # data query: no rows
            ]
        )

        response = client.get("/api/trades", params={"page": 3, "per_page": 10})

        assert response.status_code == 200
        payload = response.json()
        assert payload["page"] == 3
        assert payload["per_page"] == 10
        assert payload["pages"] == 5  # 50 total / 10 per page
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Strategies Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestStrategiesEndpoint:
    """GET /api/strategies."""

    def test_strategies_returns_list(
        self, client: TestClient, mock_session_factory
    ) -> None:
        """One active strategy with no trades is listed with its weight."""
        session = mock_session_factory[1]

        momentum = MagicMock()
        momentum.id = uuid.uuid4()
        momentum.name = "momentum"
        momentum.description = "Momentum strategy"
        momentum.current_weight = 0.333
        momentum.active = True
        momentum.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)

        # Awaited in order: the strategy listing, then the trades lookup for
        # the (only) strategy, which finds nothing.
        session.execute = AsyncMock(
            side_effect=[
                _make_execute_result([momentum]),
                _make_execute_result([]),
            ]
        )

        response = client.get("/api/strategies")

        assert response.status_code == 200
        listing = response.json()
        assert len(listing) == 1
        first = listing[0]
        assert first["name"] == "momentum"
        assert first["current_weight"] == 0.333
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# News Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestNewsEndpoint:
    """GET /api/news."""

    def test_news_returns_paginated_articles(
        self, client: TestClient, mock_session_factory
    ) -> None:
        """An (article, sentiment) row is listed and counted."""
        session = mock_session_factory[1]
        stamp = datetime(2024, 1, 1, tzinfo=timezone.utc)

        article = MagicMock()
        article.id = uuid.uuid4()
        article.source = "reuters"
        article.url = "https://reuters.com/article/1"
        article.title = "Stock rises"
        article.published_at = stamp
        article.fetched_at = stamp

        sentiment = MagicMock()
        sentiment.ticker = "AAPL"
        sentiment.score = 0.8
        sentiment.confidence = 0.9
        sentiment.model_used = "finbert"

        # Awaited in order: the count query, then the joined data query.
        rows_result = MagicMock()
        rows_result.all.return_value = [(article, sentiment)]
        session.execute = AsyncMock(
            side_effect=[_make_execute_result([], scalar=1), rows_result]
        )

        response = client.get("/api/news")

        assert response.status_code == 200
        payload = response.json()
        assert "articles" in payload
        assert payload["total"] == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Controls Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestControlsPauseResume:
    """POST /api/controls/pause and POST /api/controls/resume."""

    # Redis key the control endpoints toggle.
    _PAUSE_FLAG = "trading:paused"

    def test_pause_sets_redis_key(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        """Pausing reports 'paused' and sets the flag to "1" exactly once."""
        response = client.post("/api/controls/pause")

        assert response.status_code == 200
        assert response.json()["status"] == "paused"
        mock_redis.set.assert_called_once_with(self._PAUSE_FLAG, "1")

    def test_resume_clears_redis_key(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        """Resuming reports 'active' and deletes the flag exactly once."""
        response = client.post("/api/controls/resume")

        assert response.status_code == 200
        assert response.json()["status"] == "active"
        mock_redis.delete.assert_called_once_with(self._PAUSE_FLAG)
|
|
|
|
|
|
class TestControlsStatus:
    """GET /api/controls/status reflects the Redis pause flag."""

    @staticmethod
    def _fetch_status(client: TestClient) -> str:
        """Call the status endpoint, require HTTP 200, and return its status field."""
        response = client.get("/api/controls/status")
        assert response.status_code == 200
        return response.json()["status"]

    def test_status_active_when_not_paused(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        """A missing flag (``get`` -> None) means trading is active."""
        mock_redis.get = AsyncMock(return_value=None)
        assert self._fetch_status(client) == "active"

    def test_status_paused_when_flag_set(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        """A stored flag value of "1" means trading is paused."""
        mock_redis.get = AsyncMock(return_value="1")
        assert self._fetch_status(client) == "paused"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Backtest Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBacktestRunEndpoint:
    """POST /api/backtest/run and GET /api/backtest/{run_id}."""

    def test_backtest_run_returns_run_id(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        """Starting a run returns an id, reports 'running' and calls Redis setex."""
        request_body = {
            "start_date": "2024-01-01T00:00:00Z",
            "end_date": "2024-06-01T00:00:00Z",
            "initial_capital": 100000,
        }

        response = client.post("/api/backtest/run", json=request_body)

        assert response.status_code == 200
        payload = response.json()
        assert "run_id" in payload
        assert payload["status"] == "running"
        mock_redis.setex.assert_called()

    def test_backtest_get_not_found(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        """An id with nothing stored in Redis yields 404."""
        mock_redis.get = AsyncMock(return_value=None)

        response = client.get("/api/backtest/nonexistent-id")

        assert response.status_code == 404

    def test_backtest_get_returns_result(
        self, client: TestClient, mock_redis: AsyncMock
    ) -> None:
        """A stored JSON record is returned with its status and result."""
        stored_run = {
            "status": "completed",
            "result": {"total_return": 0.15, "sharpe_ratio": 1.2},
        }
        mock_redis.get = AsyncMock(return_value=json.dumps(stored_run))

        response = client.get("/api/backtest/some-run-id")

        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "completed"
        assert payload["result"]["total_return"] == 0.15
|