"""Tests for Pydantic schemas — serialization round-trips and validation constraints."""

import uuid
from datetime import datetime, timezone

import pytest
from pydantic import ValidationError

from shared.schemas.trading import (
    AccountInfo,
    MarketSnapshot,
    OrderRequest,
    OrderResult,
    OrderSide,
    OrderStatus,
    OrderType,
    PositionInfo,
    SentimentContext,
    SignalDirection,
    TradeExecution,
    TradeSignal,
)
from shared.schemas.news import RawArticle, ScoredArticle
from shared.schemas.learning import TradeOutcomeSchema, WeightAdjustment
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse


# ---------------------------------------------------------------------------
# Trading schemas
# ---------------------------------------------------------------------------


class TestOrderRequest:
    """OrderRequest: defaults, qty validation, and dict/JSON round-trips."""

    def test_valid_market_order(self) -> None:
        o = OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=10.0)
        # order_type and limit_price are not passed, so these are schema defaults.
        assert o.order_type == OrderType.MARKET
        assert o.limit_price is None

    def test_valid_limit_order(self) -> None:
        o = OrderRequest(
            ticker="TSLA",
            side=OrderSide.SELL,
            qty=5.0,
            order_type=OrderType.LIMIT,
            limit_price=250.50,
        )
        assert o.limit_price == 250.50

    def test_qty_must_be_positive(self) -> None:
        # Zero itself must be rejected, i.e. the bound is exclusive.
        with pytest.raises(ValidationError):
            OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=0)

    def test_qty_must_not_be_negative(self) -> None:
        with pytest.raises(ValidationError):
            OrderRequest(ticker="AAPL", side=OrderSide.BUY, qty=-5)

    def test_serialization_round_trip(self) -> None:
        o = OrderRequest(
            ticker="GOOG",
            side=OrderSide.BUY,
            qty=3.0,
            order_type=OrderType.STOP,
            stop_price=100.0,
        )
        data = o.model_dump()
        restored = OrderRequest.model_validate(data)
        assert restored == o

    def test_json_round_trip(self) -> None:
        o = OrderRequest(ticker="META", side=OrderSide.SELL, qty=1.5)
        json_str = o.model_dump_json()
        restored = OrderRequest.model_validate_json(json_str)
        assert restored == o


class TestOrderResult:
    """OrderResult construction for filled and pending (unfilled) orders."""

    def test_valid_result(self) -> None:
        now = datetime.now(timezone.utc)
        r = OrderResult(
            order_id="ord-123",
            ticker="AAPL",
            side=OrderSide.BUY,
            qty=10.0,
            filled_price=150.25,
            status=OrderStatus.FILLED,
            timestamp=now,
        )
        assert r.filled_price == 150.25
        assert r.status == OrderStatus.FILLED

    def test_pending_no_fill(self) -> None:
        now = datetime.now(timezone.utc)
        r = OrderResult(
            order_id="ord-456",
            ticker="TSLA",
            side=OrderSide.SELL,
            qty=5.0,
            status=OrderStatus.PENDING,
            timestamp=now,
        )
        # filled_price is omitted for a pending order and must default to None.
        assert r.filled_price is None


class TestPositionInfo:
    """PositionInfo construction and dict round-trip."""

    def test_valid_position(self) -> None:
        p = PositionInfo(
            ticker="NVDA",
            qty=20.0,
            avg_entry=800.0,
            current_price=850.0,
            unrealized_pnl=1000.0,
            market_value=17000.0,
        )
        assert p.market_value == 17000.0

    def test_serialization_round_trip(self) -> None:
        p = PositionInfo(
            ticker="AMZN",
            qty=10.0,
            avg_entry=180.0,
            current_price=185.0,
            unrealized_pnl=50.0,
            market_value=1850.0,
        )
        restored = PositionInfo.model_validate(p.model_dump())
        assert restored == p


class TestAccountInfo:
    """AccountInfo construction."""

    def test_valid_account(self) -> None:
        a = AccountInfo(
            equity=100_000.0,
            cash=25_000.0,
            buying_power=50_000.0,
            portfolio_value=100_000.0,
        )
        assert a.equity == 100_000.0


class TestTradeSignal:
    """TradeSignal: strength bounds, optional sentiment context, JSON round-trip."""

    def test_valid_signal(self) -> None:
        now = datetime.now(timezone.utc)
        s = TradeSignal(
            ticker="AAPL",
            direction=SignalDirection.LONG,
            strength=0.85,
            strategy_sources=["momentum", "news_driven"],
            timestamp=now,
        )
        assert s.strength == 0.85
        assert s.sentiment_context is None

    def test_strength_must_be_in_range(self) -> None:
        now = datetime.now(timezone.utc)
        with pytest.raises(ValidationError):
            TradeSignal(
                ticker="AAPL",
                direction=SignalDirection.LONG,
                strength=1.5,
                strategy_sources=["momentum"],
                timestamp=now,
            )

    def test_strength_lower_bound(self) -> None:
        now = datetime.now(timezone.utc)
        with pytest.raises(ValidationError):
            TradeSignal(
                ticker="AAPL",
                direction=SignalDirection.SHORT,
                strength=-0.1,
                strategy_sources=["mean_reversion"],
                timestamp=now,
            )

    def test_json_round_trip(self) -> None:
        now = datetime.now(timezone.utc)
        s = TradeSignal(
            ticker="TSLA",
            direction=SignalDirection.SHORT,
            strength=0.6,
            strategy_sources=["mean_reversion"],
            sentiment_context={"avg_score": -0.4},
            timestamp=now,
        )
        restored = TradeSignal.model_validate_json(s.model_dump_json())
        assert restored == s


class TestTradeExecution:
    """TradeExecution with and without the optional signal/strategy ids."""

    def test_valid_execution(self) -> None:
        now = datetime.now(timezone.utc)
        tid = uuid.uuid4()
        sid = uuid.uuid4()
        e = TradeExecution(
            trade_id=tid,
            ticker="AAPL",
            side=OrderSide.BUY,
            qty=10.0,
            price=150.0,
            status=OrderStatus.FILLED,
            signal_id=sid,
            strategy_id=uuid.uuid4(),
            timestamp=now,
        )
        assert e.trade_id == tid
        assert e.signal_id == sid

    def test_optional_ids(self) -> None:
        now = datetime.now(timezone.utc)
        e = TradeExecution(
            trade_id=uuid.uuid4(),
            ticker="GOOG",
            side=OrderSide.SELL,
            qty=5.0,
            price=2800.0,
            status=OrderStatus.FILLED,
            timestamp=now,
        )
        assert e.signal_id is None
        assert e.strategy_id is None


class TestMarketSnapshot:
    """MarketSnapshot defaults for optional indicators and bar history."""

    def test_valid_snapshot(self) -> None:
        ms = MarketSnapshot(
            ticker="AAPL",
            current_price=150.0,
            open=148.0,
            high=152.0,
            low=147.0,
            close=150.0,
            volume=5_000_000.0,
            sma_20=149.5,
            rsi=55.0,
        )
        # Fields not passed above fall back to their defaults.
        assert ms.sma_50 is None
        assert ms.bars == []

    def test_with_bars(self) -> None:
        ms = MarketSnapshot(
            ticker="TSLA",
            current_price=250.0,
            open=248.0,
            high=255.0,
            low=245.0,
            close=250.0,
            volume=10_000_000.0,
            bars=[
                {"timestamp": "2026-01-01T10:00:00Z", "open": 248, "close": 250}
            ],
        )
        assert len(ms.bars) == 1


class TestSentimentContext:
    """SentimentContext range checks on score, confidence, and article count."""

    def test_valid_context(self) -> None:
        sc = SentimentContext(
            ticker="AAPL",
            avg_score=0.65,
            article_count=5,
            recent_scores=[0.5, 0.7, 0.8],
            avg_confidence=0.85,
        )
        assert sc.article_count == 5

    def test_score_range_validation(self) -> None:
        with pytest.raises(ValidationError):
            SentimentContext(
                ticker="AAPL",
                avg_score=1.5,  # Out of range
                article_count=1,
                avg_confidence=0.5,
            )

    def test_negative_score_in_range(self) -> None:
        sc = SentimentContext(
            ticker="TSLA",
            avg_score=-0.8,
            article_count=3,
            avg_confidence=0.9,
        )
        assert sc.avg_score == -0.8

    def test_confidence_range_validation(self) -> None:
        with pytest.raises(ValidationError):
            SentimentContext(
                ticker="AAPL",
                avg_score=0.5,
                article_count=1,
                avg_confidence=1.5,  # Out of range
            )

    def test_article_count_non_negative(self) -> None:
        with pytest.raises(ValidationError):
            SentimentContext(
                ticker="AAPL",
                avg_score=0.0,
                article_count=-1,
                avg_confidence=0.5,
            )


