feat: pydantic schemas for all service message types
- shared/schemas/trading.py: OrderRequest, OrderResult, PositionInfo, AccountInfo, TradeSignal, TradeExecution, MarketSnapshot, SentimentContext - shared/schemas/news.py: RawArticle, ScoredArticle - shared/schemas/learning.py: TradeOutcomeSchema, WeightAdjustment - shared/schemas/auth.py: RegisterRequest, LoginRequest, TokenResponse - 49 schema tests covering validation constraints, serialization round-trips, required fields, and range checks
This commit is contained in:
parent
72cb1b6fe5
commit
c8277e301e
11 changed files with 889 additions and 0 deletions
37
shared/schemas/__init__.py
Normal file
37
shared/schemas/__init__.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""Pydantic v2 schemas for all service message types."""
|
||||
|
||||
from shared.schemas.trading import (
|
||||
AccountInfo,
|
||||
MarketSnapshot,
|
||||
OrderRequest,
|
||||
OrderResult,
|
||||
PositionInfo,
|
||||
SentimentContext,
|
||||
TradeExecution,
|
||||
TradeSignal,
|
||||
)
|
||||
from shared.schemas.news import RawArticle, ScoredArticle
|
||||
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
|
||||
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
|
||||
|
||||
__all__ = [
|
||||
# Trading
|
||||
"OrderRequest",
|
||||
"OrderResult",
|
||||
"PositionInfo",
|
||||
"AccountInfo",
|
||||
"TradeSignal",
|
||||
"TradeExecution",
|
||||
"MarketSnapshot",
|
||||
"SentimentContext",
|
||||
# News
|
||||
"RawArticle",
|
||||
"ScoredArticle",
|
||||
# Learning
|
||||
"TradeOutcomeSchema",
|
||||
"WeightAdjustment",
|
||||
# Auth
|
||||
"RegisterRequest",
|
||||
"LoginRequest",
|
||||
"TokenResponse",
|
||||
]
|
||||
BIN
shared/schemas/__pycache__/__init__.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/__init__.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/schemas/__pycache__/auth.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/auth.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/schemas/__pycache__/learning.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/learning.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/schemas/__pycache__/news.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/news.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/schemas/__pycache__/trading.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/trading.cpython-314.pyc
Normal file
Binary file not shown.
27
shared/schemas/auth.py
Normal file
27
shared/schemas/auth.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Authentication Pydantic schemas for API request/response payloads."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""Sent by the dashboard to begin passkey registration."""
|
||||
|
||||
username: str = Field(min_length=1, max_length=100)
|
||||
display_name: str | None = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""Sent by the dashboard to begin passkey authentication."""
|
||||
|
||||
username: str = Field(min_length=1, max_length=100)
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Returned after successful authentication."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = Field(
|
||||
description="Access token lifetime in seconds", default=900
|
||||
)
|
||||
32
shared/schemas/learning.py
Normal file
32
shared/schemas/learning.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
"""Learning domain Pydantic schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TradeOutcomeSchema(BaseModel):
|
||||
"""Represents the evaluated outcome of a closed trade."""
|
||||
|
||||
trade_id: UUID
|
||||
hold_duration_seconds: float = Field(ge=0)
|
||||
realized_pnl: float
|
||||
roi_pct: float
|
||||
was_profitable: bool
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class WeightAdjustment(BaseModel):
|
||||
"""Represents a strategy weight change made by the learning engine."""
|
||||
|
||||
strategy_id: UUID
|
||||
strategy_name: str
|
||||
old_weight: float
|
||||
new_weight: float
|
||||
reason: str
|
||||
reward_signal: float
|
||||
timestamp: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
44
shared/schemas/news.py
Normal file
44
shared/schemas/news.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""News article Pydantic schemas for Redis Stream messages."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RawArticle(BaseModel):
|
||||
"""Published to ``news:raw`` by the news fetcher."""
|
||||
|
||||
source: str
|
||||
url: str
|
||||
title: str
|
||||
content: str
|
||||
published_at: datetime | None = None
|
||||
fetched_at: datetime
|
||||
content_hash: str
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ScoredArticle(BaseModel):
|
||||
"""Published to ``news:scored`` by the sentiment analyzer.
|
||||
|
||||
Inherits all fields from RawArticle conceptually plus scoring metadata.
|
||||
"""
|
||||
|
||||
# Original article fields
|
||||
source: str
|
||||
url: str
|
||||
title: str
|
||||
content: str
|
||||
published_at: datetime | None = None
|
||||
fetched_at: datetime
|
||||
content_hash: str
|
||||
|
||||
# Scoring fields
|
||||
ticker: str
|
||||
sentiment_score: float = Field(ge=-1.0, le=1.0)
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
model_used: str
|
||||
entities: list[str] = Field(default_factory=list)
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
163
shared/schemas/trading.py
Normal file
163
shared/schemas/trading.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
"""Trading-related Pydantic schemas for Redis Streams messages and API payloads."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OrderType(str, Enum):
|
||||
MARKET = "market"
|
||||
LIMIT = "limit"
|
||||
STOP = "stop"
|
||||
|
||||
|
||||
class OrderSide(str, Enum):
|
||||
BUY = "BUY"
|
||||
SELL = "SELL"
|
||||
|
||||
|
||||
class OrderStatus(str, Enum):
|
||||
PENDING = "PENDING"
|
||||
FILLED = "FILLED"
|
||||
CANCELLED = "CANCELLED"
|
||||
REJECTED = "REJECTED"
|
||||
|
||||
|
||||
class SignalDirection(str, Enum):
|
||||
LONG = "LONG"
|
||||
SHORT = "SHORT"
|
||||
NEUTRAL = "NEUTRAL"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API request / response schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class OrderRequest(BaseModel):
|
||||
"""Submitted by the trade executor or the API to place an order."""
|
||||
|
||||
ticker: str
|
||||
side: OrderSide
|
||||
qty: float = Field(gt=0)
|
||||
order_type: OrderType = OrderType.MARKET
|
||||
limit_price: float | None = None
|
||||
stop_price: float | None = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class OrderResult(BaseModel):
|
||||
"""Returned after order submission or status query."""
|
||||
|
||||
order_id: str
|
||||
ticker: str
|
||||
side: OrderSide
|
||||
qty: float
|
||||
filled_price: float | None = None
|
||||
status: OrderStatus
|
||||
timestamp: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class PositionInfo(BaseModel):
|
||||
"""Current position state — used in API responses and portfolio views."""
|
||||
|
||||
ticker: str
|
||||
qty: float
|
||||
avg_entry: float
|
||||
current_price: float
|
||||
unrealized_pnl: float
|
||||
market_value: float
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class AccountInfo(BaseModel):
|
||||
"""Account-level summary from the brokerage."""
|
||||
|
||||
equity: float
|
||||
cash: float
|
||||
buying_power: float
|
||||
portfolio_value: float
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Redis Stream message schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TradeSignal(BaseModel):
|
||||
"""Published to ``signals:generated`` by the signal generator."""
|
||||
|
||||
ticker: str
|
||||
direction: SignalDirection
|
||||
strength: float = Field(ge=0.0, le=1.0)
|
||||
strategy_sources: list[str]
|
||||
sentiment_context: dict[str, Any] | None = None
|
||||
timestamp: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class TradeExecution(BaseModel):
|
||||
"""Published to ``trades:executed`` by the trade executor."""
|
||||
|
||||
trade_id: UUID
|
||||
ticker: str
|
||||
side: OrderSide
|
||||
qty: float
|
||||
price: float
|
||||
status: OrderStatus
|
||||
signal_id: UUID | None = None
|
||||
strategy_id: UUID | None = None
|
||||
timestamp: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class OHLCVBar(BaseModel):
|
||||
"""Single OHLCV bar."""
|
||||
|
||||
timestamp: datetime
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: float
|
||||
|
||||
|
||||
class MarketSnapshot(BaseModel):
|
||||
"""Snapshot of market data for a single ticker — used by strategies."""
|
||||
|
||||
ticker: str
|
||||
current_price: float
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: float
|
||||
sma_20: float | None = None
|
||||
sma_50: float | None = None
|
||||
rsi: float | None = None
|
||||
bars: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SentimentContext(BaseModel):
|
||||
"""Aggregated sentiment for a ticker — passed to strategies."""
|
||||
|
||||
ticker: str
|
||||
avg_score: float = Field(ge=-1.0, le=1.0)
|
||||
article_count: int = Field(ge=0)
|
||||
recent_scores: list[float] = Field(default_factory=list)
|
||||
avg_confidence: float = Field(ge=0.0, le=1.0)
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
586
tests/test_schemas.py
Normal file
586
tests/test_schemas.py
Normal file
|
|
@ -0,0 +1,586 @@
|
|||
"""Tests for Pydantic schemas — serialization round-trips and validation constraints."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from shared.schemas.trading import (
|
||||
AccountInfo,
|
||||
MarketSnapshot,
|
||||
OrderRequest,
|
||||
OrderResult,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
OrderType,
|
||||
PositionInfo,
|
||||
SentimentContext,
|
||||
SignalDirection,
|
||||
TradeExecution,
|
||||
TradeSignal,
|
||||
)
|
||||
from shared.schemas.news import RawArticle, ScoredArticle
|
||||
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
|
||||
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trading schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOrderRequest:
|
||||
def test_valid_market_order(self) -> None:
|
||||
o = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0)
|
||||
assert o.order_type == OrderType.MARKET
|
||||
assert o.limit_price is None
|
||||
|
||||
def test_valid_limit_order(self) -> None:
|
||||
o = OrderRequest(
|
||||
ticker="TSLA",
|
||||
side=OrderSide.SELL,
|
||||
qty=5.0,
|
||||
order_type=OrderType.LIMIT,
|
||||
limit_price=250.50,
|
||||
)
|
||||
assert o.limit_price == 250.50
|
||||
|
||||
def test_qty_must_be_positive(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=0)
|
||||
|
||||
def test_qty_must_not_be_negative(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=-5)
|
||||
|
||||
def test_serialization_round_trip(self) -> None:
|
||||
o = OrderRequest(
|
||||
ticker="GOOG",
|
||||
side=OrderSide.BUY,
|
||||
qty=3.0,
|
||||
order_type=OrderType.STOP,
|
||||
stop_price=100.0,
|
||||
)
|
||||
data = o.model_dump()
|
||||
restored = OrderRequest.model_validate(data)
|
||||
assert restored == o
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
o = OrderRequest(ticker="META", side=OrderSide.SELL, qty=1.5)
|
||||
json_str = o.model_dump_json()
|
||||
restored = OrderRequest.model_validate_json(json_str)
|
||||
assert restored == o
|
||||
|
||||
|
||||
class TestOrderResult:
|
||||
def test_valid_result(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
r = OrderResult(
|
||||
order_id="ord-123",
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
qty=10.0,
|
||||
filled_price=150.25,
|
||||
status=OrderStatus.FILLED,
|
||||
timestamp=now,
|
||||
)
|
||||
assert r.filled_price == 150.25
|
||||
assert r.status == OrderStatus.FILLED
|
||||
|
||||
def test_pending_no_fill(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
r = OrderResult(
|
||||
order_id="ord-456",
|
||||
ticker="TSLA",
|
||||
side=OrderSide.SELL,
|
||||
qty=5.0,
|
||||
status=OrderStatus.PENDING,
|
||||
timestamp=now,
|
||||
)
|
||||
assert r.filled_price is None
|
||||
|
||||
|
||||
class TestPositionInfo:
|
||||
def test_valid_position(self) -> None:
|
||||
p = PositionInfo(
|
||||
ticker="NVDA",
|
||||
qty=20.0,
|
||||
avg_entry=800.0,
|
||||
current_price=850.0,
|
||||
unrealized_pnl=1000.0,
|
||||
market_value=17000.0,
|
||||
)
|
||||
assert p.market_value == 17000.0
|
||||
|
||||
def test_serialization_round_trip(self) -> None:
|
||||
p = PositionInfo(
|
||||
ticker="AMZN",
|
||||
qty=10.0,
|
||||
avg_entry=180.0,
|
||||
current_price=185.0,
|
||||
unrealized_pnl=50.0,
|
||||
market_value=1850.0,
|
||||
)
|
||||
restored = PositionInfo.model_validate(p.model_dump())
|
||||
assert restored == p
|
||||
|
||||
|
||||
class TestAccountInfo:
|
||||
def test_valid_account(self) -> None:
|
||||
a = AccountInfo(
|
||||
equity=100_000.0,
|
||||
cash=25_000.0,
|
||||
buying_power=50_000.0,
|
||||
portfolio_value=100_000.0,
|
||||
)
|
||||
assert a.equity == 100_000.0
|
||||
|
||||
|
||||
class TestTradeSignal:
|
||||
def test_valid_signal(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
s = TradeSignal(
|
||||
ticker="AAPL",
|
||||
direction=SignalDirection.LONG,
|
||||
strength=0.85,
|
||||
strategy_sources=["momentum", "news_driven"],
|
||||
timestamp=now,
|
||||
)
|
||||
assert s.strength == 0.85
|
||||
assert s.sentiment_context is None
|
||||
|
||||
def test_strength_must_be_in_range(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
with pytest.raises(ValidationError):
|
||||
TradeSignal(
|
||||
ticker="AAPL",
|
||||
direction=SignalDirection.LONG,
|
||||
strength=1.5,
|
||||
strategy_sources=["momentum"],
|
||||
timestamp=now,
|
||||
)
|
||||
|
||||
def test_strength_lower_bound(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
with pytest.raises(ValidationError):
|
||||
TradeSignal(
|
||||
ticker="AAPL",
|
||||
direction=SignalDirection.SHORT,
|
||||
strength=-0.1,
|
||||
strategy_sources=["mean_reversion"],
|
||||
timestamp=now,
|
||||
)
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
s = TradeSignal(
|
||||
ticker="TSLA",
|
||||
direction=SignalDirection.SHORT,
|
||||
strength=0.6,
|
||||
strategy_sources=["mean_reversion"],
|
||||
sentiment_context={"avg_score": -0.4},
|
||||
timestamp=now,
|
||||
)
|
||||
restored = TradeSignal.model_validate_json(s.model_dump_json())
|
||||
assert restored == s
|
||||
|
||||
|
||||
class TestTradeExecution:
|
||||
def test_valid_execution(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
tid = uuid.uuid4()
|
||||
sid = uuid.uuid4()
|
||||
e = TradeExecution(
|
||||
trade_id=tid,
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
qty=10.0,
|
||||
price=150.0,
|
||||
status=OrderStatus.FILLED,
|
||||
signal_id=sid,
|
||||
strategy_id=uuid.uuid4(),
|
||||
timestamp=now,
|
||||
)
|
||||
assert e.trade_id == tid
|
||||
assert e.signal_id == sid
|
||||
|
||||
def test_optional_ids(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
e = TradeExecution(
|
||||
trade_id=uuid.uuid4(),
|
||||
ticker="GOOG",
|
||||
side=OrderSide.SELL,
|
||||
qty=5.0,
|
||||
price=2800.0,
|
||||
status=OrderStatus.FILLED,
|
||||
timestamp=now,
|
||||
)
|
||||
assert e.signal_id is None
|
||||
assert e.strategy_id is None
|
||||
|
||||
|
||||
class TestMarketSnapshot:
|
||||
def test_valid_snapshot(self) -> None:
|
||||
ms = MarketSnapshot(
|
||||
ticker="AAPL",
|
||||
current_price=150.0,
|
||||
open=148.0,
|
||||
high=152.0,
|
||||
low=147.0,
|
||||
close=150.0,
|
||||
volume=5_000_000.0,
|
||||
sma_20=149.5,
|
||||
rsi=55.0,
|
||||
)
|
||||
assert ms.sma_50 is None
|
||||
assert ms.bars == []
|
||||
|
||||
def test_with_bars(self) -> None:
|
||||
ms = MarketSnapshot(
|
||||
ticker="TSLA",
|
||||
current_price=250.0,
|
||||
open=248.0,
|
||||
high=255.0,
|
||||
low=245.0,
|
||||
close=250.0,
|
||||
volume=10_000_000.0,
|
||||
bars=[
|
||||
{"timestamp": "2026-01-01T10:00:00Z", "open": 248, "close": 250}
|
||||
],
|
||||
)
|
||||
assert len(ms.bars) == 1
|
||||
|
||||
|
||||
class TestSentimentContext:
|
||||
def test_valid_context(self) -> None:
|
||||
sc = SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=0.65,
|
||||
article_count=5,
|
||||
recent_scores=[0.5, 0.7, 0.8],
|
||||
avg_confidence=0.85,
|
||||
)
|
||||
assert sc.article_count == 5
|
||||
|
||||
def test_score_range_validation(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=1.5, # Out of range
|
||||
article_count=1,
|
||||
avg_confidence=0.5,
|
||||
)
|
||||
|
||||
def test_negative_score_in_range(self) -> None:
|
||||
sc = SentimentContext(
|
||||
ticker="TSLA",
|
||||
avg_score=-0.8,
|
||||
article_count=3,
|
||||
avg_confidence=0.9,
|
||||
)
|
||||
assert sc.avg_score == -0.8
|
||||
|
||||
def test_confidence_range_validation(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=0.5,
|
||||
article_count=1,
|
||||
avg_confidence=1.5, # Out of range
|
||||
)
|
||||
|
||||
def test_article_count_non_negative(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SentimentContext(
|
||||
ticker="AAPL",
|
||||
avg_score=0.0,
|
||||
article_count=-1,
|
||||
avg_confidence=0.5,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# News schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRawArticle:
|
||||
def test_valid_article(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
a = RawArticle(
|
||||
source="reuters",
|
||||
url="https://reuters.com/article/1",
|
||||
title="Market Rally Continues",
|
||||
content="Stocks rose sharply today...",
|
||||
published_at=now,
|
||||
fetched_at=now,
|
||||
content_hash="sha256abcdef1234567890",
|
||||
)
|
||||
assert a.source == "reuters"
|
||||
|
||||
def test_published_at_optional(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
a = RawArticle(
|
||||
source="reddit",
|
||||
url="https://reddit.com/r/stocks/1",
|
||||
title="DD on TSLA",
|
||||
content="Here is my analysis...",
|
||||
fetched_at=now,
|
||||
content_hash="hash123",
|
||||
)
|
||||
assert a.published_at is None
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
a = RawArticle(
|
||||
source="yahoo",
|
||||
url="https://finance.yahoo.com/1",
|
||||
title="Earnings Beat",
|
||||
content="Apple beat earnings...",
|
||||
published_at=now,
|
||||
fetched_at=now,
|
||||
content_hash="hash456",
|
||||
)
|
||||
restored = RawArticle.model_validate_json(a.model_dump_json())
|
||||
assert restored == a
|
||||
|
||||
def test_required_fields(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RawArticle(source="reuters") # type: ignore[call-arg]
|
||||
|
||||
|
||||
class TestScoredArticle:
|
||||
def test_valid_scored_article(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
sa = ScoredArticle(
|
||||
source="reuters",
|
||||
url="https://reuters.com/article/1",
|
||||
title="Apple Earnings Beat",
|
||||
content="Apple reported...",
|
||||
published_at=now,
|
||||
fetched_at=now,
|
||||
content_hash="hash789",
|
||||
ticker="AAPL",
|
||||
sentiment_score=0.85,
|
||||
confidence=0.92,
|
||||
model_used="finbert",
|
||||
entities=["Apple Inc", "Tim Cook"],
|
||||
)
|
||||
assert sa.sentiment_score == 0.85
|
||||
assert sa.entities == ["Apple Inc", "Tim Cook"]
|
||||
|
||||
def test_sentiment_score_range(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
with pytest.raises(ValidationError):
|
||||
ScoredArticle(
|
||||
source="reuters",
|
||||
url="https://reuters.com/1",
|
||||
title="Test",
|
||||
content="Test content",
|
||||
fetched_at=now,
|
||||
content_hash="hash",
|
||||
ticker="AAPL",
|
||||
sentiment_score=1.5, # Out of range
|
||||
confidence=0.5,
|
||||
model_used="finbert",
|
||||
)
|
||||
|
||||
def test_confidence_range(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
with pytest.raises(ValidationError):
|
||||
ScoredArticle(
|
||||
source="reuters",
|
||||
url="https://reuters.com/1",
|
||||
title="Test",
|
||||
content="Test content",
|
||||
fetched_at=now,
|
||||
content_hash="hash",
|
||||
ticker="AAPL",
|
||||
sentiment_score=0.5,
|
||||
confidence=-0.1, # Out of range
|
||||
model_used="finbert",
|
||||
)
|
||||
|
||||
def test_negative_sentiment(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
sa = ScoredArticle(
|
||||
source="reddit",
|
||||
url="https://reddit.com/1",
|
||||
title="Bad news",
|
||||
content="Terrible quarter...",
|
||||
fetched_at=now,
|
||||
content_hash="hashN",
|
||||
ticker="TSLA",
|
||||
sentiment_score=-0.9,
|
||||
confidence=0.8,
|
||||
model_used="ollama",
|
||||
)
|
||||
assert sa.sentiment_score == -0.9
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
sa = ScoredArticle(
|
||||
source="yahoo",
|
||||
url="https://yahoo.com/1",
|
||||
title="Headline",
|
||||
content="Body text",
|
||||
fetched_at=now,
|
||||
content_hash="hashRT",
|
||||
ticker="NVDA",
|
||||
sentiment_score=0.3,
|
||||
confidence=0.7,
|
||||
model_used="finbert",
|
||||
entities=["NVIDIA"],
|
||||
)
|
||||
restored = ScoredArticle.model_validate_json(sa.model_dump_json())
|
||||
assert restored == sa
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Learning schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTradeOutcomeSchema:
|
||||
def test_valid_outcome(self) -> None:
|
||||
o = TradeOutcomeSchema(
|
||||
trade_id=uuid.uuid4(),
|
||||
hold_duration_seconds=14400.0,
|
||||
realized_pnl=250.50,
|
||||
roi_pct=2.5,
|
||||
was_profitable=True,
|
||||
)
|
||||
assert o.was_profitable is True
|
||||
assert o.hold_duration_seconds == 14400.0
|
||||
|
||||
def test_hold_duration_non_negative(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
TradeOutcomeSchema(
|
||||
trade_id=uuid.uuid4(),
|
||||
hold_duration_seconds=-1.0,
|
||||
realized_pnl=100.0,
|
||||
roi_pct=1.0,
|
||||
was_profitable=True,
|
||||
)
|
||||
|
||||
def test_losing_trade(self) -> None:
|
||||
o = TradeOutcomeSchema(
|
||||
trade_id=uuid.uuid4(),
|
||||
hold_duration_seconds=3600.0,
|
||||
realized_pnl=-150.0,
|
||||
roi_pct=-3.0,
|
||||
was_profitable=False,
|
||||
)
|
||||
assert o.was_profitable is False
|
||||
assert o.realized_pnl == -150.0
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
o = TradeOutcomeSchema(
|
||||
trade_id=uuid.uuid4(),
|
||||
hold_duration_seconds=7200.0,
|
||||
realized_pnl=500.0,
|
||||
roi_pct=5.0,
|
||||
was_profitable=True,
|
||||
)
|
||||
restored = TradeOutcomeSchema.model_validate_json(o.model_dump_json())
|
||||
assert restored == o
|
||||
|
||||
|
||||
class TestWeightAdjustment:
|
||||
def test_valid_adjustment(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
wa = WeightAdjustment(
|
||||
strategy_id=uuid.uuid4(),
|
||||
strategy_name="momentum",
|
||||
old_weight=0.33,
|
||||
new_weight=0.38,
|
||||
reason="Positive reward signal from recent trades",
|
||||
reward_signal=0.72,
|
||||
timestamp=now,
|
||||
)
|
||||
assert wa.old_weight == 0.33
|
||||
assert wa.new_weight == 0.38
|
||||
|
||||
def test_required_fields(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
WeightAdjustment(
|
||||
strategy_id=uuid.uuid4(),
|
||||
strategy_name="momentum",
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
wa = WeightAdjustment(
|
||||
strategy_id=uuid.uuid4(),
|
||||
strategy_name="mean_reversion",
|
||||
old_weight=0.30,
|
||||
new_weight=0.25,
|
||||
reason="Poor recent performance",
|
||||
reward_signal=-0.4,
|
||||
timestamp=now,
|
||||
)
|
||||
restored = WeightAdjustment.model_validate_json(wa.model_dump_json())
|
||||
assert restored == wa
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterRequest:
|
||||
def test_valid_registration(self) -> None:
|
||||
r = RegisterRequest(username="trader1", display_name="Top Trader")
|
||||
assert r.username == "trader1"
|
||||
assert r.display_name == "Top Trader"
|
||||
|
||||
def test_display_name_optional(self) -> None:
|
||||
r = RegisterRequest(username="trader2")
|
||||
assert r.display_name is None
|
||||
|
||||
def test_username_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RegisterRequest(username="") # min_length=1
|
||||
|
||||
def test_username_max_length(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RegisterRequest(username="x" * 101) # max_length=100
|
||||
|
||||
|
||||
class TestLoginRequest:
|
||||
def test_valid_login(self) -> None:
|
||||
l = LoginRequest(username="trader1")
|
||||
assert l.username == "trader1"
|
||||
|
||||
def test_username_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
LoginRequest(username="")
|
||||
|
||||
|
||||
class TestTokenResponse:
|
||||
def test_valid_response(self) -> None:
|
||||
t = TokenResponse(
|
||||
access_token="eyJ...",
|
||||
refresh_token="eyR...",
|
||||
)
|
||||
assert t.token_type == "bearer"
|
||||
assert t.expires_in == 900 # default 15 min
|
||||
|
||||
def test_custom_expiry(self) -> None:
|
||||
t = TokenResponse(
|
||||
access_token="eyJ...",
|
||||
refresh_token="eyR...",
|
||||
expires_in=3600,
|
||||
)
|
||||
assert t.expires_in == 3600
|
||||
|
||||
def test_json_round_trip(self) -> None:
|
||||
t = TokenResponse(
|
||||
access_token="access123",
|
||||
refresh_token="refresh456",
|
||||
token_type="bearer",
|
||||
expires_in=1800,
|
||||
)
|
||||
restored = TokenResponse.model_validate_json(t.model_dump_json())
|
||||
assert restored == t
|
||||
Loading…
Add table
Add a link
Reference in a new issue