trading/tests/test_models.py
Viktor Barzin 72cb1b6fe5
feat: database models and alembic migrations — all tables per design
- shared/db.py: async engine + session factory
- shared/models/base.py: DeclarativeBase + TimestampMixin
- shared/models/trading.py: Strategy, Signal, Trade, Position, StrategyWeightHistory
- shared/models/news.py: Article, ArticleSentiment
- shared/models/learning.py: TradeOutcome, LearningAdjustment
- shared/models/auth.py: User, UserCredential
- shared/models/timeseries.py: MarketData, PortfolioSnapshot, StrategyMetric
- Alembic async env.py with initial migration including TimescaleDB hypertables
- 21 model tests covering enums, instantiation, metadata registration
2026-02-22 15:17:07 +00:00

323 lines
9.2 KiB
Python

"""Tests for SQLAlchemy model instantiation, enums, and relationships."""
import uuid
from datetime import datetime, timedelta, timezone
import pytest
from shared.models import (
Base,
TimestampMixin,
# Trading
Strategy,
Signal,
SignalDirection,
Trade,
TradeSide,
TradeStatus,
Position,
StrategyWeightHistory,
# News
Article,
ArticleSentiment,
# Learning
TradeOutcome,
LearningAdjustment,
# Auth
User,
UserCredential,
# Timeseries
MarketData,
PortfolioSnapshot,
StrategyMetric,
)
from shared.db import create_db
from shared.config import BaseConfig
# ---------------------------------------------------------------------------
# Enum tests
# ---------------------------------------------------------------------------
class TestEnums:
    """Pin the raw string values and the membership size of the model enums."""

    def test_trade_side_values(self) -> None:
        for member, raw in ((TradeSide.BUY, "BUY"), (TradeSide.SELL, "SELL")):
            assert member == raw
        assert set(TradeSide) == {TradeSide.BUY, TradeSide.SELL}

    def test_trade_status_values(self) -> None:
        pairs = (
            (TradeStatus.PENDING, "PENDING"),
            (TradeStatus.FILLED, "FILLED"),
            (TradeStatus.CANCELLED, "CANCELLED"),
            (TradeStatus.REJECTED, "REJECTED"),
        )
        for member, raw in pairs:
            assert member == raw
        assert len(TradeStatus) == 4

    def test_signal_direction_values(self) -> None:
        pairs = (
            (SignalDirection.LONG, "LONG"),
            (SignalDirection.SHORT, "SHORT"),
            (SignalDirection.NEUTRAL, "NEUTRAL"),
        )
        for member, raw in pairs:
            assert member == raw
        assert len(SignalDirection) == 3
# ---------------------------------------------------------------------------
# Model instantiation tests
# ---------------------------------------------------------------------------
class TestStrategy:
    """In-memory (session-less) construction tests for the ``Strategy`` model."""

    def test_create_strategy(self) -> None:
        s = Strategy(
            id=uuid.uuid4(),
            name="momentum",
            description="Trend-following strategy",
            current_weight=0.5,
            active=True,
        )
        assert s.name == "momentum"
        assert s.description == "Trend-following strategy"
        assert s.current_weight == 0.5
        assert s.active is True

    def test_strategy_defaults(self) -> None:
        """Unset attributes stay unpopulated before a flush; ``active`` must declare a default.

        Column defaults take effect only when the row is INSERTed. A
        ``default=`` is applied by SQLAlchemy while emitting the INSERT; a
        ``server_default=`` is applied by the database. So an unflushed
        instance cannot prove the default exists, and the column definition
        is inspected instead.
        """
        s = Strategy(name="test")
        assert s.description is None
        # In memory this is None, unless the mapping applies Python-side
        # defaults at construction time, in which case it must already be True.
        assert s.active is None or s.active is True
        # The check above also passes when no default is configured at all,
        # so verify the column really declares one (client- or server-side).
        active_col = Strategy.__table__.columns["active"]
        assert active_col.default is not None or active_col.server_default is not None
class TestSignal:
    """Build a ``Signal`` without a session and read its attributes back."""

    def test_create_signal(self) -> None:
        signal = Signal(
            id=uuid.uuid4(),
            ticker="AAPL",
            direction=SignalDirection.LONG,
            strength=0.85,
            strategy_sources={"momentum": 0.9},
            sentiment_score=0.7,
            acted_on=False,
        )
        checks = {
            "ticker": "AAPL",
            "direction": SignalDirection.LONG,
            "strength": 0.85,
        }
        for attr, expected in checks.items():
            assert getattr(signal, attr) == expected
        assert signal.acted_on is False
