feat: pydantic schemas for all service message types

- shared/schemas/trading.py: OrderRequest, OrderResult, PositionInfo,
  AccountInfo, TradeSignal, TradeExecution, MarketSnapshot, SentimentContext
- shared/schemas/news.py: RawArticle, ScoredArticle
- shared/schemas/learning.py: TradeOutcomeSchema, WeightAdjustment
- shared/schemas/auth.py: RegisterRequest, LoginRequest, TokenResponse
- 49 schema tests covering validation constraints, serialization round-trips,
  required fields, and range checks
This commit is contained in:
Viktor Barzin 2026-02-22 15:19:00 +00:00
parent 72cb1b6fe5
commit c8277e301e
No known key found for this signature in database
GPG key ID: 0EB088298288D958
11 changed files with 889 additions and 0 deletions

View file

@ -0,0 +1,37 @@
"""Pydantic v2 schemas for all service message types."""
from shared.schemas.trading import (
AccountInfo,
MarketSnapshot,
OrderRequest,
OrderResult,
PositionInfo,
SentimentContext,
TradeExecution,
TradeSignal,
)
from shared.schemas.news import RawArticle, ScoredArticle
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
__all__ = [
# Trading
"OrderRequest",
"OrderResult",
"PositionInfo",
"AccountInfo",
"TradeSignal",
"TradeExecution",
"MarketSnapshot",
"SentimentContext",
# News
"RawArticle",
"ScoredArticle",
# Learning
"TradeOutcomeSchema",
"WeightAdjustment",
# Auth
"RegisterRequest",
"LoginRequest",
"TokenResponse",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

27
shared/schemas/auth.py Normal file
View file

@ -0,0 +1,27 @@
"""Authentication Pydantic schemas for API request/response payloads."""
from pydantic import BaseModel, Field
class RegisterRequest(BaseModel):
"""Sent by the dashboard to begin passkey registration."""
username: str = Field(min_length=1, max_length=100)
display_name: str | None = None
class LoginRequest(BaseModel):
"""Sent by the dashboard to begin passkey authentication."""
username: str = Field(min_length=1, max_length=100)
class TokenResponse(BaseModel):
"""Returned after successful authentication."""
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int = Field(
description="Access token lifetime in seconds", default=900
)

View file

@ -0,0 +1,32 @@
"""Learning domain Pydantic schemas."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, Field
class TradeOutcomeSchema(BaseModel):
"""Represents the evaluated outcome of a closed trade."""
trade_id: UUID
hold_duration_seconds: float = Field(ge=0)
realized_pnl: float
roi_pct: float
was_profitable: bool
model_config = {"from_attributes": True}
class WeightAdjustment(BaseModel):
"""Represents a strategy weight change made by the learning engine."""
strategy_id: UUID
strategy_name: str
old_weight: float
new_weight: float
reason: str
reward_signal: float
timestamp: datetime
model_config = {"from_attributes": True}

44
shared/schemas/news.py Normal file
View file

@ -0,0 +1,44 @@
"""News article Pydantic schemas for Redis Stream messages."""
from datetime import datetime
from pydantic import BaseModel, Field
class RawArticle(BaseModel):
"""Published to ``news:raw`` by the news fetcher."""
source: str
url: str
title: str
content: str
published_at: datetime | None = None
fetched_at: datetime
content_hash: str
model_config = {"from_attributes": True}
class ScoredArticle(BaseModel):
"""Published to ``news:scored`` by the sentiment analyzer.
Inherits all fields from RawArticle conceptually plus scoring metadata.
"""
# Original article fields
source: str
url: str
title: str
content: str
published_at: datetime | None = None
fetched_at: datetime
content_hash: str
# Scoring fields
ticker: str
sentiment_score: float = Field(ge=-1.0, le=1.0)
confidence: float = Field(ge=0.0, le=1.0)
model_used: str
entities: list[str] = Field(default_factory=list)
model_config = {"from_attributes": True}

163
shared/schemas/trading.py Normal file
View file

@ -0,0 +1,163 @@
"""Trading-related Pydantic schemas for Redis Streams messages and API payloads."""
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field
class OrderType(str, Enum):
MARKET = "market"
LIMIT = "limit"
STOP = "stop"
class OrderSide(str, Enum):
BUY = "BUY"
SELL = "SELL"
class OrderStatus(str, Enum):
PENDING = "PENDING"
FILLED = "FILLED"
CANCELLED = "CANCELLED"
REJECTED = "REJECTED"
class SignalDirection(str, Enum):
LONG = "LONG"
SHORT = "SHORT"
NEUTRAL = "NEUTRAL"
# ---------------------------------------------------------------------------
# API request / response schemas
# ---------------------------------------------------------------------------
class OrderRequest(BaseModel):
"""Submitted by the trade executor or the API to place an order."""
ticker: str
side: OrderSide
qty: float = Field(gt=0)
order_type: OrderType = OrderType.MARKET
limit_price: float | None = None
stop_price: float | None = None
model_config = {"from_attributes": True}
class OrderResult(BaseModel):
"""Returned after order submission or status query."""
order_id: str
ticker: str
side: OrderSide
qty: float
filled_price: float | None = None
status: OrderStatus
timestamp: datetime
model_config = {"from_attributes": True}
class PositionInfo(BaseModel):
"""Current position state — used in API responses and portfolio views."""
ticker: str
qty: float
avg_entry: float
current_price: float
unrealized_pnl: float
market_value: float
model_config = {"from_attributes": True}
class AccountInfo(BaseModel):
"""Account-level summary from the brokerage."""
equity: float
cash: float
buying_power: float
portfolio_value: float
model_config = {"from_attributes": True}
# ---------------------------------------------------------------------------
# Redis Stream message schemas
# ---------------------------------------------------------------------------
class TradeSignal(BaseModel):
"""Published to ``signals:generated`` by the signal generator."""
ticker: str
direction: SignalDirection
strength: float = Field(ge=0.0, le=1.0)
strategy_sources: list[str]
sentiment_context: dict[str, Any] | None = None
timestamp: datetime
model_config = {"from_attributes": True}
class TradeExecution(BaseModel):
"""Published to ``trades:executed`` by the trade executor."""
trade_id: UUID
ticker: str
side: OrderSide
qty: float
price: float
status: OrderStatus
signal_id: UUID | None = None
strategy_id: UUID | None = None
timestamp: datetime
model_config = {"from_attributes": True}
class OHLCVBar(BaseModel):
"""Single OHLCV bar."""
timestamp: datetime
open: float
high: float
low: float
close: float
volume: float
class MarketSnapshot(BaseModel):
"""Snapshot of market data for a single ticker — used by strategies."""
ticker: str
current_price: float
open: float
high: float
low: float
close: float
volume: float
sma_20: float | None = None
sma_50: float | None = None
rsi: float | None = None
bars: list[dict[str, Any]] = Field(default_factory=list)
model_config = {"from_attributes": True}
class SentimentContext(BaseModel):
"""Aggregated sentiment for a ticker — passed to strategies."""
ticker: str
avg_score: float = Field(ge=-1.0, le=1.0)
article_count: int = Field(ge=0)
recent_scores: list[float] = Field(default_factory=list)
avg_confidence: float = Field(ge=0.0, le=1.0)
model_config = {"from_attributes": True}

586
tests/test_schemas.py Normal file
View file

@ -0,0 +1,586 @@
"""Tests for Pydantic schemas — serialization round-trips and validation constraints."""
import uuid
from datetime import datetime, timezone
import pytest
from pydantic import ValidationError
from shared.schemas.trading import (
AccountInfo,
MarketSnapshot,
OrderRequest,
OrderResult,
OrderSide,
OrderStatus,
OrderType,
PositionInfo,
SentimentContext,
SignalDirection,
TradeExecution,
TradeSignal,
)
from shared.schemas.news import RawArticle, ScoredArticle
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
# ---------------------------------------------------------------------------
# Trading schemas
# ---------------------------------------------------------------------------
class TestOrderRequest:
def test_valid_market_order(self) -> None:
o = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0)
assert o.order_type == OrderType.MARKET
assert o.limit_price is None
def test_valid_limit_order(self) -> None:
o = OrderRequest(
ticker="TSLA",
side=OrderSide.SELL,
qty=5.0,
order_type=OrderType.LIMIT,
limit_price=250.50,
)
assert o.limit_price == 250.50
def test_qty_must_be_positive(self) -> None:
with pytest.raises(ValidationError):
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=0)
def test_qty_must_not_be_negative(self) -> None:
with pytest.raises(ValidationError):
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=-5)
def test_serialization_round_trip(self) -> None:
o = OrderRequest(
ticker="GOOG",
side=OrderSide.BUY,
qty=3.0,
order_type=OrderType.STOP,
stop_price=100.0,
)
data = o.model_dump()
restored = OrderRequest.model_validate(data)
assert restored == o
def test_json_round_trip(self) -> None:
o = OrderRequest(ticker="META", side=OrderSide.SELL, qty=1.5)
json_str = o.model_dump_json()
restored = OrderRequest.model_validate_json(json_str)
assert restored == o
class TestOrderResult:
def test_valid_result(self) -> None:
now = datetime.now(timezone.utc)
r = OrderResult(
order_id="ord-123",
ticker="AAPL",
side=OrderSide.BUY,
qty=10.0,
filled_price=150.25,
status=OrderStatus.FILLED,
timestamp=now,
)
assert r.filled_price == 150.25
assert r.status == OrderStatus.FILLED
def test_pending_no_fill(self) -> None:
now = datetime.now(timezone.utc)
r = OrderResult(
order_id="ord-456",
ticker="TSLA",
side=OrderSide.SELL,
qty=5.0,
status=OrderStatus.PENDING,
timestamp=now,
)
assert r.filled_price is None
class TestPositionInfo:
def test_valid_position(self) -> None:
p = PositionInfo(
ticker="NVDA",
qty=20.0,
avg_entry=800.0,
current_price=850.0,
unrealized_pnl=1000.0,
market_value=17000.0,
)
assert p.market_value == 17000.0
def test_serialization_round_trip(self) -> None:
p = PositionInfo(
ticker="AMZN",
qty=10.0,
avg_entry=180.0,
current_price=185.0,
unrealized_pnl=50.0,
market_value=1850.0,
)
restored = PositionInfo.model_validate(p.model_dump())
assert restored == p
class TestAccountInfo:
def test_valid_account(self) -> None:
a = AccountInfo(
equity=100_000.0,
cash=25_000.0,
buying_power=50_000.0,
portfolio_value=100_000.0,
)
assert a.equity == 100_000.0
class TestTradeSignal:
def test_valid_signal(self) -> None:
now = datetime.now(timezone.utc)
s = TradeSignal(
ticker="AAPL",
direction=SignalDirection.LONG,
strength=0.85,
strategy_sources=["momentum", "news_driven"],
timestamp=now,
)
assert s.strength == 0.85
assert s.sentiment_context is None
def test_strength_must_be_in_range(self) -> None:
now = datetime.now(timezone.utc)
with pytest.raises(ValidationError):
TradeSignal(
ticker="AAPL",
direction=SignalDirection.LONG,
strength=1.5,
strategy_sources=["momentum"],
timestamp=now,
)
def test_strength_lower_bound(self) -> None:
now = datetime.now(timezone.utc)
with pytest.raises(ValidationError):
TradeSignal(
ticker="AAPL",
direction=SignalDirection.SHORT,
strength=-0.1,
strategy_sources=["mean_reversion"],
timestamp=now,
)
def test_json_round_trip(self) -> None:
now = datetime.now(timezone.utc)
s = TradeSignal(
ticker="TSLA",
direction=SignalDirection.SHORT,
strength=0.6,
strategy_sources=["mean_reversion"],
sentiment_context={"avg_score": -0.4},
timestamp=now,
)
restored = TradeSignal.model_validate_json(s.model_dump_json())
assert restored == s
class TestTradeExecution:
def test_valid_execution(self) -> None:
now = datetime.now(timezone.utc)
tid = uuid.uuid4()
sid = uuid.uuid4()
e = TradeExecution(
trade_id=tid,
ticker="AAPL",
side=OrderSide.BUY,
qty=10.0,
price=150.0,
status=OrderStatus.FILLED,
signal_id=sid,
strategy_id=uuid.uuid4(),
timestamp=now,
)
assert e.trade_id == tid
assert e.signal_id == sid
def test_optional_ids(self) -> None:
now = datetime.now(timezone.utc)
e = TradeExecution(
trade_id=uuid.uuid4(),
ticker="GOOG",
side=OrderSide.SELL,
qty=5.0,
price=2800.0,
status=OrderStatus.FILLED,
timestamp=now,
)
assert e.signal_id is None
assert e.strategy_id is None
class TestMarketSnapshot:
def test_valid_snapshot(self) -> None:
ms = MarketSnapshot(
ticker="AAPL",
current_price=150.0,
open=148.0,
high=152.0,
low=147.0,
close=150.0,
volume=5_000_000.0,
sma_20=149.5,
rsi=55.0,
)
assert ms.sma_50 is None
assert ms.bars == []
def test_with_bars(self) -> None:
ms = MarketSnapshot(
ticker="TSLA",
current_price=250.0,
open=248.0,
high=255.0,
low=245.0,
close=250.0,
volume=10_000_000.0,
bars=[
{"timestamp": "2026-01-01T10:00:00Z", "open": 248, "close": 250}
],
)
assert len(ms.bars) == 1
class TestSentimentContext:
def test_valid_context(self) -> None:
sc = SentimentContext(
ticker="AAPL",
avg_score=0.65,
article_count=5,
recent_scores=[0.5, 0.7, 0.8],
avg_confidence=0.85,
)
assert sc.article_count == 5
def test_score_range_validation(self) -> None:
with pytest.raises(ValidationError):
SentimentContext(
ticker="AAPL",
avg_score=1.5, # Out of range
article_count=1,
avg_confidence=0.5,
)
def test_negative_score_in_range(self) -> None:
sc = SentimentContext(
ticker="TSLA",
avg_score=-0.8,
article_count=3,
avg_confidence=0.9,
)
assert sc.avg_score == -0.8
def test_confidence_range_validation(self) -> None:
with pytest.raises(ValidationError):
SentimentContext(
ticker="AAPL",
avg_score=0.5,
article_count=1,
avg_confidence=1.5, # Out of range
)
def test_article_count_non_negative(self) -> None:
with pytest.raises(ValidationError):
SentimentContext(
ticker="AAPL",
avg_score=0.0,
article_count=-1,
avg_confidence=0.5,
)
# ---------------------------------------------------------------------------
# News schemas
# ---------------------------------------------------------------------------
class TestRawArticle:
def test_valid_article(self) -> None:
now = datetime.now(timezone.utc)
a = RawArticle(
source="reuters",
url="https://reuters.com/article/1",
title="Market Rally Continues",
content="Stocks rose sharply today...",
published_at=now,
fetched_at=now,
content_hash="sha256abcdef1234567890",
)
assert a.source == "reuters"
def test_published_at_optional(self) -> None:
now = datetime.now(timezone.utc)
a = RawArticle(
source="reddit",
url="https://reddit.com/r/stocks/1",
title="DD on TSLA",
content="Here is my analysis...",
fetched_at=now,
content_hash="hash123",
)
assert a.published_at is None
def test_json_round_trip(self) -> None:
now = datetime.now(timezone.utc)
a = RawArticle(
source="yahoo",
url="https://finance.yahoo.com/1",
title="Earnings Beat",
content="Apple beat earnings...",
published_at=now,
fetched_at=now,
content_hash="hash456",
)
restored = RawArticle.model_validate_json(a.model_dump_json())
assert restored == a
def test_required_fields(self) -> None:
with pytest.raises(ValidationError):
RawArticle(source="reuters") # type: ignore[call-arg]
class TestScoredArticle:
def test_valid_scored_article(self) -> None:
now = datetime.now(timezone.utc)
sa = ScoredArticle(
source="reuters",
url="https://reuters.com/article/1",
title="Apple Earnings Beat",
content="Apple reported...",
published_at=now,
fetched_at=now,
content_hash="hash789",
ticker="AAPL",
sentiment_score=0.85,
confidence=0.92,
model_used="finbert",
entities=["Apple Inc", "Tim Cook"],
)
assert sa.sentiment_score == 0.85
assert sa.entities == ["Apple Inc", "Tim Cook"]
def test_sentiment_score_range(self) -> None:
now = datetime.now(timezone.utc)
with pytest.raises(ValidationError):
ScoredArticle(
source="reuters",
url="https://reuters.com/1",
title="Test",
content="Test content",
fetched_at=now,
content_hash="hash",
ticker="AAPL",
sentiment_score=1.5, # Out of range
confidence=0.5,
model_used="finbert",
)
def test_confidence_range(self) -> None:
now = datetime.now(timezone.utc)
with pytest.raises(ValidationError):
ScoredArticle(
source="reuters",
url="https://reuters.com/1",
title="Test",
content="Test content",
fetched_at=now,
content_hash="hash",
ticker="AAPL",
sentiment_score=0.5,
confidence=-0.1, # Out of range
model_used="finbert",
)
def test_negative_sentiment(self) -> None:
now = datetime.now(timezone.utc)
sa = ScoredArticle(
source="reddit",
url="https://reddit.com/1",
title="Bad news",
content="Terrible quarter...",
fetched_at=now,
content_hash="hashN",
ticker="TSLA",
sentiment_score=-0.9,
confidence=0.8,
model_used="ollama",
)
assert sa.sentiment_score == -0.9
def test_json_round_trip(self) -> None:
now = datetime.now(timezone.utc)
sa = ScoredArticle(
source="yahoo",
url="https://yahoo.com/1",
title="Headline",
content="Body text",
fetched_at=now,
content_hash="hashRT",
ticker="NVDA",
sentiment_score=0.3,
confidence=0.7,
model_used="finbert",
entities=["NVIDIA"],
)
restored = ScoredArticle.model_validate_json(sa.model_dump_json())
assert restored == sa
# ---------------------------------------------------------------------------
# Learning schemas
# ---------------------------------------------------------------------------
class TestTradeOutcomeSchema:
def test_valid_outcome(self) -> None:
o = TradeOutcomeSchema(
trade_id=uuid.uuid4(),
hold_duration_seconds=14400.0,
realized_pnl=250.50,
roi_pct=2.5,
was_profitable=True,
)
assert o.was_profitable is True
assert o.hold_duration_seconds == 14400.0
def test_hold_duration_non_negative(self) -> None:
with pytest.raises(ValidationError):
TradeOutcomeSchema(
trade_id=uuid.uuid4(),
hold_duration_seconds=-1.0,
realized_pnl=100.0,
roi_pct=1.0,
was_profitable=True,
)
def test_losing_trade(self) -> None:
o = TradeOutcomeSchema(
trade_id=uuid.uuid4(),
hold_duration_seconds=3600.0,
realized_pnl=-150.0,
roi_pct=-3.0,
was_profitable=False,
)
assert o.was_profitable is False
assert o.realized_pnl == -150.0
def test_json_round_trip(self) -> None:
o = TradeOutcomeSchema(
trade_id=uuid.uuid4(),
hold_duration_seconds=7200.0,
realized_pnl=500.0,
roi_pct=5.0,
was_profitable=True,
)
restored = TradeOutcomeSchema.model_validate_json(o.model_dump_json())
assert restored == o
class TestWeightAdjustment:
def test_valid_adjustment(self) -> None:
now = datetime.now(timezone.utc)
wa = WeightAdjustment(
strategy_id=uuid.uuid4(),
strategy_name="momentum",
old_weight=0.33,
new_weight=0.38,
reason="Positive reward signal from recent trades",
reward_signal=0.72,
timestamp=now,
)
assert wa.old_weight == 0.33
assert wa.new_weight == 0.38
def test_required_fields(self) -> None:
with pytest.raises(ValidationError):
WeightAdjustment(
strategy_id=uuid.uuid4(),
strategy_name="momentum",
) # type: ignore[call-arg]
def test_json_round_trip(self) -> None:
now = datetime.now(timezone.utc)
wa = WeightAdjustment(
strategy_id=uuid.uuid4(),
strategy_name="mean_reversion",
old_weight=0.30,
new_weight=0.25,
reason="Poor recent performance",
reward_signal=-0.4,
timestamp=now,
)
restored = WeightAdjustment.model_validate_json(wa.model_dump_json())
assert restored == wa
# ---------------------------------------------------------------------------
# Auth schemas
# ---------------------------------------------------------------------------
class TestRegisterRequest:
def test_valid_registration(self) -> None:
r = RegisterRequest(username="trader1", display_name="Top Trader")
assert r.username == "trader1"
assert r.display_name == "Top Trader"
def test_display_name_optional(self) -> None:
r = RegisterRequest(username="trader2")
assert r.display_name is None
def test_username_required(self) -> None:
with pytest.raises(ValidationError):
RegisterRequest(username="") # min_length=1
def test_username_max_length(self) -> None:
with pytest.raises(ValidationError):
RegisterRequest(username="x" * 101) # max_length=100
class TestLoginRequest:
def test_valid_login(self) -> None:
l = LoginRequest(username="trader1")
assert l.username == "trader1"
def test_username_required(self) -> None:
with pytest.raises(ValidationError):
LoginRequest(username="")
class TestTokenResponse:
def test_valid_response(self) -> None:
t = TokenResponse(
access_token="eyJ...",
refresh_token="eyR...",
)
assert t.token_type == "bearer"
assert t.expires_in == 900 # default 15 min
def test_custom_expiry(self) -> None:
t = TokenResponse(
access_token="eyJ...",
refresh_token="eyR...",
expires_in=3600,
)
assert t.expires_in == 3600
def test_json_round_trip(self) -> None:
t = TokenResponse(
access_token="access123",
refresh_token="refresh456",
token_type="bearer",
expires_in=1800,
)
restored = TokenResponse.model_validate_json(t.model_dump_json())
assert restored == t