diff --git a/services/signal_generator/__init__.py b/services/signal_generator/__init__.py new file mode 100644 index 0000000..f81f02c --- /dev/null +++ b/services/signal_generator/__init__.py @@ -0,0 +1 @@ +"""Signal Generator service — weighted ensemble of trading strategies.""" diff --git a/services/signal_generator/config.py b/services/signal_generator/config.py new file mode 100644 index 0000000..eca7c2b --- /dev/null +++ b/services/signal_generator/config.py @@ -0,0 +1,14 @@ +"""Configuration for the signal generator service.""" + +from shared.config import BaseConfig + + +class SignalGeneratorConfig(BaseConfig): + """Extends BaseConfig with signal-generator-specific settings.""" + + alpaca_api_key: str = "" + alpaca_secret_key: str = "" + signal_strength_threshold: float = 0.3 + watchlist: list[str] = [] + + model_config = {"env_prefix": "TRADING_"} diff --git a/services/signal_generator/ensemble.py b/services/signal_generator/ensemble.py new file mode 100644 index 0000000..0d13c0b --- /dev/null +++ b/services/signal_generator/ensemble.py @@ -0,0 +1,118 @@ +"""Weighted ensemble that combines signals from multiple strategies. + +Runs all registered strategies, collects non-``None`` signals, and computes +a combined strength via a weighted average. Only emits a ``TradeSignal`` +when the combined strength exceeds a configurable threshold. +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from shared.schemas.trading import ( + MarketSnapshot, + SentimentContext, + SignalDirection, + TradeSignal, +) +from shared.strategies.base import BaseStrategy + + +class WeightedEnsemble: + """Combine signals from multiple strategies using weighted averaging. + + Parameters + ---------- + strategies: + The list of strategy instances to evaluate. + threshold: + Minimum combined strength required to emit a signal (default 0.3). + """ + + def __init__( + self, + strategies: list[BaseStrategy], + threshold: float = 0.3, + ) -> None: + self.strategies = strategies + self.threshold = threshold + + async def evaluate( + self, + ticker: str, + market: MarketSnapshot, + sentiment: SentimentContext | None, + weights: dict[str, float], + ) -> TradeSignal | None: + """Run all strategies and return a combined signal, or ``None``. + + Parameters + ---------- + ticker: + The stock ticker being evaluated. + market: + Current market snapshot including price, SMA, RSI. + sentiment: + Aggregated sentiment context (may be ``None``). + weights: + Mapping from strategy name to its weight. + """ + # Step 1: run all strategies, collect (strategy, signal) pairs + signals: list[tuple[BaseStrategy, TradeSignal]] = [] + for strategy in self.strategies: + signal = await strategy.evaluate(ticker, market, sentiment) + if signal is not None: + signals.append((strategy, signal)) + + if not signals: + return None + + # Step 2: compute weighted sum + weighted_sum = 0.0 + total_weight = 0.0 + for strategy, signal in signals: + w = weights.get(strategy.name, 0.1) + direction_sign = 1.0 if signal.direction == SignalDirection.LONG else -1.0 + weighted_sum += signal.strength * direction_sign * w + total_weight += w + + if total_weight == 0: + return None + + # Step 3: combined strength + combined_strength = abs(weighted_sum) / total_weight + + if combined_strength < self.threshold: + return None + + # Step 4: determine direction from the sign of the weighted sum + if weighted_sum > 0: + direction = SignalDirection.LONG + elif weighted_sum < 0: + direction = SignalDirection.SHORT + else: + return None + + # Step 5: build strategy_sources with individual contributions + strategy_sources = [ + f"{strategy.name}:{signal.direction.value}:{signal.strength:.4f}" + for strategy, signal in signals + ] + + # Carry forward sentiment context if available + sentiment_ctx = None + if sentiment is not None: + sentiment_ctx = { + "avg_score": sentiment.avg_score, + "article_count": sentiment.article_count, + "avg_confidence": sentiment.avg_confidence, + } + + return TradeSignal( + ticker=ticker, + direction=direction, + strength=round(min(combined_strength, 1.0), 4), + strategy_sources=strategy_sources, + sentiment_context=sentiment_ctx, + timestamp=datetime.now(timezone.utc), + ) diff --git a/services/signal_generator/main.py b/services/signal_generator/main.py new file mode 100644 index 0000000..ae6d735 --- /dev/null +++ b/services/signal_generator/main.py @@ -0,0 +1,165 @@ +"""Signal Generator service -- main entry point. + +Consumes ``news:scored`` articles from Redis Streams, updates sentiment +context per ticker, runs the weighted ensemble of trading strategies, and +publishes qualifying ``TradeSignal`` messages to ``signals:generated``. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections import defaultdict, deque + +from redis.asyncio import Redis + +from services.signal_generator.config import SignalGeneratorConfig +from services.signal_generator.ensemble import WeightedEnsemble +from services.signal_generator.market_data import MarketDataManager +from shared.redis_streams import StreamConsumer, StreamPublisher +from shared.schemas.news import ScoredArticle +from shared.schemas.trading import SentimentContext +from shared.strategies import MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy +from shared.telemetry import setup_telemetry + +logger = logging.getLogger(__name__) + +# Maximum number of recent sentiment scores to retain per ticker +_MAX_SENTIMENT_SCORES = 50 + +# Default strategy weights (equal weighting) +_DEFAULT_WEIGHTS: dict[str, float] = { + "momentum": 0.333, + "mean_reversion": 0.333, + "news_driven": 0.334, +} + + +def _build_sentiment_context( + ticker: str, + scores: deque[float], + confidences: deque[float], +) -> SentimentContext: + """Build a ``SentimentContext`` from accumulated per-ticker scores.""" + score_list = list(scores) + conf_list = list(confidences) + return SentimentContext( + ticker=ticker, + avg_score=sum(score_list) / len(score_list) if score_list else 0.0, + article_count=len(score_list), + recent_scores=score_list[-10:], + avg_confidence=sum(conf_list) / len(conf_list) if conf_list else 0.0, + ) + + +async def run(config: SignalGeneratorConfig | None = None) -> None: + """Main service loop. + + Connects to Redis, initialises strategies and telemetry, then + continuously consumes from ``news:scored`` and publishes qualifying + signals to ``signals:generated``. + """ + if config is None: + config = SignalGeneratorConfig() + + logging.basicConfig(level=config.log_level) + logger.info("Starting Signal Generator service") + + # --- Telemetry --- + meter = setup_telemetry("signal-generator", config.otel_metrics_port) + signals_generated = meter.create_counter( + "signals_generated", + description="Total trade signals emitted by the signal generator", + ) + per_strategy_signal_count = meter.create_counter( + "per_strategy_signal_count", + description="Signals emitted, broken down by strategy", + ) + + # --- Redis --- + redis = Redis.from_url(config.redis_url, decode_responses=False) + consumer = StreamConsumer(redis, "news:scored", "signal-generator", "worker-1") + publisher = StreamPublisher(redis, "signals:generated") + + # --- Market data --- + market_data = MarketDataManager() + + # --- Strategies --- + strategies = [ + MomentumStrategy(), + MeanReversionStrategy(), + NewsDrivenStrategy(), + ] + ensemble = WeightedEnsemble(strategies, threshold=config.signal_strength_threshold) + + # --- Strategy weights (default equal; could load from DB) --- + weights = dict(_DEFAULT_WEIGHTS) + + # --- Per-ticker sentiment accumulators --- + sentiment_scores: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)) + sentiment_confidences: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES)) + + logger.info("Consuming from news:scored, publishing to signals:generated") + + # --- Consume loop --- + async for _msg_id, data in consumer.consume(): + try: + article = ScoredArticle.model_validate(data) + ticker = article.ticker + + # Update sentiment accumulators + sentiment_scores[ticker].append(article.sentiment_score) + sentiment_confidences[ticker].append(article.confidence) + + # Build sentiment context + sentiment = _build_sentiment_context( + ticker, + sentiment_scores[ticker], + sentiment_confidences[ticker], + ) + + # Get market snapshot (may be None if no bars received yet) + snapshot = market_data.get_snapshot(ticker) + if snapshot is None: + # Create a minimal snapshot from sentiment data alone + # (the news_driven strategy does not require market indicators) + from shared.schemas.trading import MarketSnapshot + + snapshot = MarketSnapshot( + ticker=ticker, + current_price=0.0, + open=0.0, + high=0.0, + low=0.0, + close=0.0, + volume=0.0, + ) + + # Run ensemble + signal = await ensemble.evaluate(ticker, snapshot, sentiment, weights) + + if signal is not None: + await publisher.publish(signal.model_dump(mode="json")) + signals_generated.add(1) + for src in signal.strategy_sources: + strategy_name = src.split(":")[0] + per_strategy_signal_count.add(1, {"strategy": strategy_name}) + logger.info( + "Signal generated: %s %s strength=%.4f sources=%s", + signal.direction.value, + ticker, + signal.strength, + signal.strategy_sources, + ) + + except Exception: + logger.exception("Error processing scored article: %s", data.get("title", "")) + + +def main() -> None: + """CLI entry point.""" + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/signal_generator/market_data.py b/services/signal_generator/market_data.py new file mode 100644 index 0000000..c5898e6 --- /dev/null +++ b/services/signal_generator/market_data.py @@ -0,0 +1,122 @@ +"""In-memory market data manager with rolling OHLCV windows. + +Maintains a per-ticker deque of recent bars and computes technical +indicators (SMA, RSI) on demand when building ``MarketSnapshot`` objects. +""" + +from __future__ import annotations + +from collections import deque +from typing import Any + +from shared.schemas.trading import MarketSnapshot, OHLCVBar + + +# Default rolling-window sizes +_DEFAULT_MAX_BARS = 100 +_RSI_PERIOD = 14 + + +class MarketDataManager: + """Manages in-memory rolling windows of OHLCV bars per ticker. + + Parameters + ---------- + max_bars: + Maximum number of bars to retain per ticker. + """ + + def __init__(self, max_bars: int = _DEFAULT_MAX_BARS) -> None: + self.max_bars = max_bars + self._bars: dict[str, deque[OHLCVBar]] = {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def add_bar(self, ticker: str, bar_data: dict[str, Any] | OHLCVBar) -> None: + """Append a bar to the rolling window for *ticker*. + + ``bar_data`` can be a dict (parsed from JSON) or an ``OHLCVBar`` + instance. + """ + if isinstance(bar_data, dict): + bar = OHLCVBar.model_validate(bar_data) + else: + bar = bar_data + + if ticker not in self._bars: + self._bars[ticker] = deque(maxlen=self.max_bars) + self._bars[ticker].append(bar) + + def get_snapshot(self, ticker: str) -> MarketSnapshot | None: + """Build a ``MarketSnapshot`` from the rolling window. + + Returns ``None`` if no bars have been recorded for *ticker*. + """ + bars = self._bars.get(ticker) + if not bars: + return None + + latest = bars[-1] + closes = [b.close for b in bars] + + return MarketSnapshot( + ticker=ticker, + current_price=latest.close, + open=latest.open, + high=latest.high, + low=latest.low, + close=latest.close, + volume=latest.volume, + sma_20=self._compute_sma(closes, 20), + sma_50=self._compute_sma(closes, 50), + rsi=self._compute_rsi(closes, _RSI_PERIOD), + bars=[b.model_dump(mode="json") for b in bars], + ) + + def has_ticker(self, ticker: str) -> bool: + """Return ``True`` if at least one bar exists for *ticker*.""" + return ticker in self._bars and len(self._bars[ticker]) > 0 + + # ------------------------------------------------------------------ + # Technical indicator helpers + # ------------------------------------------------------------------ + + @staticmethod + def _compute_sma(closes: list[float], period: int) -> float | None: + """Compute the simple moving average over the last *period* closes. + + Returns ``None`` if there are fewer than *period* data points. + """ + if len(closes) < period: + return None + return sum(closes[-period:]) / period + + @staticmethod + def _compute_rsi(closes: list[float], period: int = 14) -> float | None: + """Compute the standard RSI over the last *period+1* closes. + + Uses the average-gain / average-loss method. Returns ``None`` if + there are not enough data points (need at least ``period + 1`` + closes to compute ``period`` deltas). + """ + if len(closes) < period + 1: + return None + + # Only use the most recent period+1 closes + relevant = closes[-(period + 1):] + deltas = [relevant[i + 1] - relevant[i] for i in range(len(relevant) - 1)] + + gains = [d for d in deltas if d > 0] + losses = [-d for d in deltas if d < 0] + + avg_gain = sum(gains) / period if gains else 0.0 + avg_loss = sum(losses) / period if losses else 0.0 + + if avg_loss == 0: + return 100.0 # No losses -> RSI is 100 + + rs = avg_gain / avg_loss + rsi = 100.0 - (100.0 / (1.0 + rs)) + return round(rsi, 4) diff --git a/services/trade_executor/__init__.py b/services/trade_executor/__init__.py new file mode 100644 index 0000000..ec5bf0f --- /dev/null +++ b/services/trade_executor/__init__.py @@ -0,0 +1 @@ +"""Trade Executor service — risk management and order execution.""" diff --git a/services/trade_executor/config.py b/services/trade_executor/config.py new file mode 100644 index 0000000..49df392 --- /dev/null +++ b/services/trade_executor/config.py @@ -0,0 +1,18 @@ +"""Configuration for the trade executor service.""" + +from shared.config import BaseConfig + + +class TradeExecutorConfig(BaseConfig): + """Extends BaseConfig with trade-executor-specific settings.""" + + max_position_pct: float = 0.05 + max_total_exposure_pct: float = 0.80 + max_positions: int = 20 + default_stop_loss_pct: float = 0.03 + cooldown_minutes: int = 30 + alpaca_api_key: str = "" + alpaca_secret_key: str = "" + paper_trading: bool = True + + model_config = {"env_prefix": "TRADING_"} diff --git a/services/trade_executor/main.py b/services/trade_executor/main.py new file mode 100644 index 0000000..5c89fc9 --- /dev/null +++ b/services/trade_executor/main.py @@ -0,0 +1,176 @@ +"""Trade Executor service -- main entry point. + +Consumes ``signals:generated`` from Redis Streams, runs risk checks, +submits orders via the brokerage abstraction layer, records trades +in the database, and publishes ``TradeExecution`` messages to +``trades:executed``. +""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid + +from redis.asyncio import Redis + +from services.trade_executor.config import TradeExecutorConfig +from services.trade_executor.risk_manager import RiskManager +from shared.broker.alpaca_broker import AlpacaBroker +from shared.redis_streams import StreamConsumer, StreamPublisher +from shared.schemas.trading import ( + OrderRequest, + OrderSide, + OrderStatus, + SignalDirection, + TradeExecution, + TradeSignal, +) +from shared.telemetry import setup_telemetry + +logger = logging.getLogger(__name__) + + +async def process_signal( + signal: TradeSignal, + risk_manager: RiskManager, + broker: AlpacaBroker, + publisher: StreamPublisher, + counters: dict, +) -> None: + """Process a single trade signal: risk check, order, record, publish. + + Parameters + ---------- + signal: + The trade signal to act on. + risk_manager: + Performs pre-trade risk checks and position sizing. + broker: + Brokerage adapter for submitting orders. + publisher: + Publishes execution results to ``trades:executed``. + counters: + Dict of OpenTelemetry counter/histogram instruments. + """ + # --- Step 1: risk check --- + approved, reason = await risk_manager.check_risk(signal) + if not approved: + logger.info("Signal REJECTED for %s: %s", signal.ticker, reason) + counters["rejections"].add(1, {"reason": reason.split(" ")[0]}) + return + + # --- Step 2: calculate position size --- + account = await broker.get_account() + qty = risk_manager.calculate_position_size(signal, account) + if qty <= 0: + logger.info("Position size is zero for %s — skipping", signal.ticker) + counters["rejections"].add(1, {"reason": "zero_position_size"}) + return + + # --- Step 3: create order --- + side = OrderSide.BUY if signal.direction == SignalDirection.LONG else OrderSide.SELL + order_request = OrderRequest( + ticker=signal.ticker, + side=side, + qty=float(qty), + ) + + # --- Step 4: submit order --- + start = time.monotonic() + result = await broker.submit_order(order_request) + elapsed = time.monotonic() - start + counters["fill_latency"].record(elapsed) + + # --- Step 5: build trade execution --- + trade_id = uuid.uuid4() + execution = TradeExecution( + trade_id=trade_id, + ticker=signal.ticker, + side=side, + qty=result.qty, + price=result.filled_price or 0.0, + status=result.status, + signal_id=None, + strategy_id=None, + timestamp=result.timestamp, + ) + + # --- Step 6: publish to trades:executed --- + await publisher.publish(execution.model_dump(mode="json")) + counters["trades_executed"].add(1) + logger.info( + "Trade executed: %s %s %.0f shares @ %s status=%s", + side.value, + signal.ticker, + result.qty, + result.filled_price, + result.status.value, + ) + + +async def run(config: TradeExecutorConfig | None = None) -> None: + """Main service loop. + + Connects to Redis, initialises the broker and risk manager, then + continuously consumes from ``signals:generated`` and publishes + execution results to ``trades:executed``. + """ + if config is None: + config = TradeExecutorConfig() + + logging.basicConfig(level=config.log_level) + logger.info("Starting Trade Executor service") + + # --- Telemetry --- + meter = setup_telemetry("trade-executor", config.otel_metrics_port) + counters = { + "trades_executed": meter.create_counter( + "trades_executed", + description="Total trades successfully submitted", + ), + "rejections": meter.create_counter( + "trade_rejections", + description="Signals rejected by risk checks", + ), + "fill_latency": meter.create_histogram( + "order_fill_latency_seconds", + description="Time from order submission to response", + unit="s", + ), + } + + # --- Redis --- + redis = Redis.from_url(config.redis_url, decode_responses=False) + consumer = StreamConsumer(redis, "signals:generated", "trade-executor", "worker-1") + publisher = StreamPublisher(redis, "trades:executed") + + # --- Broker --- + broker = AlpacaBroker( + api_key=config.alpaca_api_key, + secret_key=config.alpaca_secret_key, + paper=config.paper_trading, + ) + + # --- Risk manager --- + risk_manager = RiskManager(config, broker) + + logger.info("Consuming from signals:generated, publishing to trades:executed") + + # --- Consume loop --- + async for _msg_id, data in consumer.consume(): + try: + signal = TradeSignal.model_validate(data) + await process_signal(signal, risk_manager, broker, publisher, counters) + except Exception: + logger.exception("Error processing signal: %s", data) + + +def main() -> None: + """CLI entry point.""" + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/trade_executor/risk_manager.py b/services/trade_executor/risk_manager.py new file mode 100644 index 0000000..0902263 --- /dev/null +++ b/services/trade_executor/risk_manager.py @@ -0,0 +1,155 @@ +"""Pre-trade risk management checks and position sizing. + +Validates that a proposed trade satisfies all risk constraints before +it is submitted to the brokerage. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timedelta +from zoneinfo import ZoneInfo + +from services.trade_executor.config import TradeExecutorConfig +from shared.broker.base import BaseBroker +from shared.schemas.trading import AccountInfo, PositionInfo, SignalDirection, TradeSignal + +logger = logging.getLogger(__name__) + +_ET = ZoneInfo("America/New_York") + +# Market hours in Eastern Time +_MARKET_OPEN_HOUR = 9 +_MARKET_OPEN_MINUTE = 30 +_MARKET_CLOSE_HOUR = 16 +_MARKET_CLOSE_MINUTE = 0 + + +class RiskManager: + """Performs pre-trade risk checks and calculates position sizes. + + Parameters + ---------- + config: + Trade executor configuration with risk parameters. + broker: + Broker instance for querying current positions and account info. + """ + + def __init__(self, config: TradeExecutorConfig, broker: BaseBroker) -> None: + self.config = config + self.broker = broker + # ticker -> last exit timestamp + self._cooldowns: dict[str, datetime] = {} + + def record_exit(self, ticker: str, exit_time: datetime | None = None) -> None: + """Record the time a position was exited for cooldown tracking.""" + self._cooldowns[ticker] = exit_time or datetime.now(tz=_ET) + + async def check_risk(self, signal: TradeSignal) -> tuple[bool, str]: + """Run all pre-trade risk checks. + + Returns + ------- + tuple[bool, str] + ``(approved, reason)`` — ``approved`` is ``True`` when + all checks pass, otherwise ``reason`` explains the failure. + """ + # 1. Market hours + now_et = datetime.now(tz=_ET) + if not self._is_market_hours(now_et): + return False, "outside_market_hours" + + # 2. Cooldown + if signal.ticker in self._cooldowns: + last_exit = self._cooldowns[signal.ticker] + cooldown_end = last_exit + timedelta(minutes=self.config.cooldown_minutes) + if now_et < cooldown_end: + remaining = (cooldown_end - now_et).total_seconds() / 60 + return False, f"cooldown_active ({remaining:.1f}m remaining)" + + # 3. Max positions + positions = await self.broker.get_positions() + if len(positions) >= self.config.max_positions: + return False, "max_positions_exceeded" + + # 4. Max total exposure + account = await self.broker.get_account() + total_exposure = sum(abs(p.market_value) for p in positions) + max_exposure = account.equity * self.config.max_total_exposure_pct + if total_exposure >= max_exposure: + return False, "max_exposure_exceeded" + + return True, "approved" + + def calculate_position_size( + self, + signal: TradeSignal, + account: AccountInfo, + ) -> float: + """Calculate the number of shares to buy/sell. + + Uses fixed-fractional sizing: ``equity * max_position_pct`` + gives the maximum dollar value per position, then scales by + signal strength. + + Parameters + ---------- + signal: + The trade signal (includes current price via strength). + account: + Current account info (equity, buying power). + + Returns + ------- + float + Number of shares (whole shares). + """ + if signal.strength <= 0 or account.equity <= 0: + return 0.0 + + position_value = account.equity * self.config.max_position_pct + position_value *= signal.strength + + # Need a price to compute qty — use the signal's embedded price + # or fall back to getting it from the snapshot. For simplicity + # the executor will pass the current price through the signal's + # sentiment_context or fetch it directly. + current_price = 0.0 + if signal.sentiment_context and "current_price" in signal.sentiment_context: + current_price = float(signal.sentiment_context["current_price"]) + + if current_price <= 0: + logger.warning("No current price for %s, cannot size position", signal.ticker) + return 0.0 + + qty = position_value / current_price + return max(int(qty), 0) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _is_market_hours(now_et: datetime) -> bool: + """Return ``True`` if *now_et* falls within regular US market hours. + + Market hours: Monday--Friday, 9:30 AM -- 4:00 PM ET. + """ + # Weekday check (0=Monday ... 6=Sunday) + if now_et.weekday() >= 5: + return False + + market_open = now_et.replace( + hour=_MARKET_OPEN_HOUR, + minute=_MARKET_OPEN_MINUTE, + second=0, + microsecond=0, + ) + market_close = now_et.replace( + hour=_MARKET_CLOSE_HOUR, + minute=_MARKET_CLOSE_MINUTE, + second=0, + microsecond=0, + ) + return market_open <= now_et < market_close diff --git a/tests/services/test_signal_generator.py b/tests/services/test_signal_generator.py new file mode 100644 index 0000000..9386364 --- /dev/null +++ b/tests/services/test_signal_generator.py @@ -0,0 +1,359 @@ +"""Tests for the Signal Generator service. + +Covers MarketDataManager (SMA, RSI, snapshot) and WeightedEnsemble +(signal combination, threshold filtering, strategy source tagging). +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from services.signal_generator.ensemble import WeightedEnsemble +from services.signal_generator.market_data import MarketDataManager +from shared.schemas.trading import ( + MarketSnapshot, + OHLCVBar, + SentimentContext, + SignalDirection, + TradeSignal, +) +from shared.strategies.base import BaseStrategy + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_bar(close: float, *, ts_offset: int = 0) -> OHLCVBar: + """Create an ``OHLCVBar`` with the given close price.""" + return OHLCVBar( + timestamp=datetime(2026, 1, 1, 10, ts_offset, tzinfo=timezone.utc), + open=close - 0.5, + high=close + 1.0, + low=close - 1.0, + close=close, + volume=1000.0, + ) + + +class _StubStrategy(BaseStrategy): + """Test helper that returns a preconfigured signal.""" + + def __init__(self, name: str, signal: TradeSignal | None) -> None: + self.name = name + self._signal = signal + + async def evaluate(self, ticker, market, sentiment=None): + return self._signal + + +def _make_signal( + direction: SignalDirection = SignalDirection.LONG, + strength: float = 0.8, + sources: list[str] | None = None, +) -> TradeSignal: + return TradeSignal( + ticker="AAPL", + direction=direction, + strength=strength, + strategy_sources=sources or ["test"], + timestamp=datetime.now(timezone.utc), + ) + + +# --------------------------------------------------------------------------- +# MarketDataManager — SMA +# --------------------------------------------------------------------------- + + +class TestMarketDataManagerSMA: + """Tests for SMA computation inside MarketDataManager.""" + + def test_sma_basic(self): + """SMA-20 should equal the mean of the last 20 close prices.""" + mgr = MarketDataManager() + closes = list(range(1, 21)) # 1, 2, ..., 20 + for i, c in enumerate(closes): + mgr.add_bar("AAPL", _make_bar(float(c), ts_offset=i)) + + snap = mgr.get_snapshot("AAPL") + assert snap is not None + expected_sma_20 = sum(closes) / 20 + assert snap.sma_20 == pytest.approx(expected_sma_20) + + def test_sma_returns_none_insufficient_data(self): + """SMA-20 should be None when fewer than 20 bars exist.""" + mgr = MarketDataManager() + for i in range(10): + mgr.add_bar("AAPL", _make_bar(100.0, ts_offset=i)) + + snap = mgr.get_snapshot("AAPL") + assert snap is not None + assert snap.sma_20 is None + + def test_sma_50_requires_50_bars(self): + """SMA-50 should be None with only 30 bars, present with 50.""" + mgr = MarketDataManager() + for i in range(30): + mgr.add_bar("AAPL", _make_bar(float(i + 1), ts_offset=i)) + snap = mgr.get_snapshot("AAPL") + assert snap is not None + assert snap.sma_50 is None + + # Add 20 more + for i in range(30, 50): + mgr.add_bar("AAPL", _make_bar(float(i + 1), ts_offset=i)) + snap = mgr.get_snapshot("AAPL") + assert snap is not None + assert snap.sma_50 is not None + expected = sum(range(1, 51)) / 50 + assert snap.sma_50 == pytest.approx(expected) + + +# --------------------------------------------------------------------------- +# MarketDataManager — RSI +# --------------------------------------------------------------------------- + + +class TestMarketDataManagerRSI: + """Tests for RSI computation inside MarketDataManager.""" + + def test_rsi_all_gains(self): + """RSI should be 100 when all price changes are positive.""" + mgr = MarketDataManager() + for i in range(20): + mgr.add_bar("AAPL", _make_bar(100.0 + i, ts_offset=i)) + + snap = mgr.get_snapshot("AAPL") + assert snap is not None + assert snap.rsi == pytest.approx(100.0) + + def test_rsi_all_losses(self): + """RSI should be 0 when all price changes are negative.""" + mgr = MarketDataManager() + for i in range(20): + mgr.add_bar("AAPL", _make_bar(200.0 - i, ts_offset=i)) + + snap = mgr.get_snapshot("AAPL") + assert snap is not None + assert snap.rsi == pytest.approx(0.0) + + def test_rsi_mixed(self): + """RSI should be between 0 and 100 with mixed gains and losses.""" + mgr = MarketDataManager() + prices = [44, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42, + 45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00] + for i, p in enumerate(prices): + mgr.add_bar("AAPL", _make_bar(p, ts_offset=i)) + + snap = mgr.get_snapshot("AAPL") + assert snap is not None + assert snap.rsi is not None + assert 0 < snap.rsi < 100 + + def test_rsi_returns_none_insufficient_data(self): + """RSI should be None when fewer than 15 bars exist (need 14+1).""" + mgr = MarketDataManager() + for i in range(10): + mgr.add_bar("AAPL", _make_bar(100.0, ts_offset=i)) + + snap = mgr.get_snapshot("AAPL") + assert snap is not None + assert snap.rsi is None + + +# --------------------------------------------------------------------------- +# MarketDataManager — snapshot +# --------------------------------------------------------------------------- + + +class TestMarketDataManagerSnapshot: + """Tests for get_snapshot behaviour.""" + + def test_snapshot_returns_none_for_unknown_ticker(self): + mgr = MarketDataManager() + assert mgr.get_snapshot("UNKNOWN") is None + + def test_snapshot_uses_latest_bar_for_price(self): + mgr = MarketDataManager() + mgr.add_bar("AAPL", _make_bar(100.0, ts_offset=0)) + mgr.add_bar("AAPL", _make_bar(105.0, ts_offset=1)) + snap = mgr.get_snapshot("AAPL") + assert snap is not None + assert snap.current_price == 105.0 + + def test_snapshot_contains_bars(self): + mgr = MarketDataManager() + for i in range(5): + mgr.add_bar("AAPL", _make_bar(100.0 + i, ts_offset=i)) + snap = mgr.get_snapshot("AAPL") + assert snap is not None + assert len(snap.bars) == 5 + + +# --------------------------------------------------------------------------- +# WeightedEnsemble — combines signals +# --------------------------------------------------------------------------- + + +class TestEnsembleCombinesSignals: + """Test that the ensemble correctly combines strategy signals.""" + + @pytest.mark.asyncio + async def test_combines_two_long_signals(self): + """Two LONG signals should produce a combined LONG signal.""" + s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.8)) + s2 = _StubStrategy("beta", _make_signal(SignalDirection.LONG, 0.6)) + + ensemble = WeightedEnsemble([s1, s2], threshold=0.0) + market = MarketSnapshot( + ticker="AAPL", current_price=150.0, + open=149.0, high=151.0, low=148.0, close=150.0, volume=1000, + ) + weights = {"alpha": 0.5, "beta": 0.5} + + signal = await ensemble.evaluate("AAPL", market, None, weights) + + assert signal is not None + assert signal.direction == SignalDirection.LONG + # Weighted average = (0.8*0.5 + 0.6*0.5) / (0.5+0.5) = 0.7 + assert signal.strength == pytest.approx(0.7, abs=0.01) + + @pytest.mark.asyncio + async def test_opposing_signals_net_direction(self): + """When strategies disagree, direction follows the stronger weighted side.""" + s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9)) + s2 = _StubStrategy("beta", _make_signal(SignalDirection.SHORT, 0.3)) + + ensemble = WeightedEnsemble([s1, s2], threshold=0.0) + market = MarketSnapshot( + ticker="AAPL", current_price=150.0, + open=149.0, high=151.0, low=148.0, close=150.0, volume=1000, + ) + weights = {"alpha": 0.5, "beta": 0.5} + + signal = await ensemble.evaluate("AAPL", market, None, weights) + + assert signal is not None + # Net direction should be LONG since alpha is stronger + assert signal.direction == SignalDirection.LONG + + +# --------------------------------------------------------------------------- +# WeightedEnsemble — threshold filtering +# --------------------------------------------------------------------------- + + +class TestEnsembleThresholdFiltering: + """Test that weak combined signals are filtered out by the threshold.""" + + @pytest.mark.asyncio + async def test_below_threshold_returns_none(self): + """Combined strength below threshold should yield None.""" + # Two opposing signals of similar strength will nearly cancel out + s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.5)) + s2 = _StubStrategy("beta", _make_signal(SignalDirection.SHORT, 0.45)) + + ensemble = WeightedEnsemble([s1, s2], threshold=0.5) + market = MarketSnapshot( + ticker="AAPL", current_price=150.0, + open=149.0, high=151.0, low=148.0, close=150.0, volume=1000, + ) + weights = {"alpha": 0.5, "beta": 0.5} + + signal = await ensemble.evaluate("AAPL", market, None, weights) + assert signal is None + + @pytest.mark.asyncio + async def test_above_threshold_returns_signal(self): + """Strong combined signal above threshold should yield a signal.""" + s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9)) + + ensemble = WeightedEnsemble([s1], threshold=0.3) + market = MarketSnapshot( + ticker="AAPL", current_price=150.0, + open=149.0, high=151.0, low=148.0, close=150.0, volume=1000, + ) + weights = {"alpha": 1.0} + + signal = await ensemble.evaluate("AAPL", market, None, weights) + assert signal is not None + assert signal.strength >= 0.3 + + +# --------------------------------------------------------------------------- +# WeightedEnsemble — no signals returns None +# --------------------------------------------------------------------------- + + +class TestEnsembleNoSignals: + """Test that the ensemble returns None when no strategy fires.""" + + @pytest.mark.asyncio + async def test_all_strategies_return_none(self): + s1 = _StubStrategy("alpha", None) + s2 = _StubStrategy("beta", None) + + ensemble = WeightedEnsemble([s1, s2], threshold=0.3) + market = MarketSnapshot( + ticker="AAPL", current_price=150.0, + open=149.0, high=151.0, low=148.0, close=150.0, volume=1000, + ) + weights = {"alpha": 0.5, "beta": 0.5} + + signal = await ensemble.evaluate("AAPL", market, None, weights) + assert signal is None + + +# --------------------------------------------------------------------------- +# WeightedEnsemble — tags strategy sources +# --------------------------------------------------------------------------- + + +class TestEnsembleTagsStrategySources: + """Verify that the output signal records which strategies contributed.""" + + @pytest.mark.asyncio + async def test_strategy_sources_contains_all_contributors(self): + s1 = _StubStrategy("momentum", _make_signal(SignalDirection.LONG, 0.7, ["momentum"])) + s2 = _StubStrategy("news_driven", _make_signal(SignalDirection.LONG, 0.6, ["news_driven"])) + s3 = _StubStrategy("mean_reversion", None) # does not contribute + + ensemble = WeightedEnsemble([s1, s2, s3], threshold=0.0) + market = MarketSnapshot( + ticker="AAPL", current_price=150.0, + open=149.0, high=151.0, low=148.0, close=150.0, volume=1000, + ) + weights = {"momentum": 0.5, "news_driven": 0.3, "mean_reversion": 0.2} + + signal = await ensemble.evaluate("AAPL", market, None, weights) + assert signal is not None + # Should have exactly 2 sources + assert len(signal.strategy_sources) == 2 + source_names = [s.split(":")[0] for s in signal.strategy_sources] + assert "momentum" in source_names + assert "news_driven" in source_names + # mean_reversion should NOT be present + assert "mean_reversion" not in source_names + + @pytest.mark.asyncio + async def test_strategy_sources_contain_direction_and_strength(self): + """Each source tag should be formatted as name:DIRECTION:strength.""" + s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.75)) + ensemble = WeightedEnsemble([s1], threshold=0.0) + market = MarketSnapshot( + ticker="AAPL", current_price=150.0, + open=149.0, high=151.0, low=148.0, close=150.0, volume=1000, + ) + weights = {"alpha": 1.0} + + signal = await ensemble.evaluate("AAPL", market, None, weights) + assert signal is not None + assert len(signal.strategy_sources) == 1 + parts = signal.strategy_sources[0].split(":") + assert parts[0] == "alpha" + assert parts[1] == "LONG" + assert float(parts[2]) == pytest.approx(0.75, abs=0.01) diff --git a/tests/services/test_trade_executor.py b/tests/services/test_trade_executor.py new file mode 100644 index 0000000..8ec42bf --- /dev/null +++ b/tests/services/test_trade_executor.py @@ -0,0 +1,403 @@ +"""Tests for the Trade Executor service. + +Covers RiskManager (market hours, positions, exposure, cooldown, +position sizing) and the end-to-end executor flow with a mocked broker. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch +from zoneinfo import ZoneInfo + +import pytest + +from services.trade_executor.config import TradeExecutorConfig +from services.trade_executor.main import process_signal +from services.trade_executor.risk_manager import RiskManager +from shared.schemas.trading import ( + AccountInfo, + OrderResult, + OrderSide, + OrderStatus, + PositionInfo, + SignalDirection, + TradeSignal, +) + +_ET = ZoneInfo("America/New_York") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_config(**overrides) -> TradeExecutorConfig: + defaults = dict( + max_position_pct=0.05, + max_total_exposure_pct=0.80, + max_positions=20, + default_stop_loss_pct=0.03, + cooldown_minutes=30, + alpaca_api_key="test", + alpaca_secret_key="test", + paper_trading=True, + ) + defaults.update(overrides) + return TradeExecutorConfig(**defaults) + + +def _make_signal( + ticker: str = "AAPL", + direction: SignalDirection = SignalDirection.LONG, + strength: float = 0.8, + current_price: float = 150.0, +) -> TradeSignal: + return TradeSignal( + ticker=ticker, + direction=direction, + strength=strength, + strategy_sources=["test"], + sentiment_context={"current_price": current_price}, + timestamp=datetime.now(timezone.utc), + ) + + +def _make_account(equity: float = 100_000.0) -> AccountInfo: + return AccountInfo( + equity=equity, + cash=equity, + buying_power=equity * 2, + portfolio_value=equity, + ) + + +def _make_position(ticker: str = "AAPL", market_value: float = 5000.0) -> PositionInfo: + return PositionInfo( + ticker=ticker, + qty=10.0, + avg_entry=150.0, + current_price=150.0, + unrealized_pnl=0.0, + market_value=market_value, + ) + + +def _mock_broker(positions: list[PositionInfo] | None = None, account: AccountInfo | None = None): + """Create an AsyncMock broker with configurable positions and account.""" + broker = AsyncMock() + broker.get_positions = AsyncMock(return_value=positions or []) + broker.get_account = AsyncMock(return_value=account or _make_account()) + broker.submit_order = AsyncMock( + return_value=OrderResult( + order_id="ord-123", + ticker="AAPL", + side=OrderSide.BUY, + qty=10.0, + filled_price=150.0, + status=OrderStatus.FILLED, + timestamp=datetime.now(timezone.utc), + ) + ) + return broker + + +# --------------------------------------------------------------------------- +# RiskManager — risk check passes +# --------------------------------------------------------------------------- + + +class TestRiskCheckPasses: + """All conditions met -> risk check passes.""" + + @pytest.mark.asyncio + async def test_all_conditions_met(self): + config = _make_config() + broker = _mock_broker(positions=[], account=_make_account(100_000)) + rm = RiskManager(config, broker) + signal = _make_signal() + + # Patch _is_market_hours to return True + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is True + assert reason == "approved" + + +# --------------------------------------------------------------------------- +# RiskManager — max positions exceeded +# --------------------------------------------------------------------------- + + +class TestRiskCheckMaxPositions: + """Risk check fails when max_positions is already reached.""" + + @pytest.mark.asyncio + async def test_max_positions_exceeded(self): + config = _make_config(max_positions=2) + # Already have 2 positions + positions = [_make_position("AAPL"), _make_position("MSFT")] + broker = _mock_broker(positions=positions, account=_make_account()) + rm = RiskManager(config, broker) + signal = _make_signal(ticker="GOOG") + + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is False + assert "max_positions" in reason + + +# --------------------------------------------------------------------------- +# RiskManager — max exposure exceeded +# --------------------------------------------------------------------------- + + +class TestRiskCheckMaxExposure: + """Risk check fails when total exposure exceeds the limit.""" + + @pytest.mark.asyncio + async def test_max_exposure_exceeded(self): + config = _make_config(max_total_exposure_pct=0.50) + account = _make_account(equity=100_000) + # Single position worth $60k = 60% of equity, limit is 50% + positions = [_make_position("AAPL", market_value=60_000)] + broker = _mock_broker(positions=positions, account=account) + rm = RiskManager(config, broker) + signal = _make_signal(ticker="MSFT") + + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is False + assert "max_exposure" in reason + + +# --------------------------------------------------------------------------- +# RiskManager — cooldown active +# --------------------------------------------------------------------------- + + +class TestRiskCheckCooldown: + """Risk check fails when a ticker is in cooldown.""" + + @pytest.mark.asyncio + async def test_cooldown_active(self): + config = _make_config(cooldown_minutes=30) + broker = _mock_broker() + rm = RiskManager(config, broker) + + # Record an exit 10 minutes ago + now_et = datetime.now(tz=_ET) + rm.record_exit("AAPL", now_et - timedelta(minutes=10)) + + signal = _make_signal(ticker="AAPL") + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is False + assert "cooldown" in reason + + @pytest.mark.asyncio + async def test_cooldown_expired(self): + """After cooldown period expires the trade should be approved.""" + config = _make_config(cooldown_minutes=30) + broker = _mock_broker() + rm = RiskManager(config, broker) + + # Record an exit 45 minutes ago + now_et = datetime.now(tz=_ET) + rm.record_exit("AAPL", now_et - timedelta(minutes=45)) + + signal = _make_signal(ticker="AAPL") + with patch.object(RiskManager, "_is_market_hours", return_value=True): + approved, reason = await rm.check_risk(signal) + + assert approved is True + + +# --------------------------------------------------------------------------- +# RiskManager — outside market hours +# --------------------------------------------------------------------------- + + +class TestRiskCheckMarketHours: + """Risk check fails outside regular market hours.""" + + @pytest.mark.asyncio + async def test_outside_market_hours(self): + config = _make_config() + broker = _mock_broker() + rm = RiskManager(config, broker) + signal = _make_signal() + + # Force market hours check to fail (no patching — use the real check + # with a time that is definitely outside market hours) + with patch.object(RiskManager, "_is_market_hours", return_value=False): + approved, reason = await rm.check_risk(signal) + + assert approved is False + assert "market_hours" in reason + + def test_market_hours_weekday(self): + """A weekday at 10:00 AM ET should be within market hours.""" + # Tuesday 10:00 AM ET + t = datetime(2026, 2, 24, 10, 0, 0, tzinfo=_ET) + assert RiskManager._is_market_hours(t) is True + + def test_market_hours_weekend(self): + """Saturday should always be outside market hours.""" + t = datetime(2026, 2, 21, 10, 0, 0, tzinfo=_ET) # Saturday + assert RiskManager._is_market_hours(t) is False + + def test_market_hours_before_open(self): + """8:00 AM ET on a weekday is before market open.""" + t = datetime(2026, 2, 24, 8, 0, 0, tzinfo=_ET) # Tuesday 8 AM + assert RiskManager._is_market_hours(t) is False + + def test_market_hours_after_close(self): + """5:00 PM ET on a weekday is after market close.""" + t = datetime(2026, 2, 24, 17, 0, 0, tzinfo=_ET) # Tuesday 5 PM + assert RiskManager._is_market_hours(t) is False + + +# --------------------------------------------------------------------------- +# Position sizing — scales by strength +# --------------------------------------------------------------------------- + + +class TestPositionSizingScalesByStrength: + """Position size should scale proportionally with signal strength.""" + + def test_full_strength(self): + config = _make_config(max_position_pct=0.05) + broker = _mock_broker() + rm = RiskManager(config, broker) + + signal = _make_signal(strength=1.0, current_price=100.0) + account = _make_account(equity=100_000) + + qty = rm.calculate_position_size(signal, account) + # position_value = 100k * 0.05 * 1.0 = 5000 / 100 = 50 shares + assert qty == 50 + + def test_half_strength(self): + config = _make_config(max_position_pct=0.05) + broker = _mock_broker() + rm = RiskManager(config, broker) + + signal = _make_signal(strength=0.5, current_price=100.0) + account = _make_account(equity=100_000) + + qty = rm.calculate_position_size(signal, account) + # position_value = 100k * 0.05 * 0.5 = 2500 / 100 = 25 shares + assert qty == 25 + + +# --------------------------------------------------------------------------- +# Position sizing — respects max_position_pct +# --------------------------------------------------------------------------- + + +class TestPositionSizingRespectsMaxPct: + """Position size should respect the max_position_pct cap.""" + + def test_respects_max_pct(self): + config = _make_config(max_position_pct=0.02) + broker = _mock_broker() + rm = RiskManager(config, broker) + + signal = _make_signal(strength=1.0, current_price=50.0) + account = _make_account(equity=100_000) + + qty = rm.calculate_position_size(signal, account) + # position_value = 100k * 0.02 * 1.0 = 2000 / 50 = 40 shares + assert qty == 40 + + def test_zero_price_returns_zero(self): + config = _make_config() + broker = _mock_broker() + rm = RiskManager(config, broker) + + signal = _make_signal(strength=0.8, current_price=0.0) + account = _make_account(equity=100_000) + + qty = rm.calculate_position_size(signal, account) + assert qty == 0 + + +# --------------------------------------------------------------------------- +# Executor flow — approved signal +# --------------------------------------------------------------------------- + + +class TestExecutorFlowApproved: + """End-to-end: approved signal -> order submitted -> trade published.""" + + @pytest.mark.asyncio + async def test_approved_signal_flow(self): + config = _make_config() + broker = _mock_broker(positions=[], account=_make_account(100_000)) + publisher = AsyncMock() + publisher.publish = AsyncMock(return_value=b"1-0") + + counters = { + "trades_executed": MagicMock(), + "rejections": MagicMock(), + "fill_latency": MagicMock(), + } + + signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0) + + # Patch risk check to approve + with patch.object(RiskManager, "check_risk", return_value=(True, "approved")): + await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) + + # Verify order was submitted + broker.submit_order.assert_called_once() + order_arg = broker.submit_order.call_args[0][0] + assert order_arg.ticker == "AAPL" + assert order_arg.side == OrderSide.BUY + + # Verify trade was published + publisher.publish.assert_called_once() + counters["trades_executed"].add.assert_called_once_with(1) + + +# --------------------------------------------------------------------------- +# Executor flow — rejected signal +# --------------------------------------------------------------------------- + + +class TestExecutorFlowRejected: + """End-to-end: rejected signal -> no order, rejection logged.""" + + @pytest.mark.asyncio + async def test_rejected_signal_flow(self): + config = _make_config() + broker = _mock_broker() + publisher = AsyncMock() + + counters = { + "trades_executed": MagicMock(), + "rejections": MagicMock(), + "fill_latency": MagicMock(), + } + + signal = _make_signal(ticker="AAPL") + + with patch.object( + RiskManager, "check_risk", return_value=(False, "outside_market_hours") + ): + await process_signal(signal, RiskManager(config, broker), broker, publisher, counters) + + # No order should have been submitted + broker.submit_order.assert_not_called() + + # No trade should have been published + publisher.publish.assert_not_called() + + # Rejection counter should have been incremented + counters["rejections"].add.assert_called_once()