- shared/db.py: async engine + session factory
- shared/models/base.py: DeclarativeBase + TimestampMixin
- shared/models/trading.py: Strategy, Signal, Trade, Position, StrategyWeightHistory
- shared/models/news.py: Article, ArticleSentiment
- shared/models/learning.py: TradeOutcome, LearningAdjustment
- shared/models/auth.py: User, UserCredential
- shared/models/timeseries.py: MarketData, PortfolioSnapshot, StrategyMetric
- Alembic async env.py with an initial migration that includes TimescaleDB hypertables
- 21 model tests covering enums, instantiation, and metadata registration
323 lines · 9.2 KiB · Python
"""Tests for SQLAlchemy model instantiation, enums, and relationships."""
|
|
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
import pytest
|
|
|
|
from shared.models import (
|
|
Base,
|
|
TimestampMixin,
|
|
# Trading
|
|
Strategy,
|
|
Signal,
|
|
SignalDirection,
|
|
Trade,
|
|
TradeSide,
|
|
TradeStatus,
|
|
Position,
|
|
StrategyWeightHistory,
|
|
# News
|
|
Article,
|
|
ArticleSentiment,
|
|
# Learning
|
|
TradeOutcome,
|
|
LearningAdjustment,
|
|
# Auth
|
|
User,
|
|
UserCredential,
|
|
# Timeseries
|
|
MarketData,
|
|
PortfolioSnapshot,
|
|
StrategyMetric,
|
|
)
|
|
from shared.db import create_db
|
|
from shared.config import BaseConfig
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Enum tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestEnums:
    """Each enum member must equal its string literal, with no extra members."""

    def test_trade_side_values(self) -> None:
        buy, sell = TradeSide.BUY, TradeSide.SELL
        assert buy == "BUY"
        assert sell == "SELL"
        # Comparing the whole member set fails if a member is added or removed.
        assert set(TradeSide) == {buy, sell}

    def test_trade_status_values(self) -> None:
        pairs = (
            (TradeStatus.PENDING, "PENDING"),
            (TradeStatus.FILLED, "FILLED"),
            (TradeStatus.CANCELLED, "CANCELLED"),
            (TradeStatus.REJECTED, "REJECTED"),
        )
        for member, literal in pairs:
            assert member == literal
        # Member count must match the literals listed above exactly.
        assert len(TradeStatus) == len(pairs)

    def test_signal_direction_values(self) -> None:
        pairs = (
            (SignalDirection.LONG, "LONG"),
            (SignalDirection.SHORT, "SHORT"),
            (SignalDirection.NEUTRAL, "NEUTRAL"),
        )
        for member, literal in pairs:
            assert member == literal
        assert len(SignalDirection) == len(pairs)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Model instantiation tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestStrategy:
    def test_create_strategy(self) -> None:
        """Constructor keyword arguments are stored on the instance unchanged."""
        strategy_id = uuid.uuid4()
        s = Strategy(
            id=strategy_id,
            name="momentum",
            description="Trend-following strategy",
            current_weight=0.5,
            active=True,
        )
        assert s.id == strategy_id
        assert s.name == "momentum"
        assert s.description == "Trend-following strategy"
        assert s.current_weight == 0.5
        assert s.active is True

    def test_strategy_defaults(self) -> None:
        """Unset attributes stay unset in memory; defaults live on the column.

        Without a DB session nothing is flushed, so a column default is not
        applied to the instance. The default itself is therefore verified on
        the mapped column rather than on the object.
        """
        s = Strategy(name="test")
        assert s.description is None
        # Before any flush the attribute is normally None. True is also
        # accepted in case the mapping applies defaults at construction time.
        assert s.active is None or s.active is True
        # The in-memory check alone passes even if no default were declared,
        # so also require a Python-side default= or a server_default= on the
        # column mapped to the ``active`` attribute.
        active_col = Strategy.__mapper__.columns["active"]
        assert active_col.default is not None or active_col.server_default is not None
|
|
|
|
|
|
class TestSignal:
    def test_create_signal(self) -> None:
        """Every constructor argument, including the per-strategy dict, is kept as given."""
        sig = Signal(
            id=uuid.uuid4(),
            ticker="AAPL",
            direction=SignalDirection.LONG,
            strength=0.85,
            strategy_sources={"momentum": 0.9},
            sentiment_score=0.7,
            acted_on=False,
        )
        assert sig.ticker == "AAPL"
        assert sig.direction == SignalDirection.LONG
        assert sig.strength == 0.85
        # Assert every field that was set, so a mis-mapped attribute is caught.
        assert sig.strategy_sources == {"momentum": 0.9}
        assert sig.sentiment_score == 0.7
        assert sig.acted_on is False
|
|
|
|
|
|
class TestTrade:
    def test_create_trade(self) -> None:
        """A trade keeps the ticker, side, size, price, status and P&L it was built with."""
        expected = {
            "ticker": "TSLA",
            "side": TradeSide.BUY,
            "qty": 10.0,
            "price": 150.25,
            "status": TradeStatus.FILLED,
            "pnl": 250.50,
        }
        trade = Trade(id=uuid.uuid4(), **expected)
        for attr, value in expected.items():
            assert getattr(trade, attr) == value, attr
|
|
|
|
|
|
class TestPosition:
    def test_create_position(self) -> None:
        """All position fields, including a negative unrealized P&L, are kept as given."""
        p = Position(
            id=uuid.uuid4(),
            ticker="GOOG",
            qty=5.0,
            avg_entry=2800.00,
            unrealized_pnl=-50.0,
            stop_loss=2750.0,
            take_profit=3000.0,
        )
        assert p.ticker == "GOOG"
        assert p.qty == 5.0
        # Assert every field that was set; the negative P&L must come back unchanged.
        assert p.avg_entry == 2800.00
        assert p.unrealized_pnl == -50.0
        assert p.stop_loss == 2750.0
        assert p.take_profit == 3000.0
|
|
|
|
|
|
class TestStrategyWeightHistory:
    def test_create_weight_history(self) -> None:
        """A weight-history row records its owning strategy and the old/new weights."""
        owner = uuid.uuid4()
        weights = {"old_weight": 0.33, "new_weight": 0.40}
        history = StrategyWeightHistory(
            id=uuid.uuid4(),
            strategy_id=owner,
            **weights,
            reason="Improved win rate",
        )
        assert history.strategy_id == owner
        for field, value in weights.items():
            assert getattr(history, field) == value, field
|
|
|
|
|
|
class TestArticle:
    def test_create_article(self) -> None:
        """An article keeps its source, URL, title, timestamps and content hash."""
        now = datetime.now(timezone.utc)
        a = Article(
            id=uuid.uuid4(),
            source="reuters",
            url="https://reuters.com/article/1",
            title="Market Rally",
            published_at=now,
            fetched_at=now,
            content_hash="abc123def456",
        )
        assert a.source == "reuters"
        assert a.url == "https://reuters.com/article/1"
        assert a.title == "Market Rally"
        # The timezone-aware datetimes must come back unchanged in memory.
        assert a.published_at == now
        assert a.fetched_at == now
        assert a.content_hash == "abc123def456"
|
|
|
|
|
|
class TestArticleSentiment:
    def test_create_sentiment(self) -> None:
        """A sentiment row keeps its score and the name of the model that produced it."""
        fields = dict(
            id=uuid.uuid4(),
            article_id=uuid.uuid4(),
            ticker="AAPL",
            score=0.85,
            confidence=0.92,
            model_used="finbert",
        )
        sentiment = ArticleSentiment(**fields)
        for name in ("score", "model_used"):
            assert getattr(sentiment, name) == fields[name], name
|
|
|
|
|
|
class TestTradeOutcome:
    def test_create_outcome(self) -> None:
        """An outcome stores its realized P&L, profitability flag and holding period."""
        held_for = timedelta(hours=4, minutes=30)
        outcome = TradeOutcome(
            id=uuid.uuid4(),
            trade_id=uuid.uuid4(),
            hold_duration=held_for,
            realized_pnl=125.50,
            roi_pct=2.5,
            was_profitable=True,
        )
        assert outcome.realized_pnl == 125.50
        assert outcome.was_profitable is True
        # 4h30m expressed in minutes; timedelta equality normalizes units.
        assert outcome.hold_duration == timedelta(minutes=270)
|
|
|
|
|
|
class TestLearningAdjustment:
    def test_create_adjustment(self) -> None:
        """An adjustment keeps its reward signal and its human-readable reason."""
        reason = "Positive reward signal"
        reward = 0.72
        adjustment = LearningAdjustment(
            id=uuid.uuid4(),
            strategy_id=uuid.uuid4(),
            old_weight=0.30,
            new_weight=0.35,
            reason=reason,
            reward_signal=reward,
        )
        assert adjustment.reward_signal == reward
        assert adjustment.reason == reason