class TestTrade:
    """Build a ``Trade`` without a session; every supplied field must read back unchanged."""

    def test_create_trade(self) -> None:
        fields = {
            "ticker": "TSLA",
            "side": TradeSide.BUY,
            "qty": 10.0,
            "price": 150.25,
            "status": TradeStatus.FILLED,
            "pnl": 250.50,
        }
        trade = Trade(id=uuid.uuid4(), **fields)
        for attr, expected in fields.items():
            assert getattr(trade, attr) == expected
class TestPosition:
    """Build a ``Position`` without a session and verify every supplied field."""

    def test_create_position(self) -> None:
        p = Position(
            id=uuid.uuid4(),
            ticker="GOOG",
            qty=5.0,
            avg_entry=2800.00,
            unrealized_pnl=-50.0,
            stop_loss=2750.0,
            take_profit=3000.0,
        )
        assert p.ticker == "GOOG"
        assert p.qty == 5.0
        # avg_entry and the (negative) unrealized_pnl were previously set but
        # never checked.
        assert p.avg_entry == 2800.00
        assert p.unrealized_pnl == -50.0
        assert p.stop_loss == 2750.0
        assert p.take_profit == 3000.0
class TestStrategyWeightHistory:
    """Build a ``StrategyWeightHistory`` row in memory and check the weight change."""

    def test_create_weight_history(self) -> None:
        owner_id = uuid.uuid4()
        entry = StrategyWeightHistory(
            id=uuid.uuid4(),
            strategy_id=owner_id,
            old_weight=0.33,
            new_weight=0.40,
            reason="Improved win rate",
        )
        assert entry.strategy_id == owner_id
        assert entry.old_weight == 0.33
        assert entry.new_weight == 0.40
class TestArticle:
    """Build an ``Article`` in memory with timezone-aware timestamps."""

    def test_create_article(self) -> None:
        stamp = datetime.now(timezone.utc)
        article = Article(
            id=uuid.uuid4(),
            source="reuters",
            url="https://reuters.com/article/1",
            title="Market Rally",
            published_at=stamp,
            fetched_at=stamp,
            content_hash="abc123def456",
        )
        for attr, expected in (("source", "reuters"), ("content_hash", "abc123def456")):
            assert getattr(article, attr) == expected
class TestArticleSentiment:
    """Build an ``ArticleSentiment`` in memory and read back its score and model."""

    def test_create_sentiment(self) -> None:
        sentiment = ArticleSentiment(
            id=uuid.uuid4(),
            article_id=uuid.uuid4(),
            ticker="AAPL",
            score=0.85,
            confidence=0.92,
            model_used="finbert",
        )
        for attr, expected in (("score", 0.85), ("model_used", "finbert")):
            assert getattr(sentiment, attr) == expected
class TestTradeOutcome:
    """Build a ``TradeOutcome`` in memory, including a ``timedelta`` hold duration."""

    def test_create_outcome(self) -> None:
        held_for = timedelta(hours=4, minutes=30)
        record = TradeOutcome(
            id=uuid.uuid4(),
            trade_id=uuid.uuid4(),
            hold_duration=held_for,
            realized_pnl=125.50,
            roi_pct=2.5,
            was_profitable=True,
        )
        assert record.realized_pnl == 125.50
        assert record.was_profitable is True
        assert record.hold_duration == held_for
class TestLearningAdjustment:
    """Build a ``LearningAdjustment`` in memory and verify its weight change and reward.

    Mirrors ``TestStrategyWeightHistory``: the strategy link and the old/new
    weights are asserted, not just the reward and reason.
    """

    def test_create_adjustment(self) -> None:
        sid = uuid.uuid4()
        adj = LearningAdjustment(
            id=uuid.uuid4(),
            strategy_id=sid,
            old_weight=0.30,
            new_weight=0.35,
            reason="Positive reward signal",
            reward_signal=0.72,
        )
        assert adj.strategy_id == sid
        assert adj.old_weight == 0.30
        assert adj.new_weight == 0.35
        assert adj.reward_signal == 0.72
        assert adj.reason == "Positive reward signal"