# ---------------------------------------------------------------------------
# News schemas
# ---------------------------------------------------------------------------


class TestRawArticle:
    """RawArticle: required fields, optional published_at, JSON round-trip."""

    def test_valid_article(self) -> None:
        now = datetime.now(timezone.utc)
        a = RawArticle(
            source="reuters",
            url="https://reuters.com/article/1",
            title="Market Rally Continues",
            content="Stocks rose sharply today...",
            published_at=now,
            fetched_at=now,
            content_hash="sha256abcdef1234567890",
        )
        assert a.source == "reuters"

    def test_published_at_optional(self) -> None:
        now = datetime.now(timezone.utc)
        a = RawArticle(
            source="reddit",
            url="https://reddit.com/r/stocks/1",
            title="DD on TSLA",
            content="Here is my analysis...",
            fetched_at=now,
            content_hash="hash123",
        )
        assert a.published_at is None

    def test_json_round_trip(self) -> None:
        now = datetime.now(timezone.utc)
        a = RawArticle(
            source="yahoo",
            url="https://finance.yahoo.com/1",
            title="Earnings Beat",
            content="Apple beat earnings...",
            published_at=now,
            fetched_at=now,
            content_hash="hash456",
        )
        restored = RawArticle.model_validate_json(a.model_dump_json())
        assert restored == a

    def test_required_fields(self) -> None:
        with pytest.raises(ValidationError):
            RawArticle(source="reuters")  # type: ignore[call-arg]


class TestScoredArticle:
    """ScoredArticle: sentiment/confidence range checks and JSON round-trip."""

    def test_valid_scored_article(self) -> None:
        now = datetime.now(timezone.utc)
        sa = ScoredArticle(
            source="reuters",
            url="https://reuters.com/article/1",
            title="Apple Earnings Beat",
            content="Apple reported...",
            published_at=now,
            fetched_at=now,
            content_hash="hash789",
            ticker="AAPL",
            sentiment_score=0.85,
            confidence=0.92,
            model_used="finbert",
            entities=["Apple Inc", "Tim Cook"],
        )
        assert sa.sentiment_score == 0.85
        assert sa.entities == ["Apple Inc", "Tim Cook"]

    def test_sentiment_score_range(self) -> None:
        now = datetime.now(timezone.utc)
        with pytest.raises(ValidationError):
            ScoredArticle(
                source="reuters",
                url="https://reuters.com/1",
                title="Test",
                content="Test content",
                fetched_at=now,
                content_hash="hash",
                ticker="AAPL",
                sentiment_score=1.5,  # Out of range
                confidence=0.5,
                model_used="finbert",
            )

    def test_confidence_range(self) -> None:
        now = datetime.now(timezone.utc)
        with pytest.raises(ValidationError):
            ScoredArticle(
                source="reuters",
                url="https://reuters.com/1",
                title="Test",
                content="Test content",
                fetched_at=now,
                content_hash="hash",
                ticker="AAPL",
                sentiment_score=0.5,
                confidence=-0.1,  # Out of range
                model_used="finbert",
            )

    def test_negative_sentiment(self) -> None:
        now = datetime.now(timezone.utc)
        sa = ScoredArticle(
            source="reddit",
            url="https://reddit.com/1",
            title="Bad news",
            content="Terrible quarter...",
            fetched_at=now,
            content_hash="hashN",
            ticker="TSLA",
            sentiment_score=-0.9,
            confidence=0.8,
            model_used="ollama",
        )
        assert sa.sentiment_score == -0.9

    def test_json_round_trip(self) -> None:
        now = datetime.now(timezone.utc)
        sa = ScoredArticle(
            source="yahoo",
            url="https://yahoo.com/1",
            title="Headline",
            content="Body text",
            fetched_at=now,
            content_hash="hashRT",
            ticker="NVDA",
            sentiment_score=0.3,
            confidence=0.7,
            model_used="finbert",
            entities=["NVIDIA"],
        )
        restored = ScoredArticle.model_validate_json(sa.model_dump_json())
        assert restored == sa


# ---------------------------------------------------------------------------
# Learning schemas
# ---------------------------------------------------------------------------