|
|
|
|
|
|
class TestUser:
    def test_create_user(self) -> None:
        """Username and display name are stored exactly as given."""
        expected = {"username": "trader1", "display_name": "Top Trader"}
        user = User(id=uuid.uuid4(), **expected)
        for attr, value in expected.items():
            assert getattr(user, attr) == value, attr
|
|
|
|
|
|
class TestUserCredential:
    def test_create_credential(self) -> None:
        """A credential keeps its identifier, sign counter and raw public-key bytes."""
        cred = UserCredential(
            id=uuid.uuid4(),
            user_id=uuid.uuid4(),
            credential_id="cred-abc-123",
            public_key=b"\x04abcdef",
            sign_count=5,
        )
        checks = (
            (cred.credential_id, "cred-abc-123"),
            (cred.sign_count, 5),
            (cred.public_key, b"\x04abcdef"),
        )
        for actual, wanted in checks:
            assert actual == wanted
|
|
|
|
|
|
class TestMarketData:
    def test_create_market_data(self) -> None:
        """A market-data row keeps its timestamp, ticker and every OHLCV value."""
        now = datetime.now(timezone.utc)
        md = MarketData(
            timestamp=now,
            ticker="AAPL",
            open=150.0,
            high=155.0,
            low=149.0,
            close=153.0,
            volume=1_000_000.0,
        )
        assert md.timestamp == now
        assert md.ticker == "AAPL"
        # Assert every price column that was set, not just close.
        assert (md.open, md.high, md.low, md.close) == (150.0, 155.0, 149.0, 153.0)
        assert md.volume == 1_000_000.0
|
|
|
|
|
|
class TestPortfolioSnapshot:
    def test_create_snapshot(self) -> None:
        """A snapshot keeps the total portfolio value and daily P&L it was given."""
        snapshot = PortfolioSnapshot(
            timestamp=datetime.now(timezone.utc),
            total_value=100_000.0,
            cash=25_000.0,
            positions_value=75_000.0,
            daily_pnl=1_200.0,
        )
        assert (snapshot.total_value, snapshot.daily_pnl) == (100_000.0, 1_200.0)
|
|
|
|
|
|
class TestStrategyMetric:
    def test_create_metric(self) -> None:
        """A metric row keeps its win rate, trade count and Sharpe ratio as given."""
        expected = {"win_rate": 0.65, "trade_count": 42, "sharpe_ratio": 1.8}
        metric = StrategyMetric(
            timestamp=datetime.now(timezone.utc),
            strategy_id=uuid.uuid4(),
            win_rate=expected["win_rate"],
            total_pnl=5_432.10,
            trade_count=expected["trade_count"],
            sharpe_ratio=expected["sharpe_ratio"],
        )
        for attr, value in expected.items():
            assert getattr(metric, attr) == value, attr
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Metadata / Base tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestMetadata:
    def test_all_tables_registered(self) -> None:
        """Every expected model table must be registered on Base.metadata."""
        expected = {
            "strategies",
            "signals",
            "trades",
            "positions",
            "strategy_weight_history",
            "articles",
            "article_sentiments",
            "trade_outcomes",
            "learning_adjustments",
            "users",
            "user_credentials",
            "market_data",
            "portfolio_snapshots",
            "strategy_metrics",
        }
        # Same pass/fail as a subset check, but the failure names the missing tables.
        missing = expected - set(Base.metadata.tables)
        assert not missing, f"tables not registered: {sorted(missing)}"

    def test_timestamp_mixin_fields(self) -> None:
        """TimestampMixin should contribute created_at and updated_at columns.

        Strategy is checked explicitly. Every other mapped class that
        inherits the mixin is checked as well, so a model that loses either
        column is caught.
        """
        assert "created_at" in Strategy.__table__.columns
        assert "updated_at" in Strategy.__table__.columns
        for mapper in Base.registry.mappers:
            model = mapper.class_
            if not issubclass(model, TimestampMixin):
                continue
            # mapper.columns is keyed by attribute name and includes inherited columns.
            for name in ("created_at", "updated_at"):
                assert name in mapper.columns, f"{model.__name__} lacks {name}"
|
|
|
|
|
|
class TestDbFactory:
    def test_create_db_returns_engine_and_session(self) -> None:
        """create_db returns an engine and a session factory built from default config.

        The test does not use the engine; it only checks that both objects
        are produced and that the factory can be called.
        """
        config = BaseConfig()
        engine, session_factory = create_db(config)
        assert engine is not None
        assert session_factory is not None
        # A non-None check alone is near-vacuous; a session factory must be
        # callable to be usable for opening sessions.
        assert callable(session_factory)
|