class TestUser:
    """Build a ``User`` in memory and read back its profile fields."""

    def test_create_user(self) -> None:
        profile = {"username": "trader1", "display_name": "Top Trader"}
        user = User(id=uuid.uuid4(), **profile)
        for attr, expected in profile.items():
            assert getattr(user, attr) == expected
class TestUserCredential:
    """Build a ``UserCredential`` in memory, including a raw ``bytes`` public key."""

    def test_create_credential(self) -> None:
        key_bytes = b"\x04abcdef"
        credential = UserCredential(
            id=uuid.uuid4(),
            user_id=uuid.uuid4(),
            credential_id="cred-abc-123",
            public_key=key_bytes,
            sign_count=5,
        )
        assert credential.credential_id == "cred-abc-123"
        assert credential.sign_count == 5
        assert credential.public_key == key_bytes
class TestMarketData:
    """Build a ``MarketData`` bar in memory; the model takes no ``id`` argument here."""

    def test_create_market_data(self) -> None:
        bar = {
            "open": 150.0,
            "high": 155.0,
            "low": 149.0,
            "close": 153.0,
            "volume": 1_000_000.0,
        }
        row = MarketData(timestamp=datetime.now(timezone.utc), ticker="AAPL", **bar)
        assert row.ticker == "AAPL"
        assert row.close == bar["close"]
class TestPortfolioSnapshot:
    """Build a ``PortfolioSnapshot`` in memory and read back its value and daily P&L."""

    def test_create_snapshot(self) -> None:
        snapshot = PortfolioSnapshot(
            timestamp=datetime.now(timezone.utc),
            total_value=100_000.0,
            cash=25_000.0,
            positions_value=75_000.0,
            daily_pnl=1_200.0,
        )
        for attr, expected in (("total_value", 100_000.0), ("daily_pnl", 1_200.0)):
            assert getattr(snapshot, attr) == expected
class TestStrategyMetric:
    """Build a ``StrategyMetric`` sample in memory and read back its statistics."""

    def test_create_metric(self) -> None:
        metric = StrategyMetric(
            timestamp=datetime.now(timezone.utc),
            strategy_id=uuid.uuid4(),
            win_rate=0.65,
            total_pnl=5_432.10,
            trade_count=42,
            sharpe_ratio=1.8,
        )
        expectations = (
            ("win_rate", 0.65),
            ("trade_count", 42),
            ("sharpe_ratio", 1.8),
        )
        for attr, expected in expectations:
            assert getattr(metric, attr) == expected
# ---------------------------------------------------------------------------
# Metadata / Base tests
# ---------------------------------------------------------------------------
class TestMetadata:
    """Checks on the shared declarative ``Base`` metadata and ``TimestampMixin``."""

    def test_all_tables_registered(self) -> None:
        table_names = set(Base.metadata.tables.keys())
        expected = {
            "strategies",
            "signals",
            "trades",
            "positions",
            "strategy_weight_history",
            "articles",
            "article_sentiments",
            "trade_outcomes",
            "learning_adjustments",
            "users",
            "user_credentials",
            "market_data",
            "portfolio_snapshots",
            "strategy_metrics",
        }
        # Equivalent to expected.issubset(table_names), but a failure names the
        # missing tables instead of just reporting False.
        missing = expected - table_names
        assert not missing, f"tables not registered on Base.metadata: {sorted(missing)}"

    def test_timestamp_mixin_fields(self) -> None:
        """TimestampMixin should contribute created_at and updated_at columns.

        Strategy is checked explicitly. Every other mapped class registered on
        ``Base`` that inherits TimestampMixin is checked as well.
        """
        assert "created_at" in Strategy.__table__.columns
        assert "updated_at" in Strategy.__table__.columns
        for mapper in Base.registry.mappers:
            model = mapper.class_
            if not issubclass(model, TimestampMixin):
                continue
            # mapper.columns includes columns inherited from parent mappers.
            for column_name in ("created_at", "updated_at"):
                assert column_name in mapper.columns, (
                    f"{model.__name__} inherits TimestampMixin but lacks {column_name}"
                )
class TestDbFactory:
    """Smoke-test ``create_db`` with a default-constructed ``BaseConfig``."""

    def test_create_db_returns_engine_and_session(self) -> None:
        engine, session_factory = create_db(BaseConfig())
        for produced in (engine, session_factory):
            assert produced is not None