# trading/tests/test_models.py
#
# 369 lines
# 10 KiB
# Python
# Raw Permalink Normal View History

"""Tests for SQLAlchemy model instantiation, enums, and relationships."""
import uuid
from datetime import datetime, timedelta, timezone
import pytest
from shared.models import (
Base,
TimestampMixin,
# Trading
Strategy,
Signal,
SignalDirection,
Trade,
TradeSide,
TradeStatus,
Position,
StrategyWeightHistory,
# News
Article,
ArticleSentiment,
# Learning
TradeOutcome,
LearningAdjustment,
# Auth
User,
UserCredential,
# Timeseries
MarketData,
PortfolioSnapshot,
StrategyMetric,
# Fundamentals
Fundamentals,
)
from shared.db import create_db
from shared.config import BaseConfig
# ---------------------------------------------------------------------------
# Enum tests
# ---------------------------------------------------------------------------
class TestEnums:
    """Pin the members of the trading enums and the strings they compare equal to."""

    def test_trade_side_values(self) -> None:
        """TradeSide consists of exactly BUY and SELL, each equal to its name."""
        buy, sell = TradeSide.BUY, TradeSide.SELL
        assert buy == "BUY"
        assert sell == "SELL"
        # Iterating the enum must yield these two members and nothing else.
        assert {buy, sell} == set(TradeSide)

    def test_trade_status_values(self) -> None:
        """TradeStatus has four members, each equal to its own name."""
        for member_name in ("PENDING", "FILLED", "CANCELLED", "REJECTED"):
            assert getattr(TradeStatus, member_name) == member_name
        assert len(TradeStatus) == 4

    def test_signal_direction_values(self) -> None:
        """SignalDirection has three members, each equal to its own name."""
        for member_name in ("LONG", "SHORT", "NEUTRAL"):
            assert getattr(SignalDirection, member_name) == member_name
        assert len(SignalDirection) == 3
# ---------------------------------------------------------------------------
# Model instantiation tests
# ---------------------------------------------------------------------------
class TestStrategy:
    """In-memory construction checks for ``Strategy`` (no session is used)."""

    def test_create_strategy(self) -> None:
        """Explicit constructor arguments are readable back as attributes."""
        strategy = Strategy(
            id=uuid.uuid4(),
            name="momentum",
            description="Trend-following strategy",
            current_weight=0.5,
            active=True,
        )

        assert (strategy.name, strategy.current_weight) == ("momentum", 0.5)
        assert strategy.active is True

    def test_strategy_defaults(self) -> None:
        """Omitted fields on an unflushed instance are unset (or already defaulted).

        SQLAlchemy evaluates a column's ``default`` when the INSERT is
        emitted, not when the object is constructed, so ``active`` can still
        be ``None`` here.
        """
        strategy = Strategy(name="test")

        assert strategy.description is None
        # Accept True as well, in case the model fills the default at
        # construction time rather than at flush time.
        active = strategy.active
        assert active is True or active is None
class TestSignal:
    """In-memory construction check for ``Signal`` (no session is used)."""

    def test_create_signal(self) -> None:
        """Constructor values are readable back unchanged."""
        fields = dict(
            ticker="AAPL",
            direction=SignalDirection.LONG,
            strength=0.85,
            strategy_sources={"momentum": 0.9},
            sentiment_score=0.7,
            acted_on=False,
        )
        signal = Signal(id=uuid.uuid4(), **fields)

        assert signal.ticker == "AAPL"
        assert signal.direction == SignalDirection.LONG
        assert signal.strength == 0.85
        # Identity check: the flag must be the real bool, not merely falsy.
        assert signal.acted_on is False
class TestTrade:
    """In-memory construction check for ``Trade`` (no session is used)."""

    def test_create_trade(self) -> None:
        """Every value given to the constructor is readable back by name."""
        expected = {
            "ticker": "TSLA",
            "side": TradeSide.BUY,
            "qty": 10.0,
            "price": 150.25,
            "status": TradeStatus.FILLED,
            "pnl": 250.50,
        }
        trade = Trade(id=uuid.uuid4(), **expected)

        for attribute, value in expected.items():
            assert getattr(trade, attribute) == value
