trading/tests/test_schemas.py
Viktor Barzin c8277e301e
feat: pydantic schemas for all service message types
- shared/schemas/trading.py: OrderRequest, OrderResult, PositionInfo,
  AccountInfo, TradeSignal, TradeExecution, MarketSnapshot, SentimentContext
- shared/schemas/news.py: RawArticle, ScoredArticle
- shared/schemas/learning.py: TradeOutcomeSchema, WeightAdjustment
- shared/schemas/auth.py: RegisterRequest, LoginRequest, TokenResponse
- 49 schema tests covering validation constraints, serialization round-trips,
  required fields, and range checks
2026-02-22 15:19:00 +00:00

586 lines
18 KiB
Python

"""Tests for Pydantic schemas — serialization round-trips and validation constraints."""
import uuid
from datetime import datetime, timezone
import pytest
from pydantic import ValidationError
from shared.schemas.trading import (
AccountInfo,
MarketSnapshot,
OrderRequest,
OrderResult,
OrderSide,
OrderStatus,
OrderType,
PositionInfo,
SentimentContext,
SignalDirection,
TradeExecution,
TradeSignal,
)
from shared.schemas.news import RawArticle, ScoredArticle
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
# ---------------------------------------------------------------------------
# Trading schemas
# ---------------------------------------------------------------------------
class TestOrderRequest:
    """OrderRequest: defaults, quantity validation, and dict/JSON round-trips."""

    @staticmethod
    def _expect_rejected_qty(qty: float) -> None:
        """Assert that a BUY order for AAPL with the given ``qty`` fails validation."""
        with pytest.raises(ValidationError):
            OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=qty)

    def test_valid_market_order(self) -> None:
        """Without an explicit order_type the order is MARKET and has no limit price."""
        order = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0)
        assert order.order_type == OrderType.MARKET
        assert order.limit_price is None

    def test_valid_limit_order(self) -> None:
        """A LIMIT order keeps the limit price it was built with."""
        limit = 250.50
        order = OrderRequest(
            ticker="TSLA",
            side=OrderSide.SELL,
            qty=5.0,
            order_type=OrderType.LIMIT,
            limit_price=limit,
        )
        assert order.limit_price == limit

    def test_qty_must_be_positive(self) -> None:
        """A quantity of zero is rejected."""
        self._expect_rejected_qty(0)

    def test_qty_must_not_be_negative(self) -> None:
        """A negative quantity is rejected."""
        self._expect_rejected_qty(-5)

    def test_serialization_round_trip(self) -> None:
        """model_dump() output validates back into an equal STOP order."""
        original = OrderRequest(
            ticker="GOOG",
            side=OrderSide.BUY,
            qty=3.0,
            order_type=OrderType.STOP,
            stop_price=100.0,
        )
        assert OrderRequest.model_validate(original.model_dump()) == original

    def test_json_round_trip(self) -> None:
        """model_dump_json() output parses back into an equal order."""
        original = OrderRequest(ticker="META", side=OrderSide.SELL, qty=1.5)
        payload = original.model_dump_json()
        assert OrderRequest.model_validate_json(payload) == original
class TestOrderResult:
    """OrderResult construction for filled and still-pending orders."""

    def test_valid_result(self) -> None:
        """A FILLED result exposes the fill price and status it was given."""
        fill_price = 150.25
        result = OrderResult(
            order_id="ord-123",
            ticker="AAPL",
            side=OrderSide.BUY,
            qty=10.0,
            filled_price=fill_price,
            status=OrderStatus.FILLED,
            timestamp=datetime.now(timezone.utc),
        )
        assert result.filled_price == fill_price
        assert result.status == OrderStatus.FILLED

    def test_pending_no_fill(self) -> None:
        """A PENDING result built without a fill price has filled_price None."""
        result = OrderResult(
            order_id="ord-456",
            ticker="TSLA",
            side=OrderSide.SELL,
            qty=5.0,
            status=OrderStatus.PENDING,
            timestamp=datetime.now(timezone.utc),
        )
        assert result.filled_price is None
