diff --git a/shared/schemas/__init__.py b/shared/schemas/__init__.py new file mode 100644 index 0000000..b6583bf --- /dev/null +++ b/shared/schemas/__init__.py @@ -0,0 +1,37 @@ +"""Pydantic v2 schemas for all service message types.""" + +from shared.schemas.trading import ( + AccountInfo, + MarketSnapshot, + OrderRequest, + OrderResult, + PositionInfo, + SentimentContext, + TradeExecution, + TradeSignal, +) +from shared.schemas.news import RawArticle, ScoredArticle +from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment +from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse + +__all__ = [ + # Trading + "OrderRequest", + "OrderResult", + "PositionInfo", + "AccountInfo", + "TradeSignal", + "TradeExecution", + "MarketSnapshot", + "SentimentContext", + # News + "RawArticle", + "ScoredArticle", + # Learning + "TradeOutcomeSchema", + "WeightAdjustment", + # Auth + "RegisterRequest", + "LoginRequest", + "TokenResponse", +] diff --git a/shared/schemas/__pycache__/__init__.cpython-314.pyc b/shared/schemas/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..4e710f8 Binary files /dev/null and b/shared/schemas/__pycache__/__init__.cpython-314.pyc differ diff --git a/shared/schemas/__pycache__/auth.cpython-314.pyc b/shared/schemas/__pycache__/auth.cpython-314.pyc new file mode 100644 index 0000000..f9d9336 Binary files /dev/null and b/shared/schemas/__pycache__/auth.cpython-314.pyc differ diff --git a/shared/schemas/__pycache__/learning.cpython-314.pyc b/shared/schemas/__pycache__/learning.cpython-314.pyc new file mode 100644 index 0000000..ba29268 Binary files /dev/null and b/shared/schemas/__pycache__/learning.cpython-314.pyc differ diff --git a/shared/schemas/__pycache__/news.cpython-314.pyc b/shared/schemas/__pycache__/news.cpython-314.pyc new file mode 100644 index 0000000..148487b Binary files /dev/null and b/shared/schemas/__pycache__/news.cpython-314.pyc differ diff --git a/shared/schemas/__pycache__/trading.cpython-314.pyc b/shared/schemas/__pycache__/trading.cpython-314.pyc new file mode 100644 index 0000000..030c4e0 Binary files /dev/null and b/shared/schemas/__pycache__/trading.cpython-314.pyc differ diff --git a/shared/schemas/auth.py b/shared/schemas/auth.py new file mode 100644 index 0000000..2dadacd --- /dev/null +++ b/shared/schemas/auth.py @@ -0,0 +1,27 @@ +"""Authentication Pydantic schemas for API request/response payloads.""" + +from pydantic import BaseModel, Field + + +class RegisterRequest(BaseModel): + """Sent by the dashboard to begin passkey registration.""" + + username: str = Field(min_length=1, max_length=100) + display_name: str | None = None + + +class LoginRequest(BaseModel): + """Sent by the dashboard to begin passkey authentication.""" + + username: str = Field(min_length=1, max_length=100) + + +class TokenResponse(BaseModel): + """Returned after successful authentication.""" + + access_token: str + refresh_token: str + token_type: str = "bearer" + expires_in: int = Field( + description="Access token lifetime in seconds", default=900 + ) diff --git a/shared/schemas/learning.py b/shared/schemas/learning.py new file mode 100644 index 0000000..cb8bd41 --- /dev/null +++ b/shared/schemas/learning.py @@ -0,0 +1,32 @@ +"""Learning domain Pydantic schemas.""" + +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, Field + + +class TradeOutcomeSchema(BaseModel): + """Represents the evaluated outcome of a closed trade.""" + + trade_id: UUID + hold_duration_seconds: float = Field(ge=0) + realized_pnl: float + roi_pct: float + was_profitable: bool + + model_config = {"from_attributes": True} + + +class WeightAdjustment(BaseModel): + """Represents a strategy weight change made by the learning engine.""" + + strategy_id: UUID + strategy_name: str + old_weight: float + new_weight: float + reason: str + reward_signal: float + timestamp: datetime + + model_config = {"from_attributes": True} diff --git a/shared/schemas/news.py b/shared/schemas/news.py new file mode 100644 index 0000000..e23adb1 --- /dev/null +++ b/shared/schemas/news.py @@ -0,0 +1,44 @@ +"""News article Pydantic schemas for Redis Stream messages.""" + +from datetime import datetime + +from pydantic import BaseModel, Field + + +class RawArticle(BaseModel): + """Published to ``news:raw`` by the news fetcher.""" + + source: str + url: str + title: str + content: str + published_at: datetime | None = None + fetched_at: datetime + content_hash: str + + model_config = {"from_attributes": True} + + +class ScoredArticle(BaseModel): + """Published to ``news:scored`` by the sentiment analyzer. + + Inherits all fields from RawArticle conceptually plus scoring metadata. + """ + + # Original article fields + source: str + url: str + title: str + content: str + published_at: datetime | None = None + fetched_at: datetime + content_hash: str + + # Scoring fields + ticker: str + sentiment_score: float = Field(ge=-1.0, le=1.0) + confidence: float = Field(ge=0.0, le=1.0) + model_used: str + entities: list[str] = Field(default_factory=list) + + model_config = {"from_attributes": True} diff --git a/shared/schemas/trading.py b/shared/schemas/trading.py new file mode 100644 index 0000000..0a0046b --- /dev/null +++ b/shared/schemas/trading.py @@ -0,0 +1,163 @@ +"""Trading-related Pydantic schemas for Redis Streams messages and API payloads.""" + +from datetime import datetime +from enum import Enum +from typing import Any +from uuid import UUID + +from pydantic import BaseModel, Field + + +class OrderType(str, Enum): + MARKET = "market" + LIMIT = "limit" + STOP = "stop" + + +class OrderSide(str, Enum): + BUY = "BUY" + SELL = "SELL" + + +class OrderStatus(str, Enum): + PENDING = "PENDING" + FILLED = "FILLED" + CANCELLED = "CANCELLED" + REJECTED = "REJECTED" + + +class SignalDirection(str, Enum): + LONG = "LONG" + SHORT = "SHORT" + NEUTRAL = "NEUTRAL" + + +# --------------------------------------------------------------------------- +# API request / response schemas +# --------------------------------------------------------------------------- + + +class OrderRequest(BaseModel): + """Submitted by the trade executor or the API to place an order.""" + + ticker: str + side: OrderSide + qty: float = Field(gt=0) + order_type: OrderType = OrderType.MARKET + limit_price: float | None = None + stop_price: float | None = None + + model_config = {"from_attributes": True} + + +class OrderResult(BaseModel): + """Returned after order submission or status query.""" + + order_id: str + ticker: str + side: OrderSide + qty: float + filled_price: float | None = None + status: OrderStatus + timestamp: datetime + + model_config = {"from_attributes": True} + + +class PositionInfo(BaseModel): + """Current position state — used in API responses and portfolio views.""" + + ticker: str + qty: float + avg_entry: float + current_price: float + unrealized_pnl: float + market_value: float + + model_config = {"from_attributes": True} + + +class AccountInfo(BaseModel): + """Account-level summary from the brokerage.""" + + equity: float + cash: float + buying_power: float + portfolio_value: float + + model_config = {"from_attributes": True} + + +# --------------------------------------------------------------------------- +# Redis Stream message schemas +# --------------------------------------------------------------------------- + + +class TradeSignal(BaseModel): + """Published to ``signals:generated`` by the signal generator.""" + + ticker: str + direction: SignalDirection + strength: float = Field(ge=0.0, le=1.0) + strategy_sources: list[str] + sentiment_context: dict[str, Any] | None = None + timestamp: datetime + + model_config = {"from_attributes": True} + + +class TradeExecution(BaseModel): + """Published to ``trades:executed`` by the trade executor.""" + + trade_id: UUID + ticker: str + side: OrderSide + qty: float + price: float + status: OrderStatus + signal_id: UUID | None = None + strategy_id: UUID | None = None + timestamp: datetime + + model_config = {"from_attributes": True} + + +class OHLCVBar(BaseModel): + """Single OHLCV bar.""" + + timestamp: datetime + open: float + high: float + low: float + close: float + volume: float + + +class MarketSnapshot(BaseModel): + """Snapshot of market data for a single ticker — used by strategies.""" + + ticker: str + current_price: float + open: float + high: float + low: float + close: float + volume: float + sma_20: float | None = None + sma_50: float | None = None + rsi: float | None = None + bars: list[dict[str, Any]] = Field(default_factory=list) + + model_config = {"from_attributes": True} + + +class SentimentContext(BaseModel): + """Aggregated sentiment for a ticker — passed to strategies.""" + + ticker: str + avg_score: float = Field(ge=-1.0, le=1.0) + article_count: int = Field(ge=0) + recent_scores: list[float] = Field(default_factory=list) + avg_confidence: float = Field(ge=0.0, le=1.0) + + model_config = {"from_attributes": True} diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..d44d248 --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,586 @@ +"""Tests for Pydantic schemas — serialization round-trips and validation constraints.""" + +import uuid +from datetime import datetime, timezone + +import pytest +from pydantic import ValidationError + +from shared.schemas.trading import ( + AccountInfo, + MarketSnapshot, + OrderRequest, + OrderResult, + OrderSide, + OrderStatus, + OrderType, + PositionInfo, + SentimentContext, + SignalDirection, + TradeExecution, + TradeSignal, +) +from shared.schemas.news import RawArticle, ScoredArticle +from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment +from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse + + +# --------------------------------------------------------------------------- +# Trading schemas +# --------------------------------------------------------------------------- + + +class TestOrderRequest: + def test_valid_market_order(self) -> None: + o = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0) + assert o.order_type == OrderType.MARKET + assert o.limit_price is None + + def test_valid_limit_order(self) -> None: + o = OrderRequest( + ticker="TSLA", + side=OrderSide.SELL, + qty=5.0, + order_type=OrderType.LIMIT, + limit_price=250.50, + ) + assert o.limit_price == 250.50 + + def test_qty_must_be_positive(self) -> None: + with pytest.raises(ValidationError): + OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=0) + + def test_qty_must_not_be_negative(self) -> None: + with pytest.raises(ValidationError): + OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=-5) + + def test_serialization_round_trip(self) -> None: + o = OrderRequest( + ticker="GOOG", + side=OrderSide.BUY, + qty=3.0, + order_type=OrderType.STOP, + stop_price=100.0, + ) + data = o.model_dump() + restored = OrderRequest.model_validate(data) + assert restored == o + + def test_json_round_trip(self) -> None: + o = OrderRequest(ticker="META", side=OrderSide.SELL, qty=1.5) + json_str = o.model_dump_json() + restored = OrderRequest.model_validate_json(json_str) + assert restored == o + + +class TestOrderResult: + def test_valid_result(self) -> None: + now = datetime.now(timezone.utc) + r = OrderResult( + order_id="ord-123", + ticker="AAPL", + side=OrderSide.BUY, + qty=10.0, + filled_price=150.25, + status=OrderStatus.FILLED, + timestamp=now, + ) + assert r.filled_price == 150.25 + assert r.status == OrderStatus.FILLED + + def test_pending_no_fill(self) -> None: + now = datetime.now(timezone.utc) + r = OrderResult( + order_id="ord-456", + ticker="TSLA", + side=OrderSide.SELL, + qty=5.0, + status=OrderStatus.PENDING, + timestamp=now, + ) + assert r.filled_price is None + + +class TestPositionInfo: + def test_valid_position(self) -> None: + p = PositionInfo( + ticker="NVDA", + qty=20.0, + avg_entry=800.0, + current_price=850.0, + unrealized_pnl=1000.0, + market_value=17000.0, + ) + assert p.market_value == 17000.0 + + def test_serialization_round_trip(self) -> None: + p = PositionInfo( + ticker="AMZN", + qty=10.0, + avg_entry=180.0, + current_price=185.0, + unrealized_pnl=50.0, + market_value=1850.0, + ) + restored = PositionInfo.model_validate(p.model_dump()) + assert restored == p + + +class TestAccountInfo: + def test_valid_account(self) -> None: + a = AccountInfo( + equity=100_000.0, + cash=25_000.0, + buying_power=50_000.0, + portfolio_value=100_000.0, + ) + assert a.equity == 100_000.0 + + +class TestTradeSignal: + def test_valid_signal(self) -> None: + now = datetime.now(timezone.utc) + s = TradeSignal( + ticker="AAPL", + direction=SignalDirection.LONG, + strength=0.85, + strategy_sources=["momentum", "news_driven"], + timestamp=now, + ) + assert s.strength == 0.85 + assert s.sentiment_context is None + + def test_strength_must_be_in_range(self) -> None: + now = datetime.now(timezone.utc) + with pytest.raises(ValidationError): + TradeSignal( + ticker="AAPL", + direction=SignalDirection.LONG, + strength=1.5, + strategy_sources=["momentum"], + timestamp=now, + ) + + def test_strength_lower_bound(self) -> None: + now = datetime.now(timezone.utc) + with pytest.raises(ValidationError): + TradeSignal( + ticker="AAPL", + direction=SignalDirection.SHORT, + strength=-0.1, + strategy_sources=["mean_reversion"], + timestamp=now, + ) + + def test_json_round_trip(self) -> None: + now = datetime.now(timezone.utc) + s = TradeSignal( + ticker="TSLA", + direction=SignalDirection.SHORT, + strength=0.6, + strategy_sources=["mean_reversion"], + sentiment_context={"avg_score": -0.4}, + timestamp=now, + ) + restored = TradeSignal.model_validate_json(s.model_dump_json()) + assert restored == s + + +class TestTradeExecution: + def test_valid_execution(self) -> None: + now = datetime.now(timezone.utc) + tid = uuid.uuid4() + sid = uuid.uuid4() + e = TradeExecution( + trade_id=tid, + ticker="AAPL", + side=OrderSide.BUY, + qty=10.0, + price=150.0, + status=OrderStatus.FILLED, + signal_id=sid, + strategy_id=uuid.uuid4(), + timestamp=now, + ) + assert e.trade_id == tid + assert e.signal_id == sid + + def test_optional_ids(self) -> None: + now = datetime.now(timezone.utc) + e = TradeExecution( + trade_id=uuid.uuid4(), + ticker="GOOG", + side=OrderSide.SELL, + qty=5.0, + price=2800.0, + status=OrderStatus.FILLED, + timestamp=now, + ) + assert e.signal_id is None + assert e.strategy_id is None + + +class TestMarketSnapshot: + def test_valid_snapshot(self) -> None: + ms = MarketSnapshot( + ticker="AAPL", + current_price=150.0, + open=148.0, + high=152.0, + low=147.0, + close=150.0, + volume=5_000_000.0, + sma_20=149.5, + rsi=55.0, + ) + assert ms.sma_50 is None + assert ms.bars == [] + + def test_with_bars(self) -> None: + ms = MarketSnapshot( + ticker="TSLA", + current_price=250.0, + open=248.0, + high=255.0, + low=245.0, + close=250.0, + volume=10_000_000.0, + bars=[ + {"timestamp": "2026-01-01T10:00:00Z", "open": 248, "close": 250} + ], + ) + assert len(ms.bars) == 1 + + +class TestSentimentContext: + def test_valid_context(self) -> None: + sc = SentimentContext( + ticker="AAPL", + avg_score=0.65, + article_count=5, + recent_scores=[0.5, 0.7, 0.8], + avg_confidence=0.85, + ) + assert sc.article_count == 5 + + def test_score_range_validation(self) -> None: + with pytest.raises(ValidationError): + SentimentContext( + ticker="AAPL", + avg_score=1.5, # Out of range + article_count=1, + avg_confidence=0.5, + ) + + def test_negative_score_in_range(self) -> None: + sc = SentimentContext( + ticker="TSLA", + avg_score=-0.8, + article_count=3, + avg_confidence=0.9, + ) + assert sc.avg_score == -0.8 + + def test_confidence_range_validation(self) -> None: + with pytest.raises(ValidationError): + SentimentContext( + ticker="AAPL", + avg_score=0.5, + article_count=1, + avg_confidence=1.5, # Out of range + ) + + def test_article_count_non_negative(self) -> None: + with pytest.raises(ValidationError): + SentimentContext( + ticker="AAPL", + avg_score=0.0, + article_count=-1, + avg_confidence=0.5, + ) + + +# --------------------------------------------------------------------------- +# News schemas +# --------------------------------------------------------------------------- + + +class TestRawArticle: + def test_valid_article(self) -> None: + now = datetime.now(timezone.utc) + a = RawArticle( + source="reuters", + url="https://reuters.com/article/1", + title="Market Rally Continues", + content="Stocks rose sharply today...", + published_at=now, + fetched_at=now, + content_hash="sha256abcdef1234567890", + ) + assert a.source == "reuters" + + def test_published_at_optional(self) -> None: + now = datetime.now(timezone.utc) + a = RawArticle( + source="reddit", + url="https://reddit.com/r/stocks/1", + title="DD on TSLA", + content="Here is my analysis...", + fetched_at=now, + content_hash="hash123", + ) + assert a.published_at is None + + def test_json_round_trip(self) -> None: + now = datetime.now(timezone.utc) + a = RawArticle( + source="yahoo", + url="https://finance.yahoo.com/1", + title="Earnings Beat", + content="Apple beat earnings...", + published_at=now, + fetched_at=now, + content_hash="hash456", + ) + restored = RawArticle.model_validate_json(a.model_dump_json()) + assert restored == a + + def test_required_fields(self) -> None: + with pytest.raises(ValidationError): + RawArticle(source="reuters") # type: ignore[call-arg] + + +class TestScoredArticle: + def test_valid_scored_article(self) -> None: + now = datetime.now(timezone.utc) + sa = ScoredArticle( + source="reuters", + url="https://reuters.com/article/1", + title="Apple Earnings Beat", + content="Apple reported...", + published_at=now, + fetched_at=now, + content_hash="hash789", + ticker="AAPL", + sentiment_score=0.85, + confidence=0.92, + model_used="finbert", + entities=["Apple Inc", "Tim Cook"], + ) + assert sa.sentiment_score == 0.85 + assert sa.entities == ["Apple Inc", "Tim Cook"] + + def test_sentiment_score_range(self) -> None: + now = datetime.now(timezone.utc) + with pytest.raises(ValidationError): + ScoredArticle( + source="reuters", + url="https://reuters.com/1", + title="Test", + content="Test content", + fetched_at=now, + content_hash="hash", + ticker="AAPL", + sentiment_score=1.5, # Out of range + confidence=0.5, + model_used="finbert", + ) + + def test_confidence_range(self) -> None: + now = datetime.now(timezone.utc) + with pytest.raises(ValidationError): + ScoredArticle( + source="reuters", + url="https://reuters.com/1", + title="Test", + content="Test content", + fetched_at=now, + content_hash="hash", + ticker="AAPL", + sentiment_score=0.5, + confidence=-0.1, # Out of range + model_used="finbert", + ) + + def test_negative_sentiment(self) -> None: + now = datetime.now(timezone.utc) + sa = ScoredArticle( + source="reddit", + url="https://reddit.com/1", + title="Bad news", + content="Terrible quarter...", + fetched_at=now, + content_hash="hashN", + ticker="TSLA", + sentiment_score=-0.9, + confidence=0.8, + model_used="ollama", + ) + assert sa.sentiment_score == -0.9 + + def test_json_round_trip(self) -> None: + now = datetime.now(timezone.utc) + sa = ScoredArticle( + source="yahoo", + url="https://yahoo.com/1", + title="Headline", + content="Body text", + fetched_at=now, + content_hash="hashRT", + ticker="NVDA", + sentiment_score=0.3, + confidence=0.7, + model_used="finbert", + entities=["NVIDIA"], + ) + restored = ScoredArticle.model_validate_json(sa.model_dump_json()) + assert restored == sa + + +# --------------------------------------------------------------------------- +# Learning schemas +# --------------------------------------------------------------------------- + + +class TestTradeOutcomeSchema: + def test_valid_outcome(self) -> None: + o = TradeOutcomeSchema( + trade_id=uuid.uuid4(), + hold_duration_seconds=14400.0, + realized_pnl=250.50, + roi_pct=2.5, + was_profitable=True, + ) + assert o.was_profitable is True + assert o.hold_duration_seconds == 14400.0 + + def test_hold_duration_non_negative(self) -> None: + with pytest.raises(ValidationError): + TradeOutcomeSchema( + trade_id=uuid.uuid4(), + hold_duration_seconds=-1.0, + realized_pnl=100.0, + roi_pct=1.0, + was_profitable=True, + ) + + def test_losing_trade(self) -> None: + o = TradeOutcomeSchema( + trade_id=uuid.uuid4(), + hold_duration_seconds=3600.0, + realized_pnl=-150.0, + roi_pct=-3.0, + was_profitable=False, + ) + assert o.was_profitable is False + assert o.realized_pnl == -150.0 + + def test_json_round_trip(self) -> None: + o = TradeOutcomeSchema( + trade_id=uuid.uuid4(), + hold_duration_seconds=7200.0, + realized_pnl=500.0, + roi_pct=5.0, + was_profitable=True, + ) + restored = TradeOutcomeSchema.model_validate_json(o.model_dump_json()) + assert restored == o + + +class TestWeightAdjustment: + def test_valid_adjustment(self) -> None: + now = datetime.now(timezone.utc) + wa = WeightAdjustment( + strategy_id=uuid.uuid4(), + strategy_name="momentum", + old_weight=0.33, + new_weight=0.38, + reason="Positive reward signal from recent trades", + reward_signal=0.72, + timestamp=now, + ) + assert wa.old_weight == 0.33 + assert wa.new_weight == 0.38 + + def test_required_fields(self) -> None: + with pytest.raises(ValidationError): + WeightAdjustment( + strategy_id=uuid.uuid4(), + strategy_name="momentum", + ) # type: ignore[call-arg] + + def test_json_round_trip(self) -> None: + now = datetime.now(timezone.utc) + wa = WeightAdjustment( + strategy_id=uuid.uuid4(), + strategy_name="mean_reversion", + old_weight=0.30, + new_weight=0.25, + reason="Poor recent performance", + reward_signal=-0.4, + timestamp=now, + ) + restored = WeightAdjustment.model_validate_json(wa.model_dump_json()) + assert restored == wa + + +# --------------------------------------------------------------------------- +# Auth schemas +# --------------------------------------------------------------------------- + + +class TestRegisterRequest: + def test_valid_registration(self) -> None: + r = RegisterRequest(username="trader1", display_name="Top Trader") + assert r.username == "trader1" + assert r.display_name == "Top Trader" + + def test_display_name_optional(self) -> None: + r = RegisterRequest(username="trader2") + assert r.display_name is None + + def test_username_required(self) -> None: + with pytest.raises(ValidationError): + RegisterRequest(username="") # min_length=1 + + def test_username_max_length(self) -> None: + with pytest.raises(ValidationError): + RegisterRequest(username="x" * 101) # max_length=100 + + +class TestLoginRequest: + def test_valid_login(self) -> None: + l = LoginRequest(username="trader1") + assert l.username == "trader1" + + def test_username_required(self) -> None: + with pytest.raises(ValidationError): + LoginRequest(username="") + + +class TestTokenResponse: + def test_valid_response(self) -> None: + t = TokenResponse( + access_token="eyJ...", + refresh_token="eyR...", + ) + assert t.token_type == "bearer" + assert t.expires_in == 900 # default 15 min + + def test_custom_expiry(self) -> None: + t = TokenResponse( + access_token="eyJ...", + refresh_token="eyR...", + expires_in=3600, + ) + assert t.expires_in == 3600 + + def test_json_round_trip(self) -> None: + t = TokenResponse( + access_token="access123", + refresh_token="refresh456", + token_type="bearer", + expires_in=1800, + ) + restored = TokenResponse.model_validate_json(t.model_dump_json()) + assert restored == t