class TestPosition:
    """In-memory construction check for ``Position`` (no session is used)."""

    def test_create_position(self) -> None:
        """Ticker, quantity and both exit levels are readable back after construction."""
        position = Position(
            id=uuid.uuid4(),
            ticker="GOOG",
            qty=5.0,
            avg_entry=2800.00,
            unrealized_pnl=-50.0,
            stop_loss=2750.0,
            take_profit=3000.0,
        )

        assert (position.ticker, position.qty) == ("GOOG", 5.0)
        assert position.stop_loss == 2750.0
        assert position.take_profit == 3000.0
class TestStrategyWeightHistory:
    """In-memory construction check for ``StrategyWeightHistory``."""

    def test_create_weight_history(self) -> None:
        """The strategy reference and both weights are readable back."""
        strategy_id = uuid.uuid4()
        record = StrategyWeightHistory(
            id=uuid.uuid4(),
            strategy_id=strategy_id,
            old_weight=0.33,
            new_weight=0.40,
            reason="Improved win rate",
        )

        assert record.strategy_id == strategy_id
        assert (record.old_weight, record.new_weight) == (0.33, 0.40)
class TestArticle:
    """In-memory construction check for ``Article`` (no session is used)."""

    def test_create_article(self) -> None:
        """Source and content hash are readable back after construction."""
        # One timezone-aware instant is reused for both timestamp fields.
        stamp = datetime.now(timezone.utc)
        article = Article(
            id=uuid.uuid4(),
            source="reuters",
            url="https://reuters.com/article/1",
            title="Market Rally",
            published_at=stamp,
            fetched_at=stamp,
            content_hash="abc123def456",
        )

        assert article.source == "reuters"
        assert article.content_hash == "abc123def456"
class TestArticleSentiment:
    """In-memory construction check for ``ArticleSentiment``."""

    def test_create_sentiment(self) -> None:
        """Score and model name are readable back after construction."""
        sentiment = ArticleSentiment(
            id=uuid.uuid4(),
            article_id=uuid.uuid4(),
            ticker="AAPL",
            score=0.85,
            confidence=0.92,
            model_used="finbert",
        )

        assert (sentiment.score, sentiment.model_used) == (0.85, "finbert")
class TestTradeOutcome:
    """In-memory construction check for ``TradeOutcome``."""

    def test_create_outcome(self) -> None:
        """P&L, profitability flag and hold duration are readable back."""
        held_for = timedelta(hours=4, minutes=30)
        result = TradeOutcome(
            id=uuid.uuid4(),
            trade_id=uuid.uuid4(),
            hold_duration=held_for,
            realized_pnl=125.50,
            roi_pct=2.5,
            was_profitable=True,
        )

        assert result.realized_pnl == 125.50
        # Identity check: the flag must be the real bool, not merely truthy.
        assert result.was_profitable is True
        assert result.hold_duration == held_for
class TestLearningAdjustment:
    """In-memory construction check for ``LearningAdjustment``."""

    def test_create_adjustment(self) -> None:
        """Reward signal and reason are readable back after construction."""
        adjustment = LearningAdjustment(
            id=uuid.uuid4(),
            strategy_id=uuid.uuid4(),
            old_weight=0.30,
            new_weight=0.35,
            reason="Positive reward signal",
            reward_signal=0.72,
        )

        assert (adjustment.reward_signal, adjustment.reason) == (
            0.72,
            "Positive reward signal",
        )
class TestUser:
    """In-memory construction check for ``User`` (no session is used)."""

    def test_create_user(self) -> None:
        """Username and display name are readable back after construction."""
        names = {"username": "trader1", "display_name": "Top Trader"}
        user = User(id=uuid.uuid4(), **names)

        for attribute, value in names.items():
            assert getattr(user, attribute) == value
class TestUserCredential:
    """In-memory construction check for ``UserCredential``."""

    def test_create_credential(self) -> None:
        """Credential id, sign count and raw key bytes are readable back."""
        key_bytes = b"\x04abcdef"
        credential = UserCredential(
            id=uuid.uuid4(),
            user_id=uuid.uuid4(),
            credential_id="cred-abc-123",
            public_key=key_bytes,
            sign_count=5,
        )

        assert credential.credential_id == "cred-abc-123"
        assert credential.sign_count == 5
        assert credential.public_key == key_bytes
