feat: pydantic schemas for all service message types
- shared/schemas/trading.py: OrderRequest, OrderResult, PositionInfo, AccountInfo, TradeSignal, TradeExecution, MarketSnapshot, SentimentContext - shared/schemas/news.py: RawArticle, ScoredArticle - shared/schemas/learning.py: TradeOutcomeSchema, WeightAdjustment - shared/schemas/auth.py: RegisterRequest, LoginRequest, TokenResponse - 49 schema tests covering validation constraints, serialization round-trips, required fields, and range checks
This commit is contained in:
parent
72cb1b6fe5
commit
c8277e301e
11 changed files with 889 additions and 0 deletions
586
tests/test_schemas.py
Normal file
586
tests/test_schemas.py
Normal file
|
|
@ -0,0 +1,586 @@
|
|||
"""Tests for Pydantic schemas — serialization round-trips and validation constraints."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from shared.schemas.trading import (
|
||||
AccountInfo,
|
||||
MarketSnapshot,
|
||||
OrderRequest,
|
||||
OrderResult,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
OrderType,
|
||||
PositionInfo,
|
||||
SentimentContext,
|
||||
SignalDirection,
|
||||
TradeExecution,
|
||||
TradeSignal,
|
||||
)
|
||||
from shared.schemas.news import RawArticle, ScoredArticle
|
||||
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
|
||||
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trading schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOrderRequest:
|
||||
def test_valid_market_order(self) -> None:
|
||||
o = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0)
|
||||
assert o.order_type == OrderType.MARKET
|
||||
assert o.limit_price is None
|
||||
|
||||
def test_valid_limit_order(self) -> None:
|
||||
o = OrderRequest(
|
||||
ticker="TSLA",
|
||||
side=OrderSide.SELL,
|
||||
qty=5.0,
|
||||
order_type=OrderType.LIMIT,
|
||||
limit_price=250.50,
|
||||
)
|
||||
assert o.limit_price == 250.50
|
||||
|
||||
def test_qty_must_be_positive(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=0)
|
||||
|
||||
def test_qty_must_not_be_negative(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=-5)
|
||||
|
||||
def test_serialization_round_trip(self) -> None:
|
||||
o = OrderRequest(
|
||||
ticker="GOOG",
|
||||
side=OrderSide.BUY,
|
||||
qty=3.0,
|
||||
order_type=OrderType.STOP,
|
||||
stop_price=100.0,
|
||||
)
|
||||
data = o.model_dump()
|
||||
restored = OrderRequest.model_validate(data)
|
||||
assert restored == o
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
o = OrderRequest(ticker="META", side=OrderSide.SELL, qty=1.5)
|
||||
json_str = o.model_dump_json()
|
||||
restored = OrderRequest.model_validate_json(json_str)
|
||||
assert restored == o
|
||||
|
||||
|
||||
class TestOrderResult:
|
||||
def test_valid_result(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
r = OrderResult(
|
||||
order_id="ord-123",
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
qty=10.0,
|
||||
filled_price=150.25,
|
||||
status=OrderStatus.FILLED,
|
||||
timestamp=now,
|
||||
)
|
||||
assert r.filled_price == 150.25
|
||||
assert r.status == OrderStatus.FILLED
|
||||
|
||||
def test_pending_no_fill(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
r = OrderResult(
|
||||
order_id="ord-456",
|
||||
ticker="TSLA",
|
||||
side=OrderSide.SELL,
|
||||
qty=5.0,
|
||||
status=OrderStatus.PENDING,
|
||||
timestamp=now,
|
||||
)
|
||||
assert r.filled_price is None
|
||||
|
||||
|
||||
class TestPositionInfo:
|
||||
def test_valid_position(self) -> None:
|
||||
p = PositionInfo(
|
||||
ticker="NVDA",
|
||||
qty=20.0,
|
||||
avg_entry=800.0,
|
||||
current_price=850.0,
|
||||
unrealized_pnl=1000.0,
|
||||
market_value=17000.0,
|
||||
)
|
||||
assert p.market_value == 17000.0
|
||||
|
||||
def test_serialization_round_trip(self) -> None:
|
||||
p = PositionInfo(
|
||||
ticker="AMZN",
|
||||
qty=10.0,
|
||||
avg_entry=180.0,
|
||||
current_price=185.0,
|
||||
unrealized_pnl=50.0,
|
||||
market_value=1850.0,
|
||||
)
|
||||
restored = PositionInfo.model_validate(p.model_dump())
|
||||
assert restored == p
|
||||
|
||||
|
||||
class TestAccountInfo:
|
||||
def test_valid_account(self) -> None:
|
||||
a = AccountInfo(
|
||||
equity=100_000.0,
|
||||
cash=25_000.0,
|
||||
buying_power=50_000.0,
|
||||
portfolio_value=100_000.0,
|
||||
)
|
||||
assert a.equity == 100_000.0
|
||||
|
||||
|
||||
class TestTradeSignal:
|
||||
def test_valid_signal(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
s = TradeSignal(
|
||||
ticker="AAPL",
|
||||
direction=SignalDirection.LONG,
|
||||
strength=0.85,
|
||||
strategy_sources=["momentum", "news_driven"],
|
||||
timestamp=now,
|
||||
)
|
||||
assert s.strength == 0.85
|
||||
assert s.sentiment_context is None
|
||||
|
||||
def test_strength_must_be_in_range(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
with pytest.raises(ValidationError):
|
||||
TradeSignal(
|
||||
ticker="AAPL",
|
||||
direction=SignalDirection.LONG,
|
||||
strength=1.5,
|
||||
strategy_sources=["momentum"],
|
||||
timestamp=now,
|
||||
)
|
||||
|
||||
def test_strength_lower_bound(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
with pytest.raises(ValidationError):
|
||||
TradeSignal(
|
||||
ticker="AAPL",
|
||||
direction=SignalDirection.SHORT,
|
||||
strength=-0.1,
|
||||
strategy_sources=["mean_reversion"],
|
||||
timestamp=now,
|
||||
)
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
s = TradeSignal(
|
||||
ticker="TSLA",
|
||||
direction=SignalDirection.SHORT,
|
||||
strength=0.6,
|
||||
strategy_sources=["mean_reversion"],
|
||||
sentiment_context={"avg_score": -0.4},
|
||||
timestamp=now,
|
||||
)
|
||||
restored = TradeSignal.model_validate_json(s.model_dump_json())
|
||||
assert restored == s
|
||||
|
||||
|
||||
class TestTradeExecution:
|
||||
def test_valid_execution(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
tid = uuid.uuid4()
|
||||
sid = uuid.uuid4()
|
||||
e = TradeExecution(
|
||||
trade_id=tid,
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
qty=10.0,
|
||||
price=150.0,
|
||||
status=OrderStatus.FILLED,
|
||||
signal_id=sid,
|
||||
strategy_id=uuid.uuid4(),
|
||||
timestamp=now,
|
||||
)
|
||||
assert e.trade_id == tid
|
||||
assert e.signal_id == sid
|
||||
|
||||
def test_optional_ids(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
e = TradeExecution(
|
||||
trade_id=uuid.uuid4(),
|
||||
ticker="GOOG",
|
||||
side=OrderSide.SELL,
|
||||
qty=5.0,
|
||||
price=2800.0,
|
||||
status=OrderStatus.FILLED,
|
||||
timestamp=now,
|
||||
)
|
||||
assert e.signal_id is None
|
||||
assert e.strategy_id is None
|
||||
|
||||
|
||||
class TestMarketSnapshot:
|
||||
def test_valid_snapshot(self) -> None:
|
||||
ms = MarketSnapshot(
|
||||
ticker="AAPL",
|
||||
current_price=150.0,
|
||||
open=148.0,
|
||||
high=152.0,
|
||||
low=147.0,
|
||||
close=150.0,
|
||||
volume=5_000_000.0,
|
||||
sma_20=149.5,
|
||||
rsi=55.0,
|
||||
)
|
||||
assert ms.sma_50 is None
|
||||
assert ms.bars == []
|
||||
|
||||
def test_with_bars(self) -> None:
|
||||
ms = MarketSnapshot(
|
||||
ticker="TSLA",
|
||||
current_price=250.0,
|
||||
open=248.0,
|
||||
high=255.0,
|
||||
low=245.0,
|
||||
close=250.0,
|
||||
volume=10_000_000.0,
|
||||
bars=[
|
||||
{"timestamp": "2026-01-01T10:00:00Z", "open": 248, "close": 250}
|
||||
],
|
||||
)
|
||||
assert len(ms.bars) == 1
|
||||
|
||||
|
||||
class TestSentimentContext:
|
||||
def test_valid_context(self) -> None:
|
||||
sc = SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=0.65,
|
||||
article_count=5,
|
||||
recent_scores=[0.5, 0.7, 0.8],
|
||||
avg_confidence=0.85,
|
||||
)
|
||||
assert sc.article_count == 5
|
||||
|
||||
def test_score_range_validation(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=1.5, # Out of range
|
||||
article_count=1,
|
||||
avg_confidence=0.5,
|
||||
)
|
||||
|
||||
def test_negative_score_in_range(self) -> None:
|
||||
sc = SentimentContext(
|
||||
ticker="TSLA",
|
||||
avg_score=-0.8,
|
||||
article_count=3,
|
||||
avg_confidence=0.9,
|
||||
)
|
||||
assert sc.avg_score == -0.8
|
||||
|
||||
def test_confidence_range_validation(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=0.5,
|
||||
article_count=1,
|
||||
avg_confidence=1.5, # Out of range
|
||||
)
|
||||
|
||||
def test_article_count_non_negative(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=0.0,
|
||||
article_count=-1,
|
||||
avg_confidence=0.5,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# News schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRawArticle:
|
||||
def test_valid_article(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
a = RawArticle(
|
||||
source="reuters",
|
||||
url="https://reuters.com/article/1",
|
||||
title="Market Rally Continues",
|
||||
content="Stocks rose sharply today...",
|
||||
published_at=now,
|
||||
fetched_at=now,
|
||||
content_hash="sha256abcdef1234567890",
|
||||
)
|
||||
assert a.source == "reuters"
|
||||
|
||||
def test_published_at_optional(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
a = RawArticle(
|
||||
source="reddit",
|
||||
url="https://reddit.com/r/stocks/1",
|
||||
title="DD on TSLA",
|
||||
content="Here is my analysis...",
|
||||
fetched_at=now,
|
||||
content_hash="hash123",
|
||||
)
|
||||
assert a.published_at is None
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
a = RawArticle(
|
||||
source="yahoo",
|
||||
url="https://finance.yahoo.com/1",
|
||||
title="Earnings Beat",
|
||||
content="Apple beat earnings...",
|
||||
published_at=now,
|
||||
fetched_at=now,
|
||||
content_hash="hash456",
|
||||
)
|
||||
restored = RawArticle.model_validate_json(a.model_dump_json())
|
||||
assert restored == a
|
||||
|
||||
def test_required_fields(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RawArticle(source="reuters") # type: ignore[call-arg]
|
||||
|
||||
|
||||
class TestScoredArticle:
|
||||
def test_valid_scored_article(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
sa = ScoredArticle(
|
||||
source="reuters",
|
||||
url="https://reuters.com/article/1",
|
||||
title="Apple Earnings Beat",
|
||||
content="Apple reported...",
|
||||
published_at=now,
|
||||
fetched_at=now,
|
||||
content_hash="hash789",
|
||||
ticker="AAPL",
|
||||
sentiment_score=0.85,
|
||||
confidence=0.92,
|
||||
model_used="finbert",
|
||||
entities=["Apple Inc", "Tim Cook"],
|
||||
)
|
||||
assert sa.sentiment_score == 0.85
|
||||
assert sa.entities == ["Apple Inc", "Tim Cook"]
|
||||
|
||||
def test_sentiment_score_range(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
with pytest.raises(ValidationError):
|
||||
ScoredArticle(
|
||||
source="reuters",
|
||||
url="https://reuters.com/1",
|
||||
title="Test",
|
||||
content="Test content",
|
||||
fetched_at=now,
|
||||
content_hash="hash",
|
||||
ticker="AAPL",
|
||||
sentiment_score=1.5, # Out of range
|
||||
confidence=0.5,
|
||||
model_used="finbert",
|
||||
)
|
||||
|
||||
def test_confidence_range(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
with pytest.raises(ValidationError):
|
||||
ScoredArticle(
|
||||
source="reuters",
|
||||
url="https://reuters.com/1",
|
||||
title="Test",
|
||||
content="Test content",
|
||||
fetched_at=now,
|
||||
content_hash="hash",
|
||||
ticker="AAPL",
|
||||
sentiment_score=0.5,
|
||||
confidence=-0.1, # Out of range
|
||||
model_used="finbert",
|
||||
)
|
||||
|
||||
def test_negative_sentiment(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
sa = ScoredArticle(
|
||||
source="reddit",
|
||||
url="https://reddit.com/1",
|
||||
title="Bad news",
|
||||
content="Terrible quarter...",
|
||||
fetched_at=now,
|
||||
content_hash="hashN",
|
||||
ticker="TSLA",
|
||||
sentiment_score=-0.9,
|
||||
confidence=0.8,
|
||||
model_used="ollama",
|
||||
)
|
||||
assert sa.sentiment_score == -0.9
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
sa = ScoredArticle(
|
||||
source="yahoo",
|
||||
url="https://yahoo.com/1",
|
||||
title="Headline",
|
||||
content="Body text",
|
||||
fetched_at=now,
|
||||
content_hash="hashRT",
|
||||
ticker="NVDA",
|
||||
sentiment_score=0.3,
|
||||
confidence=0.7,
|
||||
model_used="finbert",
|
||||
entities=["NVIDIA"],
|
||||
)
|
||||
restored = ScoredArticle.model_validate_json(sa.model_dump_json())
|
||||
assert restored == sa
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Learning schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTradeOutcomeSchema:
|
||||
def test_valid_outcome(self) -> None:
|
||||
o = TradeOutcomeSchema(
|
||||
trade_id=uuid.uuid4(),
|
||||
hold_duration_seconds=14400.0,
|
||||
realized_pnl=250.50,
|
||||
roi_pct=2.5,
|
||||
was_profitable=True,
|
||||
)
|
||||
assert o.was_profitable is True
|
||||
assert o.hold_duration_seconds == 14400.0
|
||||
|
||||
def test_hold_duration_non_negative(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
TradeOutcomeSchema(
|
||||
trade_id=uuid.uuid4(),
|
||||
hold_duration_seconds=-1.0,
|
||||
realized_pnl=100.0,
|
||||
roi_pct=1.0,
|
||||
was_profitable=True,
|
||||
)
|
||||
|
||||
def test_losing_trade(self) -> None:
|
||||
o = TradeOutcomeSchema(
|
||||
trade_id=uuid.uuid4(),
|
||||
hold_duration_seconds=3600.0,
|
||||
realized_pnl=-150.0,
|
||||
roi_pct=-3.0,
|
||||
was_profitable=False,
|
||||
)
|
||||
assert o.was_profitable is False
|
||||
assert o.realized_pnl == -150.0
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
o = TradeOutcomeSchema(
|
||||
trade_id=uuid.uuid4(),
|
||||
hold_duration_seconds=7200.0,
|
||||
realized_pnl=500.0,
|
||||
roi_pct=5.0,
|
||||
was_profitable=True,
|
||||
)
|
||||
restored = TradeOutcomeSchema.model_validate_json(o.model_dump_json())
|
||||
assert restored == o
|
||||
|
||||
|
||||
class TestWeightAdjustment:
|
||||
def test_valid_adjustment(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
wa = WeightAdjustment(
|
||||
strategy_id=uuid.uuid4(),
|
||||
strategy_name="momentum",
|
||||
old_weight=0.33,
|
||||
new_weight=0.38,
|
||||
reason="Positive reward signal from recent trades",
|
||||
reward_signal=0.72,
|
||||
timestamp=now,
|
||||
)
|
||||
assert wa.old_weight == 0.33
|
||||
assert wa.new_weight == 0.38
|
||||
|
||||
def test_required_fields(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
WeightAdjustment(
|
||||
strategy_id=uuid.uuid4(),
|
||||
strategy_name="momentum",
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
wa = WeightAdjustment(
|
||||
strategy_id=uuid.uuid4(),
|
||||
strategy_name="mean_reversion",
|
||||
old_weight=0.30,
|
||||
new_weight=0.25,
|
||||
reason="Poor recent performance",
|
||||
reward_signal=-0.4,
|
||||
timestamp=now,
|
||||
)
|
||||
restored = WeightAdjustment.model_validate_json(wa.model_dump_json())
|
||||
assert restored == wa
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterRequest:
|
||||
def test_valid_registration(self) -> None:
|
||||
r = RegisterRequest(username="trader1", display_name="Top Trader")
|
||||
assert r.username == "trader1"
|
||||
assert r.display_name == "Top Trader"
|
||||
|
||||
def test_display_name_optional(self) -> None:
|
||||
r = RegisterRequest(username="trader2")
|
||||
assert r.display_name is None
|
||||
|
||||
def test_username_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RegisterRequest(username="") # min_length=1
|
||||
|
||||
def test_username_max_length(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RegisterRequest(username="x" * 101) # max_length=100
|
||||
|
||||
|
||||
class TestLoginRequest:
|
||||
def test_valid_login(self) -> None:
|
||||
l = LoginRequest(username="trader1")
|
||||
assert l.username == "trader1"
|
||||
|
||||
def test_username_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
LoginRequest(username="")
|
||||
|
||||
|
||||
class TestTokenResponse:
|
||||
def test_valid_response(self) -> None:
|
||||
t = TokenResponse(
|
||||
access_token="eyJ...",
|
||||
refresh_token="eyR...",
|
||||
)
|
||||
assert t.token_type == "bearer"
|
||||
assert t.expires_in == 900 # default 15 min
|
||||
|
||||
def test_custom_expiry(self) -> None:
|
||||
t = TokenResponse(
|
||||
access_token="eyJ...",
|
||||
refresh_token="eyR...",
|
||||
expires_in=3600,
|
||||
)
|
||||
assert t.expires_in == 3600
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
t = TokenResponse(
|
||||
access_token="access123",
|
||||
refresh_token="refresh456",
|
||||
token_type="bearer",
|
||||
expires_in=1800,
|
||||
)
|
||||
restored = TokenResponse.model_validate_json(t.model_dump_json())
|
||||
assert restored == t
|
||||
Loading…
Add table
Add a link
Reference in a new issue