feat: pydantic schemas for all service message types
- shared/schemas/trading.py: OrderRequest, OrderResult, PositionInfo, AccountInfo, TradeSignal, TradeExecution, MarketSnapshot, SentimentContext - shared/schemas/news.py: RawArticle, ScoredArticle - shared/schemas/learning.py: TradeOutcomeSchema, WeightAdjustment - shared/schemas/auth.py: RegisterRequest, LoginRequest, TokenResponse - 49 schema tests covering validation constraints, serialization round-trips, required fields, and range checks
This commit is contained in:
parent
72cb1b6fe5
commit
c8277e301e
11 changed files with 889 additions and 0 deletions
37
shared/schemas/__init__.py
Normal file
37
shared/schemas/__init__.py
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
"""Pydantic v2 schemas for all service message types."""
|
||||||
|
|
||||||
|
from shared.schemas.trading import (
|
||||||
|
AccountInfo,
|
||||||
|
MarketSnapshot,
|
||||||
|
OrderRequest,
|
||||||
|
OrderResult,
|
||||||
|
PositionInfo,
|
||||||
|
SentimentContext,
|
||||||
|
TradeExecution,
|
||||||
|
TradeSignal,
|
||||||
|
)
|
||||||
|
from shared.schemas.news import RawArticle, ScoredArticle
|
||||||
|
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
|
||||||
|
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Trading
|
||||||
|
"OrderRequest",
|
||||||
|
"OrderResult",
|
||||||
|
"PositionInfo",
|
||||||
|
"AccountInfo",
|
||||||
|
"TradeSignal",
|
||||||
|
"TradeExecution",
|
||||||
|
"MarketSnapshot",
|
||||||
|
"SentimentContext",
|
||||||
|
# News
|
||||||
|
"RawArticle",
|
||||||
|
"ScoredArticle",
|
||||||
|
# Learning
|
||||||
|
"TradeOutcomeSchema",
|
||||||
|
"WeightAdjustment",
|
||||||
|
# Auth
|
||||||
|
"RegisterRequest",
|
||||||
|
"LoginRequest",
|
||||||
|
"TokenResponse",
|
||||||
|
]
|
||||||
BIN
shared/schemas/__pycache__/__init__.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/__init__.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/schemas/__pycache__/auth.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/auth.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/schemas/__pycache__/learning.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/learning.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/schemas/__pycache__/news.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/news.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/schemas/__pycache__/trading.cpython-314.pyc
Normal file
BIN
shared/schemas/__pycache__/trading.cpython-314.pyc
Normal file
Binary file not shown.
27
shared/schemas/auth.py
Normal file
27
shared/schemas/auth.py
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""Authentication Pydantic schemas for API request/response payloads."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRequest(BaseModel):
|
||||||
|
"""Sent by the dashboard to begin passkey registration."""
|
||||||
|
|
||||||
|
username: str = Field(min_length=1, max_length=100)
|
||||||
|
display_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
"""Sent by the dashboard to begin passkey authentication."""
|
||||||
|
|
||||||
|
username: str = Field(min_length=1, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
"""Returned after successful authentication."""
|
||||||
|
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
expires_in: int = Field(
|
||||||
|
description="Access token lifetime in seconds", default=900
|
||||||
|
)
|
||||||
32
shared/schemas/learning.py
Normal file
32
shared/schemas/learning.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
"""Learning domain Pydantic schemas."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TradeOutcomeSchema(BaseModel):
|
||||||
|
"""Represents the evaluated outcome of a closed trade."""
|
||||||
|
|
||||||
|
trade_id: UUID
|
||||||
|
hold_duration_seconds: float = Field(ge=0)
|
||||||
|
realized_pnl: float
|
||||||
|
roi_pct: float
|
||||||
|
was_profitable: bool
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class WeightAdjustment(BaseModel):
|
||||||
|
"""Represents a strategy weight change made by the learning engine."""
|
||||||
|
|
||||||
|
strategy_id: UUID
|
||||||
|
strategy_name: str
|
||||||
|
old_weight: float
|
||||||
|
new_weight: float
|
||||||
|
reason: str
|
||||||
|
reward_signal: float
|
||||||
|
timestamp: datetime
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
44
shared/schemas/news.py
Normal file
44
shared/schemas/news.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""News article Pydantic schemas for Redis Stream messages."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class RawArticle(BaseModel):
|
||||||
|
"""Published to ``news:raw`` by the news fetcher."""
|
||||||
|
|
||||||
|
source: str
|
||||||
|
url: str
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
published_at: datetime | None = None
|
||||||
|
fetched_at: datetime
|
||||||
|
content_hash: str
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class ScoredArticle(BaseModel):
|
||||||
|
"""Published to ``news:scored`` by the sentiment analyzer.
|
||||||
|
|
||||||
|
Inherits all fields from RawArticle conceptually plus scoring metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Original article fields
|
||||||
|
source: str
|
||||||
|
url: str
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
published_at: datetime | None = None
|
||||||
|
fetched_at: datetime
|
||||||
|
content_hash: str
|
||||||
|
|
||||||
|
# Scoring fields
|
||||||
|
ticker: str
|
||||||
|
sentiment_score: float = Field(ge=-1.0, le=1.0)
|
||||||
|
confidence: float = Field(ge=0.0, le=1.0)
|
||||||
|
model_used: str
|
||||||
|
entities: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
163
shared/schemas/trading.py
Normal file
163
shared/schemas/trading.py
Normal file
|
|
@ -0,0 +1,163 @@
|
||||||
|
"""Trading-related Pydantic schemas for Redis Streams messages and API payloads."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class OrderType(str, Enum):
|
||||||
|
MARKET = "market"
|
||||||
|
LIMIT = "limit"
|
||||||
|
STOP = "stop"
|
||||||
|
|
||||||
|
|
||||||
|
class OrderSide(str, Enum):
|
||||||
|
BUY = "BUY"
|
||||||
|
SELL = "SELL"
|
||||||
|
|
||||||
|
|
||||||
|
class OrderStatus(str, Enum):
|
||||||
|
PENDING = "PENDING"
|
||||||
|
FILLED = "FILLED"
|
||||||
|
CANCELLED = "CANCELLED"
|
||||||
|
REJECTED = "REJECTED"
|
||||||
|
|
||||||
|
|
||||||
|
class SignalDirection(str, Enum):
|
||||||
|
LONG = "LONG"
|
||||||
|
SHORT = "SHORT"
|
||||||
|
NEUTRAL = "NEUTRAL"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# API request / response schemas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class OrderRequest(BaseModel):
|
||||||
|
"""Submitted by the trade executor or the API to place an order."""
|
||||||
|
|
||||||
|
ticker: str
|
||||||
|
side: OrderSide
|
||||||
|
qty: float = Field(gt=0)
|
||||||
|
order_type: OrderType = OrderType.MARKET
|
||||||
|
limit_price: float | None = None
|
||||||
|
stop_price: float | None = None
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class OrderResult(BaseModel):
|
||||||
|
"""Returned after order submission or status query."""
|
||||||
|
|
||||||
|
order_id: str
|
||||||
|
ticker: str
|
||||||
|
side: OrderSide
|
||||||
|
qty: float
|
||||||
|
filled_price: float | None = None
|
||||||
|
status: OrderStatus
|
||||||
|
timestamp: datetime
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class PositionInfo(BaseModel):
|
||||||
|
"""Current position state — used in API responses and portfolio views."""
|
||||||
|
|
||||||
|
ticker: str
|
||||||
|
qty: float
|
||||||
|
avg_entry: float
|
||||||
|
current_price: float
|
||||||
|
unrealized_pnl: float
|
||||||
|
market_value: float
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class AccountInfo(BaseModel):
|
||||||
|
"""Account-level summary from the brokerage."""
|
||||||
|
|
||||||
|
equity: float
|
||||||
|
cash: float
|
||||||
|
buying_power: float
|
||||||
|
portfolio_value: float
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Redis Stream message schemas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TradeSignal(BaseModel):
|
||||||
|
"""Published to ``signals:generated`` by the signal generator."""
|
||||||
|
|
||||||
|
ticker: str
|
||||||
|
direction: SignalDirection
|
||||||
|
strength: float = Field(ge=0.0, le=1.0)
|
||||||
|
strategy_sources: list[str]
|
||||||
|
sentiment_context: dict[str, Any] | None = None
|
||||||
|
timestamp: datetime
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class TradeExecution(BaseModel):
|
||||||
|
"""Published to ``trades:executed`` by the trade executor."""
|
||||||
|
|
||||||
|
trade_id: UUID
|
||||||
|
ticker: str
|
||||||
|
side: OrderSide
|
||||||
|
qty: float
|
||||||
|
price: float
|
||||||
|
status: OrderStatus
|
||||||
|
signal_id: UUID | None = None
|
||||||
|
strategy_id: UUID | None = None
|
||||||
|
timestamp: datetime
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class OHLCVBar(BaseModel):
|
||||||
|
"""Single OHLCV bar."""
|
||||||
|
|
||||||
|
timestamp: datetime
|
||||||
|
open: float
|
||||||
|
high: float
|
||||||
|
low: float
|
||||||
|
close: float
|
||||||
|
volume: float
|
||||||
|
|
||||||
|
|
||||||
|
class MarketSnapshot(BaseModel):
|
||||||
|
"""Snapshot of market data for a single ticker — used by strategies."""
|
||||||
|
|
||||||
|
ticker: str
|
||||||
|
current_price: float
|
||||||
|
open: float
|
||||||
|
high: float
|
||||||
|
low: float
|
||||||
|
close: float
|
||||||
|
volume: float
|
||||||
|
sma_20: float | None = None
|
||||||
|
sma_50: float | None = None
|
||||||
|
rsi: float | None = None
|
||||||
|
bars: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class SentimentContext(BaseModel):
|
||||||
|
"""Aggregated sentiment for a ticker — passed to strategies."""
|
||||||
|
|
||||||
|
ticker: str
|
||||||
|
avg_score: float = Field(ge=-1.0, le=1.0)
|
||||||
|
article_count: int = Field(ge=0)
|
||||||
|
recent_scores: list[float] = Field(default_factory=list)
|
||||||
|
avg_confidence: float = Field(ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
586
tests/test_schemas.py
Normal file
586
tests/test_schemas.py
Normal file
|
|
@ -0,0 +1,586 @@
|
||||||
|
"""Tests for Pydantic schemas — serialization round-trips and validation constraints."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from shared.schemas.trading import (
|
||||||
|
AccountInfo,
|
||||||
|
MarketSnapshot,
|
||||||
|
OrderRequest,
|
||||||
|
OrderResult,
|
||||||
|
OrderSide,
|
||||||
|
OrderStatus,
|
||||||
|
OrderType,
|
||||||
|
PositionInfo,
|
||||||
|
SentimentContext,
|
||||||
|
SignalDirection,
|
||||||
|
TradeExecution,
|
||||||
|
TradeSignal,
|
||||||
|
)
|
||||||
|
from shared.schemas.news import RawArticle, ScoredArticle
|
||||||
|
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
|
||||||
|
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trading schemas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderRequest:
|
||||||
|
def test_valid_market_order(self) -> None:
|
||||||
|
o = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0)
|
||||||
|
assert o.order_type == OrderType.MARKET
|
||||||
|
assert o.limit_price is None
|
||||||
|
|
||||||
|
def test_valid_limit_order(self) -> None:
|
||||||
|
o = OrderRequest(
|
||||||
|
ticker="TSLA",
|
||||||
|
side=OrderSide.SELL,
|
||||||
|
qty=5.0,
|
||||||
|
order_type=OrderType.LIMIT,
|
||||||
|
limit_price=250.50,
|
||||||
|
)
|
||||||
|
assert o.limit_price == 250.50
|
||||||
|
|
||||||
|
def test_qty_must_be_positive(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=0)
|
||||||
|
|
||||||
|
def test_qty_must_not_be_negative(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=-5)
|
||||||
|
|
||||||
|
def test_serialization_round_trip(self) -> None:
|
||||||
|
o = OrderRequest(
|
||||||
|
ticker="GOOG",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
qty=3.0,
|
||||||
|
order_type=OrderType.STOP,
|
||||||
|
stop_price=100.0,
|
||||||
|
)
|
||||||
|
data = o.model_dump()
|
||||||
|
restored = OrderRequest.model_validate(data)
|
||||||
|
assert restored == o
|
||||||
|
|
||||||
|
def test_json_round_trip(self) -> None:
|
||||||
|
o = OrderRequest(ticker="META", side=OrderSide.SELL, qty=1.5)
|
||||||
|
json_str = o.model_dump_json()
|
||||||
|
restored = OrderRequest.model_validate_json(json_str)
|
||||||
|
assert restored == o
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderResult:
|
||||||
|
def test_valid_result(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
r = OrderResult(
|
||||||
|
order_id="ord-123",
|
||||||
|
ticker="AAPL",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
qty=10.0,
|
||||||
|
filled_price=150.25,
|
||||||
|
status=OrderStatus.FILLED,
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
assert r.filled_price == 150.25
|
||||||
|
assert r.status == OrderStatus.FILLED
|
||||||
|
|
||||||
|
def test_pending_no_fill(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
r = OrderResult(
|
||||||
|
order_id="ord-456",
|
||||||
|
ticker="TSLA",
|
||||||
|
side=OrderSide.SELL,
|
||||||
|
qty=5.0,
|
||||||
|
status=OrderStatus.PENDING,
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
assert r.filled_price is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPositionInfo:
|
||||||
|
def test_valid_position(self) -> None:
|
||||||
|
p = PositionInfo(
|
||||||
|
ticker="NVDA",
|
||||||
|
qty=20.0,
|
||||||
|
avg_entry=800.0,
|
||||||
|
current_price=850.0,
|
||||||
|
unrealized_pnl=1000.0,
|
||||||
|
market_value=17000.0,
|
||||||
|
)
|
||||||
|
assert p.market_value == 17000.0
|
||||||
|
|
||||||
|
def test_serialization_round_trip(self) -> None:
|
||||||
|
p = PositionInfo(
|
||||||
|
ticker="AMZN",
|
||||||
|
qty=10.0,
|
||||||
|
avg_entry=180.0,
|
||||||
|
current_price=185.0,
|
||||||
|
unrealized_pnl=50.0,
|
||||||
|
market_value=1850.0,
|
||||||
|
)
|
||||||
|
restored = PositionInfo.model_validate(p.model_dump())
|
||||||
|
assert restored == p
|
||||||
|
|
||||||
|
|
||||||
|
class TestAccountInfo:
|
||||||
|
def test_valid_account(self) -> None:
|
||||||
|
a = AccountInfo(
|
||||||
|
equity=100_000.0,
|
||||||
|
cash=25_000.0,
|
||||||
|
buying_power=50_000.0,
|
||||||
|
portfolio_value=100_000.0,
|
||||||
|
)
|
||||||
|
assert a.equity == 100_000.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestTradeSignal:
|
||||||
|
def test_valid_signal(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
s = TradeSignal(
|
||||||
|
ticker="AAPL",
|
||||||
|
direction=SignalDirection.LONG,
|
||||||
|
strength=0.85,
|
||||||
|
strategy_sources=["momentum", "news_driven"],
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
assert s.strength == 0.85
|
||||||
|
assert s.sentiment_context is None
|
||||||
|
|
||||||
|
def test_strength_must_be_in_range(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TradeSignal(
|
||||||
|
ticker="AAPL",
|
||||||
|
direction=SignalDirection.LONG,
|
||||||
|
strength=1.5,
|
||||||
|
strategy_sources=["momentum"],
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_strength_lower_bound(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TradeSignal(
|
||||||
|
ticker="AAPL",
|
||||||
|
direction=SignalDirection.SHORT,
|
||||||
|
strength=-0.1,
|
||||||
|
strategy_sources=["mean_reversion"],
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_json_round_trip(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
s = TradeSignal(
|
||||||
|
ticker="TSLA",
|
||||||
|
direction=SignalDirection.SHORT,
|
||||||
|
strength=0.6,
|
||||||
|
strategy_sources=["mean_reversion"],
|
||||||
|
sentiment_context={"avg_score": -0.4},
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
restored = TradeSignal.model_validate_json(s.model_dump_json())
|
||||||
|
assert restored == s
|
||||||
|
|
||||||
|
|
||||||
|
class TestTradeExecution:
|
||||||
|
def test_valid_execution(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
tid = uuid.uuid4()
|
||||||
|
sid = uuid.uuid4()
|
||||||
|
e = TradeExecution(
|
||||||
|
trade_id=tid,
|
||||||
|
ticker="AAPL",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
qty=10.0,
|
||||||
|
price=150.0,
|
||||||
|
status=OrderStatus.FILLED,
|
||||||
|
signal_id=sid,
|
||||||
|
strategy_id=uuid.uuid4(),
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
assert e.trade_id == tid
|
||||||
|
assert e.signal_id == sid
|
||||||
|
|
||||||
|
def test_optional_ids(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
e = TradeExecution(
|
||||||
|
trade_id=uuid.uuid4(),
|
||||||
|
ticker="GOOG",
|
||||||
|
side=OrderSide.SELL,
|
||||||
|
qty=5.0,
|
||||||
|
price=2800.0,
|
||||||
|
status=OrderStatus.FILLED,
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
assert e.signal_id is None
|
||||||
|
assert e.strategy_id is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarketSnapshot:
|
||||||
|
def test_valid_snapshot(self) -> None:
|
||||||
|
ms = MarketSnapshot(
|
||||||
|
ticker="AAPL",
|
||||||
|
current_price=150.0,
|
||||||
|
open=148.0,
|
||||||
|
high=152.0,
|
||||||
|
low=147.0,
|
||||||
|
close=150.0,
|
||||||
|
volume=5_000_000.0,
|
||||||
|
sma_20=149.5,
|
||||||
|
rsi=55.0,
|
||||||
|
)
|
||||||
|
assert ms.sma_50 is None
|
||||||
|
assert ms.bars == []
|
||||||
|
|
||||||
|
def test_with_bars(self) -> None:
|
||||||
|
ms = MarketSnapshot(
|
||||||
|
ticker="TSLA",
|
||||||
|
current_price=250.0,
|
||||||
|
open=248.0,
|
||||||
|
high=255.0,
|
||||||
|
low=245.0,
|
||||||
|
close=250.0,
|
||||||
|
volume=10_000_000.0,
|
||||||
|
bars=[
|
||||||
|
{"timestamp": "2026-01-01T10:00:00Z", "open": 248, "close": 250}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert len(ms.bars) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestSentimentContext:
|
||||||
|
def test_valid_context(self) -> None:
|
||||||
|
sc = SentimentContext(
|
||||||
|
ticker="AAPL",
|
||||||
|
avg_score=0.65,
|
||||||
|
article_count=5,
|
||||||
|
recent_scores=[0.5, 0.7, 0.8],
|
||||||
|
avg_confidence=0.85,
|
||||||
|
)
|
||||||
|
assert sc.article_count == 5
|
||||||
|
|
||||||
|
def test_score_range_validation(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SentimentContext(
|
||||||
|
ticker="AAPL",
|
||||||
|
avg_score=1.5, # Out of range
|
||||||
|
article_count=1,
|
||||||
|
avg_confidence=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_negative_score_in_range(self) -> None:
|
||||||
|
sc = SentimentContext(
|
||||||
|
ticker="TSLA",
|
||||||
|
avg_score=-0.8,
|
||||||
|
article_count=3,
|
||||||
|
avg_confidence=0.9,
|
||||||
|
)
|
||||||
|
assert sc.avg_score == -0.8
|
||||||
|
|
||||||
|
def test_confidence_range_validation(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SentimentContext(
|
||||||
|
ticker="AAPL",
|
||||||
|
avg_score=0.5,
|
||||||
|
article_count=1,
|
||||||
|
avg_confidence=1.5, # Out of range
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_article_count_non_negative(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SentimentContext(
|
||||||
|
ticker="AAPL",
|
||||||
|
avg_score=0.0,
|
||||||
|
article_count=-1,
|
||||||
|
avg_confidence=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# News schemas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRawArticle:
|
||||||
|
def test_valid_article(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
a = RawArticle(
|
||||||
|
source="reuters",
|
||||||
|
url="https://reuters.com/article/1",
|
||||||
|
title="Market Rally Continues",
|
||||||
|
content="Stocks rose sharply today...",
|
||||||
|
published_at=now,
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="sha256abcdef1234567890",
|
||||||
|
)
|
||||||
|
assert a.source == "reuters"
|
||||||
|
|
||||||
|
def test_published_at_optional(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
a = RawArticle(
|
||||||
|
source="reddit",
|
||||||
|
url="https://reddit.com/r/stocks/1",
|
||||||
|
title="DD on TSLA",
|
||||||
|
content="Here is my analysis...",
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="hash123",
|
||||||
|
)
|
||||||
|
assert a.published_at is None
|
||||||
|
|
||||||
|
def test_json_round_trip(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
a = RawArticle(
|
||||||
|
source="yahoo",
|
||||||
|
url="https://finance.yahoo.com/1",
|
||||||
|
title="Earnings Beat",
|
||||||
|
content="Apple beat earnings...",
|
||||||
|
published_at=now,
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="hash456",
|
||||||
|
)
|
||||||
|
restored = RawArticle.model_validate_json(a.model_dump_json())
|
||||||
|
assert restored == a
|
||||||
|
|
||||||
|
def test_required_fields(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
RawArticle(source="reuters") # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoredArticle:
|
||||||
|
def test_valid_scored_article(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
sa = ScoredArticle(
|
||||||
|
source="reuters",
|
||||||
|
url="https://reuters.com/article/1",
|
||||||
|
title="Apple Earnings Beat",
|
||||||
|
content="Apple reported...",
|
||||||
|
published_at=now,
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="hash789",
|
||||||
|
ticker="AAPL",
|
||||||
|
sentiment_score=0.85,
|
||||||
|
confidence=0.92,
|
||||||
|
model_used="finbert",
|
||||||
|
entities=["Apple Inc", "Tim Cook"],
|
||||||
|
)
|
||||||
|
assert sa.sentiment_score == 0.85
|
||||||
|
assert sa.entities == ["Apple Inc", "Tim Cook"]
|
||||||
|
|
||||||
|
def test_sentiment_score_range(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
ScoredArticle(
|
||||||
|
source="reuters",
|
||||||
|
url="https://reuters.com/1",
|
||||||
|
title="Test",
|
||||||
|
content="Test content",
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="hash",
|
||||||
|
ticker="AAPL",
|
||||||
|
sentiment_score=1.5, # Out of range
|
||||||
|
confidence=0.5,
|
||||||
|
model_used="finbert",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_confidence_range(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
ScoredArticle(
|
||||||
|
source="reuters",
|
||||||
|
url="https://reuters.com/1",
|
||||||
|
title="Test",
|
||||||
|
content="Test content",
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="hash",
|
||||||
|
ticker="AAPL",
|
||||||
|
sentiment_score=0.5,
|
||||||
|
confidence=-0.1, # Out of range
|
||||||
|
model_used="finbert",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_negative_sentiment(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
sa = ScoredArticle(
|
||||||
|
source="reddit",
|
||||||
|
url="https://reddit.com/1",
|
||||||
|
title="Bad news",
|
||||||
|
content="Terrible quarter...",
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="hashN",
|
||||||
|
ticker="TSLA",
|
||||||
|
sentiment_score=-0.9,
|
||||||
|
confidence=0.8,
|
||||||
|
model_used="ollama",
|
||||||
|
)
|
||||||
|
assert sa.sentiment_score == -0.9
|
||||||
|
|
||||||
|
def test_json_round_trip(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
sa = ScoredArticle(
|
||||||
|
source="yahoo",
|
||||||
|
url="https://yahoo.com/1",
|
||||||
|
title="Headline",
|
||||||
|
content="Body text",
|
||||||
|
fetched_at=now,
|
||||||
|
content_hash="hashRT",
|
||||||
|
ticker="NVDA",
|
||||||
|
sentiment_score=0.3,
|
||||||
|
confidence=0.7,
|
||||||
|
model_used="finbert",
|
||||||
|
entities=["NVIDIA"],
|
||||||
|
)
|
||||||
|
restored = ScoredArticle.model_validate_json(sa.model_dump_json())
|
||||||
|
assert restored == sa
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Learning schemas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTradeOutcomeSchema:
|
||||||
|
def test_valid_outcome(self) -> None:
|
||||||
|
o = TradeOutcomeSchema(
|
||||||
|
trade_id=uuid.uuid4(),
|
||||||
|
hold_duration_seconds=14400.0,
|
||||||
|
realized_pnl=250.50,
|
||||||
|
roi_pct=2.5,
|
||||||
|
was_profitable=True,
|
||||||
|
)
|
||||||
|
assert o.was_profitable is True
|
||||||
|
assert o.hold_duration_seconds == 14400.0
|
||||||
|
|
||||||
|
def test_hold_duration_non_negative(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TradeOutcomeSchema(
|
||||||
|
trade_id=uuid.uuid4(),
|
||||||
|
hold_duration_seconds=-1.0,
|
||||||
|
realized_pnl=100.0,
|
||||||
|
roi_pct=1.0,
|
||||||
|
was_profitable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_losing_trade(self) -> None:
|
||||||
|
o = TradeOutcomeSchema(
|
||||||
|
trade_id=uuid.uuid4(),
|
||||||
|
hold_duration_seconds=3600.0,
|
||||||
|
realized_pnl=-150.0,
|
||||||
|
roi_pct=-3.0,
|
||||||
|
was_profitable=False,
|
||||||
|
)
|
||||||
|
assert o.was_profitable is False
|
||||||
|
assert o.realized_pnl == -150.0
|
||||||
|
|
||||||
|
def test_json_round_trip(self) -> None:
|
||||||
|
o = TradeOutcomeSchema(
|
||||||
|
trade_id=uuid.uuid4(),
|
||||||
|
hold_duration_seconds=7200.0,
|
||||||
|
realized_pnl=500.0,
|
||||||
|
roi_pct=5.0,
|
||||||
|
was_profitable=True,
|
||||||
|
)
|
||||||
|
restored = TradeOutcomeSchema.model_validate_json(o.model_dump_json())
|
||||||
|
assert restored == o
|
||||||
|
|
||||||
|
|
||||||
|
class TestWeightAdjustment:
|
||||||
|
def test_valid_adjustment(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
wa = WeightAdjustment(
|
||||||
|
strategy_id=uuid.uuid4(),
|
||||||
|
strategy_name="momentum",
|
||||||
|
old_weight=0.33,
|
||||||
|
new_weight=0.38,
|
||||||
|
reason="Positive reward signal from recent trades",
|
||||||
|
reward_signal=0.72,
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
assert wa.old_weight == 0.33
|
||||||
|
assert wa.new_weight == 0.38
|
||||||
|
|
||||||
|
def test_required_fields(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
WeightAdjustment(
|
||||||
|
strategy_id=uuid.uuid4(),
|
||||||
|
strategy_name="momentum",
|
||||||
|
) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
def test_json_round_trip(self) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
wa = WeightAdjustment(
|
||||||
|
strategy_id=uuid.uuid4(),
|
||||||
|
strategy_name="mean_reversion",
|
||||||
|
old_weight=0.30,
|
||||||
|
new_weight=0.25,
|
||||||
|
reason="Poor recent performance",
|
||||||
|
reward_signal=-0.4,
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
restored = WeightAdjustment.model_validate_json(wa.model_dump_json())
|
||||||
|
assert restored == wa
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Auth schemas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegisterRequest:
|
||||||
|
def test_valid_registration(self) -> None:
|
||||||
|
r = RegisterRequest(username="trader1", display_name="Top Trader")
|
||||||
|
assert r.username == "trader1"
|
||||||
|
assert r.display_name == "Top Trader"
|
||||||
|
|
||||||
|
def test_display_name_optional(self) -> None:
|
||||||
|
r = RegisterRequest(username="trader2")
|
||||||
|
assert r.display_name is None
|
||||||
|
|
||||||
|
def test_username_required(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
RegisterRequest(username="") # min_length=1
|
||||||
|
|
||||||
|
def test_username_max_length(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
RegisterRequest(username="x" * 101) # max_length=100
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoginRequest:
|
||||||
|
def test_valid_login(self) -> None:
|
||||||
|
l = LoginRequest(username="trader1")
|
||||||
|
assert l.username == "trader1"
|
||||||
|
|
||||||
|
def test_username_required(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
LoginRequest(username="")
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenResponse:
|
||||||
|
def test_valid_response(self) -> None:
|
||||||
|
t = TokenResponse(
|
||||||
|
access_token="eyJ...",
|
||||||
|
refresh_token="eyR...",
|
||||||
|
)
|
||||||
|
assert t.token_type == "bearer"
|
||||||
|
assert t.expires_in == 900 # default 15 min
|
||||||
|
|
||||||
|
def test_custom_expiry(self) -> None:
|
||||||
|
t = TokenResponse(
|
||||||
|
access_token="eyJ...",
|
||||||
|
refresh_token="eyR...",
|
||||||
|
expires_in=3600,
|
||||||
|
)
|
||||||
|
assert t.expires_in == 3600
|
||||||
|
|
||||||
|
def test_json_round_trip(self) -> None:
|
||||||
|
t = TokenResponse(
|
||||||
|
access_token="access123",
|
||||||
|
refresh_token="refresh456",
|
||||||
|
token_type="bearer",
|
||||||
|
expires_in=1800,
|
||||||
|
)
|
||||||
|
restored = TokenResponse.model_validate_json(t.model_dump_json())
|
||||||
|
assert restored == t
|
||||||
Loading…
Add table
Add a link
Reference in a new issue