feat: real data pipeline — market data, DB persistence, portfolio sync, signal-trade linkage

Wire the trading bot to real Alpaca market data and persist pipeline
state to the database so the dashboard displays live information.

- Add market-data service fetching OHLCV bars from Alpaca, publishing
  to market:bars Redis Stream; signal generator consumes bars and
  injects current_price into signals for position sizing
- Sentiment analyzer now persists Article + ArticleSentiment rows to
  DB after scoring, with duplicate and error handling
- API gateway runs a background portfolio sync task that snapshots
  Alpaca account state into PortfolioSnapshot/Position DB tables
  during market hours
- TradeSignal carries a signal_id UUID; signal generator and trade
  executor both persist their records to DB with cross-references
- 303 unit tests pass (57 new tests added)
This commit is contained in:
Viktor Barzin 2026-02-22 19:52:45 +00:00
parent 5a6b20c8f1
commit e2a3bd456d
No known key found for this signature in database
GPG key ID: 0EB088298288D958
19 changed files with 2238 additions and 72 deletions

View file

@ -0,0 +1,462 @@
"""Tests for the Market Data service.
Covers configuration defaults, bar parsing from Alpaca response format,
historical bar fetching, poll logic, and Redis publish behaviour.
All Alpaca SDK imports are mocked to avoid requiring pytz and other
transitive dependencies in the test environment.
"""
from __future__ import annotations
import sys
from datetime import datetime, timezone
from types import ModuleType
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from services.market_data.config import MarketDataConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeBar:
"""Mimics an Alpaca Bar object with the required attributes."""
def __init__(
self,
timestamp: datetime,
open: float,
high: float,
low: float,
close: float,
volume: float,
) -> None:
self.timestamp = timestamp
self.open = open
self.high = high
self.low = low
self.close = close
self.volume = volume
def _make_fake_bar(close: float = 150.0) -> _FakeBar:
return _FakeBar(
timestamp=datetime(2026, 1, 15, 10, 30, tzinfo=timezone.utc),
open=149.0,
high=151.0,
low=148.0,
close=close,
volume=10000.0,
)
class _FakeBarSet(dict):
"""Mimics an Alpaca BarSet (dict-like) mapping ticker to list of bars."""
pass
class _FakeTimeFrame:
"""Mimics alpaca.data.timeframe.TimeFrame."""
def __init__(self, amount, unit):
self.amount = amount
self.unit = unit
class _FakeTimeFrameUnit:
"""Mimics alpaca.data.timeframe.TimeFrameUnit."""
Minute = "Minute"
Hour = "Hour"
Day = "Day"
class _FakeStockBarsRequest:
"""Mimics alpaca.data.requests.StockBarsRequest."""
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def _install_alpaca_mocks():
"""Install mock modules for the alpaca SDK so we can import market_data.main."""
timeframe_mod = ModuleType("alpaca.data.timeframe")
timeframe_mod.TimeFrame = _FakeTimeFrame
timeframe_mod.TimeFrameUnit = _FakeTimeFrameUnit
requests_mod = ModuleType("alpaca.data.requests")
requests_mod.StockBarsRequest = _FakeStockBarsRequest
historical_mod = ModuleType("alpaca.data.historical")
historical_mod.StockHistoricalDataClient = MagicMock
# Build the package hierarchy
alpaca_mod = sys.modules.get("alpaca") or ModuleType("alpaca")
data_mod = sys.modules.get("alpaca.data") or ModuleType("alpaca.data")
sys.modules.setdefault("alpaca", alpaca_mod)
sys.modules["alpaca.data"] = data_mod
sys.modules["alpaca.data.timeframe"] = timeframe_mod
sys.modules["alpaca.data.requests"] = requests_mod
sys.modules["alpaca.data.historical"] = historical_mod
# Install mocks before importing from market_data.main
_install_alpaca_mocks()
from services.market_data.main import ( # noqa: E402
MARKET_BARS_STREAM,
_bar_to_dict,
_fetch_historical_bars,
_parse_timeframe,
_poll_latest_bars,
)
async def _to_thread_passthrough(func, *args, **kwargs):
"""Replacement for asyncio.to_thread that calls synchronously."""
return func(*args, **kwargs)
# ---------------------------------------------------------------------------
# Config defaults
# ---------------------------------------------------------------------------
class TestMarketDataConfig:
"""Tests for MarketDataConfig defaults."""
def test_default_watchlist(self):
config = MarketDataConfig()
assert config.watchlist == ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]
def test_default_bar_timeframe(self):
config = MarketDataConfig()
assert config.bar_timeframe == "5Min"
def test_default_poll_interval(self):
config = MarketDataConfig()
assert config.poll_interval_seconds == 60
def test_default_historical_bars(self):
config = MarketDataConfig()
assert config.historical_bars == 100
def test_default_alpaca_keys_empty(self):
config = MarketDataConfig()
assert config.alpaca_api_key == ""
assert config.alpaca_secret_key == ""
def test_inherits_base_config_defaults(self):
config = MarketDataConfig()
assert config.log_level == "INFO"
assert config.otel_metrics_port == 9090
# ---------------------------------------------------------------------------
# Timeframe parsing
# ---------------------------------------------------------------------------
class TestParseTimeframe:
"""Tests for the _parse_timeframe helper."""
def test_parse_5min(self):
tf = _parse_timeframe("5Min")
assert tf is not None
assert tf.amount == 5
def test_parse_1min(self):
tf = _parse_timeframe("1Min")
assert tf is not None
assert tf.amount == 1
def test_parse_15min(self):
tf = _parse_timeframe("15Min")
assert tf is not None
assert tf.amount == 15
def test_parse_1hour(self):
tf = _parse_timeframe("1Hour")
assert tf is not None
assert tf.amount == 1
def test_parse_1day(self):
tf = _parse_timeframe("1Day")
assert tf is not None
assert tf.amount == 1
def test_parse_invalid_raises(self):
with pytest.raises(ValueError, match="Unsupported timeframe"):
_parse_timeframe("3Min")
# ---------------------------------------------------------------------------
# Bar to dict conversion
# ---------------------------------------------------------------------------
class TestBarToDict:
"""Tests for _bar_to_dict conversion."""
def test_basic_conversion(self):
bar = _make_fake_bar(close=155.5)
result = _bar_to_dict("AAPL", bar)
assert result["ticker"] == "AAPL"
assert result["close"] == 155.5
assert result["open"] == 149.0
assert result["high"] == 151.0
assert result["low"] == 148.0
assert result["volume"] == 10000.0
assert "timestamp" in result
def test_timestamp_is_iso_format(self):
bar = _make_fake_bar()
result = _bar_to_dict("TSLA", bar)
# Should be parseable back
parsed = datetime.fromisoformat(result["timestamp"])
assert parsed.tzinfo is not None
def test_all_fields_are_floats(self):
bar = _make_fake_bar()
result = _bar_to_dict("AAPL", bar)
for field in ("open", "high", "low", "close", "volume"):
assert isinstance(result[field], float)
def test_ticker_is_passed_through(self):
bar = _make_fake_bar()
result = _bar_to_dict("NVDA", bar)
assert result["ticker"] == "NVDA"
# ---------------------------------------------------------------------------
# Fetch historical bars
# ---------------------------------------------------------------------------
class TestFetchHistoricalBars:
"""Tests for _fetch_historical_bars with mocked Alpaca client."""
@pytest.mark.asyncio
async def test_publishes_bars_for_each_ticker(self):
"""Should publish all historical bars for each ticker in the watchlist."""
bars_aapl = [_make_fake_bar(150.0), _make_fake_bar(151.0)]
bars_tsla = [_make_fake_bar(200.0)]
bar_set = _FakeBarSet({"AAPL": bars_aapl, "TSLA": bars_tsla})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
total = await _fetch_historical_bars(
mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter
)
assert total == 3
assert publisher.publish.call_count == 3
counter.add.assert_called_once_with(3)
@pytest.mark.asyncio
async def test_handles_empty_response(self):
"""Should handle tickers with no bars gracefully."""
bar_set = _FakeBarSet({})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
total = await _fetch_historical_bars(
mock_client, ["AAPL"], MagicMock(), 100, publisher, counter
)
assert total == 0
publisher.publish.assert_not_called()
@pytest.mark.asyncio
async def test_handles_client_exception(self):
"""Should log and continue if one ticker fails."""
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(side_effect=Exception("API error"))
publisher = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
total = await _fetch_historical_bars(
mock_client, ["AAPL", "TSLA"], MagicMock(), 100, publisher, counter
)
assert total == 0
publisher.publish.assert_not_called()
@pytest.mark.asyncio
async def test_published_message_format(self):
"""Published messages should have the expected fields."""
bar = _make_fake_bar(175.0)
bar_set = _FakeBarSet({"MSFT": [bar]})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
await _fetch_historical_bars(
mock_client, ["MSFT"], MagicMock(), 100, publisher, counter
)
msg = publisher.publish.call_args[0][0]
assert msg["ticker"] == "MSFT"
assert msg["close"] == 175.0
assert "timestamp" in msg
assert "open" in msg
assert "high" in msg
assert "low" in msg
assert "volume" in msg
# ---------------------------------------------------------------------------
# Poll latest bars
# ---------------------------------------------------------------------------
class TestPollLatestBars:
"""Tests for _poll_latest_bars with mocked Alpaca client."""
@pytest.mark.asyncio
async def test_publishes_latest_bar_per_ticker(self):
"""Should publish one bar per ticker."""
bar_set = _FakeBarSet({"AAPL": [_make_fake_bar(155.0)]})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL"], MagicMock(), publisher, counter
)
assert count == 1
assert publisher.publish.call_count == 1
published_msg = publisher.publish.call_args[0][0]
assert published_msg["ticker"] == "AAPL"
assert published_msg["close"] == 155.0
@pytest.mark.asyncio
async def test_publishes_only_most_recent_bar(self):
"""When multiple bars returned, should only publish the last one."""
bars = [_make_fake_bar(150.0), _make_fake_bar(155.0)]
bar_set = _FakeBarSet({"AAPL": bars})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL"], MagicMock(), publisher, counter
)
assert count == 1
published_msg = publisher.publish.call_args[0][0]
assert published_msg["close"] == 155.0
@pytest.mark.asyncio
async def test_handles_no_bars_for_ticker(self):
"""Should return 0 when a ticker has no bars."""
bar_set = _FakeBarSet({"AAPL": []})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL"], MagicMock(), publisher, counter
)
assert count == 0
publisher.publish.assert_not_called()
@pytest.mark.asyncio
async def test_handles_ticker_not_in_response(self):
"""Should handle gracefully when the response doesn't contain the ticker."""
bar_set = _FakeBarSet({})
mock_client = MagicMock()
mock_client.get_stock_bars = MagicMock(return_value=bar_set)
publisher = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL"], MagicMock(), publisher, counter
)
assert count == 0
publisher.publish.assert_not_called()
@pytest.mark.asyncio
async def test_multiple_tickers(self):
"""Should poll and publish bars for all tickers in watchlist."""
bar_set_aapl = _FakeBarSet({"AAPL": [_make_fake_bar(150.0)]})
bar_set_tsla = _FakeBarSet({"TSLA": [_make_fake_bar(250.0)]})
mock_client = MagicMock()
# Return different bar sets for each call
mock_client.get_stock_bars = MagicMock(
side_effect=[bar_set_aapl, bar_set_tsla]
)
publisher = AsyncMock()
publisher.publish = AsyncMock()
counter = MagicMock()
with patch("services.market_data.main.asyncio.to_thread", side_effect=_to_thread_passthrough):
count = await _poll_latest_bars(
mock_client, ["AAPL", "TSLA"], MagicMock(), publisher, counter
)
assert count == 2
assert publisher.publish.call_count == 2
# ---------------------------------------------------------------------------
# Stream name constant
# ---------------------------------------------------------------------------
class TestStreamConstants:
"""Verify the stream name constant."""
def test_market_bars_stream_name(self):
assert MARKET_BARS_STREAM == "market:bars"