class TestPositionInfo:
    """PositionInfo construction and dict round-trip."""

    def test_valid_position(self) -> None:
        """market_value is stored exactly as supplied."""
        value = 17000.0
        position = PositionInfo(
            ticker="NVDA",
            qty=20.0,
            avg_entry=800.0,
            current_price=850.0,
            unrealized_pnl=1000.0,
            market_value=value,
        )
        assert position.market_value == value

    def test_serialization_round_trip(self) -> None:
        """model_dump() output validates back into an equal position."""
        fields = {
            "ticker": "AMZN",
            "qty": 10.0,
            "avg_entry": 180.0,
            "current_price": 185.0,
            "unrealized_pnl": 50.0,
            "market_value": 1850.0,
        }
        position = PositionInfo(**fields)
        assert PositionInfo.model_validate(position.model_dump()) == position
class TestAccountInfo:
    """AccountInfo construction."""

    def test_valid_account(self) -> None:
        """equity is stored exactly as supplied."""
        equity = 100_000.0
        account = AccountInfo(
            equity=equity,
            cash=25_000.0,
            buying_power=50_000.0,
            portfolio_value=100_000.0,
        )
        assert account.equity == equity
class TestTradeSignal:
    """TradeSignal: construction, strength bounds, and JSON round-trip."""

    @staticmethod
    def _assert_strength_rejected(
        direction: "SignalDirection", strength: float, source: str
    ) -> None:
        """Assert that a single-source AAPL signal with ``strength`` fails validation."""
        stamp = datetime.now(timezone.utc)
        with pytest.raises(ValidationError):
            TradeSignal(
                ticker="AAPL",
                direction=direction,
                strength=strength,
                strategy_sources=[source],
                timestamp=stamp,
            )

    def test_valid_signal(self) -> None:
        """Strength is stored as given; omitted sentiment_context is None."""
        signal = TradeSignal(
            ticker="AAPL",
            direction=SignalDirection.LONG,
            strength=0.85,
            strategy_sources=["momentum", "news_driven"],
            timestamp=datetime.now(timezone.utc),
        )
        assert signal.strength == 0.85
        assert signal.sentiment_context is None

    def test_strength_must_be_in_range(self) -> None:
        """A strength of 1.5 (above the allowed range) is rejected."""
        self._assert_strength_rejected(SignalDirection.LONG, 1.5, "momentum")

    def test_strength_lower_bound(self) -> None:
        """A negative strength (-0.1) is rejected."""
        self._assert_strength_rejected(SignalDirection.SHORT, -0.1, "mean_reversion")

    def test_json_round_trip(self) -> None:
        """A signal carrying sentiment context survives a JSON round-trip unchanged."""
        original = TradeSignal(
            ticker="TSLA",
            direction=SignalDirection.SHORT,
            strength=0.6,
            strategy_sources=["mean_reversion"],
            sentiment_context={"avg_score": -0.4},
            timestamp=datetime.now(timezone.utc),
        )
        payload = original.model_dump_json()
        assert TradeSignal.model_validate_json(payload) == original
class TestTradeExecution:
    """TradeExecution: UUID linkage fields and their optionality."""

    def test_valid_execution(self) -> None:
        """trade_id and signal_id are preserved as the UUIDs passed in."""
        trade_uuid = uuid.uuid4()
        signal_uuid = uuid.uuid4()
        execution = TradeExecution(
            trade_id=trade_uuid,
            ticker="AAPL",
            side=OrderSide.BUY,
            qty=10.0,
            price=150.0,
            status=OrderStatus.FILLED,
            signal_id=signal_uuid,
            strategy_id=uuid.uuid4(),
            timestamp=datetime.now(timezone.utc),
        )
        assert execution.trade_id == trade_uuid
        assert execution.signal_id == signal_uuid

    def test_optional_ids(self) -> None:
        """When omitted, signal_id and strategy_id are both None."""
        execution = TradeExecution(
            trade_id=uuid.uuid4(),
            ticker="GOOG",
            side=OrderSide.SELL,
            qty=5.0,
            price=2800.0,
            status=OrderStatus.FILLED,
            timestamp=datetime.now(timezone.utc),
        )
        assert execution.signal_id is None
        assert execution.strategy_id is None
class TestMarketSnapshot:
    """MarketSnapshot defaults and attached price bars."""

    def test_valid_snapshot(self) -> None:
        """When omitted, sma_50 is None and bars is an empty list."""
        snapshot = MarketSnapshot(
            ticker="AAPL",
            current_price=150.0,
            open=148.0,
            high=152.0,
            low=147.0,
            close=150.0,
            volume=5_000_000.0,
            sma_20=149.5,
            rsi=55.0,
        )
        assert snapshot.sma_50 is None
        assert snapshot.bars == []

    def test_with_bars(self) -> None:
        """A single supplied bar is kept on the snapshot."""
        bar = {"timestamp": "2026-01-01T10:00:00Z", "open": 248, "close": 250}
        snapshot = MarketSnapshot(
            ticker="TSLA",
            current_price=250.0,
            open=248.0,
            high=255.0,
            low=245.0,
            close=250.0,
            volume=10_000_000.0,
            bars=[bar],
        )
        assert len(snapshot.bars) == 1
class TestSentimentContext:
    """SentimentContext: score/confidence ranges and article-count bound."""

    @staticmethod
    def _assert_invalid(
        avg_score: float, article_count: int, avg_confidence: float
    ) -> None:
        """Assert that an AAPL context built from these values fails validation."""
        with pytest.raises(ValidationError):
            SentimentContext(
                ticker="AAPL",
                avg_score=avg_score,
                article_count=article_count,
                avg_confidence=avg_confidence,
            )

    def test_valid_context(self) -> None:
        """A context with recent scores stores its article count."""
        context = SentimentContext(
            ticker="AAPL",
            avg_score=0.65,
            article_count=5,
            recent_scores=[0.5, 0.7, 0.8],
            avg_confidence=0.85,
        )
        assert context.article_count == 5

    def test_score_range_validation(self) -> None:
        """An average score of 1.5 is out of range and rejected."""
        self._assert_invalid(avg_score=1.5, article_count=1, avg_confidence=0.5)

    def test_negative_score_in_range(self) -> None:
        """A negative average score (-0.8) is accepted as-is."""
        context = SentimentContext(
            ticker="TSLA",
            avg_score=-0.8,
            article_count=3,
            avg_confidence=0.9,
        )
        assert context.avg_score == -0.8

    def test_confidence_range_validation(self) -> None:
        """An average confidence of 1.5 is out of range and rejected."""
        self._assert_invalid(avg_score=0.5, article_count=1, avg_confidence=1.5)

    def test_article_count_non_negative(self) -> None:
        """A negative article count is rejected."""
        self._assert_invalid(avg_score=0.0, article_count=-1, avg_confidence=0.5)
# ---------------------------------------------------------------------------
# News schemas
# ---------------------------------------------------------------------------
class TestRawArticle:
    """RawArticle: construction, optional publish time, JSON round-trip, required fields."""

    def test_valid_article(self) -> None:
        """A fully populated article keeps its source."""
        stamp = datetime.now(timezone.utc)
        article = RawArticle(
            source="reuters",
            url="https://reuters.com/article/1",
            title="Market Rally Continues",
            content="Stocks rose sharply today...",
            published_at=stamp,
            fetched_at=stamp,
            content_hash="sha256abcdef1234567890",
        )
        assert article.source == "reuters"

    def test_published_at_optional(self) -> None:
        """Omitting published_at leaves it as None."""
        article = RawArticle(
            source="reddit",
            url="https://reddit.com/r/stocks/1",
            title="DD on TSLA",
            content="Here is my analysis...",
            fetched_at=datetime.now(timezone.utc),
            content_hash="hash123",
        )
        assert article.published_at is None

    def test_json_round_trip(self) -> None:
        """model_dump_json() output parses back into an equal article."""
        stamp = datetime.now(timezone.utc)
        original = RawArticle(
            source="yahoo",
            url="https://finance.yahoo.com/1",
            title="Earnings Beat",
            content="Apple beat earnings...",
            published_at=stamp,
            fetched_at=stamp,
            content_hash="hash456",
        )
        payload = original.model_dump_json()
        assert RawArticle.model_validate_json(payload) == original

    def test_required_fields(self) -> None:
        """Supplying only a source is rejected."""
        with pytest.raises(ValidationError):
            RawArticle(source="reuters")  # type: ignore[call-arg]
