# trading/scripts/seed_sample_data.py

"""Seed the database with ~30 days of realistic sample data.
Populates portfolio snapshots, trades, signals, positions, news articles,
sentiments, strategy metrics, weight history, and trade outcomes.
Usage:
TRADING_DATABASE_URL=postgresql+asyncpg://trading:trading@localhost:5432/trading \
python -m scripts.seed_sample_data
"""
from __future__ import annotations
import asyncio
import hashlib
import logging
import random
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
from shared.config import BaseConfig
from shared.db import create_db
from shared.models.learning import TradeOutcome
from shared.models.news import Article, ArticleSentiment
from shared.models.timeseries import PortfolioSnapshot, StrategyMetric
from shared.models.trading import (
Position,
Signal,
SignalDirection,
Strategy,
StrategyWeightHistory,
Trade,
TradeSide,
TradeStatus,
)
logger = logging.getLogger(__name__)

# Universe of tickers the sample data is generated for.
TICKERS = ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]

# Strategy names expected by the rest of the seeding logic.
STRATEGY_NAMES = ["momentum", "mean_reversion", "news_driven"]

# Approximate (low, high) trading ranges used to draw realistic prices.
TICKER_PRICES = {
    "AAPL": (170.0, 195.0),
    "TSLA": (220.0, 280.0),
    "NVDA": (700.0, 900.0),
    "MSFT": (380.0, 430.0),
    "GOOGL": (140.0, 170.0),
}

# (ticker, headline) pairs used to fabricate the news-article fixtures.
NEWS_HEADLINES = [
    ("AAPL", "Apple Reports Record Q4 Revenue Driven by iPhone Sales"),
    ("AAPL", "Apple Vision Pro Sees Slow Adoption Rates"),
    ("AAPL", "Apple Expands AI Features Across Product Line"),
    ("TSLA", "Tesla Deliveries Beat Expectations in Q3"),
    ("TSLA", "Tesla Cuts Prices Amid Growing EV Competition"),
    ("TSLA", "Tesla's Robotaxi Event Draws Mixed Reviews"),
    ("NVDA", "NVIDIA Reports Blowout Earnings on AI Chip Demand"),
    ("NVDA", "NVIDIA Blackwell GPUs Face Supply Constraints"),
    ("NVDA", "Data Center Revenue Drives NVIDIA to New Highs"),
    ("MSFT", "Microsoft Azure Growth Accelerates on AI Workloads"),
    ("MSFT", "Microsoft Copilot Adoption Grows Among Enterprise Clients"),
    ("MSFT", "Microsoft Invests $10B in AI Infrastructure"),
    ("GOOGL", "Google Search Revenue Beats Expectations"),
    ("GOOGL", "Alphabet Faces Antitrust Ruling on Search Monopoly"),
    ("GOOGL", "Google Cloud Turns Profitable for Third Consecutive Quarter"),
    ("AAPL", "Apple Supply Chain Diversifies Beyond China"),
    ("TSLA", "Tesla Semi Truck Enters Mass Production"),
    ("NVDA", "NVIDIA Partners with Healthcare Companies for AI Diagnostics"),
    ("MSFT", "Microsoft Teams Surpasses 300 Million Monthly Users"),
    ("GOOGL", "YouTube Ad Revenue Surges 15% Year-Over-Year"),
    ("AAPL", "Apple Services Revenue Hits All-Time High"),
    ("TSLA", "Tesla Energy Storage Deployments Double Year-Over-Year"),
    ("NVDA", "NVIDIA Stock Included in Dow Jones Industrial Average"),
    ("MSFT", "Microsoft Gaming Division Posts Strong Results"),
    ("GOOGL", "Google DeepMind Achieves Breakthrough in Protein Folding"),
    ("AAPL", "Apple Announces Stock Buyback Program"),
    ("TSLA", "Analysts Divided on Tesla Valuation After Rally"),
    ("NVDA", "NVIDIA Announces Next-Gen GPU Architecture"),
    ("MSFT", "Microsoft 365 Price Increase Draws Customer Pushback"),
    ("GOOGL", "Waymo Expands Autonomous Ride Service to New Cities"),
]
def _random_price(ticker: str) -> float:
    """Return a random price within *ticker*'s configured (low, high) range."""
    low, high = TICKER_PRICES[ticker]
    price = random.uniform(low, high)
    return round(price, 2)
