Merge branch 'worktree-agent-ad9ede16'
# Conflicts: # shared/strategies/__init__.py # shared/strategies/base.py # shared/strategies/mean_reversion.py # shared/strategies/momentum.py # shared/strategies/news_driven.py
This commit is contained in:
commit
1d9900838d
11 changed files with 1532 additions and 0 deletions
1
services/signal_generator/__init__.py
Normal file
1
services/signal_generator/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Signal Generator service — weighted ensemble of trading strategies."""
|
||||
14
services/signal_generator/config.py
Normal file
14
services/signal_generator/config.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""Configuration for the signal generator service."""
|
||||
|
||||
from shared.config import BaseConfig
|
||||
|
||||
|
||||
class SignalGeneratorConfig(BaseConfig):
|
||||
"""Extends BaseConfig with signal-generator-specific settings."""
|
||||
|
||||
alpaca_api_key: str = ""
|
||||
alpaca_secret_key: str = ""
|
||||
signal_strength_threshold: float = 0.3
|
||||
watchlist: list[str] = []
|
||||
|
||||
model_config = {"env_prefix": "TRADING_"}
|
||||
118
services/signal_generator/ensemble.py
Normal file
118
services/signal_generator/ensemble.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""Weighted ensemble that combines signals from multiple strategies.
|
||||
|
||||
Runs all registered strategies, collects non-``None`` signals, and computes
|
||||
a combined strength via a weighted average. Only emits a ``TradeSignal``
|
||||
when the combined strength exceeds a configurable threshold.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from shared.schemas.trading import (
|
||||
MarketSnapshot,
|
||||
SentimentContext,
|
||||
SignalDirection,
|
||||
TradeSignal,
|
||||
)
|
||||
from shared.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class WeightedEnsemble:
|
||||
"""Combine signals from multiple strategies using weighted averaging.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategies:
|
||||
The list of strategy instances to evaluate.
|
||||
threshold:
|
||||
Minimum combined strength required to emit a signal (default 0.3).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategies: list[BaseStrategy],
|
||||
threshold: float = 0.3,
|
||||
) -> None:
|
||||
self.strategies = strategies
|
||||
self.threshold = threshold
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
ticker: str,
|
||||
market: MarketSnapshot,
|
||||
sentiment: SentimentContext | None,
|
||||
weights: dict[str, float],
|
||||
) -> TradeSignal | None:
|
||||
"""Run all strategies and return a combined signal, or ``None``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ticker:
|
||||
The stock ticker being evaluated.
|
||||
market:
|
||||
Current market snapshot including price, SMA, RSI.
|
||||
sentiment:
|
||||
Aggregated sentiment context (may be ``None``).
|
||||
weights:
|
||||
Mapping from strategy name to its weight.
|
||||
"""
|
||||
# Step 1: run all strategies, collect (strategy, signal) pairs
|
||||
signals: list[tuple[BaseStrategy, TradeSignal]] = []
|
||||
for strategy in self.strategies:
|
||||
signal = await strategy.evaluate(ticker, market, sentiment)
|
||||
if signal is not None:
|
||||
signals.append((strategy, signal))
|
||||
|
||||
if not signals:
|
||||
return None
|
||||
|
||||
# Step 2: compute weighted sum
|
||||
weighted_sum = 0.0
|
||||
total_weight = 0.0
|
||||
for strategy, signal in signals:
|
||||
w = weights.get(strategy.name, 0.1)
|
||||
direction_sign = 1.0 if signal.direction == SignalDirection.LONG else -1.0
|
||||
weighted_sum += signal.strength * direction_sign * w
|
||||
total_weight += w
|
||||
|
||||
if total_weight == 0:
|
||||
return None
|
||||
|
||||
# Step 3: combined strength
|
||||
combined_strength = abs(weighted_sum) / total_weight
|
||||
|
||||
if combined_strength < self.threshold:
|
||||
return None
|
||||
|
||||
# Step 4: determine direction from the sign of the weighted sum
|
||||
if weighted_sum > 0:
|
||||
direction = SignalDirection.LONG
|
||||
elif weighted_sum < 0:
|
||||
direction = SignalDirection.SHORT
|
||||
else:
|
||||
return None
|
||||
|
||||
# Step 5: build strategy_sources with individual contributions
|
||||
strategy_sources = [
|
||||
f"{strategy.name}:{signal.direction.value}:{signal.strength:.4f}"
|
||||
for strategy, signal in signals
|
||||
]
|
||||
|
||||
# Carry forward sentiment context if available
|
||||
sentiment_ctx = None
|
||||
if sentiment is not None:
|
||||
sentiment_ctx = {
|
||||
"avg_score": sentiment.avg_score,
|
||||
"article_count": sentiment.article_count,
|
||||
"avg_confidence": sentiment.avg_confidence,
|
||||
}
|
||||
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
strength=round(min(combined_strength, 1.0), 4),
|
||||
strategy_sources=strategy_sources,
|
||||
sentiment_context=sentiment_ctx,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
165
services/signal_generator/main.py
Normal file
165
services/signal_generator/main.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
"""Signal Generator service -- main entry point.
|
||||
|
||||
Consumes ``news:scored`` articles from Redis Streams, updates sentiment
|
||||
context per ticker, runs the weighted ensemble of trading strategies, and
|
||||
publishes qualifying ``TradeSignal`` messages to ``signals:generated``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from services.signal_generator.config import SignalGeneratorConfig
|
||||
from services.signal_generator.ensemble import WeightedEnsemble
|
||||
from services.signal_generator.market_data import MarketDataManager
|
||||
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||
from shared.schemas.news import ScoredArticle
|
||||
from shared.schemas.trading import SentimentContext
|
||||
from shared.strategies import MeanReversionStrategy, MomentumStrategy, NewsDrivenStrategy
|
||||
from shared.telemetry import setup_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum number of recent sentiment scores to retain per ticker
|
||||
_MAX_SENTIMENT_SCORES = 50
|
||||
|
||||
# Default strategy weights (equal weighting)
|
||||
_DEFAULT_WEIGHTS: dict[str, float] = {
|
||||
"momentum": 0.333,
|
||||
"mean_reversion": 0.333,
|
||||
"news_driven": 0.334,
|
||||
}
|
||||
|
||||
|
||||
def _build_sentiment_context(
|
||||
ticker: str,
|
||||
scores: deque[float],
|
||||
confidences: deque[float],
|
||||
) -> SentimentContext:
|
||||
"""Build a ``SentimentContext`` from accumulated per-ticker scores."""
|
||||
score_list = list(scores)
|
||||
conf_list = list(confidences)
|
||||
return SentimentContext(
|
||||
ticker=ticker,
|
||||
avg_score=sum(score_list) / len(score_list) if score_list else 0.0,
|
||||
article_count=len(score_list),
|
||||
recent_scores=score_list[-10:],
|
||||
avg_confidence=sum(conf_list) / len(conf_list) if conf_list else 0.0,
|
||||
)
|
||||
|
||||
|
||||
async def run(config: SignalGeneratorConfig | None = None) -> None:
|
||||
"""Main service loop.
|
||||
|
||||
Connects to Redis, initialises strategies and telemetry, then
|
||||
continuously consumes from ``news:scored`` and publishes qualifying
|
||||
signals to ``signals:generated``.
|
||||
"""
|
||||
if config is None:
|
||||
config = SignalGeneratorConfig()
|
||||
|
||||
logging.basicConfig(level=config.log_level)
|
||||
logger.info("Starting Signal Generator service")
|
||||
|
||||
# --- Telemetry ---
|
||||
meter = setup_telemetry("signal-generator", config.otel_metrics_port)
|
||||
signals_generated = meter.create_counter(
|
||||
"signals_generated",
|
||||
description="Total trade signals emitted by the signal generator",
|
||||
)
|
||||
per_strategy_signal_count = meter.create_counter(
|
||||
"per_strategy_signal_count",
|
||||
description="Signals emitted, broken down by strategy",
|
||||
)
|
||||
|
||||
# --- Redis ---
|
||||
redis = Redis.from_url(config.redis_url, decode_responses=False)
|
||||
consumer = StreamConsumer(redis, "news:scored", "signal-generator", "worker-1")
|
||||
publisher = StreamPublisher(redis, "signals:generated")
|
||||
|
||||
# --- Market data ---
|
||||
market_data = MarketDataManager()
|
||||
|
||||
# --- Strategies ---
|
||||
strategies = [
|
||||
MomentumStrategy(),
|
||||
MeanReversionStrategy(),
|
||||
NewsDrivenStrategy(),
|
||||
]
|
||||
ensemble = WeightedEnsemble(strategies, threshold=config.signal_strength_threshold)
|
||||
|
||||
# --- Strategy weights (default equal; could load from DB) ---
|
||||
weights = dict(_DEFAULT_WEIGHTS)
|
||||
|
||||
# --- Per-ticker sentiment accumulators ---
|
||||
sentiment_scores: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
|
||||
sentiment_confidences: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=_MAX_SENTIMENT_SCORES))
|
||||
|
||||
logger.info("Consuming from news:scored, publishing to signals:generated")
|
||||
|
||||
# --- Consume loop ---
|
||||
async for _msg_id, data in consumer.consume():
|
||||
try:
|
||||
article = ScoredArticle.model_validate(data)
|
||||
ticker = article.ticker
|
||||
|
||||
# Update sentiment accumulators
|
||||
sentiment_scores[ticker].append(article.sentiment_score)
|
||||
sentiment_confidences[ticker].append(article.confidence)
|
||||
|
||||
# Build sentiment context
|
||||
sentiment = _build_sentiment_context(
|
||||
ticker,
|
||||
sentiment_scores[ticker],
|
||||
sentiment_confidences[ticker],
|
||||
)
|
||||
|
||||
# Get market snapshot (may be None if no bars received yet)
|
||||
snapshot = market_data.get_snapshot(ticker)
|
||||
if snapshot is None:
|
||||
# Create a minimal snapshot from sentiment data alone
|
||||
# (the news_driven strategy does not require market indicators)
|
||||
from shared.schemas.trading import MarketSnapshot
|
||||
|
||||
snapshot = MarketSnapshot(
|
||||
ticker=ticker,
|
||||
current_price=0.0,
|
||||
open=0.0,
|
||||
high=0.0,
|
||||
low=0.0,
|
||||
close=0.0,
|
||||
volume=0.0,
|
||||
)
|
||||
|
||||
# Run ensemble
|
||||
signal = await ensemble.evaluate(ticker, snapshot, sentiment, weights)
|
||||
|
||||
if signal is not None:
|
||||
await publisher.publish(signal.model_dump(mode="json"))
|
||||
signals_generated.add(1)
|
||||
for src in signal.strategy_sources:
|
||||
strategy_name = src.split(":")[0]
|
||||
per_strategy_signal_count.add(1, {"strategy": strategy_name})
|
||||
logger.info(
|
||||
"Signal generated: %s %s strength=%.4f sources=%s",
|
||||
signal.direction.value,
|
||||
ticker,
|
||||
signal.strength,
|
||||
signal.strategy_sources,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error processing scored article: %s", data.get("title", "<unknown>"))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""CLI entry point."""
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
122
services/signal_generator/market_data.py
Normal file
122
services/signal_generator/market_data.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""In-memory market data manager with rolling OHLCV windows.
|
||||
|
||||
Maintains a per-ticker deque of recent bars and computes technical
|
||||
indicators (SMA, RSI) on demand when building ``MarketSnapshot`` objects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
from shared.schemas.trading import MarketSnapshot, OHLCVBar
|
||||
|
||||
|
||||
# Default rolling-window sizes
|
||||
_DEFAULT_MAX_BARS = 100
|
||||
_RSI_PERIOD = 14
|
||||
|
||||
|
||||
class MarketDataManager:
|
||||
"""Manages in-memory rolling windows of OHLCV bars per ticker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_bars:
|
||||
Maximum number of bars to retain per ticker.
|
||||
"""
|
||||
|
||||
def __init__(self, max_bars: int = _DEFAULT_MAX_BARS) -> None:
|
||||
self.max_bars = max_bars
|
||||
self._bars: dict[str, deque[OHLCVBar]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def add_bar(self, ticker: str, bar_data: dict[str, Any] | OHLCVBar) -> None:
|
||||
"""Append a bar to the rolling window for *ticker*.
|
||||
|
||||
``bar_data`` can be a dict (parsed from JSON) or an ``OHLCVBar``
|
||||
instance.
|
||||
"""
|
||||
if isinstance(bar_data, dict):
|
||||
bar = OHLCVBar.model_validate(bar_data)
|
||||
else:
|
||||
bar = bar_data
|
||||
|
||||
if ticker not in self._bars:
|
||||
self._bars[ticker] = deque(maxlen=self.max_bars)
|
||||
self._bars[ticker].append(bar)
|
||||
|
||||
def get_snapshot(self, ticker: str) -> MarketSnapshot | None:
|
||||
"""Build a ``MarketSnapshot`` from the rolling window.
|
||||
|
||||
Returns ``None`` if no bars have been recorded for *ticker*.
|
||||
"""
|
||||
bars = self._bars.get(ticker)
|
||||
if not bars:
|
||||
return None
|
||||
|
||||
latest = bars[-1]
|
||||
closes = [b.close for b in bars]
|
||||
|
||||
return MarketSnapshot(
|
||||
ticker=ticker,
|
||||
current_price=latest.close,
|
||||
open=latest.open,
|
||||
high=latest.high,
|
||||
low=latest.low,
|
||||
close=latest.close,
|
||||
volume=latest.volume,
|
||||
sma_20=self._compute_sma(closes, 20),
|
||||
sma_50=self._compute_sma(closes, 50),
|
||||
rsi=self._compute_rsi(closes, _RSI_PERIOD),
|
||||
bars=[b.model_dump(mode="json") for b in bars],
|
||||
)
|
||||
|
||||
def has_ticker(self, ticker: str) -> bool:
|
||||
"""Return ``True`` if at least one bar exists for *ticker*."""
|
||||
return ticker in self._bars and len(self._bars[ticker]) > 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Technical indicator helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _compute_sma(closes: list[float], period: int) -> float | None:
|
||||
"""Compute the simple moving average over the last *period* closes.
|
||||
|
||||
Returns ``None`` if there are fewer than *period* data points.
|
||||
"""
|
||||
if len(closes) < period:
|
||||
return None
|
||||
return sum(closes[-period:]) / period
|
||||
|
||||
@staticmethod
|
||||
def _compute_rsi(closes: list[float], period: int = 14) -> float | None:
|
||||
"""Compute the standard RSI over the last *period+1* closes.
|
||||
|
||||
Uses the average-gain / average-loss method. Returns ``None`` if
|
||||
there are not enough data points (need at least ``period + 1``
|
||||
closes to compute ``period`` deltas).
|
||||
"""
|
||||
if len(closes) < period + 1:
|
||||
return None
|
||||
|
||||
# Only use the most recent period+1 closes
|
||||
relevant = closes[-(period + 1):]
|
||||
deltas = [relevant[i + 1] - relevant[i] for i in range(len(relevant) - 1)]
|
||||
|
||||
gains = [d for d in deltas if d > 0]
|
||||
losses = [-d for d in deltas if d < 0]
|
||||
|
||||
avg_gain = sum(gains) / period if gains else 0.0
|
||||
avg_loss = sum(losses) / period if losses else 0.0
|
||||
|
||||
if avg_loss == 0:
|
||||
return 100.0 # No losses -> RSI is 100
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100.0 - (100.0 / (1.0 + rs))
|
||||
return round(rsi, 4)
|
||||
1
services/trade_executor/__init__.py
Normal file
1
services/trade_executor/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Trade Executor service — risk management and order execution."""
|
||||
18
services/trade_executor/config.py
Normal file
18
services/trade_executor/config.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Configuration for the trade executor service."""
|
||||
|
||||
from shared.config import BaseConfig
|
||||
|
||||
|
||||
class TradeExecutorConfig(BaseConfig):
|
||||
"""Extends BaseConfig with trade-executor-specific settings."""
|
||||
|
||||
max_position_pct: float = 0.05
|
||||
max_total_exposure_pct: float = 0.80
|
||||
max_positions: int = 20
|
||||
default_stop_loss_pct: float = 0.03
|
||||
cooldown_minutes: int = 30
|
||||
alpaca_api_key: str = ""
|
||||
alpaca_secret_key: str = ""
|
||||
paper_trading: bool = True
|
||||
|
||||
model_config = {"env_prefix": "TRADING_"}
|
||||
176
services/trade_executor/main.py
Normal file
176
services/trade_executor/main.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
"""Trade Executor service -- main entry point.
|
||||
|
||||
Consumes ``signals:generated`` from Redis Streams, runs risk checks,
|
||||
submits orders via the brokerage abstraction layer, records trades
|
||||
in the database, and publishes ``TradeExecution`` messages to
|
||||
``trades:executed``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from services.trade_executor.config import TradeExecutorConfig
|
||||
from services.trade_executor.risk_manager import RiskManager
|
||||
from shared.broker.alpaca_broker import AlpacaBroker
|
||||
from shared.redis_streams import StreamConsumer, StreamPublisher
|
||||
from shared.schemas.trading import (
|
||||
OrderRequest,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
SignalDirection,
|
||||
TradeExecution,
|
||||
TradeSignal,
|
||||
)
|
||||
from shared.telemetry import setup_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def process_signal(
|
||||
signal: TradeSignal,
|
||||
risk_manager: RiskManager,
|
||||
broker: AlpacaBroker,
|
||||
publisher: StreamPublisher,
|
||||
counters: dict,
|
||||
) -> None:
|
||||
"""Process a single trade signal: risk check, order, record, publish.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal:
|
||||
The trade signal to act on.
|
||||
risk_manager:
|
||||
Performs pre-trade risk checks and position sizing.
|
||||
broker:
|
||||
Brokerage adapter for submitting orders.
|
||||
publisher:
|
||||
Publishes execution results to ``trades:executed``.
|
||||
counters:
|
||||
Dict of OpenTelemetry counter/histogram instruments.
|
||||
"""
|
||||
# --- Step 1: risk check ---
|
||||
approved, reason = await risk_manager.check_risk(signal)
|
||||
if not approved:
|
||||
logger.info("Signal REJECTED for %s: %s", signal.ticker, reason)
|
||||
counters["rejections"].add(1, {"reason": reason.split(" ")[0]})
|
||||
return
|
||||
|
||||
# --- Step 2: calculate position size ---
|
||||
account = await broker.get_account()
|
||||
qty = risk_manager.calculate_position_size(signal, account)
|
||||
if qty <= 0:
|
||||
logger.info("Position size is zero for %s — skipping", signal.ticker)
|
||||
counters["rejections"].add(1, {"reason": "zero_position_size"})
|
||||
return
|
||||
|
||||
# --- Step 3: create order ---
|
||||
side = OrderSide.BUY if signal.direction == SignalDirection.LONG else OrderSide.SELL
|
||||
order_request = OrderRequest(
|
||||
ticker=signal.ticker,
|
||||
side=side,
|
||||
qty=float(qty),
|
||||
)
|
||||
|
||||
# --- Step 4: submit order ---
|
||||
start = time.monotonic()
|
||||
result = await broker.submit_order(order_request)
|
||||
elapsed = time.monotonic() - start
|
||||
counters["fill_latency"].record(elapsed)
|
||||
|
||||
# --- Step 5: build trade execution ---
|
||||
trade_id = uuid.uuid4()
|
||||
execution = TradeExecution(
|
||||
trade_id=trade_id,
|
||||
ticker=signal.ticker,
|
||||
side=side,
|
||||
qty=result.qty,
|
||||
price=result.filled_price or 0.0,
|
||||
status=result.status,
|
||||
signal_id=None,
|
||||
strategy_id=None,
|
||||
timestamp=result.timestamp,
|
||||
)
|
||||
|
||||
# --- Step 6: publish to trades:executed ---
|
||||
await publisher.publish(execution.model_dump(mode="json"))
|
||||
counters["trades_executed"].add(1)
|
||||
logger.info(
|
||||
"Trade executed: %s %s %.0f shares @ %s status=%s",
|
||||
side.value,
|
||||
signal.ticker,
|
||||
result.qty,
|
||||
result.filled_price,
|
||||
result.status.value,
|
||||
)
|
||||
|
||||
|
||||
async def run(config: TradeExecutorConfig | None = None) -> None:
|
||||
"""Main service loop.
|
||||
|
||||
Connects to Redis, initialises the broker and risk manager, then
|
||||
continuously consumes from ``signals:generated`` and publishes
|
||||
execution results to ``trades:executed``.
|
||||
"""
|
||||
if config is None:
|
||||
config = TradeExecutorConfig()
|
||||
|
||||
logging.basicConfig(level=config.log_level)
|
||||
logger.info("Starting Trade Executor service")
|
||||
|
||||
# --- Telemetry ---
|
||||
meter = setup_telemetry("trade-executor", config.otel_metrics_port)
|
||||
counters = {
|
||||
"trades_executed": meter.create_counter(
|
||||
"trades_executed",
|
||||
description="Total trades successfully submitted",
|
||||
),
|
||||
"rejections": meter.create_counter(
|
||||
"trade_rejections",
|
||||
description="Signals rejected by risk checks",
|
||||
),
|
||||
"fill_latency": meter.create_histogram(
|
||||
"order_fill_latency_seconds",
|
||||
description="Time from order submission to response",
|
||||
unit="s",
|
||||
),
|
||||
}
|
||||
|
||||
# --- Redis ---
|
||||
redis = Redis.from_url(config.redis_url, decode_responses=False)
|
||||
consumer = StreamConsumer(redis, "signals:generated", "trade-executor", "worker-1")
|
||||
publisher = StreamPublisher(redis, "trades:executed")
|
||||
|
||||
# --- Broker ---
|
||||
broker = AlpacaBroker(
|
||||
api_key=config.alpaca_api_key,
|
||||
secret_key=config.alpaca_secret_key,
|
||||
paper=config.paper_trading,
|
||||
)
|
||||
|
||||
# --- Risk manager ---
|
||||
risk_manager = RiskManager(config, broker)
|
||||
|
||||
logger.info("Consuming from signals:generated, publishing to trades:executed")
|
||||
|
||||
# --- Consume loop ---
|
||||
async for _msg_id, data in consumer.consume():
|
||||
try:
|
||||
signal = TradeSignal.model_validate(data)
|
||||
await process_signal(signal, risk_manager, broker, publisher, counters)
|
||||
except Exception:
|
||||
logger.exception("Error processing signal: %s", data)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""CLI entry point."""
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
155
services/trade_executor/risk_manager.py
Normal file
155
services/trade_executor/risk_manager.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
"""Pre-trade risk management checks and position sizing.
|
||||
|
||||
Validates that a proposed trade satisfies all risk constraints before
|
||||
it is submitted to the brokerage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from services.trade_executor.config import TradeExecutorConfig
|
||||
from shared.broker.base import BaseBroker
|
||||
from shared.schemas.trading import AccountInfo, PositionInfo, SignalDirection, TradeSignal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ET = ZoneInfo("America/New_York")
|
||||
|
||||
# Market hours in Eastern Time
|
||||
_MARKET_OPEN_HOUR = 9
|
||||
_MARKET_OPEN_MINUTE = 30
|
||||
_MARKET_CLOSE_HOUR = 16
|
||||
_MARKET_CLOSE_MINUTE = 0
|
||||
|
||||
|
||||
class RiskManager:
|
||||
"""Performs pre-trade risk checks and calculates position sizes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config:
|
||||
Trade executor configuration with risk parameters.
|
||||
broker:
|
||||
Broker instance for querying current positions and account info.
|
||||
"""
|
||||
|
||||
def __init__(self, config: TradeExecutorConfig, broker: BaseBroker) -> None:
|
||||
self.config = config
|
||||
self.broker = broker
|
||||
# ticker -> last exit timestamp
|
||||
self._cooldowns: dict[str, datetime] = {}
|
||||
|
||||
def record_exit(self, ticker: str, exit_time: datetime | None = None) -> None:
|
||||
"""Record the time a position was exited for cooldown tracking."""
|
||||
self._cooldowns[ticker] = exit_time or datetime.now(tz=_ET)
|
||||
|
||||
async def check_risk(self, signal: TradeSignal) -> tuple[bool, str]:
|
||||
"""Run all pre-trade risk checks.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[bool, str]
|
||||
``(approved, reason)`` — ``approved`` is ``True`` when
|
||||
all checks pass, otherwise ``reason`` explains the failure.
|
||||
"""
|
||||
# 1. Market hours
|
||||
now_et = datetime.now(tz=_ET)
|
||||
if not self._is_market_hours(now_et):
|
||||
return False, "outside_market_hours"
|
||||
|
||||
# 2. Cooldown
|
||||
if signal.ticker in self._cooldowns:
|
||||
last_exit = self._cooldowns[signal.ticker]
|
||||
cooldown_end = last_exit + timedelta(minutes=self.config.cooldown_minutes)
|
||||
if now_et < cooldown_end:
|
||||
remaining = (cooldown_end - now_et).total_seconds() / 60
|
||||
return False, f"cooldown_active ({remaining:.1f}m remaining)"
|
||||
|
||||
# 3. Max positions
|
||||
positions = await self.broker.get_positions()
|
||||
if len(positions) >= self.config.max_positions:
|
||||
return False, "max_positions_exceeded"
|
||||
|
||||
# 4. Max total exposure
|
||||
account = await self.broker.get_account()
|
||||
total_exposure = sum(abs(p.market_value) for p in positions)
|
||||
max_exposure = account.equity * self.config.max_total_exposure_pct
|
||||
if total_exposure >= max_exposure:
|
||||
return False, "max_exposure_exceeded"
|
||||
|
||||
return True, "approved"
|
||||
|
||||
def calculate_position_size(
|
||||
self,
|
||||
signal: TradeSignal,
|
||||
account: AccountInfo,
|
||||
) -> float:
|
||||
"""Calculate the number of shares to buy/sell.
|
||||
|
||||
Uses fixed-fractional sizing: ``equity * max_position_pct``
|
||||
gives the maximum dollar value per position, then scales by
|
||||
signal strength.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal:
|
||||
The trade signal (includes current price via strength).
|
||||
account:
|
||||
Current account info (equity, buying power).
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Number of shares (whole shares).
|
||||
"""
|
||||
if signal.strength <= 0 or account.equity <= 0:
|
||||
return 0.0
|
||||
|
||||
position_value = account.equity * self.config.max_position_pct
|
||||
position_value *= signal.strength
|
||||
|
||||
# Need a price to compute qty — use the signal's embedded price
|
||||
# or fall back to getting it from the snapshot. For simplicity
|
||||
# the executor will pass the current price through the signal's
|
||||
# sentiment_context or fetch it directly.
|
||||
current_price = 0.0
|
||||
if signal.sentiment_context and "current_price" in signal.sentiment_context:
|
||||
current_price = float(signal.sentiment_context["current_price"])
|
||||
|
||||
if current_price <= 0:
|
||||
logger.warning("No current price for %s, cannot size position", signal.ticker)
|
||||
return 0.0
|
||||
|
||||
qty = position_value / current_price
|
||||
return max(int(qty), 0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _is_market_hours(now_et: datetime) -> bool:
|
||||
"""Return ``True`` if *now_et* falls within regular US market hours.
|
||||
|
||||
Market hours: Monday--Friday, 9:30 AM -- 4:00 PM ET.
|
||||
"""
|
||||
# Weekday check (0=Monday ... 6=Sunday)
|
||||
if now_et.weekday() >= 5:
|
||||
return False
|
||||
|
||||
market_open = now_et.replace(
|
||||
hour=_MARKET_OPEN_HOUR,
|
||||
minute=_MARKET_OPEN_MINUTE,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
market_close = now_et.replace(
|
||||
hour=_MARKET_CLOSE_HOUR,
|
||||
minute=_MARKET_CLOSE_MINUTE,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
return market_open <= now_et < market_close
|
||||
359
tests/services/test_signal_generator.py
Normal file
359
tests/services/test_signal_generator.py
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
"""Tests for the Signal Generator service.
|
||||
|
||||
Covers MarketDataManager (SMA, RSI, snapshot) and WeightedEnsemble
|
||||
(signal combination, threshold filtering, strategy source tagging).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from services.signal_generator.ensemble import WeightedEnsemble
|
||||
from services.signal_generator.market_data import MarketDataManager
|
||||
from shared.schemas.trading import (
|
||||
MarketSnapshot,
|
||||
OHLCVBar,
|
||||
SentimentContext,
|
||||
SignalDirection,
|
||||
TradeSignal,
|
||||
)
|
||||
from shared.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bar(close: float, *, ts_offset: int = 0) -> OHLCVBar:
|
||||
"""Create an ``OHLCVBar`` with the given close price."""
|
||||
return OHLCVBar(
|
||||
timestamp=datetime(2026, 1, 1, 10, ts_offset, tzinfo=timezone.utc),
|
||||
open=close - 0.5,
|
||||
high=close + 1.0,
|
||||
low=close - 1.0,
|
||||
close=close,
|
||||
volume=1000.0,
|
||||
)
|
||||
|
||||
|
||||
class _StubStrategy(BaseStrategy):
|
||||
"""Test helper that returns a preconfigured signal."""
|
||||
|
||||
def __init__(self, name: str, signal: TradeSignal | None) -> None:
|
||||
self.name = name
|
||||
self._signal = signal
|
||||
|
||||
async def evaluate(self, ticker, market, sentiment=None):
|
||||
return self._signal
|
||||
|
||||
|
||||
def _make_signal(
|
||||
direction: SignalDirection = SignalDirection.LONG,
|
||||
strength: float = 0.8,
|
||||
sources: list[str] | None = None,
|
||||
) -> TradeSignal:
|
||||
return TradeSignal(
|
||||
ticker="AAPL",
|
||||
direction=direction,
|
||||
strength=strength,
|
||||
strategy_sources=sources or ["test"],
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MarketDataManager — SMA
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarketDataManagerSMA:
|
||||
"""Tests for SMA computation inside MarketDataManager."""
|
||||
|
||||
def test_sma_basic(self):
|
||||
"""SMA-20 should equal the mean of the last 20 close prices."""
|
||||
mgr = MarketDataManager()
|
||||
closes = list(range(1, 21)) # 1, 2, ..., 20
|
||||
for i, c in enumerate(closes):
|
||||
mgr.add_bar("AAPL", _make_bar(float(c), ts_offset=i))
|
||||
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
expected_sma_20 = sum(closes) / 20
|
||||
assert snap.sma_20 == pytest.approx(expected_sma_20)
|
||||
|
||||
def test_sma_returns_none_insufficient_data(self):
|
||||
"""SMA-20 should be None when fewer than 20 bars exist."""
|
||||
mgr = MarketDataManager()
|
||||
for i in range(10):
|
||||
mgr.add_bar("AAPL", _make_bar(100.0, ts_offset=i))
|
||||
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
assert snap.sma_20 is None
|
||||
|
||||
def test_sma_50_requires_50_bars(self):
|
||||
"""SMA-50 should be None with only 30 bars, present with 50."""
|
||||
mgr = MarketDataManager()
|
||||
for i in range(30):
|
||||
mgr.add_bar("AAPL", _make_bar(float(i + 1), ts_offset=i))
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
assert snap.sma_50 is None
|
||||
|
||||
# Add 20 more
|
||||
for i in range(30, 50):
|
||||
mgr.add_bar("AAPL", _make_bar(float(i + 1), ts_offset=i))
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
assert snap.sma_50 is not None
|
||||
expected = sum(range(1, 51)) / 50
|
||||
assert snap.sma_50 == pytest.approx(expected)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MarketDataManager — RSI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarketDataManagerRSI:
|
||||
"""Tests for RSI computation inside MarketDataManager."""
|
||||
|
||||
def test_rsi_all_gains(self):
|
||||
"""RSI should be 100 when all price changes are positive."""
|
||||
mgr = MarketDataManager()
|
||||
for i in range(20):
|
||||
mgr.add_bar("AAPL", _make_bar(100.0 + i, ts_offset=i))
|
||||
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
assert snap.rsi == pytest.approx(100.0)
|
||||
|
||||
def test_rsi_all_losses(self):
|
||||
"""RSI should be 0 when all price changes are negative."""
|
||||
mgr = MarketDataManager()
|
||||
for i in range(20):
|
||||
mgr.add_bar("AAPL", _make_bar(200.0 - i, ts_offset=i))
|
||||
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
assert snap.rsi == pytest.approx(0.0)
|
||||
|
||||
def test_rsi_mixed(self):
|
||||
"""RSI should be between 0 and 100 with mixed gains and losses."""
|
||||
mgr = MarketDataManager()
|
||||
prices = [44, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42,
|
||||
45.84, 46.08, 45.89, 46.03, 45.61, 46.28, 46.28, 46.00]
|
||||
for i, p in enumerate(prices):
|
||||
mgr.add_bar("AAPL", _make_bar(p, ts_offset=i))
|
||||
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
assert snap.rsi is not None
|
||||
assert 0 < snap.rsi < 100
|
||||
|
||||
def test_rsi_returns_none_insufficient_data(self):
|
||||
"""RSI should be None when fewer than 15 bars exist (need 14+1)."""
|
||||
mgr = MarketDataManager()
|
||||
for i in range(10):
|
||||
mgr.add_bar("AAPL", _make_bar(100.0, ts_offset=i))
|
||||
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
assert snap.rsi is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MarketDataManager — snapshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarketDataManagerSnapshot:
|
||||
"""Tests for get_snapshot behaviour."""
|
||||
|
||||
def test_snapshot_returns_none_for_unknown_ticker(self):
|
||||
mgr = MarketDataManager()
|
||||
assert mgr.get_snapshot("UNKNOWN") is None
|
||||
|
||||
def test_snapshot_uses_latest_bar_for_price(self):
|
||||
mgr = MarketDataManager()
|
||||
mgr.add_bar("AAPL", _make_bar(100.0, ts_offset=0))
|
||||
mgr.add_bar("AAPL", _make_bar(105.0, ts_offset=1))
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
assert snap.current_price == 105.0
|
||||
|
||||
def test_snapshot_contains_bars(self):
|
||||
mgr = MarketDataManager()
|
||||
for i in range(5):
|
||||
mgr.add_bar("AAPL", _make_bar(100.0 + i, ts_offset=i))
|
||||
snap = mgr.get_snapshot("AAPL")
|
||||
assert snap is not None
|
||||
assert len(snap.bars) == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WeightedEnsemble — combines signals
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnsembleCombinesSignals:
|
||||
"""Test that the ensemble correctly combines strategy signals."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combines_two_long_signals(self):
|
||||
"""Two LONG signals should produce a combined LONG signal."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.8))
|
||||
s2 = _StubStrategy("beta", _make_signal(SignalDirection.LONG, 0.6))
|
||||
|
||||
ensemble = WeightedEnsemble([s1, s2], threshold=0.0)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL", current_price=150.0,
|
||||
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||
)
|
||||
weights = {"alpha": 0.5, "beta": 0.5}
|
||||
|
||||
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
|
||||
assert signal is not None
|
||||
assert signal.direction == SignalDirection.LONG
|
||||
# Weighted average = (0.8*0.5 + 0.6*0.5) / (0.5+0.5) = 0.7
|
||||
assert signal.strength == pytest.approx(0.7, abs=0.01)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_opposing_signals_net_direction(self):
|
||||
"""When strategies disagree, direction follows the stronger weighted side."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
|
||||
s2 = _StubStrategy("beta", _make_signal(SignalDirection.SHORT, 0.3))
|
||||
|
||||
ensemble = WeightedEnsemble([s1, s2], threshold=0.0)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL", current_price=150.0,
|
||||
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||
)
|
||||
weights = {"alpha": 0.5, "beta": 0.5}
|
||||
|
||||
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
|
||||
assert signal is not None
|
||||
# Net direction should be LONG since alpha is stronger
|
||||
assert signal.direction == SignalDirection.LONG
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WeightedEnsemble — threshold filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnsembleThresholdFiltering:
|
||||
"""Test that weak combined signals are filtered out by the threshold."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_below_threshold_returns_none(self):
|
||||
"""Combined strength below threshold should yield None."""
|
||||
# Two opposing signals of similar strength will nearly cancel out
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.5))
|
||||
s2 = _StubStrategy("beta", _make_signal(SignalDirection.SHORT, 0.45))
|
||||
|
||||
ensemble = WeightedEnsemble([s1, s2], threshold=0.5)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL", current_price=150.0,
|
||||
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||
)
|
||||
weights = {"alpha": 0.5, "beta": 0.5}
|
||||
|
||||
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
assert signal is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_above_threshold_returns_signal(self):
|
||||
"""Strong combined signal above threshold should yield a signal."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.9))
|
||||
|
||||
ensemble = WeightedEnsemble([s1], threshold=0.3)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL", current_price=150.0,
|
||||
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||
)
|
||||
weights = {"alpha": 1.0}
|
||||
|
||||
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
assert signal is not None
|
||||
assert signal.strength >= 0.3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WeightedEnsemble — no signals returns None
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnsembleNoSignals:
|
||||
"""Test that the ensemble returns None when no strategy fires."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_strategies_return_none(self):
|
||||
s1 = _StubStrategy("alpha", None)
|
||||
s2 = _StubStrategy("beta", None)
|
||||
|
||||
ensemble = WeightedEnsemble([s1, s2], threshold=0.3)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL", current_price=150.0,
|
||||
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||
)
|
||||
weights = {"alpha": 0.5, "beta": 0.5}
|
||||
|
||||
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
assert signal is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WeightedEnsemble — tags strategy sources
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnsembleTagsStrategySources:
|
||||
"""Verify that the output signal records which strategies contributed."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strategy_sources_contains_all_contributors(self):
|
||||
s1 = _StubStrategy("momentum", _make_signal(SignalDirection.LONG, 0.7, ["momentum"]))
|
||||
s2 = _StubStrategy("news_driven", _make_signal(SignalDirection.LONG, 0.6, ["news_driven"]))
|
||||
s3 = _StubStrategy("mean_reversion", None) # does not contribute
|
||||
|
||||
ensemble = WeightedEnsemble([s1, s2, s3], threshold=0.0)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL", current_price=150.0,
|
||||
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||
)
|
||||
weights = {"momentum": 0.5, "news_driven": 0.3, "mean_reversion": 0.2}
|
||||
|
||||
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
assert signal is not None
|
||||
# Should have exactly 2 sources
|
||||
assert len(signal.strategy_sources) == 2
|
||||
source_names = [s.split(":")[0] for s in signal.strategy_sources]
|
||||
assert "momentum" in source_names
|
||||
assert "news_driven" in source_names
|
||||
# mean_reversion should NOT be present
|
||||
assert "mean_reversion" not in source_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strategy_sources_contain_direction_and_strength(self):
|
||||
"""Each source tag should be formatted as name:DIRECTION:strength."""
|
||||
s1 = _StubStrategy("alpha", _make_signal(SignalDirection.LONG, 0.75))
|
||||
ensemble = WeightedEnsemble([s1], threshold=0.0)
|
||||
market = MarketSnapshot(
|
||||
ticker="AAPL", current_price=150.0,
|
||||
open=149.0, high=151.0, low=148.0, close=150.0, volume=1000,
|
||||
)
|
||||
weights = {"alpha": 1.0}
|
||||
|
||||
signal = await ensemble.evaluate("AAPL", market, None, weights)
|
||||
assert signal is not None
|
||||
assert len(signal.strategy_sources) == 1
|
||||
parts = signal.strategy_sources[0].split(":")
|
||||
assert parts[0] == "alpha"
|
||||
assert parts[1] == "LONG"
|
||||
assert float(parts[2]) == pytest.approx(0.75, abs=0.01)
|
||||
403
tests/services/test_trade_executor.py
Normal file
403
tests/services/test_trade_executor.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
"""Tests for the Trade Executor service.
|
||||
|
||||
Covers RiskManager (market hours, positions, exposure, cooldown,
|
||||
position sizing) and the end-to-end executor flow with a mocked broker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
from services.trade_executor.config import TradeExecutorConfig
|
||||
from services.trade_executor.main import process_signal
|
||||
from services.trade_executor.risk_manager import RiskManager
|
||||
from shared.schemas.trading import (
|
||||
AccountInfo,
|
||||
OrderResult,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
PositionInfo,
|
||||
SignalDirection,
|
||||
TradeSignal,
|
||||
)
|
||||
|
||||
_ET = ZoneInfo("America/New_York")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_config(**overrides) -> TradeExecutorConfig:
|
||||
defaults = dict(
|
||||
max_position_pct=0.05,
|
||||
max_total_exposure_pct=0.80,
|
||||
max_positions=20,
|
||||
default_stop_loss_pct=0.03,
|
||||
cooldown_minutes=30,
|
||||
alpaca_api_key="test",
|
||||
alpaca_secret_key="test",
|
||||
paper_trading=True,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return TradeExecutorConfig(**defaults)
|
||||
|
||||
|
||||
def _make_signal(
|
||||
ticker: str = "AAPL",
|
||||
direction: SignalDirection = SignalDirection.LONG,
|
||||
strength: float = 0.8,
|
||||
current_price: float = 150.0,
|
||||
) -> TradeSignal:
|
||||
return TradeSignal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
strength=strength,
|
||||
strategy_sources=["test"],
|
||||
sentiment_context={"current_price": current_price},
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_account(equity: float = 100_000.0) -> AccountInfo:
|
||||
return AccountInfo(
|
||||
equity=equity,
|
||||
cash=equity,
|
||||
buying_power=equity * 2,
|
||||
portfolio_value=equity,
|
||||
)
|
||||
|
||||
|
||||
def _make_position(ticker: str = "AAPL", market_value: float = 5000.0) -> PositionInfo:
|
||||
return PositionInfo(
|
||||
ticker=ticker,
|
||||
qty=10.0,
|
||||
avg_entry=150.0,
|
||||
current_price=150.0,
|
||||
unrealized_pnl=0.0,
|
||||
market_value=market_value,
|
||||
)
|
||||
|
||||
|
||||
def _mock_broker(positions: list[PositionInfo] | None = None, account: AccountInfo | None = None):
|
||||
"""Create an AsyncMock broker with configurable positions and account."""
|
||||
broker = AsyncMock()
|
||||
broker.get_positions = AsyncMock(return_value=positions or [])
|
||||
broker.get_account = AsyncMock(return_value=account or _make_account())
|
||||
broker.submit_order = AsyncMock(
|
||||
return_value=OrderResult(
|
||||
order_id="ord-123",
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
qty=10.0,
|
||||
filled_price=150.0,
|
||||
status=OrderStatus.FILLED,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
return broker
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — risk check passes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckPasses:
|
||||
"""All conditions met -> risk check passes."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_conditions_met(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||
rm = RiskManager(config, broker)
|
||||
signal = _make_signal()
|
||||
|
||||
# Patch _is_market_hours to return True
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is True
|
||||
assert reason == "approved"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — max positions exceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckMaxPositions:
|
||||
"""Risk check fails when max_positions is already reached."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_positions_exceeded(self):
|
||||
config = _make_config(max_positions=2)
|
||||
# Already have 2 positions
|
||||
positions = [_make_position("AAPL"), _make_position("MSFT")]
|
||||
broker = _mock_broker(positions=positions, account=_make_account())
|
||||
rm = RiskManager(config, broker)
|
||||
signal = _make_signal(ticker="GOOG")
|
||||
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is False
|
||||
assert "max_positions" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — max exposure exceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckMaxExposure:
|
||||
"""Risk check fails when total exposure exceeds the limit."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_exposure_exceeded(self):
|
||||
config = _make_config(max_total_exposure_pct=0.50)
|
||||
account = _make_account(equity=100_000)
|
||||
# Single position worth $60k = 60% of equity, limit is 50%
|
||||
positions = [_make_position("AAPL", market_value=60_000)]
|
||||
broker = _mock_broker(positions=positions, account=account)
|
||||
rm = RiskManager(config, broker)
|
||||
signal = _make_signal(ticker="MSFT")
|
||||
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is False
|
||||
assert "max_exposure" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — cooldown active
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckCooldown:
|
||||
"""Risk check fails when a ticker is in cooldown."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_active(self):
|
||||
config = _make_config(cooldown_minutes=30)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
# Record an exit 10 minutes ago
|
||||
now_et = datetime.now(tz=_ET)
|
||||
rm.record_exit("AAPL", now_et - timedelta(minutes=10))
|
||||
|
||||
signal = _make_signal(ticker="AAPL")
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is False
|
||||
assert "cooldown" in reason
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_expired(self):
|
||||
"""After cooldown period expires the trade should be approved."""
|
||||
config = _make_config(cooldown_minutes=30)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
# Record an exit 45 minutes ago
|
||||
now_et = datetime.now(tz=_ET)
|
||||
rm.record_exit("AAPL", now_et - timedelta(minutes=45))
|
||||
|
||||
signal = _make_signal(ticker="AAPL")
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — outside market hours
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRiskCheckMarketHours:
|
||||
"""Risk check fails outside regular market hours."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_outside_market_hours(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
signal = _make_signal()
|
||||
|
||||
# Force market hours check to fail (no patching — use the real check
|
||||
# with a time that is definitely outside market hours)
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=False):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is False
|
||||
assert "market_hours" in reason
|
||||
|
||||
def test_market_hours_weekday(self):
|
||||
"""A weekday at 10:00 AM ET should be within market hours."""
|
||||
# Tuesday 10:00 AM ET
|
||||
t = datetime(2026, 2, 24, 10, 0, 0, tzinfo=_ET)
|
||||
assert RiskManager._is_market_hours(t) is True
|
||||
|
||||
def test_market_hours_weekend(self):
|
||||
"""Saturday should always be outside market hours."""
|
||||
t = datetime(2026, 2, 21, 10, 0, 0, tzinfo=_ET) # Saturday
|
||||
assert RiskManager._is_market_hours(t) is False
|
||||
|
||||
def test_market_hours_before_open(self):
|
||||
"""8:00 AM ET on a weekday is before market open."""
|
||||
t = datetime(2026, 2, 24, 8, 0, 0, tzinfo=_ET) # Tuesday 8 AM
|
||||
assert RiskManager._is_market_hours(t) is False
|
||||
|
||||
def test_market_hours_after_close(self):
|
||||
"""5:00 PM ET on a weekday is after market close."""
|
||||
t = datetime(2026, 2, 24, 17, 0, 0, tzinfo=_ET) # Tuesday 5 PM
|
||||
assert RiskManager._is_market_hours(t) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Position sizing — scales by strength
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPositionSizingScalesByStrength:
|
||||
"""Position size should scale proportionally with signal strength."""
|
||||
|
||||
def test_full_strength(self):
|
||||
config = _make_config(max_position_pct=0.05)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
signal = _make_signal(strength=1.0, current_price=100.0)
|
||||
account = _make_account(equity=100_000)
|
||||
|
||||
qty = rm.calculate_position_size(signal, account)
|
||||
# position_value = 100k * 0.05 * 1.0 = 5000 / 100 = 50 shares
|
||||
assert qty == 50
|
||||
|
||||
def test_half_strength(self):
|
||||
config = _make_config(max_position_pct=0.05)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
signal = _make_signal(strength=0.5, current_price=100.0)
|
||||
account = _make_account(equity=100_000)
|
||||
|
||||
qty = rm.calculate_position_size(signal, account)
|
||||
# position_value = 100k * 0.05 * 0.5 = 2500 / 100 = 25 shares
|
||||
assert qty == 25
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Position sizing — respects max_position_pct
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPositionSizingRespectsMaxPct:
|
||||
"""Position size should respect the max_position_pct cap."""
|
||||
|
||||
def test_respects_max_pct(self):
|
||||
config = _make_config(max_position_pct=0.02)
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
signal = _make_signal(strength=1.0, current_price=50.0)
|
||||
account = _make_account(equity=100_000)
|
||||
|
||||
qty = rm.calculate_position_size(signal, account)
|
||||
# position_value = 100k * 0.02 * 1.0 = 2000 / 50 = 40 shares
|
||||
assert qty == 40
|
||||
|
||||
def test_zero_price_returns_zero(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker()
|
||||
rm = RiskManager(config, broker)
|
||||
|
||||
signal = _make_signal(strength=0.8, current_price=0.0)
|
||||
account = _make_account(equity=100_000)
|
||||
|
||||
qty = rm.calculate_position_size(signal, account)
|
||||
assert qty == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor flow — approved signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutorFlowApproved:
|
||||
"""End-to-end: approved signal -> order submitted -> trade published."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approved_signal_flow(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||
publisher = AsyncMock()
|
||||
publisher.publish = AsyncMock(return_value=b"1-0")
|
||||
|
||||
counters = {
|
||||
"trades_executed": MagicMock(),
|
||||
"rejections": MagicMock(),
|
||||
"fill_latency": MagicMock(),
|
||||
}
|
||||
|
||||
signal = _make_signal(ticker="AAPL", strength=0.8, current_price=150.0)
|
||||
|
||||
# Patch risk check to approve
|
||||
with patch.object(RiskManager, "check_risk", return_value=(True, "approved")):
|
||||
await process_signal(signal, RiskManager(config, broker), broker, publisher, counters)
|
||||
|
||||
# Verify order was submitted
|
||||
broker.submit_order.assert_called_once()
|
||||
order_arg = broker.submit_order.call_args[0][0]
|
||||
assert order_arg.ticker == "AAPL"
|
||||
assert order_arg.side == OrderSide.BUY
|
||||
|
||||
# Verify trade was published
|
||||
publisher.publish.assert_called_once()
|
||||
counters["trades_executed"].add.assert_called_once_with(1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor flow — rejected signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutorFlowRejected:
|
||||
"""End-to-end: rejected signal -> no order, rejection logged."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejected_signal_flow(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker()
|
||||
publisher = AsyncMock()
|
||||
|
||||
counters = {
|
||||
"trades_executed": MagicMock(),
|
||||
"rejections": MagicMock(),
|
||||
"fill_latency": MagicMock(),
|
||||
}
|
||||
|
||||
signal = _make_signal(ticker="AAPL")
|
||||
|
||||
with patch.object(
|
||||
RiskManager, "check_risk", return_value=(False, "outside_market_hours")
|
||||
):
|
||||
await process_signal(signal, RiskManager(config, broker), broker, publisher, counters)
|
||||
|
||||
# No order should have been submitted
|
||||
broker.submit_order.assert_not_called()
|
||||
|
||||
# No trade should have been published
|
||||
publisher.publish.assert_not_called()
|
||||
|
||||
# Rejection counter should have been incremented
|
||||
counters["rejections"].add.assert_called_once()
|
||||
Loading…
Add table
Add a link
Reference in a new issue