class TestTradeOutcomeSchema:
    """TradeOutcomeSchema: winning/losing trades, duration bound, round-trip."""

    def test_valid_outcome(self) -> None:
        o = TradeOutcomeSchema(
            trade_id=uuid.uuid4(),
            hold_duration_seconds=14400.0,
            realized_pnl=250.50,
            roi_pct=2.5,
            was_profitable=True,
        )
        assert o.was_profitable is True
        assert o.hold_duration_seconds == 14400.0

    def test_hold_duration_non_negative(self) -> None:
        with pytest.raises(ValidationError):
            TradeOutcomeSchema(
                trade_id=uuid.uuid4(),
                hold_duration_seconds=-1.0,
                realized_pnl=100.0,
                roi_pct=1.0,
                was_profitable=True,
            )

    def test_losing_trade(self) -> None:
        o = TradeOutcomeSchema(
            trade_id=uuid.uuid4(),
            hold_duration_seconds=3600.0,
            realized_pnl=-150.0,
            roi_pct=-3.0,
            was_profitable=False,
        )
        assert o.was_profitable is False
        assert o.realized_pnl == -150.0

    def test_json_round_trip(self) -> None:
        o = TradeOutcomeSchema(
            trade_id=uuid.uuid4(),
            hold_duration_seconds=7200.0,
            realized_pnl=500.0,
            roi_pct=5.0,
            was_profitable=True,
        )
        restored = TradeOutcomeSchema.model_validate_json(o.model_dump_json())
        assert restored == o


class TestWeightAdjustment:
    """WeightAdjustment: full construction, missing fields, JSON round-trip."""

    def test_valid_adjustment(self) -> None:
        now = datetime.now(timezone.utc)
        wa = WeightAdjustment(
            strategy_id=uuid.uuid4(),
            strategy_name="momentum",
            old_weight=0.33,
            new_weight=0.38,
            reason="Positive reward signal from recent trades",
            reward_signal=0.72,
            timestamp=now,
        )
        assert wa.old_weight == 0.33
        assert wa.new_weight == 0.38

    def test_required_fields(self) -> None:
        with pytest.raises(ValidationError):
            # The ignore must sit on the call's first line: mypy reports a
            # missing-argument error there, not on the closing parenthesis.
            WeightAdjustment(  # type: ignore[call-arg]
                strategy_id=uuid.uuid4(),
                strategy_name="momentum",
            )

    def test_json_round_trip(self) -> None:
        now = datetime.now(timezone.utc)
        wa = WeightAdjustment(
            strategy_id=uuid.uuid4(),
            strategy_name="mean_reversion",
            old_weight=0.30,
            new_weight=0.25,
            reason="Poor recent performance",
            reward_signal=-0.4,
            timestamp=now,
        )
        restored = WeightAdjustment.model_validate_json(wa.model_dump_json())
        assert restored == wa


# ---------------------------------------------------------------------------
# Auth schemas
# ---------------------------------------------------------------------------


class TestRegisterRequest:
    """RegisterRequest: optional display name and username length limits."""

    def test_valid_registration(self) -> None:
        r = RegisterRequest(username="trader1", display_name="Top Trader")
        assert r.username == "trader1"
        assert r.display_name == "Top Trader"

    def test_display_name_optional(self) -> None:
        r = RegisterRequest(username="trader2")
        assert r.display_name is None

    def test_username_required(self) -> None:
        with pytest.raises(ValidationError):
            RegisterRequest(username="")  # min_length=1

    def test_username_max_length(self) -> None:
        with pytest.raises(ValidationError):
            RegisterRequest(username="x" * 101)  # max_length=100


class TestLoginRequest:
    """LoginRequest: accepts a username and rejects an empty one."""

    def test_valid_login(self) -> None:
        # Named `req` rather than `l`, which is ambiguous with `1`/`I` (E741).
        req = LoginRequest(username="trader1")
        assert req.username == "trader1"

    def test_username_required(self) -> None:
        with pytest.raises(ValidationError):
            LoginRequest(username="")


class TestTokenResponse:
    """TokenResponse: default token type/expiry, overrides, JSON round-trip."""

    def test_valid_response(self) -> None:
        t = TokenResponse(
            access_token="eyJ...",
            refresh_token="eyR...",
        )
        assert t.token_type == "bearer"
        assert t.expires_in == 900  # default 15 min

    def test_custom_expiry(self) -> None:
        t = TokenResponse(
            access_token="eyJ...",
            refresh_token="eyR...",
            expires_in=3600,
        )
        assert t.expires_in == 3600

    def test_json_round_trip(self) -> None:
        t = TokenResponse(
            access_token="access123",
            refresh_token="refresh456",
            token_type="bearer",
            expires_in=1800,
        )
        restored = TokenResponse.model_validate_json(t.model_dump_json())
        assert restored == t