class TestScoredArticle:
    """ScoredArticle: scoring fields, score/confidence bounds, JSON round-trip."""

    @staticmethod
    def _assert_rejected(sentiment_score: float, confidence: float) -> None:
        """Assert that a reuters/AAPL finbert article with these scores fails validation."""
        fetched = datetime.now(timezone.utc)
        with pytest.raises(ValidationError):
            ScoredArticle(
                source="reuters",
                url="https://reuters.com/1",
                title="Test",
                content="Test content",
                fetched_at=fetched,
                content_hash="hash",
                ticker="AAPL",
                sentiment_score=sentiment_score,
                confidence=confidence,
                model_used="finbert",
            )

    def test_valid_scored_article(self) -> None:
        """Sentiment score and extracted entities are stored as provided."""
        stamp = datetime.now(timezone.utc)
        article = ScoredArticle(
            source="reuters",
            url="https://reuters.com/article/1",
            title="Apple Earnings Beat",
            content="Apple reported...",
            published_at=stamp,
            fetched_at=stamp,
            content_hash="hash789",
            ticker="AAPL",
            sentiment_score=0.85,
            confidence=0.92,
            model_used="finbert",
            entities=["Apple Inc", "Tim Cook"],
        )
        assert article.sentiment_score == 0.85
        assert article.entities == ["Apple Inc", "Tim Cook"]

    def test_sentiment_score_range(self) -> None:
        """A sentiment score of 1.5 is out of range and rejected."""
        self._assert_rejected(sentiment_score=1.5, confidence=0.5)

    def test_confidence_range(self) -> None:
        """A negative confidence (-0.1) is out of range and rejected."""
        self._assert_rejected(sentiment_score=0.5, confidence=-0.1)

    def test_negative_sentiment(self) -> None:
        """A strongly negative sentiment score (-0.9) is accepted as-is."""
        article = ScoredArticle(
            source="reddit",
            url="https://reddit.com/1",
            title="Bad news",
            content="Terrible quarter...",
            fetched_at=datetime.now(timezone.utc),
            content_hash="hashN",
            ticker="TSLA",
            sentiment_score=-0.9,
            confidence=0.8,
            model_used="ollama",
        )
        assert article.sentiment_score == -0.9

    def test_json_round_trip(self) -> None:
        """model_dump_json() output parses back into an equal scored article."""
        original = ScoredArticle(
            source="yahoo",
            url="https://yahoo.com/1",
            title="Headline",
            content="Body text",
            fetched_at=datetime.now(timezone.utc),
            content_hash="hashRT",
            ticker="NVDA",
            sentiment_score=0.3,
            confidence=0.7,
            model_used="finbert",
            entities=["NVIDIA"],
        )
        payload = original.model_dump_json()
        assert ScoredArticle.model_validate_json(payload) == original
# ---------------------------------------------------------------------------
# Learning schemas
# ---------------------------------------------------------------------------
class TestTradeOutcomeSchema:
    """TradeOutcomeSchema: winning/losing outcomes, duration bound, JSON round-trip."""

    def test_valid_outcome(self) -> None:
        """A profitable outcome stores its flag and hold duration."""
        held_for = 14400.0  # 4 hours, in seconds
        outcome = TradeOutcomeSchema(
            trade_id=uuid.uuid4(),
            hold_duration_seconds=held_for,
            realized_pnl=250.50,
            roi_pct=2.5,
            was_profitable=True,
        )
        assert outcome.was_profitable is True
        assert outcome.hold_duration_seconds == held_for

    def test_hold_duration_non_negative(self) -> None:
        """A negative hold duration is rejected."""
        trade_uuid = uuid.uuid4()
        with pytest.raises(ValidationError):
            TradeOutcomeSchema(
                trade_id=trade_uuid,
                hold_duration_seconds=-1.0,
                realized_pnl=100.0,
                roi_pct=1.0,
                was_profitable=True,
            )

    def test_losing_trade(self) -> None:
        """Negative PnL/ROI with was_profitable=False is accepted."""
        outcome = TradeOutcomeSchema(
            trade_id=uuid.uuid4(),
            hold_duration_seconds=3600.0,
            realized_pnl=-150.0,
            roi_pct=-3.0,
            was_profitable=False,
        )
        assert outcome.was_profitable is False
        assert outcome.realized_pnl == -150.0

    def test_json_round_trip(self) -> None:
        """model_dump_json() output parses back into an equal outcome."""
        original = TradeOutcomeSchema(
            trade_id=uuid.uuid4(),
            hold_duration_seconds=7200.0,
            realized_pnl=500.0,
            roi_pct=5.0,
            was_profitable=True,
        )
        payload = original.model_dump_json()
        assert TradeOutcomeSchema.model_validate_json(payload) == original