View file

@ -0,0 +1,456 @@
"""Tests for portfolio sync background task.
Verifies that the sync loop correctly:
- Creates PortfolioSnapshot rows from broker account data
- Upserts Position rows from broker positions
- Removes Position rows for closed positions
- Handles broker errors gracefully
- Respects US market hours
"""
from __future__ import annotations
import asyncio
from datetime import datetime, time, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from services.api_gateway.config import ApiGatewayConfig
from services.api_gateway.tasks.portfolio_sync import (
_sync_once,
is_market_open,
portfolio_sync_loop,
)
from shared.schemas.trading import AccountInfo, PositionInfo
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def config() -> ApiGatewayConfig:
return ApiGatewayConfig(
jwt_secret_key="test-secret-for-sync",
database_url="sqlite+aiosqlite:///:memory:",
redis_url="redis://localhost:6379/0",
alpaca_api_key="test-key",
alpaca_secret_key="test-secret",
paper_trading=True,
snapshot_interval_seconds=1,
)
@pytest.fixture()
def config_no_creds() -> ApiGatewayConfig:
return ApiGatewayConfig(
jwt_secret_key="test-secret-for-sync",
database_url="sqlite+aiosqlite:///:memory:",
redis_url="redis://localhost:6379/0",
alpaca_api_key="",
alpaca_secret_key="",
)
@pytest.fixture()
def mock_account() -> AccountInfo:
return AccountInfo(
equity=105000.0,
cash=50000.0,
buying_power=100000.0,
portfolio_value=105000.0,
)
@pytest.fixture()
def mock_positions() -> list[PositionInfo]:
return [
PositionInfo(
ticker="AAPL",
qty=10.0,
avg_entry=150.0,
current_price=155.0,
unrealized_pnl=50.0,
market_value=1550.0,
),
PositionInfo(
ticker="MSFT",
qty=5.0,
avg_entry=400.0,
current_price=410.0,
unrealized_pnl=50.0,
market_value=2050.0,
),
]
@pytest.fixture()
def mock_broker(mock_account, mock_positions):
broker = AsyncMock()
broker.get_account = AsyncMock(return_value=mock_account)
broker.get_positions = AsyncMock(return_value=mock_positions)
return broker
@pytest.fixture()
def mock_session():
"""Create a mock async session with context manager support."""
session = AsyncMock()
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock(return_value=False)
# Mock the begin() context manager
begin_ctx = AsyncMock()
begin_ctx.__aenter__ = AsyncMock(return_value=None)
begin_ctx.__aexit__ = AsyncMock(return_value=False)
session.begin = MagicMock(return_value=begin_ctx)
# session.add is synchronous in SQLAlchemy — use MagicMock to avoid warnings
session.add = MagicMock()
return session
@pytest.fixture()
def mock_session_factory(mock_session):
factory = MagicMock()
factory.return_value = mock_session
return factory
# ---------------------------------------------------------------------------
# Market hours tests
# ---------------------------------------------------------------------------
class TestMarketHours:
"""Tests for the is_market_open() function."""
def test_weekday_during_market_hours(self) -> None:
# Wednesday 2024-01-10 at 10:00 AM ET = 15:00 UTC
dt = datetime(2024, 1, 10, 15, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is True
def test_weekday_before_market_open(self) -> None:
# Wednesday 2024-01-10 at 9:00 AM ET = 14:00 UTC
dt = datetime(2024, 1, 10, 14, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
def test_weekday_after_market_close(self) -> None:
# Wednesday 2024-01-10 at 4:30 PM ET = 21:30 UTC
dt = datetime(2024, 1, 10, 21, 30, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
def test_weekend_saturday(self) -> None:
# Saturday 2024-01-13 at 12:00 PM ET = 17:00 UTC
dt = datetime(2024, 1, 13, 17, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
def test_weekend_sunday(self) -> None:
# Sunday 2024-01-14 at 12:00 PM ET = 17:00 UTC
dt = datetime(2024, 1, 14, 17, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
def test_market_open_boundary(self) -> None:
# Wednesday 2024-01-10 at exactly 9:30 AM ET = 14:30 UTC
dt = datetime(2024, 1, 10, 14, 30, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is True
def test_market_close_boundary(self) -> None:
# Wednesday 2024-01-10 at exactly 4:00 PM ET = 21:00 UTC
dt = datetime(2024, 1, 10, 21, 0, 0, tzinfo=timezone.utc)
assert is_market_open(dt) is False
# ---------------------------------------------------------------------------
# Snapshot creation tests
# ---------------------------------------------------------------------------
class TestSyncOnce:
"""Tests for the _sync_once() function."""
async def test_creates_portfolio_snapshot(
self, mock_broker, mock_session_factory, mock_session
) -> None:
# Mock the select query to return None (no existing positions)
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=execute_result)
await _sync_once(mock_broker, mock_session_factory)
# Verify the broker was called
mock_broker.get_account.assert_awaited_once()
mock_broker.get_positions.assert_awaited_once()
# Verify session.add was called (snapshot + 2 new positions)
assert mock_session.add.call_count == 3 # 1 snapshot + 2 positions
# Check the snapshot
snapshot_call = mock_session.add.call_args_list[0]
snapshot = snapshot_call[0][0]
assert snapshot.total_value == 105000.0
assert snapshot.cash == 50000.0
assert snapshot.positions_value == 55000.0 # 105000 - 50000
assert snapshot.daily_pnl == 0.0
async def test_creates_position_rows_for_new_positions(
self, mock_broker, mock_session_factory, mock_session
) -> None:
# No existing positions in DB
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=execute_result)
await _sync_once(mock_broker, mock_session_factory)
# Positions are added via session.add (after the snapshot)
position_calls = mock_session.add.call_args_list[1:]
assert len(position_calls) == 2
pos1 = position_calls[0][0][0]
assert pos1.ticker == "AAPL"
assert pos1.qty == 10.0
assert pos1.avg_entry == 150.0
assert pos1.unrealized_pnl == 50.0
pos2 = position_calls[1][0][0]
assert pos2.ticker == "MSFT"
assert pos2.qty == 5.0
assert pos2.avg_entry == 400.0
async def test_updates_existing_position(
self, mock_broker, mock_session_factory, mock_session
) -> None:
# Mock an existing position for AAPL, None for MSFT
existing_aapl = MagicMock()
existing_aapl.ticker = "AAPL"
existing_aapl.qty = 5.0 # old qty
existing_aapl.avg_entry = 140.0 # old entry
result_aapl = MagicMock()
result_aapl.scalar_one_or_none.return_value = existing_aapl
result_msft = MagicMock()
result_msft.scalar_one_or_none.return_value = None
# First execute call is for the delete of stale positions;
# but within the loop, select calls come first
mock_session.execute = AsyncMock(
side_effect=[result_aapl, result_msft, MagicMock()]
)
await _sync_once(mock_broker, mock_session_factory)
# AAPL should be updated in place
assert existing_aapl.qty == 10.0
assert existing_aapl.avg_entry == 150.0
assert existing_aapl.unrealized_pnl == 50.0
# MSFT should be added as new (snapshot + MSFT = 2 adds)
assert mock_session.add.call_count == 2 # snapshot + new MSFT
async def test_removes_closed_positions(
self, mock_session_factory, mock_session
) -> None:
# Broker returns only AAPL (MSFT was sold)
broker = AsyncMock()
broker.get_account = AsyncMock(
return_value=AccountInfo(
equity=100000, cash=90000, buying_power=90000, portfolio_value=100000
)
)
broker.get_positions = AsyncMock(
return_value=[
PositionInfo(
ticker="AAPL",
qty=10.0,
avg_entry=150.0,
current_price=155.0,
unrealized_pnl=50.0,
market_value=1550.0,
)
]
)
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=execute_result)
await _sync_once(broker, mock_session_factory)
# The delete statement should have been executed
# Find the delete call among execute calls
delete_called = False
for call in mock_session.execute.call_args_list:
stmt = call[0][0]
# Check if it's a delete statement (SQLAlchemy Delete object)
stmt_str = str(stmt)
if "DELETE" in stmt_str.upper():
delete_called = True
break
assert delete_called, "Expected a DELETE statement for closed positions"
async def test_removes_all_positions_when_broker_has_none(
self, mock_session_factory, mock_session
) -> None:
broker = AsyncMock()
broker.get_account = AsyncMock(
return_value=AccountInfo(
equity=100000, cash=100000, buying_power=100000, portfolio_value=100000
)
)
broker.get_positions = AsyncMock(return_value=[])
mock_session.execute = AsyncMock(return_value=MagicMock())
await _sync_once(broker, mock_session_factory)
# Should delete all positions since broker has none
delete_called = False
for call in mock_session.execute.call_args_list:
stmt = call[0][0]
stmt_str = str(stmt)
if "DELETE" in stmt_str.upper():
delete_called = True
break
assert delete_called, "Expected a DELETE statement to clear all positions"
# ---------------------------------------------------------------------------
# Error handling tests
# ---------------------------------------------------------------------------
class TestSyncErrorHandling:
"""Tests that the sync loop handles errors gracefully."""
async def test_broker_error_does_not_crash_loop(
self, config, mock_session_factory
) -> None:
"""Broker raises an exception — loop should catch it and continue."""
call_count = 0
async def mock_sync_once(broker, sf):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ConnectionError("Broker API down")
# Second call succeeds
with (
patch(
"services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
) as MockBroker,
patch(
"services.api_gateway.tasks.portfolio_sync._sync_once",
side_effect=mock_sync_once,
),
patch(
"services.api_gateway.tasks.portfolio_sync.is_market_open",
return_value=True,
),
):
MockBroker.return_value = AsyncMock()
task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
# Give it time for 2 iterations (interval=1s)
await asyncio.sleep(2.5)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert call_count >= 2, "Loop should have retried after the error"
async def test_no_credentials_returns_immediately(
self, config_no_creds, mock_session_factory
) -> None:
"""When Alpaca credentials are empty, the loop should exit immediately."""
task = asyncio.create_task(
portfolio_sync_loop(config_no_creds, mock_session_factory)
)
# Should complete almost immediately since no creds
await asyncio.wait_for(task, timeout=2.0)
# If we get here without timeout, the function returned correctly
# ---------------------------------------------------------------------------
# Market hours integration with loop
# ---------------------------------------------------------------------------
class TestSyncLoopMarketHours:
"""Tests that the loop respects market hours."""
async def test_skips_sync_outside_market_hours(
self, config, mock_session_factory
) -> None:
sync_called = False
async def mock_sync(broker, sf):
nonlocal sync_called
sync_called = True
with (
patch(
"services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
) as MockBroker,
patch(
"services.api_gateway.tasks.portfolio_sync._sync_once",
side_effect=mock_sync,
),
patch(
"services.api_gateway.tasks.portfolio_sync.is_market_open",
return_value=False,
),
):
MockBroker.return_value = AsyncMock()
task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
await asyncio.sleep(1.5)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert not sync_called, "Sync should not run outside market hours"
async def test_runs_sync_during_market_hours(
self, config, mock_session_factory
) -> None:
sync_called = False
async def mock_sync(broker, sf):
nonlocal sync_called
sync_called = True
with (
patch(
"services.api_gateway.tasks.portfolio_sync.AlpacaBroker"
) as MockBroker,
patch(
"services.api_gateway.tasks.portfolio_sync._sync_once",
side_effect=mock_sync,
),
patch(
"services.api_gateway.tasks.portfolio_sync.is_market_open",
return_value=True,
),
):
MockBroker.return_value = AsyncMock()
task = asyncio.create_task(portfolio_sync_loop(config, mock_session_factory))
await asyncio.sleep(1.5)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert sync_called, "Sync should run during market hours"

View file

@ -1,7 +1,7 @@
"""Tests for the sentiment analyzer service.
Covers FinBERT analyzer, Ollama analyzer, ticker extraction, and the main
service flow.
service flow including DB persistence.
"""
from __future__ import annotations
@ -10,6 +10,7 @@ from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.exc import IntegrityError
from services.sentiment_analyzer.analyzers.finbert import FinBERTAnalyzer
from services.sentiment_analyzer.analyzers.ollama_analyzer import OllamaAnalyzer
@ -409,3 +410,229 @@ class TestMainFlow:
publisher.publish.assert_not_called()
# Still counted as scored
counters["articles_scored"].add.assert_called_once_with(1)
# ---------------------------------------------------------------------------
# Helpers for DB persistence tests
# ---------------------------------------------------------------------------
def _make_counters() -> dict:
"""Create a dict of mock OpenTelemetry counter/histogram instruments."""
return {
"articles_scored": MagicMock(),
"finbert_count": MagicMock(),
"ollama_count": MagicMock(),
"inference_latency": MagicMock(),
}
def _make_mock_db_session_factory(session: AsyncMock | None = None) -> AsyncMock:
"""Create a mock async_sessionmaker that yields a mock session.
The returned factory, when called, returns an async context manager
that yields ``session`` (a mock AsyncSession).
"""
if session is None:
session = AsyncMock()
session.add = MagicMock()
session.commit = AsyncMock()
factory = MagicMock()
# factory() should return an async context manager (the session)
ctx = AsyncMock()
ctx.__aenter__ = AsyncMock(return_value=session)
ctx.__aexit__ = AsyncMock(return_value=False)
factory.return_value = ctx
return factory
# ---------------------------------------------------------------------------
# DB Persistence Tests
# ---------------------------------------------------------------------------
class TestDBPersistence:
"""Tests for the DB write step in process_article."""
@pytest.mark.asyncio
async def test_db_write_creates_article_and_sentiments(self):
"""When db_session_factory is provided, Article and ArticleSentiment rows are created."""
finbert = AsyncMock(spec=FinBERTAnalyzer)
finbert.analyze = AsyncMock(return_value=(0.75, 0.88))
ollama = AsyncMock(spec=OllamaAnalyzer)
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
config = SentimentAnalyzerConfig(
finbert_confidence_threshold=0.6,
otel_metrics_port=0,
)
counters = _make_counters()
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
db_factory = _make_mock_db_session_factory(mock_session)
article = _make_raw_article(
title="$AAPL and $MSFT report strong earnings",
content="Both Apple and Microsoft beat estimates.",
)
await process_article(
article, finbert, ollama, publisher, config, counters, db_factory
)
# session.add should be called: 1 Article + 2 ArticleSentiments = 3 calls
assert mock_session.add.call_count == 3
# Verify the types of objects added
added_objects = [call.args[0] for call in mock_session.add.call_args_list]
from shared.models.news import Article, ArticleSentiment
articles = [o for o in added_objects if isinstance(o, Article)]
sentiments = [o for o in added_objects if isinstance(o, ArticleSentiment)]
assert len(articles) == 1
assert len(sentiments) == 2
# Verify article fields
db_article = articles[0]
assert db_article.source == "test"
assert db_article.url == "https://example.com/article"
assert db_article.content_hash == "abc123"
# Verify sentiment fields
tickers_in_sentiments = {s.ticker for s in sentiments}
assert "AAPL" in tickers_in_sentiments
assert "MSFT" in tickers_in_sentiments
for s in sentiments:
assert s.score == 0.75
assert s.confidence == 0.88
assert s.model_used == "finbert"
# session.commit should be called once
mock_session.commit.assert_awaited_once()
# Redis publishing should still happen
assert publisher.publish.call_count == 2
@pytest.mark.asyncio
async def test_db_duplicate_article_handled_gracefully(self):
"""Duplicate content_hash (IntegrityError) should be caught silently."""
finbert = AsyncMock(spec=FinBERTAnalyzer)
finbert.analyze = AsyncMock(return_value=(0.5, 0.9))
ollama = AsyncMock(spec=OllamaAnalyzer)
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
config = SentimentAnalyzerConfig(
finbert_confidence_threshold=0.6,
otel_metrics_port=0,
)
counters = _make_counters()
mock_session = AsyncMock()
mock_session.add = MagicMock()
# Simulate IntegrityError on commit (duplicate content_hash)
mock_session.commit = AsyncMock(
side_effect=IntegrityError("duplicate", {}, Exception())
)
db_factory = _make_mock_db_session_factory(mock_session)
article = _make_raw_article(
title="$AAPL news",
content="Apple earnings report.",
)
# Should NOT raise — IntegrityError is caught
await process_article(
article, finbert, ollama, publisher, config, counters, db_factory
)
# Redis publishing should still have happened before the DB write
assert publisher.publish.call_count >= 1
counters["articles_scored"].add.assert_called_once_with(1)
@pytest.mark.asyncio
async def test_db_none_backward_compatible(self):
"""When db_session_factory is None, process_article works without DB writes."""
finbert = AsyncMock(spec=FinBERTAnalyzer)
finbert.analyze = AsyncMock(return_value=(0.6, 0.85))
ollama = AsyncMock(spec=OllamaAnalyzer)
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
config = SentimentAnalyzerConfig(
finbert_confidence_threshold=0.6,
otel_metrics_port=0,
)
counters = _make_counters()
article = _make_raw_article(
title="$GOOG quarterly results",
content="Google reports revenue growth.",
)
# Pass db_session_factory=None explicitly (the default)
await process_article(
article, finbert, ollama, publisher, config, counters, db_session_factory=None
)
# Redis publishing should work as before
assert publisher.publish.call_count >= 1
counters["articles_scored"].add.assert_called_once_with(1)
@pytest.mark.asyncio
async def test_db_error_does_not_break_processing(self):
"""A DB error should be logged but not prevent Redis publishing."""
finbert = AsyncMock(spec=FinBERTAnalyzer)
finbert.analyze = AsyncMock(return_value=(0.3, 0.7))
ollama = AsyncMock(spec=OllamaAnalyzer)
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
config = SentimentAnalyzerConfig(
finbert_confidence_threshold=0.6,
otel_metrics_port=0,
)
counters = _make_counters()
mock_session = AsyncMock()
mock_session.add = MagicMock()
# Simulate a generic DB error
mock_session.commit = AsyncMock(
side_effect=RuntimeError("connection lost")
)
db_factory = _make_mock_db_session_factory(mock_session)
article = _make_raw_article(
title="$TSLA stock update",
content="Tesla announces new factory.",
)
# Should NOT raise — generic exceptions in DB write are caught
await process_article(
article, finbert, ollama, publisher, config, counters, db_factory
)
# Redis publishing should still have happened
assert publisher.publish.call_count >= 1
counters["articles_scored"].add.assert_called_once_with(1)

View file

@ -357,3 +357,201 @@ class TestEnsembleTagsStrategySources:
assert parts[0] == "alpha"
assert parts[1] == "LONG"
assert float(parts[2]) == pytest.approx(0.75, abs=0.01)
# ---------------------------------------------------------------------------
# Signal Generator — current_price injection into sentiment_context
# ---------------------------------------------------------------------------
class TestCurrentPriceInjection:
"""Verify that current_price flows into sentiment_context on published signals."""
@pytest.mark.asyncio
async def test_current_price_set_when_snapshot_has_price(self):
"""When snapshot has a positive current_price, it should appear in sentiment_context."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
market = MarketSnapshot(
ticker="AAPL",
current_price=155.50,
open=154.0,
high=156.0,
low=153.0,
close=155.50,
volume=5000,
)
sentiment = SentimentContext(
ticker="AAPL",
avg_score=0.7,
article_count=5,
recent_scores=[0.6, 0.7, 0.8],
avg_confidence=0.85,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
assert signal_result is not None
# Simulate the current_price injection from the signal generator main loop
if signal_result.sentiment_context is None:
signal_result.sentiment_context = {}
if market.current_price > 0:
signal_result.sentiment_context["current_price"] = market.current_price
assert "current_price" in signal_result.sentiment_context
assert signal_result.sentiment_context["current_price"] == 155.50
@pytest.mark.asyncio
async def test_current_price_not_set_when_snapshot_is_none(self):
"""When snapshot is None (minimal fallback), current_price should not be injected."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
# Minimal fallback snapshot with current_price=0.0
market = MarketSnapshot(
ticker="AAPL",
current_price=0.0,
open=0.0,
high=0.0,
low=0.0,
close=0.0,
volume=0.0,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, None, weights)
assert signal_result is not None
# Simulate the injection logic — should NOT set current_price when price is 0
if market.current_price > 0:
if signal_result.sentiment_context is None:
signal_result.sentiment_context = {}
signal_result.sentiment_context["current_price"] = market.current_price
# sentiment_context should be None (no sentiment was passed, and no price injection)
assert signal_result.sentiment_context is None
@pytest.mark.asyncio
async def test_current_price_not_set_when_price_is_zero(self):
"""current_price should NOT be injected when snapshot.current_price is exactly 0."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.8))
ensemble = WeightedEnsemble([s1], threshold=0.0)
sentiment = SentimentContext(
ticker="AAPL",
avg_score=0.5,
article_count=3,
recent_scores=[0.5],
avg_confidence=0.9,
)
market = MarketSnapshot(
ticker="AAPL",
current_price=0.0,
open=0.0,
high=0.0,
low=0.0,
close=0.0,
volume=0.0,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
assert signal_result is not None
# Simulate injection logic
if market.current_price > 0:
if signal_result.sentiment_context is None:
signal_result.sentiment_context = {}
signal_result.sentiment_context["current_price"] = market.current_price
# sentiment_context should exist (from ensemble) but WITHOUT current_price
assert signal_result.sentiment_context is not None
assert "current_price" not in signal_result.sentiment_context
@pytest.mark.asyncio
async def test_current_price_preserves_existing_sentiment_context(self):
"""Injecting current_price should not overwrite other sentiment_context fields."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
sentiment = SentimentContext(
ticker="AAPL",
avg_score=0.75,
article_count=10,
recent_scores=[0.7, 0.8],
avg_confidence=0.9,
)
market = MarketSnapshot(
ticker="AAPL",
current_price=200.0,
open=198.0,
high=202.0,
low=197.0,
close=200.0,
volume=8000,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, sentiment, weights)
assert signal_result is not None
assert signal_result.sentiment_context is not None
# Preserve existing keys
original_keys = set(signal_result.sentiment_context.keys())
# Inject current_price
if market.current_price > 0:
signal_result.sentiment_context["current_price"] = market.current_price
# All original keys should still be present
for key in original_keys:
assert key in signal_result.sentiment_context
# Plus the new one
assert signal_result.sentiment_context["current_price"] == 200.0
# ---------------------------------------------------------------------------
# Signal Generator — signal_id generation
# ---------------------------------------------------------------------------
class TestSignalIdGeneration:
"""Verify that TradeSignal includes a signal_id UUID."""
@pytest.mark.asyncio
async def test_signal_has_uuid(self):
"""Each generated signal should have a signal_id UUID."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
market = MarketSnapshot(
ticker="AAPL", current_price=150.0,
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
)
weights = {"alpha": 1.0}
signal_result = await ensemble.evaluate("AAPL", market, None, weights)
assert signal_result is not None
assert signal_result.signal_id is not None
from uuid import UUID
assert isinstance(signal_result.signal_id, UUID)
@pytest.mark.asyncio
async def test_signal_ids_are_unique(self):
"""Multiple signals should get different signal_id values."""
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
ensemble = WeightedEnsemble([s1], threshold=0.0)
market = MarketSnapshot(
ticker="AAPL", current_price=150.0,
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
)
weights = {"alpha": 1.0}
sig1 = await ensemble.evaluate("AAPL", market, None, weights)
sig2 = await ensemble.evaluate("AAPL", market, None, weights)
assert sig1 is not None and sig2 is not None
assert sig1.signal_id != sig2.signal_id
def test_signal_id_serializes_in_json(self):
"""signal_id should serialize properly in JSON mode."""
signal = _make_signal(SignalDirection.LONG, 0.8)
data = signal.model_dump(mode="json")
assert "signal_id" in data
assert isinstance(data["signal_id"], str)

View file

@ -401,3 +401,119 @@ class TestExecutorFlowRejected:
# Rejection counter should have been incremented
counters["rejections"].add.assert_called_once()
# ---------------------------------------------------------------------------
# Executor flow — DB persistence
# ---------------------------------------------------------------------------
def _make_mock_db_session_factory(session=None):
"""Create a mock async_sessionmaker that yields a mock session."""
if session is None:
session = AsyncMock()
session.add = MagicMock()
session.commit = AsyncMock()
factory = MagicMock()
ctx = AsyncMock()
ctx.__aenter__ = AsyncMock(return_value=session)
ctx.__aexit__ = AsyncMock(return_value=False)
factory.return_value = ctx
return factory
class TestExecutorDBPersistence:
"""Verify that trades are persisted to the DB when db_session_factory is provided."""
@pytest.mark.asyncio
async def test_trade_persisted_with_signal_id(self):
"""When db_session_factory is provided, a Trade row should be created."""
config = _make_config()
broker = _mock_broker(positions=[], account=_make_account(100_000))
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
counters = {
"trades_executed": MagicMock(),
"rejections": MagicMock(),
"fill_latency": MagicMock(),
}
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
db_factory = _make_mock_db_session_factory(mock_session)
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
await process_signal(
signal, RiskManager(config, broker), broker, publisher, counters, db_factory
)
# Trade should be persisted
mock_session.add.assert_called_once()
mock_session.commit.assert_awaited_once()
# Verify the trade object
trade_obj = mock_session.add.call_args[0][0]
assert trade_obj.ticker == "AAPL"
assert trade_obj.signal_id == signal.signal_id
@pytest.mark.asyncio
async def test_trade_not_persisted_without_db(self):
"""When db_session_factory is None, no DB write should happen."""
config = _make_config()
broker = _mock_broker(positions=[], account=_make_account(100_000))
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
counters = {
"trades_executed": MagicMock(),
"rejections": MagicMock(),
"fill_latency": MagicMock(),
}
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
await process_signal(
signal, RiskManager(config, broker), broker, publisher, counters, None
)
# Should still publish
publisher.publish.assert_called_once()
@pytest.mark.asyncio
async def test_db_error_does_not_block_publishing(self):
"""A DB error should not prevent the trade from being published."""
config = _make_config()
broker = _mock_broker(positions=[], account=_make_account(100_000))
publisher = AsyncMock()
publisher.publish = AsyncMock(return_value=b"1-0")
counters = {
"trades_executed": MagicMock(),
"rejections": MagicMock(),
"fill_latency": MagicMock(),
}
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock(side_effect=RuntimeError("DB connection lost"))
db_factory = _make_mock_db_session_factory(mock_session)
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
await process_signal(
signal, RiskManager(config, broker), broker, publisher, counters, db_factory
)
# Trade should still be published despite DB error
publisher.publish.assert_called_once()
counters["trades_executed"].add.assert_called_once_with(1)
def test_signal_id_flows_through_execution(self):
"""signal_id from TradeSignal should appear in the published TradeExecution."""
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
assert signal.signal_id is not None
# Verify signal_id is a UUID
from uuid import UUID
assert isinstance(signal.signal_id, UUID)