class TestMarketData:
    """In-memory construction check for ``MarketData`` (no session is used)."""

    def test_create_market_data(self) -> None:
        """Ticker and close price are readable back after construction."""
        bar = MarketData(
            timestamp=datetime.now(timezone.utc),
            ticker="AAPL",
            open=150.0,
            high=155.0,
            low=149.0,
            close=153.0,
            volume=1_000_000.0,
        )

        assert (bar.ticker, bar.close) == ("AAPL", 153.0)
class TestPortfolioSnapshot:
    """In-memory construction check for ``PortfolioSnapshot``."""

    def test_create_snapshot(self) -> None:
        """Total value and daily P&L are readable back after construction."""
        snapshot = PortfolioSnapshot(
            timestamp=datetime.now(timezone.utc),
            total_value=100_000.0,
            cash=25_000.0,
            positions_value=75_000.0,
            daily_pnl=1_200.0,
        )

        assert (snapshot.total_value, snapshot.daily_pnl) == (100_000.0, 1_200.0)
class TestStrategyMetric:
    """In-memory construction check for ``StrategyMetric``."""

    def test_create_metric(self) -> None:
        """Win rate, trade count and Sharpe ratio are readable back."""
        checked = {"win_rate": 0.65, "trade_count": 42, "sharpe_ratio": 1.8}
        metric = StrategyMetric(
            timestamp=datetime.now(timezone.utc),
            strategy_id=uuid.uuid4(),
            total_pnl=5_432.10,
            **checked,
        )

        for attribute, value in checked.items():
            assert getattr(metric, attribute) == value
class TestFundamentals:
    """In-memory construction checks for ``Fundamentals`` (no session is used)."""

    def test_create_fundamentals(self) -> None:
        """Every supplied field is readable back with the value given."""
        supplied = {
            "ticker": "AAPL",
            "eps_ttm": 6.57,
            "pe_ratio": 28.3,
            "peg_ratio": 2.1,
            "revenue_growth_yoy": 0.08,
            "profit_margin": 0.26,
            "debt_to_equity": 1.87,
            "market_cap": 2_800_000_000_000.0,
            "fetched_at": datetime.now(timezone.utc),
        }
        record = Fundamentals(id=uuid.uuid4(), **supplied)

        for attribute, value in supplied.items():
            assert getattr(record, attribute) == value

    def test_create_with_optional_fields_none(self) -> None:
        """Metrics that are not passed to the constructor read back as None."""
        record = Fundamentals(
            id=uuid.uuid4(),
            ticker="XYZ",
            fetched_at=datetime.now(timezone.utc),
        )

        assert record.ticker == "XYZ"
        omitted = (
            "eps_ttm",
            "pe_ratio",
            "peg_ratio",
            "revenue_growth_yoy",
            "profit_margin",
            "debt_to_equity",
            "market_cap",
        )
        for attribute in omitted:
            assert getattr(record, attribute) is None
# ---------------------------------------------------------------------------
# Metadata / Base tests
# ---------------------------------------------------------------------------
class TestMetadata:
    """Checks against the declarative ``Base`` metadata and mapped tables."""

    def test_all_tables_registered(self) -> None:
        """Every expected table name is present in ``Base.metadata``.

        Extra tables are allowed; only missing ones fail the test.
        """
        required = {
            "strategies",
            "signals",
            "trades",
            "positions",
            "strategy_weight_history",
            "articles",
            "article_sentiments",
            "trade_outcomes",
            "learning_adjustments",
            "users",
            "user_credentials",
            "market_data",
            "portfolio_snapshots",
            "strategy_metrics",
            "fundamentals",
        }
        registered = set(Base.metadata.tables.keys())

        # An empty difference means the same thing as required ⊆ registered.
        absent = required - registered
        assert not absent

    def test_timestamp_mixin_fields(self) -> None:
        """The ``strategies`` table exposes created_at and updated_at columns."""
        columns = Strategy.__table__.columns
        for column_name in ("created_at", "updated_at"):
            assert column_name in columns
class TestDbFactory:
    """Smoke test for the ``create_db`` factory."""

    def test_create_db_returns_engine_and_session(self) -> None:
        """``create_db`` returns a two-item result with neither item None."""
        # Unpacking also verifies that exactly two values come back.
        engine, session_factory = create_db(BaseConfig())

        for produced in (engine, session_factory):
            assert produced is not None