class TestWeightAdjustment:
    """WeightAdjustment: construction, missing-field rejection, JSON round-trip."""

    def test_valid_adjustment(self) -> None:
        """Old and new weights are stored exactly as supplied."""
        now = datetime.now(timezone.utc)
        wa = WeightAdjustment(
            strategy_id=uuid.uuid4(),
            strategy_name="momentum",
            old_weight=0.33,
            new_weight=0.38,
            reason="Positive reward signal from recent trades",
            reward_signal=0.72,
            timestamp=now,
        )
        assert wa.old_weight == 0.33
        assert wa.new_weight == 0.38

    def test_required_fields(self) -> None:
        """Supplying only strategy_id and strategy_name is rejected."""
        with pytest.raises(ValidationError):
            # The ignore belongs on the call's first line: mypy reports the
            # missing-argument error there. An ignore on the closing ``)`` line
            # neither silences it nor counts as a used ignore.
            WeightAdjustment(  # type: ignore[call-arg]
                strategy_id=uuid.uuid4(),
                strategy_name="momentum",
            )

    def test_json_round_trip(self) -> None:
        """model_dump_json() output parses back into an equal adjustment."""
        now = datetime.now(timezone.utc)
        wa = WeightAdjustment(
            strategy_id=uuid.uuid4(),
            strategy_name="mean_reversion",
            old_weight=0.30,
            new_weight=0.25,
            reason="Poor recent performance",
            reward_signal=-0.4,
            timestamp=now,
        )
        restored = WeightAdjustment.model_validate_json(wa.model_dump_json())
        assert restored == wa
# ---------------------------------------------------------------------------
# Auth schemas
# ---------------------------------------------------------------------------
class TestRegisterRequest:
    """RegisterRequest: optional display name and username length limits."""

    def test_valid_registration(self) -> None:
        """Username and display name are stored as supplied."""
        req = RegisterRequest(username="trader1", display_name="Top Trader")
        assert req.username == "trader1"
        assert req.display_name == "Top Trader"

    def test_display_name_optional(self) -> None:
        """Omitting display_name leaves it as None."""
        assert RegisterRequest(username="trader2").display_name is None

    def test_username_required(self) -> None:
        """An empty username is rejected."""
        with pytest.raises(ValidationError):
            RegisterRequest(username="")  # min_length=1

    def test_username_max_length(self) -> None:
        """A 101-character username is rejected."""
        too_long = "x" * 101  # max_length=100
        with pytest.raises(ValidationError):
            RegisterRequest(username=too_long)
class TestLoginRequest:
    """LoginRequest: username handling."""

    def test_valid_login(self) -> None:
        """The supplied username is exposed unchanged."""
        # Named ``req`` rather than ``l``: a lone ``l`` is easily misread as
        # ``1`` or ``I`` (PEP 8 "Names to Avoid"; ruff/flake8 E741).
        req = LoginRequest(username="trader1")
        assert req.username == "trader1"

    def test_username_required(self) -> None:
        """An empty username is rejected."""
        with pytest.raises(ValidationError):
            LoginRequest(username="")
class TestTokenResponse:
    """TokenResponse defaults, custom expiry, and JSON round-trip."""

    def test_valid_response(self) -> None:
        """Without overrides, token_type is "bearer" and expires_in is 900."""
        token = TokenResponse(access_token="eyJ...", refresh_token="eyR...")
        assert token.token_type == "bearer"
        assert token.expires_in == 900  # default 15 min

    def test_custom_expiry(self) -> None:
        """An explicit expires_in overrides the default."""
        one_hour = 3600
        token = TokenResponse(
            access_token="eyJ...",
            refresh_token="eyR...",
            expires_in=one_hour,
        )
        assert token.expires_in == one_hour

    def test_json_round_trip(self) -> None:
        """model_dump_json() output parses back into an equal token response."""
        original = TokenResponse(
            access_token="access123",
            refresh_token="refresh456",
            token_type="bearer",
            expires_in=1800,
        )
        payload = original.model_dump_json()
        assert TokenResponse.model_validate_json(payload) == original