fix: resolve all remaining TODOs, add dev mode auth bypass
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
- Learning engine: expand default weights from 3 to all 9 strategies - Learning engine: resolve placeholder strategy_id with DB lookup - Learning engine: pass strategy_sources from trade execution - Trade executor: respect trading:paused Redis flag in RiskManager - Portfolio sync: compute actual daily P&L from day-start snapshot - Portfolio API: cumulative P&L from first snapshot, read pause flag - Portfolio metrics: compute max drawdown and avg hold duration - Add strategy_sources field to TradeExecution schema - Add dev_mode config (TRADING_DEV_MODE) to bypass auth for local dev - Dashboard: VITE_DEV_MODE bypasses ProtectedRoute and 401 redirects - Vite proxy target configurable via VITE_API_TARGET - Add top-level README.md and remaining-work-plan.md - Update CLAUDE.md with correct counts and remove stale TODOs - 404 tests passing Made-with: Cursor
This commit is contained in:
parent
4094e4b10f
commit
a3cdd0f1a5
16 changed files with 511 additions and 45 deletions
|
|
@ -5,8 +5,9 @@ Automated stock trading bot combining news sentiment analysis with technical str
|
|||
|
||||
## Architecture
|
||||
- **7 microservices**: news-fetcher, sentiment-analyzer, signal-generator, trade-executor, learning-engine, api-gateway, dashboard
|
||||
- **Shared libraries** in `shared/`: config, redis_streams, telemetry, db, models, schemas, broker abstraction, strategies
|
||||
- **Shared libraries** in `shared/`: config, redis_streams, telemetry, db, models, schemas, broker abstraction, 9 strategies, fundamentals providers
|
||||
- **Infra**: PostgreSQL+TimescaleDB, Redis Streams, Ollama (local LLM), Docker Compose
|
||||
- **CI/CD**: Woodpecker pipeline → Docker build → Kubernetes deploy
|
||||
- **Brokerage**: Alpaca (paper trading) via abstraction layer in `shared/broker/`
|
||||
|
||||
## Key Design Decisions
|
||||
|
|
@ -32,10 +33,11 @@ trading-bot/
|
|||
│ ├── redis_streams.py # StreamPublisher + StreamConsumer
|
||||
│ ├── telemetry.py # OpenTelemetry setup
|
||||
│ ├── db.py # Async engine + sessionmaker
|
||||
│ ├── models/ # SQLAlchemy models (14 tables)
|
||||
│ ├── models/ # SQLAlchemy models (16 tables)
|
||||
│ ├── schemas/ # Pydantic v2 schemas (message types)
|
||||
│ ├── broker/ # BaseBroker ABC + AlpacaBroker
|
||||
│ └── strategies/ # BaseStrategy + 3 implementations
|
||||
│ ├── fundamentals/ # Alpha Vantage, FMP, Yahoo providers + cache
|
||||
│ └── strategies/ # BaseStrategy + 9 implementations
|
||||
├── services/
|
||||
│ ├── news_fetcher/ # RSS + Reddit → news:raw
|
||||
│ ├── sentiment_analyzer/ # FinBERT + Ollama → news:scored
|
||||
|
|
@ -90,10 +92,8 @@ trading-bot/
|
|||
- Code review was started but not completed — should be done before production use
|
||||
- JWT test warnings about short HMAC keys (test-only, production keys should be 32+ bytes)
|
||||
- Integration tests mock FinBERT and Alpaca — need real integration tests with live services
|
||||
- No CI/CD pipeline yet
|
||||
- Twitter/X source not implemented in news fetcher (only RSS + Reddit)
|
||||
- Dashboard has no unit tests (only build verification)
|
||||
- Strategy conflict during Sprint 3 merge — both signal-generator and strategies agents created `shared/strategies/`; resolved by keeping the dedicated strategies agent's versions
|
||||
|
||||
## How This Was Built
|
||||
- Built in 6 sprints using parallel subagent worktrees
|
||||
|
|
|
|||
81
README.md
Normal file
81
README.md
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# Trading Bot
|
||||
|
||||
Automated stock trading bot combining news sentiment analysis with technical strategies. Built as event-driven Python microservices communicating via Redis Streams, with a React/TypeScript dashboard and Alpaca paper trading.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
RSS/Reddit ─→ news_fetcher ─→ [news:raw] ─→ sentiment_analyzer ─→ [news:scored] ┐
|
||||
│
|
||||
Alpaca OHLCV ─→ market_data ─→ [market:bars] ────────────────────────────────────┤
|
||||
│
|
||||
signal_generator ←──────────────────┘
|
||||
│
|
||||
[signals:generated]
|
||||
│
|
||||
trade_executor ─→ [trades:executed] ─→ learning_engine
|
||||
│ │
|
||||
Alpaca API Redis (weights)
|
||||
```
|
||||
|
||||
**Services**: news-fetcher, sentiment-analyzer, signal-generator, trade-executor, learning-engine, market-data, api-gateway, dashboard
|
||||
|
||||
**9 Trading Strategies**: Momentum, Mean Reversion, News-Driven, Value, MACD Crossover, Bollinger Breakout, VWAP, Liquidity, MA Stack — combined via weighted ensemble with multi-armed bandit weight adjustment.
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Backend**: Python 3.12, FastAPI, SQLAlchemy 2.0 (async), Pydantic v2, alpaca-py
|
||||
- **Frontend**: React 19, TypeScript, Vite, Tailwind CSS, TanStack Query, TradingView lightweight-charts
|
||||
- **ML**: transformers (FinBERT), Ollama (local LLM fallback)
|
||||
- **Database**: PostgreSQL 16 + TimescaleDB, Alembic migrations (16 tables)
|
||||
- **Messaging**: Redis Streams + pub/sub
|
||||
- **Auth**: WebAuthn/Passkeys + JWT sessions
|
||||
- **Observability**: OpenTelemetry + Prometheus metrics
|
||||
- **CI/CD**: Woodpecker → Docker → Kubernetes
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Full stack with Docker Compose
|
||||
docker compose up -d
|
||||
|
||||
# Seed default strategies
|
||||
docker compose exec api-gateway python -m scripts.seed_strategies
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
python3 -m venv .venv && source .venv/bin/activate
|
||||
|
||||
# Install all dependencies
|
||||
pip install -e ".[api,news,sentiment,trading,backtester,dev]"
|
||||
|
||||
# Run unit tests (404 tests)
|
||||
python -m pytest tests/ -v -m "not integration"
|
||||
|
||||
# Run integration tests (requires Redis + PostgreSQL)
|
||||
python -m pytest tests/ -v -m integration
|
||||
|
||||
# Dashboard development
|
||||
cd dashboard && npm install && npm run dev
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
trading-bot/
|
||||
├── shared/ # Shared libraries (config, DB, Redis, models, schemas, broker, strategies, fundamentals)
|
||||
├── services/ # 7 microservices (news_fetcher, sentiment_analyzer, signal_generator,
|
||||
│ # trade_executor, learning_engine, market_data, api_gateway)
|
||||
├── backtester/ # Historical replay engine with simulated broker
|
||||
├── dashboard/ # React 19 / TypeScript / Vite frontend
|
||||
├── docker/ # Dockerfiles and nginx configs
|
||||
├── scripts/ # Seed scripts and smoke tests
|
||||
├── tests/ # 404 unit + 9 integration tests
|
||||
├── alembic/ # Database migrations
|
||||
├── docker-compose.yml # Full stack orchestration
|
||||
├── .woodpecker.yml # CI/CD pipeline
|
||||
└── pyproject.toml # Python monorepo with optional dependency groups
|
||||
```
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import axios from 'axios';
|
||||
|
||||
const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || '/api';
|
||||
const DEV_MODE = import.meta.env.VITE_DEV_MODE === 'true';
|
||||
|
||||
const client = axios.create({
|
||||
baseURL: API_BASE_URL,
|
||||
|
|
@ -44,7 +45,7 @@ client.interceptors.response.use(
|
|||
async (error) => {
|
||||
const originalRequest = error.config;
|
||||
|
||||
if (error.response?.status === 401 && !originalRequest._retry) {
|
||||
if (!DEV_MODE && error.response?.status === 401 && !originalRequest._retry) {
|
||||
if (isRefreshing) {
|
||||
return new Promise((resolve, reject) => {
|
||||
failedQueue.push({ resolve, reject });
|
||||
|
|
|
|||
|
|
@ -1,11 +1,17 @@
|
|||
import type { ReactNode } from 'react';
|
||||
import { Navigate } from 'react-router-dom';
|
||||
|
||||
const DEV_MODE = import.meta.env.VITE_DEV_MODE === 'true';
|
||||
|
||||
interface ProtectedRouteProps {
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export function ProtectedRoute({ children }: ProtectedRouteProps) {
|
||||
if (DEV_MODE) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
const token = localStorage.getItem('access_token');
|
||||
|
||||
if (!token) {
|
||||
|
|
|
|||
|
|
@ -2,16 +2,18 @@ import { defineConfig } from 'vite'
|
|||
import react from '@vitejs/plugin-react'
|
||||
import tailwindcss from '@tailwindcss/vite'
|
||||
|
||||
const API_TARGET = process.env.VITE_API_TARGET || 'http://localhost:8000'
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react(), tailwindcss()],
|
||||
server: {
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: 'http://localhost:8000',
|
||||
target: API_TARGET,
|
||||
changeOrigin: true,
|
||||
},
|
||||
'/ws': {
|
||||
target: 'ws://localhost:8000',
|
||||
target: API_TARGET.replace('http', 'ws'),
|
||||
ws: true,
|
||||
},
|
||||
},
|
||||
|
|
|
|||
183
docs/plans/remaining-work-plan.md
Normal file
183
docs/plans/remaining-work-plan.md
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
# Trading Bot — Remaining Work Plan
|
||||
|
||||
**Created**: 2026-02-25
|
||||
**Status**: Complete
|
||||
|
||||
## Overview
|
||||
|
||||
This plan addresses all remaining TODOs, placeholders, and incomplete features in the
|
||||
trading bot. Tasks are ordered by dependency — earlier tasks unblock later ones.
|
||||
|
||||
---
|
||||
|
||||
## Task 1: Learning engine — expand default weights to all 9 strategies
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `services/learning_engine/main.py`
|
||||
|
||||
**Problem**: `_load_strategy_weights()` defaults to 3 strategies (momentum, mean_reversion,
|
||||
news_driven) but the signal generator uses all 9. When Redis has no cached weights, the
|
||||
learning engine operates on a stale subset.
|
||||
|
||||
**Fix**: Change the default fallback to all 9 strategies with equal weights (~0.111 each),
|
||||
matching `scripts/seed_strategies.py`.
|
||||
|
||||
---
|
||||
|
||||
## Task 2: Portfolio sync — compute actual daily P&L
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `services/api_gateway/tasks/portfolio_sync.py`
|
||||
|
||||
**Problem**: `daily_pnl` is hardcoded to `0.0` in `_sync_once()` (line 60).
|
||||
|
||||
**Fix**: Before inserting the new snapshot, query the most recent prior snapshot for
|
||||
today. Compute `daily_pnl = current_total_value - day_start_total_value`. If no prior
|
||||
snapshot exists for today, the first snapshot's daily_pnl is 0.0.
|
||||
|
||||
---
|
||||
|
||||
## Task 3: Portfolio API — cumulative P&L from first snapshot
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `services/api_gateway/routes/portfolio.py`
|
||||
|
||||
**Problem**: `total_pnl` on line 90 just returns `latest.daily_pnl` instead of the
|
||||
cumulative P&L since inception.
|
||||
|
||||
**Fix**: Query the earliest `PortfolioSnapshot` and compute:
|
||||
- `total_pnl = latest.total_value - earliest.total_value`
|
||||
- `total_pnl_pct = total_pnl / earliest.total_value * 100`
|
||||
|
||||
---
|
||||
|
||||
## Task 4: Portfolio API — read trading pause flag from Redis
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `services/api_gateway/routes/portfolio.py`
|
||||
|
||||
**Problem**: `trading_active` is hardcoded to `True` (line 92).
|
||||
|
||||
**Fix**: Read the `trading:paused` key from Redis (same key used by
|
||||
`services/api_gateway/routes/controls.py`). Return `trading_active = not paused`.
|
||||
|
||||
---
|
||||
|
||||
## Task 5: Portfolio metrics — compute max drawdown from snapshots
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `services/api_gateway/routes/portfolio.py`
|
||||
|
||||
**Problem**: `max_drawdown` is hardcoded to `0.0` (line 170).
|
||||
|
||||
**Fix**: Query all `PortfolioSnapshot` rows, compute the running peak and maximum
|
||||
percentage drop from peak. Return as a positive decimal (e.g. 0.12 = 12% drawdown).
|
||||
|
||||
---
|
||||
|
||||
## Task 6: Portfolio metrics — compute avg hold duration from trade outcomes
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `services/api_gateway/routes/portfolio.py`
|
||||
|
||||
**Problem**: `avg_hold_duration` is hardcoded to `"0h"` (line 172).
|
||||
|
||||
**Fix**: Query `TradeOutcome` rows, average the `hold_duration_seconds` column,
|
||||
and format as a human-readable string (e.g. "4h 30m", "2d 6h").
|
||||
|
||||
---
|
||||
|
||||
## Task 7: Trade executor — respect the trading pause flag
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `services/trade_executor/risk_manager.py`
|
||||
|
||||
**Problem**: The controls API sets/clears `trading:paused` in Redis, but the trade
|
||||
executor's `RiskManager.check_risk()` never checks it. Pausing has no effect.
|
||||
|
||||
**Fix**: Accept a Redis client in `RiskManager.__init__()`. Add a check at the top of
|
||||
`check_risk()` that reads `trading:paused` and rejects with `"trading_paused"` if set.
|
||||
Wire the Redis client through from `trade_executor/main.py`.
|
||||
|
||||
---
|
||||
|
||||
## Task 8: Learning engine — resolve placeholder strategy_id
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `services/learning_engine/main.py`
|
||||
|
||||
**Problem**: `WeightAdjustment` on line 207 uses `strategy_id=UUID(int=0)` as a
|
||||
placeholder. This means weight history records can't be linked to the correct strategy.
|
||||
|
||||
**Fix**: On startup, load the `strategies` table into a `name -> UUID` lookup dict.
|
||||
Use this lookup when building `WeightAdjustment` records. Fall back to `UUID(int=0)`
|
||||
only if the strategy name isn't in the DB.
|
||||
|
||||
---
|
||||
|
||||
## Task 9: Pass strategy_sources from signal through to learning engine
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `services/trade_executor/main.py`, `services/learning_engine/main.py`
|
||||
|
||||
**Problem**: `TradeExecution` doesn't carry `strategy_sources`. When the learning engine
|
||||
stores the opening trade (line 136), `strategy_sources` is always `[]`. This means credit
|
||||
attribution (`evaluator.attribute_credit()`) has no strategies to reward.
|
||||
|
||||
**Fix**:
|
||||
1. Add `strategy_sources: list[str] = []` to the `TradeExecution` schema.
|
||||
2. In the trade executor's `process_signal()`, copy `signal.strategy_sources` into the
|
||||
execution message.
|
||||
3. In the learning engine, read `trade.strategy_sources` (via the extended schema) and
|
||||
store them in the opening trade record.
|
||||
|
||||
---
|
||||
|
||||
## Task 10: Update CLAUDE.md and add top-level README
|
||||
|
||||
**Status**: [x] Complete
|
||||
**Files**: `.claude/CLAUDE.md`, `README.md`
|
||||
|
||||
**Fix**:
|
||||
- Remove "No CI/CD pipeline yet" from Known Issues (Woodpecker exists).
|
||||
- Update strategy count references (3 → 9) in Project Structure section.
|
||||
- Update model count (14 → 16 tables).
|
||||
- Create a concise top-level `README.md` with: project description, architecture diagram
|
||||
(text), quickstart (docker compose up), dev setup, and test commands.
|
||||
|
||||
---
|
||||
|
||||
## Execution Order
|
||||
|
||||
Tasks are grouped by dependency:
|
||||
|
||||
**Group A (independent, do first)**:
|
||||
- Task 1 (learning engine weights)
|
||||
- Task 7 (trade executor pause flag)
|
||||
- Task 9 (strategy_sources passthrough)
|
||||
|
||||
**Group B (depends on nothing, can parallel with A)**:
|
||||
- Task 2 (portfolio sync daily P&L)
|
||||
- Task 8 (learning engine strategy_id lookup)
|
||||
|
||||
**Group C (depends on Task 2)**:
|
||||
- Task 3 (cumulative P&L API)
|
||||
- Task 4 (trading_active flag API)
|
||||
- Task 5 (max drawdown API)
|
||||
- Task 6 (avg hold duration API)
|
||||
|
||||
**Group D (after all code changes)**:
|
||||
- Task 10 (documentation)
|
||||
|
||||
---
|
||||
|
||||
## Test Strategy
|
||||
|
||||
After each task, run the relevant unit tests:
|
||||
```bash
|
||||
python -m pytest tests/ -v -m "not integration" --tb=short
|
||||
```
|
||||
|
||||
Existing tests should continue to pass. New tests may be needed for:
|
||||
- Task 7: risk manager pause check
|
||||
- Task 9: strategy_sources in TradeExecution schema
|
||||
|
|
@ -24,6 +24,13 @@ def get_config() -> ApiGatewayConfig:
|
|||
return _config
|
||||
|
||||
|
||||
_DEV_USER = {
|
||||
"sub": "00000000-0000-0000-0000-000000000000",
|
||||
"username": "dev",
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
|
|
@ -33,7 +40,13 @@ async def get_current_user(
|
|||
Returns the decoded token payload (contains ``sub``, ``username``, etc.)
|
||||
on success. Raises a 401 ``HTTPException`` for missing, expired, or
|
||||
invalid tokens.
|
||||
|
||||
When ``config.dev_mode`` is ``True``, authentication is bypassed and a
|
||||
synthetic dev user is returned.
|
||||
"""
|
||||
if config.dev_mode:
|
||||
return _DEV_USER
|
||||
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -57,7 +70,6 @@ async def get_current_user(
|
|||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Ensure it is an access token, not a refresh token
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
|
|||
|
|
@ -10,8 +10,11 @@ class ApiGatewayConfig(BaseConfig):
|
|||
prefixed with ``TRADING_``.
|
||||
"""
|
||||
|
||||
# Dev mode — bypasses authentication (set TRADING_DEV_MODE=true)
|
||||
dev_mode: bool = False
|
||||
|
||||
# JWT settings — TRADING_JWT_SECRET_KEY must be set in environment
|
||||
jwt_secret_key: str
|
||||
jwt_secret_key: str = "dev-secret-not-for-production"
|
||||
jwt_algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 15
|
||||
refresh_token_expire_days: int = 7
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from enum import Enum
|
|||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc
|
||||
from sqlalchemy import select, desc, asc, func
|
||||
|
||||
router = APIRouter(prefix="/api/portfolio", tags=["portfolio"])
|
||||
|
||||
|
|
@ -33,6 +33,41 @@ class HistoryPeriod(str, Enum):
|
|||
return None
|
||||
|
||||
|
||||
def _compute_max_drawdown(values: list[float]) -> float:
|
||||
"""Compute maximum percentage drawdown from a list of portfolio values.
|
||||
|
||||
Returns a positive decimal (e.g. 0.12 for a 12% drawdown).
|
||||
"""
|
||||
if len(values) < 2:
|
||||
return 0.0
|
||||
peak = values[0]
|
||||
max_dd = 0.0
|
||||
for v in values:
|
||||
if v > peak:
|
||||
peak = v
|
||||
if peak > 0:
|
||||
dd = (peak - v) / peak
|
||||
if dd > max_dd:
|
||||
max_dd = dd
|
||||
return max_dd
|
||||
|
||||
|
||||
def _format_duration(durations: list[timedelta]) -> str:
|
||||
"""Compute average of timedelta list and format as human-readable string."""
|
||||
if not durations:
|
||||
return "0h"
|
||||
total_seconds = sum(d.total_seconds() for d in durations) / len(durations)
|
||||
if total_seconds < 0:
|
||||
total_seconds = 0
|
||||
hours = int(total_seconds // 3600)
|
||||
minutes = int((total_seconds % 3600) // 60)
|
||||
if hours >= 24:
|
||||
days = hours // 24
|
||||
remaining_hours = hours % 24
|
||||
return f"{days}d {remaining_hours}h"
|
||||
return f"{hours}h {minutes}m"
|
||||
|
||||
|
||||
def _period_to_timedelta(period: HistoryPeriod) -> timedelta:
|
||||
"""Convert a period enum value to a timedelta."""
|
||||
mapping = {
|
||||
|
|
@ -47,6 +82,9 @@ def _period_to_timedelta(period: HistoryPeriod) -> timedelta:
|
|||
return mapping[period]
|
||||
|
||||
|
||||
TRADING_PAUSED_KEY = "trading:paused"
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_portfolio(
|
||||
request: Request,
|
||||
|
|
@ -77,20 +115,38 @@ async def get_portfolio(
|
|||
"trading_active": True,
|
||||
}
|
||||
|
||||
# Compute percentage fields from snapshot data
|
||||
# Cumulative P&L: difference between latest and earliest snapshot
|
||||
earliest = (
|
||||
await session.execute(
|
||||
select(PortfolioSnapshot)
|
||||
.order_by(asc(PortfolioSnapshot.timestamp))
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
total_pnl = 0.0
|
||||
total_pnl_pct = 0.0
|
||||
if earliest is not None and earliest.total_value > 0:
|
||||
total_pnl = latest.total_value - earliest.total_value
|
||||
total_pnl_pct = total_pnl / earliest.total_value * 100.0
|
||||
|
||||
daily_pnl_pct = (latest.daily_pnl / (latest.total_value - latest.daily_pnl) * 100.0
|
||||
if latest.total_value != latest.daily_pnl else 0.0)
|
||||
|
||||
return {
|
||||
"total_value": latest.total_value,
|
||||
"cash": latest.cash,
|
||||
"buying_power": latest.cash,
|
||||
"daily_pnl": latest.daily_pnl,
|
||||
"daily_pnl_pct": round(daily_pnl_pct, 2),
|
||||
"total_pnl": latest.daily_pnl, # TODO: compute cumulative P&L from first snapshot
|
||||
"total_pnl_pct": round(daily_pnl_pct, 2),
|
||||
"trading_active": True, # TODO: read from Redis trading pause flag
|
||||
}
|
||||
# Read trading pause flag from Redis
|
||||
redis = request.app.state.redis
|
||||
paused = await redis.get(TRADING_PAUSED_KEY) if redis else None
|
||||
|
||||
return {
|
||||
"total_value": latest.total_value,
|
||||
"cash": latest.cash,
|
||||
"buying_power": latest.cash,
|
||||
"daily_pnl": latest.daily_pnl,
|
||||
"daily_pnl_pct": round(daily_pnl_pct, 2),
|
||||
"total_pnl": round(total_pnl, 2),
|
||||
"total_pnl_pct": round(total_pnl_pct, 2),
|
||||
"trading_active": not bool(paused),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/positions")
|
||||
|
|
@ -132,18 +188,17 @@ async def get_portfolio_metrics(
|
|||
_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Aggregate portfolio performance metrics — ROI, Sharpe, win rate, drawdown."""
|
||||
from shared.models.learning import TradeOutcome
|
||||
from shared.models.trading import Trade, TradeStatus
|
||||
from shared.models.timeseries import StrategyMetric
|
||||
from shared.models.timeseries import PortfolioSnapshot, StrategyMetric
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
# Total trades and win rate from trades table
|
||||
trades_result = await session.execute(
|
||||
select(Trade).where(Trade.status == TradeStatus.FILLED)
|
||||
)
|
||||
trades = trades_result.scalars().all()
|
||||
|
||||
# Latest strategy metrics for Sharpe
|
||||
metrics_result = await session.execute(
|
||||
select(StrategyMetric)
|
||||
.order_by(desc(StrategyMetric.timestamp))
|
||||
|
|
@ -156,20 +211,34 @@ async def get_portfolio_metrics(
|
|||
win_rate = winning / total_trades if total_trades > 0 else 0.0
|
||||
|
||||
total_pnl = sum(t.pnl for t in trades if t.pnl is not None)
|
||||
# Approximate ROI from P&L (rough — proper calculation needs initial capital)
|
||||
roi = total_pnl / 100_000.0 * 100.0 # assumes 100k starting capital
|
||||
roi = total_pnl / 100_000.0 * 100.0
|
||||
|
||||
# Average Sharpe from strategy metrics
|
||||
sharpe_values = [m.sharpe_ratio for m in strategy_metrics if m.sharpe_ratio is not None]
|
||||
avg_sharpe = sum(sharpe_values) / len(sharpe_values) if sharpe_values else 0.0
|
||||
|
||||
# Max drawdown from portfolio snapshots (peak-to-trough)
|
||||
snapshots_result = await session.execute(
|
||||
select(PortfolioSnapshot.total_value)
|
||||
.order_by(PortfolioSnapshot.timestamp)
|
||||
)
|
||||
values = [row[0] for row in snapshots_result.all()]
|
||||
max_drawdown = _compute_max_drawdown(values)
|
||||
|
||||
# Average hold duration from trade outcomes
|
||||
outcomes_result = await session.execute(
|
||||
select(TradeOutcome.hold_duration)
|
||||
.where(TradeOutcome.hold_duration.isnot(None))
|
||||
)
|
||||
durations = [row[0] for row in outcomes_result.all()]
|
||||
avg_hold = _format_duration(durations)
|
||||
|
||||
return {
|
||||
"roi": round(roi, 4),
|
||||
"sharpe": round(avg_sharpe, 2),
|
||||
"win_rate": round(win_rate, 4),
|
||||
"max_drawdown": 0.0, # TODO: compute from portfolio snapshots
|
||||
"max_drawdown": round(max_drawdown, 4),
|
||||
"total_trades": total_trades,
|
||||
"avg_hold_duration": "0h", # TODO: compute from trade outcomes
|
||||
"avg_hold_duration": avg_hold,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -52,15 +52,30 @@ async def _sync_once(
|
|||
# 1. Snapshot account state
|
||||
account = await broker.get_account()
|
||||
|
||||
# 2. Compute daily P&L: difference from the first snapshot today
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
daily_pnl = 0.0
|
||||
async with session_factory() as read_session:
|
||||
day_start_snapshot = (
|
||||
await read_session.execute(
|
||||
select(PortfolioSnapshot)
|
||||
.where(PortfolioSnapshot.timestamp >= today_start)
|
||||
.order_by(PortfolioSnapshot.timestamp)
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if day_start_snapshot is not None:
|
||||
daily_pnl = account.portfolio_value - day_start_snapshot.total_value
|
||||
|
||||
snapshot = PortfolioSnapshot(
|
||||
timestamp=now,
|
||||
total_value=account.portfolio_value,
|
||||
cash=account.cash,
|
||||
positions_value=account.portfolio_value - account.cash,
|
||||
daily_pnl=0.0,
|
||||
daily_pnl=daily_pnl,
|
||||
)
|
||||
|
||||
# 2. Fetch broker positions
|
||||
# 3. Fetch broker positions
|
||||
broker_positions = await broker.get_positions()
|
||||
broker_tickers = {p.ticker for p in broker_positions}
|
||||
|
||||
|
|
@ -91,7 +106,7 @@ async def _sync_once(
|
|||
)
|
||||
session.add(new_pos)
|
||||
|
||||
# 3. Remove positions that are no longer held at the broker
|
||||
# 4. Remove positions that are no longer held at the broker
|
||||
if broker_tickers:
|
||||
await session.execute(
|
||||
delete(Position).where(Position.ticker.notin_(broker_tickers))
|
||||
|
|
|
|||
|
|
@ -15,10 +15,13 @@ from datetime import datetime, timezone
|
|||
from uuid import UUID
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import select
|
||||
|
||||
from services.learning_engine.config import LearningEngineConfig
|
||||
from services.learning_engine.evaluator import TradeEvaluator
|
||||
from services.learning_engine.weight_adjuster import WeightAdjuster
|
||||
from shared.db import create_db
|
||||
from shared.models.trading import Strategy
|
||||
from shared.redis_streams import StreamConsumer
|
||||
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
|
||||
from shared.schemas.trading import OrderSide, TradeExecution
|
||||
|
|
@ -39,11 +42,17 @@ async def _load_strategy_weights(redis: Redis) -> dict[str, float]:
|
|||
raw = await redis.get(_STRATEGY_WEIGHTS_KEY)
|
||||
if raw:
|
||||
return json.loads(raw)
|
||||
# Default equal weights
|
||||
# Default equal weights for all 9 strategies (matches seed_strategies.py)
|
||||
return {
|
||||
"momentum": 0.333,
|
||||
"mean_reversion": 0.333,
|
||||
"news_driven": 0.334,
|
||||
"momentum": 0.111,
|
||||
"mean_reversion": 0.111,
|
||||
"news_driven": 0.111,
|
||||
"value": 0.111,
|
||||
"macd_crossover": 0.111,
|
||||
"bollinger_breakout": 0.111,
|
||||
"vwap": 0.111,
|
||||
"liquidity": 0.112,
|
||||
"ma_stack": 0.111,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -104,6 +113,7 @@ async def process_trade(
|
|||
evaluator: TradeEvaluator,
|
||||
adjuster: WeightAdjuster,
|
||||
counters: dict,
|
||||
strategy_id_lookup: dict[str, UUID] | None = None,
|
||||
) -> list[WeightAdjustment]:
|
||||
"""Process a single trade execution.
|
||||
|
||||
|
|
@ -132,7 +142,7 @@ async def process_trade(
|
|||
"price": trade.price,
|
||||
"qty": trade.qty,
|
||||
"timestamp": trade.timestamp.isoformat(),
|
||||
"strategy_sources": [], # would come from signal
|
||||
"strategy_sources": trade.strategy_sources,
|
||||
},
|
||||
)
|
||||
return adjustments
|
||||
|
|
@ -203,8 +213,9 @@ async def process_trade(
|
|||
weights[strategy_name] = new_weight
|
||||
any_adjusted = True
|
||||
|
||||
sid = (strategy_id_lookup or {}).get(strategy_name, UUID(int=0))
|
||||
adjustment = WeightAdjustment(
|
||||
strategy_id=UUID(int=0), # placeholder -- DB would assign real ID
|
||||
strategy_id=sid,
|
||||
strategy_name=strategy_name,
|
||||
old_weight=old_weight,
|
||||
new_weight=new_weight,
|
||||
|
|
@ -274,6 +285,19 @@ async def run(config: LearningEngineConfig | None = None) -> None:
|
|||
evaluator = TradeEvaluator()
|
||||
adjuster = WeightAdjuster(config)
|
||||
|
||||
# --- Load strategy name -> UUID lookup from DB ---
|
||||
strategy_id_lookup: dict[str, UUID] = {}
|
||||
try:
|
||||
_engine, session_factory = create_db(config)
|
||||
async with session_factory() as session:
|
||||
result = await session.execute(select(Strategy))
|
||||
for s in result.scalars().all():
|
||||
strategy_id_lookup[s.name] = s.id
|
||||
await _engine.dispose()
|
||||
logger.info("Loaded %d strategy IDs from DB", len(strategy_id_lookup))
|
||||
except Exception:
|
||||
logger.exception("Failed to load strategy IDs — using fallback UUID(int=0)")
|
||||
|
||||
logger.info("Consuming from trades:executed")
|
||||
|
||||
# Graceful shutdown on SIGTERM/SIGINT
|
||||
|
|
@ -294,7 +318,9 @@ async def run(config: LearningEngineConfig | None = None) -> None:
|
|||
logger.debug("Skipping non-filled trade: %s", trade.status.value)
|
||||
continue
|
||||
|
||||
adjustments = await process_trade(trade, redis, evaluator, adjuster, counters)
|
||||
adjustments = await process_trade(
|
||||
trade, redis, evaluator, adjuster, counters, strategy_id_lookup
|
||||
)
|
||||
if adjustments:
|
||||
logger.info(
|
||||
"Made %d weight adjustment(s) for %s",
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ async def process_signal(
|
|||
status=result.status,
|
||||
signal_id=signal.signal_id,
|
||||
strategy_id=None,
|
||||
strategy_sources=signal.strategy_sources,
|
||||
timestamp=result.timestamp,
|
||||
)
|
||||
|
||||
|
|
@ -193,7 +194,7 @@ async def run(config: TradeExecutorConfig | None = None) -> None:
|
|||
)
|
||||
|
||||
# --- Risk manager ---
|
||||
risk_manager = RiskManager(config, broker)
|
||||
risk_manager = RiskManager(config, broker, redis=redis)
|
||||
|
||||
# --- Database (for persisting trades) ---
|
||||
db_session_factory = None
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import logging
|
|||
from datetime import datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from services.trade_executor.config import TradeExecutorConfig
|
||||
from shared.broker.base import BaseBroker
|
||||
from shared.schemas.trading import AccountInfo, PositionInfo, SignalDirection, TradeSignal
|
||||
|
|
@ -24,6 +26,8 @@ _MARKET_OPEN_MINUTE = 30
|
|||
_MARKET_CLOSE_HOUR = 16
|
||||
_MARKET_CLOSE_MINUTE = 0
|
||||
|
||||
TRADING_PAUSED_KEY = "trading:paused"
|
||||
|
||||
|
||||
class RiskManager:
|
||||
"""Performs pre-trade risk checks and calculates position sizes.
|
||||
|
|
@ -34,11 +38,19 @@ class RiskManager:
|
|||
Trade executor configuration with risk parameters.
|
||||
broker:
|
||||
Broker instance for querying current positions and account info.
|
||||
redis:
|
||||
Redis client for checking the trading pause flag.
|
||||
"""
|
||||
|
||||
def __init__(self, config: TradeExecutorConfig, broker: BaseBroker) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: TradeExecutorConfig,
|
||||
broker: BaseBroker,
|
||||
redis: Redis | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.broker = broker
|
||||
self.redis = redis
|
||||
# ticker -> last exit timestamp
|
||||
self._cooldowns: dict[str, datetime] = {}
|
||||
|
||||
|
|
@ -55,6 +67,12 @@ class RiskManager:
|
|||
``(approved, reason)`` — ``approved`` is ``True`` when
|
||||
all checks pass, otherwise ``reason`` explains the failure.
|
||||
"""
|
||||
# 0. Trading pause flag
|
||||
if self.redis is not None:
|
||||
paused = await self.redis.get(TRADING_PAUSED_KEY)
|
||||
if paused:
|
||||
return False, "trading_paused"
|
||||
|
||||
# 1. Market hours
|
||||
now_et = datetime.now(tz=_ET)
|
||||
if not self._is_market_hours(now_et):
|
||||
|
|
|
|||
|
|
@ -118,6 +118,7 @@ class TradeExecution(BaseModel):
|
|||
status: OrderStatus
|
||||
signal_id: UUID | None = None
|
||||
strategy_id: UUID | None = None
|
||||
strategy_sources: list[str] = Field(default_factory=list)
|
||||
timestamp: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
|
|
|||
|
|
@ -231,15 +231,18 @@ class TestSyncOnce:
|
|||
existing_aapl.qty = 5.0 # old qty
|
||||
existing_aapl.avg_entry = 140.0 # old entry
|
||||
|
||||
# Day-start snapshot query (returns None = first snapshot today)
|
||||
result_day_start = MagicMock()
|
||||
result_day_start.scalar_one_or_none.return_value = None
|
||||
|
||||
result_aapl = MagicMock()
|
||||
result_aapl.scalar_one_or_none.return_value = existing_aapl
|
||||
result_msft = MagicMock()
|
||||
result_msft.scalar_one_or_none.return_value = None
|
||||
|
||||
# First execute call is for the delete of stale positions;
|
||||
# but within the loop, select calls come first
|
||||
# Execute calls: day-start snapshot, AAPL lookup, MSFT lookup, DELETE
|
||||
mock_session.execute = AsyncMock(
|
||||
side_effect=[result_aapl, result_msft, MagicMock()]
|
||||
side_effect=[result_day_start, result_aapl, result_msft, MagicMock()]
|
||||
)
|
||||
|
||||
await _sync_once(mock_broker, mock_session_factory)
|
||||
|
|
|
|||
|
|
@ -126,6 +126,51 @@ class TestRiskCheckPasses:
|
|||
assert reason == "approved"
|
||||
|
||||
|
||||
class TestRiskCheckTradingPaused:
|
||||
"""Risk check fails when trading is paused via Redis flag."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paused_flag_rejects(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker()
|
||||
redis_mock = AsyncMock()
|
||||
redis_mock.get = AsyncMock(return_value=b"1")
|
||||
rm = RiskManager(config, broker, redis=redis_mock)
|
||||
signal = _make_signal()
|
||||
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is False
|
||||
assert reason == "trading_paused"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_pause_flag_passes_through(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||
redis_mock = AsyncMock()
|
||||
redis_mock.get = AsyncMock(return_value=None)
|
||||
rm = RiskManager(config, broker, redis=redis_mock)
|
||||
signal = _make_signal()
|
||||
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_redis_skips_pause_check(self):
|
||||
config = _make_config()
|
||||
broker = _mock_broker(positions=[], account=_make_account(100_000))
|
||||
rm = RiskManager(config, broker, redis=None)
|
||||
signal = _make_signal()
|
||||
|
||||
with patch.object(RiskManager, "_is_market_hours", return_value=True):
|
||||
approved, reason = await rm.check_risk(signal)
|
||||
|
||||
assert approved is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager — max positions exceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue