- shared/schemas/trading.py: OrderRequest, OrderResult, PositionInfo, AccountInfo, TradeSignal, TradeExecution, MarketSnapshot, SentimentContext
- shared/schemas/news.py: RawArticle, ScoredArticle
- shared/schemas/learning.py: TradeOutcomeSchema, WeightAdjustment
- shared/schemas/auth.py: RegisterRequest, LoginRequest, TokenResponse
- 49 schema tests covering validation constraints, serialization round-trips, required fields, and range checks
586 lines
18 KiB
Python
"""Tests for Pydantic schemas — serialization round-trips and validation constraints."""
|
|
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from shared.schemas.trading import (
|
|
AccountInfo,
|
|
MarketSnapshot,
|
|
OrderRequest,
|
|
OrderResult,
|
|
OrderSide,
|
|
OrderStatus,
|
|
OrderType,
|
|
PositionInfo,
|
|
SentimentContext,
|
|
SignalDirection,
|
|
TradeExecution,
|
|
TradeSignal,
|
|
)
|
|
from shared.schemas.news import RawArticle, ScoredArticle
|
|
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
|
|
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Trading schemas
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestOrderRequest:
    """OrderRequest: defaults, qty validation, and dict / JSON round-trips."""

    def test_valid_market_order(self) -> None:
        # Only ticker/side/qty are given, so order_type and limit_price
        # must come from the schema defaults.
        req = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0)
        assert req.order_type == OrderType.MARKET
        assert req.limit_price is None

    def test_valid_limit_order(self) -> None:
        req = OrderRequest(
            ticker="TSLA",
            side=OrderSide.SELL,
            order_type=OrderType.LIMIT,
            qty=5.0,
            limit_price=250.50,
        )
        assert req.limit_price == 250.50

    def test_qty_must_be_positive(self) -> None:
        # Zero is rejected, not just negative values.
        with pytest.raises(ValidationError):
            OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=0)

    def test_qty_must_not_be_negative(self) -> None:
        with pytest.raises(ValidationError):
            OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=-5)

    def test_serialization_round_trip(self) -> None:
        # A STOP order exercises the optional stop_price field.
        original = OrderRequest(
            ticker="GOOG",
            side=OrderSide.BUY,
            order_type=OrderType.STOP,
            qty=3.0,
            stop_price=100.0,
        )
        assert OrderRequest.model_validate(original.model_dump()) == original

    def test_json_round_trip(self) -> None:
        original = OrderRequest(ticker="META", side=OrderSide.SELL, qty=1.5)
        payload = original.model_dump_json()
        assert OrderRequest.model_validate_json(payload) == original
|
|
|
|
|
|
class TestOrderResult:
    """OrderResult: a filled result and a pending result without a fill."""

    def test_valid_result(self) -> None:
        stamp = datetime.now(timezone.utc)
        result = OrderResult(
            order_id="ord-123",
            ticker="AAPL",
            side=OrderSide.BUY,
            status=OrderStatus.FILLED,
            qty=10.0,
            filled_price=150.25,
            timestamp=stamp,
        )
        assert result.filled_price == 150.25
        assert result.status == OrderStatus.FILLED

    def test_pending_no_fill(self) -> None:
        # filled_price is not supplied here, so it must default to None.
        stamp = datetime.now(timezone.utc)
        result = OrderResult(
            order_id="ord-456",
            ticker="TSLA",
            side=OrderSide.SELL,
            status=OrderStatus.PENDING,
            qty=5.0,
            timestamp=stamp,
        )
        assert result.filled_price is None