def _content_hash(text: str) -> str:
return hashlib.sha256(text.encode()).hexdigest()
async def seed(database_url: str | None = None) -> None:
    """Seed the database with ~30 days of realistic sample data.

    Creates strategies (if missing), 30 daily portfolio snapshots, ~50
    signal/trade pairs with outcomes, 4 open positions, 30 news articles
    with sentiment scores, daily per-strategy metrics, and a handful of
    strategy weight-adjustment records.

    Idempotent: if any ``Trade`` row already exists, seeding is skipped.
    All inserts are committed in a single transaction at the end.

    Args:
        database_url: Optional override for ``config.database_url``.
    """
    config = BaseConfig()
    if database_url:
        config.database_url = database_url
    engine, session_factory = create_db(config)
    # try/finally guarantees the engine is disposed even if seeding raises
    # partway through (the original leaked the engine on any exception).
    try:
        async with session_factory() as session:
            # ---------------------------------------------------------------
            # 1. Ensure strategies exist (reuse from seed_strategies)
            # ---------------------------------------------------------------
            result = await session.execute(select(Strategy))
            existing = {s.name: s for s in result.scalars().all()}
            strategies: dict[str, Strategy] = {}
            for name in STRATEGY_NAMES:
                if name in existing:
                    strategies[name] = existing[name]
                    logger.info("Strategy '%s' already exists", name)
                else:
                    s = Strategy(
                        name=name,
                        description=f"Auto-seeded {name} strategy",
                        current_weight=0.333,
                        active=True,
                    )
                    session.add(s)
                    strategies[name] = s
                    logger.info("Created strategy '%s'", name)
            # Flush so strategy IDs are assigned before we reference them.
            await session.flush()

            # ---------------------------------------------------------------
            # 2. Check if data already seeded (idempotency)
            # ---------------------------------------------------------------
            # LIMIT 1: we only need existence, not the whole trades table.
            first_trade = (
                await session.execute(select(Trade).limit(1))
            ).scalars().first()
            if first_trade is not None:
                logger.info("Sample data already exists, skipping seed")
                return

            now = datetime.now(timezone.utc)
            random.seed(42)  # reproducible data

            # ---------------------------------------------------------------
            # 3. Portfolio snapshots — 30 days of equity curve
            # ---------------------------------------------------------------
            equity = 100_000.0
            snapshots = []
            for day_offset in range(30, 0, -1):
                ts = now - timedelta(days=day_offset)
                daily_change = random.gauss(0.001, 0.008)  # ~0.1% mean, 0.8% std
                prev_equity = equity
                equity *= 1 + daily_change
                positions_value = equity * random.uniform(0.3, 0.7)
                cash = equity - positions_value
                # Daily P&L is the change from yesterday's equity
                # (prev * change), not a fraction of the *new* equity.
                daily_pnl = equity - prev_equity
                snapshots.append(
                    PortfolioSnapshot(
                        timestamp=ts,
                        total_value=round(equity, 2),
                        cash=round(cash, 2),
                        positions_value=round(positions_value, 2),
                        daily_pnl=round(daily_pnl, 2),
                    )
                )
            session.add_all(snapshots)
            logger.info("Added %d portfolio snapshots", len(snapshots))

            # ---------------------------------------------------------------
            # 4. Signals + Trades — ~50 trades spread across 30 days
            # ---------------------------------------------------------------
            strategy_list = list(strategies.values())
            all_trades: list[Trade] = []
            all_signals: list[Signal] = []
            for _ in range(50):
                day_offset = random.randint(1, 30)
                ts = now - timedelta(
                    days=day_offset,
                    hours=random.randint(9, 15),
                    minutes=random.randint(0, 59),
                )
                ticker = random.choice(TICKERS)
                strat = random.choice(strategy_list)
                price = _random_price(ticker)
                side = random.choice([TradeSide.BUY, TradeSide.SELL])
                direction = (
                    SignalDirection.LONG if side == TradeSide.BUY else SignalDirection.SHORT
                )
                strength = round(random.uniform(0.4, 0.95), 3)
                qty = round(random.uniform(5, 50), 0)
                # P&L: 60% of trades are profitable; magnitude scaled by price.
                is_profitable = random.random() < 0.6
                pnl = round(
                    random.uniform(20, 800) * (1 if is_profitable else -1)
                    * (price / 200),
                    2,
                )
                signal = Signal(
                    ticker=ticker,
                    direction=direction,
                    strength=strength,
                    strategy_sources={strat.name: strength},
                    sentiment_score=round(random.uniform(-0.5, 0.8), 3),
                    acted_on=True,
                    strategy_id=strat.id,
                    created_at=ts,
                    updated_at=ts,
                )
                session.add(signal)
                # Flush so signal.id is assigned before the trade references it.
                await session.flush()
                trade = Trade(
                    ticker=ticker,
                    side=side,
                    qty=qty,
                    price=price,
                    status=TradeStatus.FILLED,
                    pnl=pnl,
                    strategy_id=strat.id,
                    signal_id=signal.id,
                    created_at=ts,
                    updated_at=ts,
                )
                session.add(trade)
                all_trades.append(trade)
                all_signals.append(signal)
            # One flush for all trades (IDs needed by outcomes below) instead
            # of one flush per iteration.
            await session.flush()
            logger.info("Added %d signals and %d trades", len(all_signals), len(all_trades))

            # ---------------------------------------------------------------
            # 5. Trade outcomes — for all closed trades
            # ---------------------------------------------------------------
            outcomes = []
            for trade in all_trades:
                hold_hours = random.randint(1, 72)
                roi_pct = round((trade.pnl or 0) / (trade.price * trade.qty) * 100, 2)
                outcome = TradeOutcome(
                    trade_id=trade.id,
                    hold_duration=timedelta(hours=hold_hours),
                    realized_pnl=trade.pnl or 0.0,
                    roi_pct=roi_pct,
                    was_profitable=(trade.pnl or 0) > 0,
                )
                outcomes.append(outcome)
            session.add_all(outcomes)
            logger.info("Added %d trade outcomes", len(outcomes))

            # ---------------------------------------------------------------
            # 6. Open positions — 4 current positions
            # ---------------------------------------------------------------
            open_tickers = random.sample(TICKERS, 4)
            positions = []
            for ticker in open_tickers:
                price = _random_price(ticker)
                qty = round(random.uniform(10, 100), 0)
                unrealized = round(random.gauss(0, price * qty * 0.03), 2)
                positions.append(
                    Position(
                        ticker=ticker,
                        qty=qty,
                        avg_entry=price,
                        unrealized_pnl=unrealized,
                        stop_loss=round(price * 0.95, 2),
                        take_profit=round(price * 1.10, 2),
                    )
                )
            session.add_all(positions)
            logger.info("Added %d open positions", len(positions))

            # ---------------------------------------------------------------
            # 7. News articles + sentiments
            # ---------------------------------------------------------------
            # Word lists are loop-invariant — build them once, not per article.
            positive_words = {"record", "beat", "surge", "high", "growth", "strong", "profit", "expand", "breakthrough", "buyback"}
            negative_words = {"slow", "cut", "face", "divided", "pushback", "mixed", "antitrust"}
            articles = []
            sentiments = []
            for idx, (ticker, headline) in enumerate(NEWS_HEADLINES):
                day_offset = random.randint(1, 30)
                ts = now - timedelta(
                    days=day_offset,
                    hours=random.randint(6, 20),
                )
                url = f"https://finance.example.com/article/{idx + 1}"
                article = Article(
                    source="RSS",
                    url=url,
                    title=headline,
                    published_at=ts,
                    fetched_at=ts + timedelta(minutes=random.randint(1, 30)),
                    content_hash=_content_hash(f"{headline}-{idx}"),
                    created_at=ts,
                    updated_at=ts,
                )
                session.add(article)
                # Flush so article.id is assigned before the sentiment row.
                await session.flush()
                # Sentiment score correlated with headline tone.
                headline_lower = headline.lower()
                pos_count = sum(1 for w in positive_words if w in headline_lower)
                neg_count = sum(1 for w in negative_words if w in headline_lower)
                base_score = 0.3 * pos_count - 0.3 * neg_count
                score = max(-1.0, min(1.0, base_score + random.gauss(0, 0.1)))
                sentiment = ArticleSentiment(
                    article_id=article.id,
                    ticker=ticker,
                    score=round(score, 3),
                    confidence=round(random.uniform(0.6, 0.95), 3),
                    model_used="finbert",
                    created_at=ts,
                    updated_at=ts,
                )
                sentiments.append(sentiment)
                articles.append(article)
            session.add_all(sentiments)
            logger.info(
                "Added %d articles and %d sentiments", len(articles), len(sentiments)
            )

            # ---------------------------------------------------------------
            # 8. Strategy metrics — daily metrics per strategy for 30 days
            # ---------------------------------------------------------------
            metrics = []
            for strat in strategy_list:
                cum_pnl = 0.0
                cum_trades = 0
                for day_offset in range(30, 0, -1):
                    ts = now - timedelta(days=day_offset)
                    daily_trades = random.randint(0, 3)
                    cum_trades += daily_trades
                    daily_pnl = round(random.gauss(50, 200), 2)
                    cum_pnl += daily_pnl
                    win_rate = round(random.uniform(0.35, 0.75), 4)
                    sharpe = round(random.gauss(1.2, 0.5), 2)
                    metrics.append(
                        StrategyMetric(
                            timestamp=ts,
                            strategy_id=strat.id,
                            win_rate=win_rate,
                            total_pnl=round(cum_pnl, 2),
                            trade_count=cum_trades,
                            sharpe_ratio=sharpe,
                        )
                    )
            session.add_all(metrics)
            logger.info("Added %d strategy metric records", len(metrics))

            # ---------------------------------------------------------------
            # 9. Strategy weight history — a few adjustment records
            # ---------------------------------------------------------------
            weight_records = []
            reasons = [
                "Periodic performance review — increased weight due to positive Sharpe",
                "Reduced weight after string of losses",
                "Rebalanced weights to equal distribution",
                "Increased weight — strong win rate last 7 days",
                "Decreased weight — high drawdown detected",
            ]
            for strat in strategy_list:
                weight = 0.333
                for adj_idx in range(random.randint(2, 4)):
                    # Adjustments roughly weekly, walking forward from day 30.
                    day_offset = 30 - adj_idx * 7
                    if day_offset < 1:
                        break
                    ts = now - timedelta(days=day_offset)
                    old_weight = weight
                    # Random walk clamped to [0.1, 0.6].
                    weight = round(max(0.1, min(0.6, weight + random.gauss(0, 0.05))), 3)
                    weight_records.append(
                        StrategyWeightHistory(
                            strategy_id=strat.id,
                            old_weight=old_weight,
                            new_weight=weight,
                            reason=random.choice(reasons),
                            created_at=ts,
                            updated_at=ts,
                        )
                    )
            session.add_all(weight_records)
            logger.info("Added %d weight history records", len(weight_records))

            # ---------------------------------------------------------------
            # Commit all data
            # ---------------------------------------------------------------
            await session.commit()
            logger.info("All sample data committed successfully")
    finally:
        await engine.dispose()
def main() -> None:
    """Configure console logging and run the async seeding routine."""
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    asyncio.run(seed())


if __name__ == "__main__":
    main()