diff --git a/shared/schemas/kevin.py b/shared/schemas/kevin.py new file mode 100644 index 0000000..f90179b --- /dev/null +++ b/shared/schemas/kevin.py @@ -0,0 +1,66 @@ +"""Pydantic schemas for the Kevin strategy. + +Used by KevinStrategy.evaluate_mention as input/output contracts and by +the live signal bridge + backtest engine to talk to the strategy. +""" + +from __future__ import annotations + +from decimal import Decimal +from enum import Enum + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class KevinDecisionType(str, Enum): + OPEN_LONG = "open_long" # new position or top-up + CLOSE_LONG = "close_long" # exit existing long + NO_OP = "no_op" # filter says skip + + +class KevinDecision(BaseModel): + """A single trade decision emitted by KevinStrategy.evaluate_mention.""" + + model_config = ConfigDict(frozen=True) + + decision: KevinDecisionType + symbol: str = Field(min_length=1, max_length=16) + target_dollars: Decimal | None = None # required for OPEN_LONG + holding_days: int | None = None # required for OPEN_LONG + effective_conviction: Decimal | None = None # post-aggregation, 0-1 + rationale: str # one-line audit string + + @model_validator(mode="after") + def _open_long_requires_target_dollars(self) -> "KevinDecision": + if self.decision == KevinDecisionType.OPEN_LONG: + if self.target_dollars is None: + raise ValueError("OPEN_LONG requires target_dollars") + if self.holding_days is None: + raise ValueError("OPEN_LONG requires holding_days") + if self.target_dollars <= 0: + raise ValueError("target_dollars must be positive") + return self + + +class KevinAccountState(BaseModel): + """Snapshot of the account passed to KevinStrategy.evaluate_mention. + + The bridge populates this from live Alpaca account + Redis counters; the + backtest populates it from the simulated portfolio state. Same shape. + """ + + model_config = ConfigDict(frozen=True) + + equity_usd: Decimal + cash_usd: Decimal + held_positions: dict[str, Decimal] # symbol -> cost-basis $ + blocklisted_symbols: frozenset[str] | set[str] + daily_trade_count: int + daily_alloc_usd: Decimal + paused: bool + + def is_held(self, symbol: str) -> bool: + return symbol in self.held_positions and self.held_positions[symbol] > 0 + + def is_blocklisted(self, symbol: str) -> bool: + return symbol in self.blocklisted_symbols diff --git a/tests/shared/__init__.py b/tests/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/shared/schemas/__init__.py b/tests/shared/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/shared/schemas/test_kevin.py b/tests/shared/schemas/test_kevin.py new file mode 100644 index 0000000..6183d3b --- /dev/null +++ b/tests/shared/schemas/test_kevin.py @@ -0,0 +1,59 @@ +"""Tests for KevinDecision + KevinAccountState pydantic schemas.""" + +from decimal import Decimal + +import pytest +from pydantic import ValidationError + +from shared.schemas.kevin import ( + KevinAccountState, + KevinDecision, + KevinDecisionType, +) + + +def test_kevin_decision_open_long_requires_target_dollars(): + d = KevinDecision( + decision=KevinDecisionType.OPEN_LONG, + symbol="NVDA", + target_dollars=Decimal("2000"), + holding_days=10, + effective_conviction=Decimal("0.75"), + rationale="conv 0.7 + 1 boost", + ) + assert d.symbol == "NVDA" + assert d.target_dollars == Decimal("2000") + + +def test_kevin_decision_close_long_does_not_require_target_dollars(): + d = KevinDecision( + decision=KevinDecisionType.CLOSE_LONG, + symbol="NVDA", + rationale="kevin reverse", + ) + assert d.target_dollars is None + + +def test_kevin_decision_open_long_rejects_missing_target_dollars(): + with pytest.raises(ValidationError, match="target_dollars"): + KevinDecision( + decision=KevinDecisionType.OPEN_LONG, + symbol="NVDA", + rationale="missing $", + ) + + +def test_kevin_account_state_held_symbols_lookup(): + state = KevinAccountState( + equity_usd=Decimal("100000"), + cash_usd=Decimal("80000"), + held_positions={"NVDA": Decimal("5000"), "INTC": Decimal("2000")}, + blocklisted_symbols={"WMT"}, + daily_trade_count=2, + daily_alloc_usd=Decimal("4000"), + paused=False, + ) + assert state.is_held("NVDA") + assert not state.is_held("AAPL") + assert state.is_blocklisted("WMT") + assert not state.is_blocklisted("NVDA")