|
|
|
|
|
|
class TestPositionInfo:
    """PositionInfo: construction and dict round-trip."""

    def test_valid_position(self) -> None:
        pos = PositionInfo(
            ticker="NVDA",
            qty=20.0,
            avg_entry=800.0,
            current_price=850.0,
            market_value=17000.0,
            unrealized_pnl=1000.0,
        )
        assert pos.market_value == 17000.0

    def test_serialization_round_trip(self) -> None:
        pos = PositionInfo(
            ticker="AMZN",
            qty=10.0,
            avg_entry=180.0,
            current_price=185.0,
            market_value=1850.0,
            unrealized_pnl=50.0,
        )
        dumped = pos.model_dump()
        assert PositionInfo.model_validate(dumped) == pos
|
|
|
|
|
|
class TestAccountInfo:
    """AccountInfo: construction with all balance fields supplied."""

    def test_valid_account(self) -> None:
        acct = AccountInfo(
            portfolio_value=100_000.0,
            equity=100_000.0,
            cash=25_000.0,
            buying_power=50_000.0,
        )
        assert acct.equity == 100_000.0
|
|
|
|
|
|
class TestTradeSignal:
    """TradeSignal: construction, strength bounds, and JSON round-trip."""

    @staticmethod
    def _build(
        direction: SignalDirection, strength: float, sources: list[str]
    ) -> TradeSignal:
        """Create an AAPL signal stamped "now" with the given knobs."""
        return TradeSignal(
            ticker="AAPL",
            direction=direction,
            strength=strength,
            strategy_sources=sources,
            timestamp=datetime.now(timezone.utc),
        )

    def test_valid_signal(self) -> None:
        sig = self._build(SignalDirection.LONG, 0.85, ["momentum", "news_driven"])
        assert sig.strength == 0.85
        # sentiment_context was not supplied, so it must default to None.
        assert sig.sentiment_context is None

    def test_strength_must_be_in_range(self) -> None:
        # Above the upper bound.
        with pytest.raises(ValidationError):
            self._build(SignalDirection.LONG, 1.5, ["momentum"])

    def test_strength_lower_bound(self) -> None:
        # Below the lower bound.
        with pytest.raises(ValidationError):
            self._build(SignalDirection.SHORT, -0.1, ["mean_reversion"])

    def test_json_round_trip(self) -> None:
        sig = TradeSignal(
            ticker="TSLA",
            direction=SignalDirection.SHORT,
            strength=0.6,
            strategy_sources=["mean_reversion"],
            sentiment_context={"avg_score": -0.4},
            timestamp=datetime.now(timezone.utc),
        )
        payload = sig.model_dump_json()
        assert TradeSignal.model_validate_json(payload) == sig
|
|
|
|
|
|
class TestTradeExecution:
    """TradeExecution: UUID fields are kept, and the optional IDs default to None."""

    def test_valid_execution(self) -> None:
        trade_uuid, signal_uuid = uuid.uuid4(), uuid.uuid4()
        execution = TradeExecution(
            trade_id=trade_uuid,
            signal_id=signal_uuid,
            strategy_id=uuid.uuid4(),
            ticker="AAPL",
            side=OrderSide.BUY,
            qty=10.0,
            price=150.0,
            status=OrderStatus.FILLED,
            timestamp=datetime.now(timezone.utc),
        )
        assert execution.trade_id == trade_uuid
        assert execution.signal_id == signal_uuid

    def test_optional_ids(self) -> None:
        # Neither signal_id nor strategy_id is passed.
        execution = TradeExecution(
            trade_id=uuid.uuid4(),
            ticker="GOOG",
            side=OrderSide.SELL,
            qty=5.0,
            price=2800.0,
            status=OrderStatus.FILLED,
            timestamp=datetime.now(timezone.utc),
        )
        assert execution.signal_id is None
        assert execution.strategy_id is None
|
|
|
|
|
|
class TestMarketSnapshot:
    """MarketSnapshot: optional indicator defaults and the bars list."""

    def test_valid_snapshot(self) -> None:
        snap = MarketSnapshot(
            ticker="AAPL",
            current_price=150.0,
            open=148.0,
            high=152.0,
            low=147.0,
            close=150.0,
            volume=5_000_000.0,
            rsi=55.0,
            sma_20=149.5,
        )
        # sma_50 and bars are omitted, so their defaults apply.
        assert snap.sma_50 is None
        assert snap.bars == []

    def test_with_bars(self) -> None:
        one_bar = {"timestamp": "2026-01-01T10:00:00Z", "open": 248, "close": 250}
        snap = MarketSnapshot(
            ticker="TSLA",
            current_price=250.0,
            open=248.0,
            high=255.0,
            low=245.0,
            close=250.0,
            volume=10_000_000.0,
            bars=[one_bar],
        )
        assert len(snap.bars) == 1
|
|
|
|
|
|
class TestSentimentContext:
    """SentimentContext: score, confidence, and article-count constraints."""

    @staticmethod
    def _aapl(avg_score: float, article_count: int, avg_confidence: float) -> SentimentContext:
        """Construct an AAPL context without recent_scores."""
        return SentimentContext(
            ticker="AAPL",
            avg_score=avg_score,
            article_count=article_count,
            avg_confidence=avg_confidence,
        )

    def test_valid_context(self) -> None:
        ctx = SentimentContext(
            ticker="AAPL",
            article_count=5,
            avg_score=0.65,
            avg_confidence=0.85,
            recent_scores=[0.5, 0.7, 0.8],
        )
        assert ctx.article_count == 5

    def test_score_range_validation(self) -> None:
        with pytest.raises(ValidationError):
            self._aapl(avg_score=1.5, article_count=1, avg_confidence=0.5)  # Out of range

    def test_negative_score_in_range(self) -> None:
        # Negative scores are valid as long as they stay within bounds.
        ctx = SentimentContext(
            ticker="TSLA",
            avg_score=-0.8,
            article_count=3,
            avg_confidence=0.9,
        )
        assert ctx.avg_score == -0.8

    def test_confidence_range_validation(self) -> None:
        with pytest.raises(ValidationError):
            self._aapl(avg_score=0.5, article_count=1, avg_confidence=1.5)  # Out of range

    def test_article_count_non_negative(self) -> None:
        with pytest.raises(ValidationError):
            self._aapl(avg_score=0.0, article_count=-1, avg_confidence=0.5)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# News schemas
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRawArticle:
    """RawArticle: construction, optional published_at, JSON round-trip."""

    def test_valid_article(self) -> None:
        when = datetime.now(timezone.utc)
        art = RawArticle(
            source="reuters",
            url="https://reuters.com/article/1",
            title="Market Rally Continues",
            content="Stocks rose sharply today...",
            content_hash="sha256abcdef1234567890",
            published_at=when,
            fetched_at=when,
        )
        assert art.source == "reuters"

    def test_published_at_optional(self) -> None:
        # Only fetched_at is supplied; published_at must default to None.
        art = RawArticle(
            source="reddit",
            url="https://reddit.com/r/stocks/1",
            title="DD on TSLA",
            content="Here is my analysis...",
            content_hash="hash123",
            fetched_at=datetime.now(timezone.utc),
        )
        assert art.published_at is None

    def test_json_round_trip(self) -> None:
        when = datetime.now(timezone.utc)
        art = RawArticle(
            source="yahoo",
            url="https://finance.yahoo.com/1",
            title="Earnings Beat",
            content="Apple beat earnings...",
            content_hash="hash456",
            published_at=when,
            fetched_at=when,
        )
        payload = art.model_dump_json()
        assert RawArticle.model_validate_json(payload) == art

    def test_required_fields(self) -> None:
        # Only `source` is given; the remaining required fields are missing.
        with pytest.raises(ValidationError):
            RawArticle(source="reuters")  # type: ignore[call-arg]
|
|
|
|
|
|
class TestScoredArticle:
    """ScoredArticle: sentiment/confidence bounds, entities, JSON round-trip."""

    @staticmethod
    def _filler(sentiment_score: float, confidence: float) -> ScoredArticle:
        """Build a minimal reuters/AAPL article with the given scores."""
        return ScoredArticle(
            source="reuters",
            url="https://reuters.com/1",
            title="Test",
            content="Test content",
            fetched_at=datetime.now(timezone.utc),
            content_hash="hash",
            ticker="AAPL",
            sentiment_score=sentiment_score,
            confidence=confidence,
            model_used="finbert",
        )

    def test_valid_scored_article(self) -> None:
        when = datetime.now(timezone.utc)
        people = ["Apple Inc", "Tim Cook"]
        scored = ScoredArticle(
            source="reuters",
            url="https://reuters.com/article/1",
            title="Apple Earnings Beat",
            content="Apple reported...",
            published_at=when,
            fetched_at=when,
            content_hash="hash789",
            ticker="AAPL",
            model_used="finbert",
            sentiment_score=0.85,
            confidence=0.92,
            entities=people,
        )
        assert scored.sentiment_score == 0.85
        assert scored.entities == ["Apple Inc", "Tim Cook"]

    def test_sentiment_score_range(self) -> None:
        with pytest.raises(ValidationError):
            self._filler(sentiment_score=1.5, confidence=0.5)  # Out of range

    def test_confidence_range(self) -> None:
        with pytest.raises(ValidationError):
            self._filler(sentiment_score=0.5, confidence=-0.1)  # Out of range

    def test_negative_sentiment(self) -> None:
        scored = ScoredArticle(
            source="reddit",
            url="https://reddit.com/1",
            title="Bad news",
            content="Terrible quarter...",
            fetched_at=datetime.now(timezone.utc),
            content_hash="hashN",
            ticker="TSLA",
            model_used="ollama",
            sentiment_score=-0.9,
            confidence=0.8,
        )
        assert scored.sentiment_score == -0.9

    def test_json_round_trip(self) -> None:
        scored = ScoredArticle(
            source="yahoo",
            url="https://yahoo.com/1",
            title="Headline",
            content="Body text",
            fetched_at=datetime.now(timezone.utc),
            content_hash="hashRT",
            ticker="NVDA",
            model_used="finbert",
            sentiment_score=0.3,
            confidence=0.7,
            entities=["NVIDIA"],
        )
        payload = scored.model_dump_json()
        assert ScoredArticle.model_validate_json(payload) == scored
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Learning schemas
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestTradeOutcomeSchema:
    """TradeOutcomeSchema: winning/losing outcomes, duration bound, JSON."""

    def test_valid_outcome(self) -> None:
        outcome = TradeOutcomeSchema(
            trade_id=uuid.uuid4(),
            was_profitable=True,
            realized_pnl=250.50,
            roi_pct=2.5,
            hold_duration_seconds=14400.0,
        )
        assert outcome.was_profitable is True
        assert outcome.hold_duration_seconds == 14400.0

    def test_hold_duration_non_negative(self) -> None:
        with pytest.raises(ValidationError):
            TradeOutcomeSchema(
                trade_id=uuid.uuid4(),
                was_profitable=True,
                realized_pnl=100.0,
                roi_pct=1.0,
                hold_duration_seconds=-1.0,
            )

    def test_losing_trade(self) -> None:
        # Negative P&L and ROI must be accepted as-is.
        outcome = TradeOutcomeSchema(
            trade_id=uuid.uuid4(),
            was_profitable=False,
            realized_pnl=-150.0,
            roi_pct=-3.0,
            hold_duration_seconds=3600.0,
        )
        assert outcome.was_profitable is False
        assert outcome.realized_pnl == -150.0

    def test_json_round_trip(self) -> None:
        outcome = TradeOutcomeSchema(
            trade_id=uuid.uuid4(),
            was_profitable=True,
            realized_pnl=500.0,
            roi_pct=5.0,
            hold_duration_seconds=7200.0,
        )
        payload = outcome.model_dump_json()
        assert TradeOutcomeSchema.model_validate_json(payload) == outcome
|
|
|
|
|
|
class TestWeightAdjustment:
    """Tests for WeightAdjustment: construction, required fields, JSON round-trip."""

    def test_valid_adjustment(self) -> None:
        now = datetime.now(timezone.utc)
        wa = WeightAdjustment(
            strategy_id=uuid.uuid4(),
            strategy_name="momentum",
            old_weight=0.33,
            new_weight=0.38,
            reason="Positive reward signal from recent trades",
            reward_signal=0.72,
            timestamp=now,
        )
        assert wa.old_weight == 0.33
        assert wa.new_weight == 0.38

    def test_required_fields(self) -> None:
        # Only the identifying fields are supplied, so validation is expected
        # to fail on the missing ones. The type-ignore belongs on the first
        # line of the call: mypy reports call-arg errors at the start of the
        # call expression. On the closing-paren line it would not suppress
        # the error.
        with pytest.raises(ValidationError):
            WeightAdjustment(  # type: ignore[call-arg]
                strategy_id=uuid.uuid4(),
                strategy_name="momentum",
            )

    def test_json_round_trip(self) -> None:
        now = datetime.now(timezone.utc)
        wa = WeightAdjustment(
            strategy_id=uuid.uuid4(),
            strategy_name="mean_reversion",
            old_weight=0.30,
            new_weight=0.25,
            reason="Poor recent performance",
            reward_signal=-0.4,
            timestamp=now,
        )
        restored = WeightAdjustment.model_validate_json(wa.model_dump_json())
        assert restored == wa
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auth schemas
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRegisterRequest:
    """Tests for RegisterRequest: optional display name and username length bounds."""

    def test_valid_registration(self) -> None:
        r = RegisterRequest(username="trader1", display_name="Top Trader")
        assert r.username == "trader1"
        assert r.display_name == "Top Trader"

    def test_display_name_optional(self) -> None:
        r = RegisterRequest(username="trader2")
        assert r.display_name is None

    def test_username_required(self) -> None:
        # Pin both sides of min_length=1: a single character is accepted,
        # and the empty string is rejected.
        assert RegisterRequest(username="x").username == "x"
        with pytest.raises(ValidationError):
            RegisterRequest(username="")  # min_length=1

    def test_username_max_length(self) -> None:
        # Pin both sides of max_length=100. Checking only the rejected side
        # would still pass if the limit were tightened by mistake.
        assert len(RegisterRequest(username="x" * 100).username) == 100
        with pytest.raises(ValidationError):
            RegisterRequest(username="x" * 101)  # max_length=100
|
|
|
|
|
|
class TestLoginRequest:
    """Tests for LoginRequest: valid username and empty-username rejection."""

    def test_valid_login(self) -> None:
        # Named `req` rather than `l`, which is easily confused with `1`/`I`
        # (PEP 8 names to avoid; ruff E741).
        req = LoginRequest(username="trader1")
        assert req.username == "trader1"

    def test_username_required(self) -> None:
        with pytest.raises(ValidationError):
            LoginRequest(username="")
|
|
|
|
|
|
class TestTokenResponse:
    """TokenResponse: default token type/expiry, override, JSON round-trip."""

    def test_valid_response(self) -> None:
        tok = TokenResponse(access_token="eyJ...", refresh_token="eyR...")
        # Neither token_type nor expires_in is passed, so defaults apply.
        assert tok.token_type == "bearer"
        assert tok.expires_in == 900  # default 15 min

    def test_custom_expiry(self) -> None:
        tok = TokenResponse(
            expires_in=3600,
            access_token="eyJ...",
            refresh_token="eyR...",
        )
        assert tok.expires_in == 3600

    def test_json_round_trip(self) -> None:
        tok = TokenResponse(
            access_token="access123",
            refresh_token="refresh456",
            token_type="bearer",
            expires_in=1800,
        )
        payload = tok.model_dump_json()
        assert TokenResponse.model_validate_json(payload) == tok
|