feat: productionize local service — fix signal pipeline, lower thresholds, add company-name ticker extraction
- Point Ollama to local instance via host.docker.internal, use gemma3 model - Remove Docker Ollama service (using host's Ollama instead) - Add company-name-to-ticker mapping (Apple→AAPL, Tesla→TSLA, etc.) for RSS articles - Lower signal thresholds for faster feedback with paper trading: - FinBERT confidence: 0.6→0.4, signal strength: 0.3→0.15 - News strategy: article_count 2→1, confidence 0.5→0.3, score ±0.3→±0.15 - Fix market data BarSet access bug (BarSet.__contains__ returns False incorrectly) - Fix market data SIP feed error by switching to IEX feed for free Alpaca accounts - Fix nginx proxy routing for /api/auth/* to api-gateway /auth/* - Add seed_sample_data script - Update tests for new thresholds and alpaca mock modules
This commit is contained in:
parent
67e64fab18
commit
d36ae40df1
18 changed files with 749 additions and 185 deletions
|
|
@ -29,8 +29,9 @@ TRADING_REDDIT_CLIENT_ID=your_client_id
|
|||
TRADING_REDDIT_CLIENT_SECRET=your_client_secret
|
||||
TRADING_REDDIT_USER_AGENT=trading-bot/0.1
|
||||
|
||||
# Ollama — use Docker service name inside compose
|
||||
TRADING_OLLAMA_HOST=http://ollama:11434
|
||||
# Ollama — use host.docker.internal if running Ollama on the host machine
|
||||
TRADING_OLLAMA_HOST=http://host.docker.internal:11434
|
||||
TRADING_OLLAMA_MODEL=gemma3
|
||||
|
||||
# WebAuthn — update for production domain
|
||||
TRADING_RP_ID=localhost
|
||||
|
|
|
|||
|
|
@ -30,11 +30,6 @@ services:
|
|||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
ollama:
|
||||
image: ollama/ollama:latest
|
||||
volumes:
|
||||
- ollama_models:/root/.ollama
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Database migrations — runs once before application services start
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -82,8 +77,6 @@ services:
|
|||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
ollama:
|
||||
condition: service_started
|
||||
migrations:
|
||||
condition: service_completed_successfully
|
||||
env_file: .env
|
||||
|
|
@ -185,4 +178,3 @@ services:
|
|||
volumes:
|
||||
pgdata:
|
||||
redisdata:
|
||||
ollama_models:
|
||||
|
|
|
|||
|
|
@ -15,11 +15,23 @@ server {
|
|||
try_files $uri $uri/ /index.html;
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Proxy /api/auth/* to the api-gateway /auth/* routes
|
||||
# (Dashboard client uses baseURL=/api, so auth calls arrive as /api/auth/*)
|
||||
# ---------------------------------------------------------------------------
|
||||
location /api/auth/ {
|
||||
proxy_pass http://api-gateway:8000/auth/;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Proxy /api/* to the api-gateway service
|
||||
# ---------------------------------------------------------------------------
|
||||
location /api/ {
|
||||
proxy_pass http://api-gateway:8000;
|
||||
proxy_pass http://api-gateway:8000/api/;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
|
|
|
|||
398
scripts/seed_sample_data.py
Normal file
398
scripts/seed_sample_data.py
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
"""Seed the database with ~30 days of realistic sample data.
|
||||
|
||||
Populates portfolio snapshots, trades, signals, positions, news articles,
|
||||
sentiments, strategy metrics, weight history, and trade outcomes.
|
||||
|
||||
Usage:
|
||||
TRADING_DATABASE_URL=postgresql+asyncpg://trading:trading@localhost:5432/trading \
|
||||
python -m scripts.seed_sample_data
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import random
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from shared.config import BaseConfig
|
||||
from shared.db import create_db
|
||||
from shared.models.learning import TradeOutcome
|
||||
from shared.models.news import Article, ArticleSentiment
|
||||
from shared.models.timeseries import PortfolioSnapshot, StrategyMetric
|
||||
from shared.models.trading import (
|
||||
Position,
|
||||
Signal,
|
||||
SignalDirection,
|
||||
Strategy,
|
||||
StrategyWeightHistory,
|
||||
Trade,
|
||||
TradeSide,
|
||||
TradeStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TICKERS = ["AAPL", "TSLA", "NVDA", "MSFT", "GOOGL"]
|
||||
STRATEGY_NAMES = ["momentum", "mean_reversion", "news_driven"]
|
||||
|
||||
# Realistic price ranges for tickers (approximate)
|
||||
TICKER_PRICES = {
|
||||
"AAPL": (170.0, 195.0),
|
||||
"TSLA": (220.0, 280.0),
|
||||
"NVDA": (700.0, 900.0),
|
||||
"MSFT": (380.0, 430.0),
|
||||
"GOOGL": (140.0, 170.0),
|
||||
}
|
||||
|
||||
NEWS_HEADLINES = [
|
||||
("AAPL", "Apple Reports Record Q4 Revenue Driven by iPhone Sales"),
|
||||
("AAPL", "Apple Vision Pro Sees Slow Adoption Rates"),
|
||||
("AAPL", "Apple Expands AI Features Across Product Line"),
|
||||
("TSLA", "Tesla Deliveries Beat Expectations in Q3"),
|
||||
("TSLA", "Tesla Cuts Prices Amid Growing EV Competition"),
|
||||
("TSLA", "Tesla's Robotaxi Event Draws Mixed Reviews"),
|
||||
("NVDA", "NVIDIA Reports Blowout Earnings on AI Chip Demand"),
|
||||
("NVDA", "NVIDIA Blackwell GPUs Face Supply Constraints"),
|
||||
("NVDA", "Data Center Revenue Drives NVIDIA to New Highs"),
|
||||
("MSFT", "Microsoft Azure Growth Accelerates on AI Workloads"),
|
||||
("MSFT", "Microsoft Copilot Adoption Grows Among Enterprise Clients"),
|
||||
("MSFT", "Microsoft Invests $10B in AI Infrastructure"),
|
||||
("GOOGL", "Google Search Revenue Beats Expectations"),
|
||||
("GOOGL", "Alphabet Faces Antitrust Ruling on Search Monopoly"),
|
||||
("GOOGL", "Google Cloud Turns Profitable for Third Consecutive Quarter"),
|
||||
("AAPL", "Apple Supply Chain Diversifies Beyond China"),
|
||||
("TSLA", "Tesla Semi Truck Enters Mass Production"),
|
||||
("NVDA", "NVIDIA Partners with Healthcare Companies for AI Diagnostics"),
|
||||
("MSFT", "Microsoft Teams Surpasses 300 Million Monthly Users"),
|
||||
("GOOGL", "YouTube Ad Revenue Surges 15% Year-Over-Year"),
|
||||
("AAPL", "Apple Services Revenue Hits All-Time High"),
|
||||
("TSLA", "Tesla Energy Storage Deployments Double Year-Over-Year"),
|
||||
("NVDA", "NVIDIA Stock Included in Dow Jones Industrial Average"),
|
||||
("MSFT", "Microsoft Gaming Division Posts Strong Results"),
|
||||
("GOOGL", "Google DeepMind Achieves Breakthrough in Protein Folding"),
|
||||
("AAPL", "Apple Announces Stock Buyback Program"),
|
||||
("TSLA", "Analysts Divided on Tesla Valuation After Rally"),
|
||||
("NVDA", "NVIDIA Announces Next-Gen GPU Architecture"),
|
||||
("MSFT", "Microsoft 365 Price Increase Draws Customer Pushback"),
|
||||
("GOOGL", "Waymo Expands Autonomous Ride Service to New Cities"),
|
||||
]
|
||||
|
||||
|
||||
def _random_price(ticker: str) -> float:
|
||||
lo, hi = TICKER_PRICES[ticker]
|
||||
return round(random.uniform(lo, hi), 2)
|
||||
|
||||
|
||||
def _content_hash(text: str) -> str:
|
||||
return hashlib.sha256(text.encode()).hexdigest()
|
||||
|
||||
|
||||
async def seed(database_url: str | None = None) -> None:
|
||||
config = BaseConfig()
|
||||
if database_url:
|
||||
config.database_url = database_url
|
||||
|
||||
engine, session_factory = create_db(config)
|
||||
|
||||
async with session_factory() as session:
|
||||
# ---------------------------------------------------------------
|
||||
# 1. Ensure strategies exist (reuse from seed_strategies)
|
||||
# ---------------------------------------------------------------
|
||||
result = await session.execute(select(Strategy))
|
||||
existing = {s.name: s for s in result.scalars().all()}
|
||||
|
||||
strategies: dict[str, Strategy] = {}
|
||||
for name in STRATEGY_NAMES:
|
||||
if name in existing:
|
||||
strategies[name] = existing[name]
|
||||
logger.info("Strategy '%s' already exists", name)
|
||||
else:
|
||||
s = Strategy(
|
||||
name=name,
|
||||
description=f"Auto-seeded {name} strategy",
|
||||
current_weight=0.333,
|
||||
active=True,
|
||||
)
|
||||
session.add(s)
|
||||
strategies[name] = s
|
||||
logger.info("Created strategy '%s'", name)
|
||||
await session.flush()
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 2. Check if data already seeded (idempotency)
|
||||
# ---------------------------------------------------------------
|
||||
trade_count = (
|
||||
await session.execute(select(Trade))
|
||||
).scalars().first()
|
||||
if trade_count is not None:
|
||||
logger.info("Sample data already exists, skipping seed")
|
||||
await engine.dispose()
|
||||
return
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
random.seed(42) # reproducible data
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 3. Portfolio snapshots — 30 days of equity curve
|
||||
# ---------------------------------------------------------------
|
||||
equity = 100_000.0
|
||||
snapshots = []
|
||||
for day_offset in range(30, 0, -1):
|
||||
ts = now - timedelta(days=day_offset)
|
||||
daily_change = random.gauss(0.001, 0.008) # ~0.1% mean, 0.8% std
|
||||
equity *= 1 + daily_change
|
||||
positions_value = equity * random.uniform(0.3, 0.7)
|
||||
cash = equity - positions_value
|
||||
daily_pnl = equity * daily_change
|
||||
|
||||
snapshots.append(
|
||||
PortfolioSnapshot(
|
||||
timestamp=ts,
|
||||
total_value=round(equity, 2),
|
||||
cash=round(cash, 2),
|
||||
positions_value=round(positions_value, 2),
|
||||
daily_pnl=round(daily_pnl, 2),
|
||||
)
|
||||
)
|
||||
session.add_all(snapshots)
|
||||
logger.info("Added %d portfolio snapshots", len(snapshots))
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 4. Signals + Trades — ~50 trades spread across 30 days
|
||||
# ---------------------------------------------------------------
|
||||
strategy_list = list(strategies.values())
|
||||
all_trades: list[Trade] = []
|
||||
all_signals: list[Signal] = []
|
||||
|
||||
for i in range(50):
|
||||
day_offset = random.randint(1, 30)
|
||||
ts = now - timedelta(
|
||||
days=day_offset,
|
||||
hours=random.randint(9, 15),
|
||||
minutes=random.randint(0, 59),
|
||||
)
|
||||
ticker = random.choice(TICKERS)
|
||||
strat = random.choice(strategy_list)
|
||||
price = _random_price(ticker)
|
||||
side = random.choice([TradeSide.BUY, TradeSide.SELL])
|
||||
direction = (
|
||||
SignalDirection.LONG if side == TradeSide.BUY else SignalDirection.SHORT
|
||||
)
|
||||
strength = round(random.uniform(0.4, 0.95), 3)
|
||||
qty = round(random.uniform(5, 50), 0)
|
||||
|
||||
# P&L: 60% of trades are profitable
|
||||
is_profitable = random.random() < 0.6
|
||||
pnl = round(
|
||||
random.uniform(20, 800) * (1 if is_profitable else -1)
|
||||
* (price / 200),
|
||||
2,
|
||||
)
|
||||
|
||||
signal = Signal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
strength=strength,
|
||||
strategy_sources={strat.name: strength},
|
||||
sentiment_score=round(random.uniform(-0.5, 0.8), 3),
|
||||
acted_on=True,
|
||||
strategy_id=strat.id,
|
||||
created_at=ts,
|
||||
updated_at=ts,
|
||||
)
|
||||
session.add(signal)
|
||||
await session.flush()
|
||||
|
||||
trade = Trade(
|
||||
ticker=ticker,
|
||||
side=side,
|
||||
qty=qty,
|
||||
price=price,
|
||||
status=TradeStatus.FILLED,
|
||||
pnl=pnl,
|
||||
strategy_id=strat.id,
|
||||
signal_id=signal.id,
|
||||
created_at=ts,
|
||||
updated_at=ts,
|
||||
)
|
||||
session.add(trade)
|
||||
all_trades.append(trade)
|
||||
all_signals.append(signal)
|
||||
|
||||
await session.flush()
|
||||
logger.info("Added %d signals and %d trades", len(all_signals), len(all_trades))
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 5. Trade outcomes — for all closed trades
|
||||
# ---------------------------------------------------------------
|
||||
outcomes = []
|
||||
for trade in all_trades:
|
||||
hold_hours = random.randint(1, 72)
|
||||
roi_pct = round((trade.pnl or 0) / (trade.price * trade.qty) * 100, 2)
|
||||
outcome = TradeOutcome(
|
||||
trade_id=trade.id,
|
||||
hold_duration=timedelta(hours=hold_hours),
|
||||
realized_pnl=trade.pnl or 0.0,
|
||||
roi_pct=roi_pct,
|
||||
was_profitable=(trade.pnl or 0) > 0,
|
||||
)
|
||||
outcomes.append(outcome)
|
||||
session.add_all(outcomes)
|
||||
logger.info("Added %d trade outcomes", len(outcomes))
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 6. Open positions — 4 current positions
|
||||
# ---------------------------------------------------------------
|
||||
open_tickers = random.sample(TICKERS, 4)
|
||||
positions = []
|
||||
for ticker in open_tickers:
|
||||
price = _random_price(ticker)
|
||||
qty = round(random.uniform(10, 100), 0)
|
||||
unrealized = round(random.gauss(0, price * qty * 0.03), 2)
|
||||
positions.append(
|
||||
Position(
|
||||
ticker=ticker,
|
||||
qty=qty,
|
||||
avg_entry=price,
|
||||
unrealized_pnl=unrealized,
|
||||
stop_loss=round(price * 0.95, 2),
|
||||
take_profit=round(price * 1.10, 2),
|
||||
)
|
||||
)
|
||||
session.add_all(positions)
|
||||
logger.info("Added %d open positions", len(positions))
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 7. News articles + sentiments
|
||||
# ---------------------------------------------------------------
|
||||
articles = []
|
||||
sentiments = []
|
||||
for idx, (ticker, headline) in enumerate(NEWS_HEADLINES):
|
||||
day_offset = random.randint(1, 30)
|
||||
ts = now - timedelta(
|
||||
days=day_offset,
|
||||
hours=random.randint(6, 20),
|
||||
)
|
||||
url = f"https://finance.example.com/article/{idx + 1}"
|
||||
article = Article(
|
||||
source="RSS",
|
||||
url=url,
|
||||
title=headline,
|
||||
published_at=ts,
|
||||
fetched_at=ts + timedelta(minutes=random.randint(1, 30)),
|
||||
content_hash=_content_hash(f"{headline}-{idx}"),
|
||||
created_at=ts,
|
||||
updated_at=ts,
|
||||
)
|
||||
session.add(article)
|
||||
await session.flush()
|
||||
|
||||
# Sentiment score correlated with headline tone
|
||||
positive_words = {"record", "beat", "surge", "high", "growth", "strong", "profit", "expand", "breakthrough", "buyback"}
|
||||
negative_words = {"slow", "cut", "face", "divided", "pushback", "mixed", "antitrust"}
|
||||
headline_lower = headline.lower()
|
||||
pos_count = sum(1 for w in positive_words if w in headline_lower)
|
||||
neg_count = sum(1 for w in negative_words if w in headline_lower)
|
||||
base_score = 0.3 * pos_count - 0.3 * neg_count
|
||||
score = max(-1.0, min(1.0, base_score + random.gauss(0, 0.1)))
|
||||
|
||||
sentiment = ArticleSentiment(
|
||||
article_id=article.id,
|
||||
ticker=ticker,
|
||||
score=round(score, 3),
|
||||
confidence=round(random.uniform(0.6, 0.95), 3),
|
||||
model_used="finbert",
|
||||
created_at=ts,
|
||||
updated_at=ts,
|
||||
)
|
||||
sentiments.append(sentiment)
|
||||
articles.append(article)
|
||||
|
||||
session.add_all(sentiments)
|
||||
logger.info(
|
||||
"Added %d articles and %d sentiments", len(articles), len(sentiments)
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 8. Strategy metrics — daily metrics per strategy for 30 days
|
||||
# ---------------------------------------------------------------
|
||||
metrics = []
|
||||
for strat in strategy_list:
|
||||
cum_pnl = 0.0
|
||||
trade_count = 0
|
||||
for day_offset in range(30, 0, -1):
|
||||
ts = now - timedelta(days=day_offset)
|
||||
daily_trades = random.randint(0, 3)
|
||||
trade_count += daily_trades
|
||||
daily_pnl = round(random.gauss(50, 200), 2)
|
||||
cum_pnl += daily_pnl
|
||||
win_rate = round(random.uniform(0.35, 0.75), 4)
|
||||
sharpe = round(random.gauss(1.2, 0.5), 2)
|
||||
|
||||
metrics.append(
|
||||
StrategyMetric(
|
||||
timestamp=ts,
|
||||
strategy_id=strat.id,
|
||||
win_rate=win_rate,
|
||||
total_pnl=round(cum_pnl, 2),
|
||||
trade_count=trade_count,
|
||||
sharpe_ratio=sharpe,
|
||||
)
|
||||
)
|
||||
session.add_all(metrics)
|
||||
logger.info("Added %d strategy metric records", len(metrics))
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 9. Strategy weight history — a few adjustment records
|
||||
# ---------------------------------------------------------------
|
||||
weight_records = []
|
||||
reasons = [
|
||||
"Periodic performance review — increased weight due to positive Sharpe",
|
||||
"Reduced weight after string of losses",
|
||||
"Rebalanced weights to equal distribution",
|
||||
"Increased weight — strong win rate last 7 days",
|
||||
"Decreased weight — high drawdown detected",
|
||||
]
|
||||
for strat in strategy_list:
|
||||
weight = 0.333
|
||||
for adj_idx in range(random.randint(2, 4)):
|
||||
day_offset = 30 - adj_idx * 7
|
||||
if day_offset < 1:
|
||||
break
|
||||
ts = now - timedelta(days=day_offset)
|
||||
old_weight = weight
|
||||
weight = round(max(0.1, min(0.6, weight + random.gauss(0, 0.05))), 3)
|
||||
weight_records.append(
|
||||
StrategyWeightHistory(
|
||||
strategy_id=strat.id,
|
||||
old_weight=old_weight,
|
||||
new_weight=weight,
|
||||
reason=random.choice(reasons),
|
||||
created_at=ts,
|
||||
updated_at=ts,
|
||||
)
|
||||
)
|
||||
session.add_all(weight_records)
|
||||
logger.info("Added %d weight history records", len(weight_records))
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Commit all data
|
||||
# ---------------------------------------------------------------
|
||||
await session.commit()
|
||||
logger.info("All sample data committed successfully")
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
asyncio.run(seed())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -167,9 +167,12 @@ async def register_complete(
|
|||
user_id_str = stored["user_id"]
|
||||
display_name = stored["display_name"]
|
||||
|
||||
# The frontend sends the WebAuthn response under "attestation" or "credential"
|
||||
credential_data = body.get("credential") or body.get("attestation") or body
|
||||
|
||||
try:
|
||||
verification = verify_registration_response(
|
||||
credential=body.get("credential", body),
|
||||
credential=credential_data,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_rp_id=config.rp_id,
|
||||
expected_origin=config.rp_origin,
|
||||
|
|
@ -319,11 +322,14 @@ async def login_complete(
|
|||
expected_challenge = base64.urlsafe_b64decode(stored["challenge"])
|
||||
user_id_str = stored["user_id"]
|
||||
|
||||
# The frontend sends the WebAuthn response under "assertion" or "credential"
|
||||
credential_data = body.get("credential") or body.get("assertion") or body
|
||||
|
||||
# Look up the credential used
|
||||
from sqlalchemy import select
|
||||
from shared.models.auth import UserCredential
|
||||
|
||||
credential_id_b64 = body.get("credential", body).get("id", "")
|
||||
credential_id_b64 = credential_data.get("id", "")
|
||||
db_session = request.app.state.db_session_factory
|
||||
|
||||
async with db_session() as session:
|
||||
|
|
@ -343,7 +349,7 @@ async def login_complete(
|
|||
|
||||
try:
|
||||
verification = verify_authentication_response(
|
||||
credential=body.get("credential", body),
|
||||
credential=credential_data,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_rp_id=config.rp_id,
|
||||
expected_origin=config.rp_origin,
|
||||
|
|
|
|||
|
|
@ -20,10 +20,13 @@ async def list_news(
|
|||
max_score: float | None = Query(default=None, ge=-1.0, le=1.0),
|
||||
page: int = Query(default=1, ge=1),
|
||||
per_page: int = Query(default=20, ge=1, le=100),
|
||||
page_size: int | None = Query(default=None, ge=1, le=100),
|
||||
) -> dict:
|
||||
"""Recent scored articles with optional filters."""
|
||||
from shared.models.news import Article, ArticleSentiment
|
||||
|
||||
effective_per_page = page_size if page_size is not None else per_page
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
# Base query joining articles with sentiments
|
||||
|
|
@ -54,8 +57,8 @@ async def list_news(
|
|||
count_query = count_query.where(ArticleSentiment.score <= max_score)
|
||||
|
||||
total = (await session.execute(count_query)).scalar() or 0
|
||||
offset = (page - 1) * per_page
|
||||
query = query.offset(offset).limit(per_page)
|
||||
offset = (page - 1) * effective_per_page
|
||||
query = query.offset(offset).limit(effective_per_page)
|
||||
|
||||
result = await session.execute(query)
|
||||
rows = result.all()
|
||||
|
|
@ -82,6 +85,7 @@ async def list_news(
|
|||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
"page_size": effective_per_page,
|
||||
"per_page": effective_per_page,
|
||||
"pages": (total + effective_per_page - 1) // effective_per_page if effective_per_page else 0,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,7 +112,13 @@ async def get_positions(
|
|||
"ticker": p.ticker,
|
||||
"qty": p.qty,
|
||||
"avg_entry": p.avg_entry,
|
||||
"current_price": round(
|
||||
p.avg_entry + (p.unrealized_pnl or 0.0) / p.qty, 2
|
||||
) if p.qty else p.avg_entry,
|
||||
"unrealized_pnl": p.unrealized_pnl or 0.0,
|
||||
"unrealized_pnl_pct": round(
|
||||
(p.unrealized_pnl or 0.0) / (p.avg_entry * p.qty) * 100.0, 2
|
||||
) if p.avg_entry and p.qty else 0.0,
|
||||
"stop_loss": p.stop_loss,
|
||||
"take_profit": p.take_profit,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from uuid import UUID
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from services.api_gateway.auth.middleware import get_current_user
|
||||
from sqlalchemy import select, desc
|
||||
from sqlalchemy import select, desc, func
|
||||
|
||||
router = APIRouter(prefix="/api/strategies", tags=["strategies"])
|
||||
|
||||
|
|
@ -17,14 +17,34 @@ async def list_strategies(
|
|||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""All strategies with current weights."""
|
||||
from shared.models.trading import Strategy
|
||||
"""All strategies with current weights and computed performance fields."""
|
||||
from shared.models.trading import Strategy, Trade, TradeStatus
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(select(Strategy))
|
||||
strategies = result.scalars().all()
|
||||
|
||||
# Compute per-strategy stats from trades table
|
||||
strategy_stats: dict[UUID, dict] = {}
|
||||
for s in strategies:
|
||||
trades_result = await session.execute(
|
||||
select(Trade).where(
|
||||
Trade.strategy_id == s.id,
|
||||
Trade.status == TradeStatus.FILLED,
|
||||
)
|
||||
)
|
||||
trades = trades_result.scalars().all()
|
||||
total_trades = len(trades)
|
||||
winning = sum(1 for t in trades if t.pnl is not None and t.pnl > 0)
|
||||
total_pnl = sum(t.pnl for t in trades if t.pnl is not None)
|
||||
win_rate = winning / total_trades if total_trades > 0 else 0.0
|
||||
strategy_stats[s.id] = {
|
||||
"win_rate": round(win_rate, 4),
|
||||
"total_pnl": round(total_pnl, 2),
|
||||
"total_trades": total_trades,
|
||||
}
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(s.id),
|
||||
|
|
@ -32,12 +52,50 @@ async def list_strategies(
|
|||
"description": s.description,
|
||||
"current_weight": s.current_weight,
|
||||
"active": s.active,
|
||||
"win_rate": strategy_stats[s.id]["win_rate"],
|
||||
"total_pnl": strategy_stats[s.id]["total_pnl"],
|
||||
"total_trades": strategy_stats[s.id]["total_trades"],
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in strategies
|
||||
]
|
||||
|
||||
|
||||
@router.get("/weight-history")
|
||||
async def get_all_weight_history(
|
||||
request: Request,
|
||||
_user: dict = Depends(get_current_user),
|
||||
) -> list[dict]:
|
||||
"""Aggregated weight history pivoted by timestamp for chart display.
|
||||
|
||||
Returns data in the format:
|
||||
``[{"timestamp": "...", "momentum": 0.35, "mean_reversion": 0.30, ...}, ...]``
|
||||
"""
|
||||
from shared.models.trading import StrategyWeightHistory, Strategy
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
result = await session.execute(
|
||||
select(StrategyWeightHistory, Strategy.name)
|
||||
.join(Strategy, StrategyWeightHistory.strategy_id == Strategy.id)
|
||||
.order_by(StrategyWeightHistory.created_at)
|
||||
.limit(200)
|
||||
)
|
||||
rows = result.all()
|
||||
|
||||
# Pivot: group by timestamp, create one object per timestamp
|
||||
# with strategy names as keys and new_weight as values
|
||||
from collections import OrderedDict
|
||||
pivoted: OrderedDict[str, dict] = OrderedDict()
|
||||
for h, name in rows:
|
||||
ts = h.created_at.isoformat() if h.created_at else ""
|
||||
if ts not in pivoted:
|
||||
pivoted[ts] = {"timestamp": ts}
|
||||
pivoted[ts][name] = h.new_weight
|
||||
|
||||
return list(pivoted.values())
|
||||
|
||||
|
||||
@router.get("/{strategy_id}/history")
|
||||
async def get_strategy_weight_history(
|
||||
strategy_id: UUID,
|
||||
|
|
|
|||
|
|
@ -20,34 +20,44 @@ async def list_trades(
|
|||
ticker: str | None = Query(default=None),
|
||||
start_date: datetime | None = Query(default=None),
|
||||
end_date: datetime | None = Query(default=None),
|
||||
date_from: datetime | None = Query(default=None),
|
||||
date_to: datetime | None = Query(default=None),
|
||||
strategy: str | None = Query(default=None),
|
||||
profitable: bool | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
per_page: int = Query(default=20, ge=1, le=100),
|
||||
page_size: int | None = Query(default=None, ge=1, le=100),
|
||||
) -> dict:
|
||||
"""Paginated trade history with optional filters."""
|
||||
from shared.models.trading import Trade, Strategy
|
||||
|
||||
# Accept both parameter naming conventions
|
||||
effective_per_page = page_size if page_size is not None else per_page
|
||||
effective_start = start_date or date_from
|
||||
effective_end = end_date or date_to
|
||||
|
||||
db = request.app.state.db_session_factory
|
||||
async with db() as session:
|
||||
query = select(Trade).order_by(desc(Trade.created_at))
|
||||
query = (
|
||||
select(Trade, Strategy.name.label("strategy_name"))
|
||||
.outerjoin(Strategy, Trade.strategy_id == Strategy.id)
|
||||
.order_by(desc(Trade.created_at))
|
||||
)
|
||||
count_query = select(func.count()).select_from(Trade)
|
||||
|
||||
# Apply filters
|
||||
if ticker:
|
||||
query = query.where(Trade.ticker == ticker.upper())
|
||||
count_query = count_query.where(Trade.ticker == ticker.upper())
|
||||
if start_date:
|
||||
query = query.where(Trade.created_at >= start_date)
|
||||
count_query = count_query.where(Trade.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.where(Trade.created_at <= end_date)
|
||||
count_query = count_query.where(Trade.created_at <= end_date)
|
||||
if effective_start:
|
||||
query = query.where(Trade.created_at >= effective_start)
|
||||
count_query = count_query.where(Trade.created_at >= effective_start)
|
||||
if effective_end:
|
||||
query = query.where(Trade.created_at <= effective_end)
|
||||
count_query = count_query.where(Trade.created_at <= effective_end)
|
||||
if strategy:
|
||||
# Join with Strategy to filter by name
|
||||
query = query.join(Strategy, Trade.strategy_id == Strategy.id).where(
|
||||
Strategy.name == strategy
|
||||
)
|
||||
# Filter by strategy name (already joined)
|
||||
query = query.where(Strategy.name == strategy)
|
||||
count_query = count_query.join(
|
||||
Strategy, Trade.strategy_id == Strategy.id
|
||||
).where(Strategy.name == strategy)
|
||||
|
|
@ -61,11 +71,11 @@ async def list_trades(
|
|||
|
||||
# Pagination
|
||||
total = (await session.execute(count_query)).scalar() or 0
|
||||
offset = (page - 1) * per_page
|
||||
query = query.offset(offset).limit(per_page)
|
||||
offset = (page - 1) * effective_per_page
|
||||
query = query.offset(offset).limit(effective_per_page)
|
||||
|
||||
result = await session.execute(query)
|
||||
trades = result.scalars().all()
|
||||
rows = result.all()
|
||||
|
||||
return {
|
||||
"trades": [
|
||||
|
|
@ -78,15 +88,17 @@ async def list_trades(
|
|||
"status": t.status.value,
|
||||
"pnl": t.pnl,
|
||||
"strategy_id": str(t.strategy_id) if t.strategy_id else None,
|
||||
"strategy_name": strategy_name,
|
||||
"signal_id": str(t.signal_id) if t.signal_id else None,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in trades
|
||||
for t, strategy_name in rows
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page else 0,
|
||||
"page_size": effective_per_page,
|
||||
"per_page": effective_per_page,
|
||||
"pages": (total + effective_per_page - 1) // effective_per_page if effective_per_page else 0,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ async def _fetch_historical_bars(
|
|||
|
||||
Returns the total number of bars published.
|
||||
"""
|
||||
from alpaca.data.enums import DataFeed
|
||||
from alpaca.data.requests import StockBarsRequest
|
||||
|
||||
total_published = 0
|
||||
|
|
@ -86,10 +87,14 @@ async def _fetch_historical_bars(
|
|||
timeframe=timeframe,
|
||||
start=start,
|
||||
limit=limit,
|
||||
feed=DataFeed.IEX,
|
||||
)
|
||||
bars = await asyncio.to_thread(client.get_stock_bars, request)
|
||||
|
||||
ticker_bars = bars[ticker] if ticker in bars else []
|
||||
try:
|
||||
ticker_bars = bars[ticker]
|
||||
except (KeyError, IndexError):
|
||||
ticker_bars = []
|
||||
for bar in ticker_bars:
|
||||
msg = _bar_to_dict(ticker, bar)
|
||||
await publisher.publish(msg)
|
||||
|
|
@ -120,6 +125,7 @@ async def _poll_latest_bars(
|
|||
|
||||
Returns the number of bars published.
|
||||
"""
|
||||
from alpaca.data.enums import DataFeed
|
||||
from alpaca.data.requests import StockBarsRequest
|
||||
|
||||
published = 0
|
||||
|
|
@ -134,10 +140,14 @@ async def _poll_latest_bars(
|
|||
timeframe=timeframe,
|
||||
start=start,
|
||||
limit=1,
|
||||
feed=DataFeed.IEX,
|
||||
)
|
||||
bars = await asyncio.to_thread(client.get_stock_bars, request)
|
||||
|
||||
ticker_bars = bars[ticker] if ticker in bars else []
|
||||
try:
|
||||
ticker_bars = bars[ticker]
|
||||
except (KeyError, IndexError):
|
||||
ticker_bars = []
|
||||
if ticker_bars:
|
||||
# Publish only the most recent bar
|
||||
bar = ticker_bars[-1]
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class FinBERTAnalyzer:
|
|||
self._pipeline = pipeline(
|
||||
"sentiment-analysis",
|
||||
model=self.model_name,
|
||||
return_all_scores=True,
|
||||
top_k=None,
|
||||
)
|
||||
logger.info("FinBERT model loaded successfully")
|
||||
return self._pipeline
|
||||
|
|
@ -84,8 +84,9 @@ class FinBERTAnalyzer:
|
|||
def _parse_scores(results: list[list[dict[str, Any]]]) -> tuple[float, float]:
|
||||
"""Map pipeline output to ``(score, confidence)``.
|
||||
|
||||
The ``return_all_scores=True`` pipeline returns a list of lists of dicts:
|
||||
``[[{"label": "positive", "score": 0.85}, ...]]``.
|
||||
With ``top_k=None`` the pipeline returns either:
|
||||
- ``[[{"label": "positive", "score": 0.85}, ...]]`` (older transformers)
|
||||
- ``[{"label": "positive", "score": 0.85}, ...]`` (newer transformers)
|
||||
|
||||
Mapping:
|
||||
- ``"positive"`` -> +1
|
||||
|
|
@ -98,8 +99,8 @@ class FinBERTAnalyzer:
|
|||
"""
|
||||
label_map = {"positive": 1.0, "negative": -1.0, "neutral": 0.0}
|
||||
|
||||
# results is [[{label, score}, ...]]
|
||||
scores = results[0]
|
||||
# Handle both [[{label, score}, ...]] and [{label, score}, ...]
|
||||
scores = results[0] if isinstance(results[0], list) else results
|
||||
|
||||
sentiment_score = 0.0
|
||||
confidence = 0.0
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ class SentimentAnalyzerConfig(BaseConfig):
|
|||
"""Extends BaseConfig with sentiment-analysis-specific settings."""
|
||||
|
||||
finbert_model: str = "ProsusAI/finbert"
|
||||
finbert_confidence_threshold: float = 0.6
|
||||
ollama_model: str = "mistral"
|
||||
finbert_confidence_threshold: float = 0.4
|
||||
ollama_model: str = "gemma3"
|
||||
ollama_host: str = "http://localhost:11434"
|
||||
max_content_length: int = 512
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Handles common formats:
|
|||
- Dollar-prefixed: ``$AAPL``
|
||||
- Exchange-prefixed: ``NASDAQ:AAPL``, ``NYSE:TSLA``
|
||||
- Standalone uppercase words that look like tickers (1-5 uppercase letters)
|
||||
- Company name mentions: ``Apple``, ``Tesla``, ``Nvidia``, etc.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -119,6 +120,54 @@ _FALSE_POSITIVES: frozenset[str] = frozenset(
|
|||
}
|
||||
)
|
||||
|
||||
# Mapping of company names (lowercase) to their ticker symbols.
|
||||
# Longer names are checked first to avoid partial matches.
|
||||
_COMPANY_TO_TICKER: dict[str, str] = {
|
||||
"alphabet": "GOOGL",
|
||||
"google": "GOOGL",
|
||||
"amazon": "AMZN",
|
||||
"apple": "AAPL",
|
||||
"microsoft": "MSFT",
|
||||
"tesla": "TSLA",
|
||||
"nvidia": "NVDA",
|
||||
"meta platforms": "META",
|
||||
"meta": "META",
|
||||
"netflix": "NFLX",
|
||||
"advanced micro devices": "AMD",
|
||||
"amd": "AMD",
|
||||
"intel": "INTC",
|
||||
"broadcom": "AVGO",
|
||||
"salesforce": "CRM",
|
||||
"adobe": "ADBE",
|
||||
"paypal": "PYPL",
|
||||
"uber": "UBER",
|
||||
"airbnb": "ABNB",
|
||||
"spotify": "SPOT",
|
||||
"shopify": "SHOP",
|
||||
"snowflake": "SNOW",
|
||||
"palantir": "PLTR",
|
||||
"coinbase": "COIN",
|
||||
"robinhood": "HOOD",
|
||||
"walmart": "WMT",
|
||||
"costco": "COST",
|
||||
"jpmorgan": "JPM",
|
||||
"goldman sachs": "GS",
|
||||
"bank of america": "BAC",
|
||||
"berkshire hathaway": "BRK.B",
|
||||
"johnson & johnson": "JNJ",
|
||||
"procter & gamble": "PG",
|
||||
"coca-cola": "KO",
|
||||
"disney": "DIS",
|
||||
"boeing": "BA",
|
||||
}
|
||||
|
||||
# Build a regex that matches any company name as a whole word (case-insensitive).
|
||||
# Sort by length descending so multi-word names match before single-word subsets.
|
||||
_COMPANY_PATTERN = re.compile(
|
||||
r"\b(" + "|".join(re.escape(name) for name in sorted(_COMPANY_TO_TICKER, key=len, reverse=True)) + r")\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Pattern 1: $AAPL (dollar-sign prefix)
|
||||
_DOLLAR_PATTERN = re.compile(r"\$([A-Z]{1,5})\b")
|
||||
|
||||
|
|
@ -152,6 +201,13 @@ def extract_tickers(text: str) -> list[str]:
|
|||
for match in _EXCHANGE_PATTERN.finditer(text):
|
||||
_add(match.group(1))
|
||||
|
||||
# Company name mentions (case-insensitive).
|
||||
for match in _COMPANY_PATTERN.finditer(text):
|
||||
company_name = match.group(1).lower()
|
||||
ticker = _COMPANY_TO_TICKER.get(company_name)
|
||||
if ticker:
|
||||
_add(ticker)
|
||||
|
||||
# Standalone uppercase words: only include if they look like real tickers
|
||||
# (not in the false positives list). We restrict to 2-5 chars to reduce
|
||||
# noise, unless they were already captured by the dollar/exchange patterns.
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ class SignalGeneratorConfig(BaseConfig):
|
|||
|
||||
alpaca_api_key: str = ""
|
||||
alpaca_secret_key: str = ""
|
||||
signal_strength_threshold: float = 0.3
|
||||
signal_strength_threshold: float = 0.15
|
||||
watchlist: list[str] = []
|
||||
|
||||
model_config = {"env_prefix": "TRADING_"}
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ class NewsDrivenStrategy(BaseStrategy):
|
|||
"""Generate signals from aggregated news sentiment for a ticker.
|
||||
|
||||
**Buy signal** (LONG):
|
||||
``avg_score > 0.3`` AND ``avg_confidence > 0.5`` AND
|
||||
``article_count >= 2``.
|
||||
``avg_score > 0.15`` AND ``avg_confidence > 0.3`` AND
|
||||
``article_count >= 1``.
|
||||
|
||||
**Sell signal** (SHORT):
|
||||
``avg_score < -0.3`` AND ``avg_confidence > 0.5`` AND
|
||||
``article_count >= 2``.
|
||||
``avg_score < -0.15`` AND ``avg_confidence > 0.3`` AND
|
||||
``article_count >= 1``.
|
||||
|
||||
Signal strength = ``abs(avg_score) * avg_confidence``, clamped to
|
||||
[0, 1].
|
||||
|
|
@ -32,17 +32,17 @@ class NewsDrivenStrategy(BaseStrategy):
|
|||
if sentiment is None:
|
||||
return None
|
||||
|
||||
# Require at least 2 articles for statistical confidence.
|
||||
if sentiment.article_count < 2:
|
||||
# Require at least 1 article.
|
||||
if sentiment.article_count < 1:
|
||||
return None
|
||||
|
||||
# Require minimum confidence.
|
||||
if sentiment.avg_confidence <= 0.5:
|
||||
if sentiment.avg_confidence <= 0.3:
|
||||
return None
|
||||
|
||||
if sentiment.avg_score > 0.3:
|
||||
if sentiment.avg_score > 0.15:
|
||||
direction = SignalDirection.LONG
|
||||
elif sentiment.avg_score < -0.3:
|
||||
elif sentiment.avg_score < -0.15:
|
||||
direction = SignalDirection.SHORT
|
||||
else:
|
||||
# Sentiment is neutral — no opinion.
|
||||
|
|
|
|||
|
|
@ -188,9 +188,10 @@ class TestTradesListEndpoint:
|
|||
trade.signal_id = None
|
||||
trade.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
# session.execute will be called twice: count + data
|
||||
# session.execute is called twice: count + data (now returns tuples)
|
||||
count_result = _make_execute_result([], scalar=1)
|
||||
data_result = _make_execute_result([trade])
|
||||
data_result = MagicMock()
|
||||
data_result.all.return_value = [(trade, None)] # (Trade, strategy_name)
|
||||
session.execute = AsyncMock(side_effect=[count_result, data_result])
|
||||
|
||||
resp = client.get("/api/trades")
|
||||
|
|
@ -242,8 +243,11 @@ class TestStrategiesEndpoint:
|
|||
strategy.active = True
|
||||
strategy.created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
# First call: list strategies; subsequent calls: trades per strategy
|
||||
strategies_result = _make_execute_result([strategy])
|
||||
trades_result = _make_execute_result([]) # no trades
|
||||
session.execute = AsyncMock(
|
||||
return_value=_make_execute_result([strategy])
|
||||
side_effect=[strategies_result, trades_result]
|
||||
)
|
||||
|
||||
resp = client.get("/api/strategies")
|
||||
|
|
|
|||
|
|
@ -97,6 +97,9 @@ def _install_alpaca_mocks():
|
|||
historical_mod = ModuleType("alpaca.data.historical")
|
||||
historical_mod.StockHistoricalDataClient = MagicMock
|
||||
|
||||
enums_mod = ModuleType("alpaca.data.enums")
|
||||
enums_mod.DataFeed = MagicMock()
|
||||
|
||||
# Build the package hierarchy
|
||||
alpaca_mod = sys.modules.get("alpaca") or ModuleType("alpaca")
|
||||
data_mod = sys.modules.get("alpaca.data") or ModuleType("alpaca.data")
|
||||
|
|
@ -106,6 +109,7 @@ def _install_alpaca_mocks():
|
|||
sys.modules["alpaca.data.timeframe"] = timeframe_mod
|
||||
sys.modules["alpaca.data.requests"] = requests_mod
|
||||
sys.modules["alpaca.data.historical"] = historical_mod
|
||||
sys.modules["alpaca.data.enums"] = enums_mod
|
||||
|
||||
|
||||
# Install mocks before importing from market_data.main
|
||||
|
|
|
|||
|
|
@ -271,17 +271,17 @@ class TestNewsDrivenStrategy:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_no_signal_low_confidence(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""No signal when avg_confidence is too low (<=0.5)."""
|
||||
"""No signal when avg_confidence is too low (<=0.3)."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.4, article_count=5)
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.2, article_count=5)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
assert signal is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_no_signal_few_articles(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""No signal when article_count < 2."""
|
||||
"""No signal when article_count < 1."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=1)
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.7, article_count=0)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
assert signal is None
|
||||
|
||||
|
|
@ -311,7 +311,7 @@ class TestNewsDrivenStrategy:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_neutral_score(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""No signal when avg_score is between -0.3 and 0.3 (neutral)."""
|
||||
"""No signal when avg_score is between -0.15 and 0.15 (neutral)."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=0.1, avg_confidence=0.9, article_count=10)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
|
|
@ -319,9 +319,9 @@ class TestNewsDrivenStrategy:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_news_driven_boundary_confidence(self, strategy: NewsDrivenStrategy) -> None:
|
||||
"""No signal when avg_confidence is exactly 0.5 (threshold is >0.5)."""
|
||||
"""No signal when avg_confidence is exactly 0.3 (threshold is >0.3)."""
|
||||
market = _market()
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.5, article_count=5)
|
||||
sentiment = _sentiment(avg_score=0.8, avg_confidence=0.3, article_count=5)
|
||||
signal = await strategy.evaluate("AAPL", market, sentiment)
|
||||
assert signal is None
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue