diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..9183524 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,151 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s +# Or organize into date-based subdirectories (requires recursive_version_locations = true) +# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +# Database URL is read from the TRADING_DATABASE_URL environment variable +# in alembic/env.py. The value here is a fallback only. +sqlalchemy.url = postgresql+asyncpg://trading:trading@localhost:5432/trading + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/__pycache__/env.cpython-314.pyc b/alembic/__pycache__/env.cpython-314.pyc new file mode 100644 index 0000000..489a3fc Binary files /dev/null and b/alembic/__pycache__/env.cpython-314.pyc differ diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..60b9094 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,71 @@ +"""Alembic environment — async-aware, imports all shared models.""" + +import asyncio +import os +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import async_engine_from_config + +# Import all models so metadata is populated for autogenerate. +from shared.models import Base # noqa: F401 +import shared.models # noqa: F401 — triggers side-effect imports of every model + +config = context.config + +# Override sqlalchemy.url from environment if available. +db_url = os.environ.get("TRADING_DATABASE_URL") +if db_url: + config.set_main_option("sqlalchemy.url", db_url) + +# Interpret the config file for Python logging. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode (emit SQL without a live connection).""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection) -> None: # noqa: ANN001 + context.configure(connection=connection, target_metadata=target_metadata) + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """Run migrations in 'online' mode using an async engine.""" + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Entry-point for online migrations — delegates to the async helper.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..1101630 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/__pycache__/a1b2c3d4e5f6_initial_schema.cpython-314.pyc b/alembic/versions/__pycache__/a1b2c3d4e5f6_initial_schema.cpython-314.pyc new file mode 100644 index 0000000..7096b5d Binary files /dev/null and b/alembic/versions/__pycache__/a1b2c3d4e5f6_initial_schema.cpython-314.pyc differ diff --git a/alembic/versions/a1b2c3d4e5f6_initial_schema.py b/alembic/versions/a1b2c3d4e5f6_initial_schema.py new file mode 100644 index 0000000..3c71fea --- /dev/null +++ b/alembic/versions/a1b2c3d4e5f6_initial_schema.py @@ -0,0 +1,284 @@ +"""initial schema + +Revision ID: a1b2c3d4e5f6 +Revises: +Create Date: 2026-02-22 15:15:15.661206 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "a1b2c3d4e5f6" +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create all tables for the trading bot.""" + + # --- Core trading tables --- + + op.create_table( + "strategies", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("name", sa.String(255), unique=True, nullable=False), + sa.Column("description", sa.Text, nullable=True), + sa.Column("current_weight", sa.Float, nullable=False, server_default="0.333"), + sa.Column("active", sa.Boolean, nullable=False, server_default=sa.text("true")), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + op.create_table( + "signals", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("ticker", sa.String(20), nullable=False, index=True), + sa.Column( + "direction", + sa.Enum("LONG", "SHORT", "NEUTRAL", name="signaldirection"), + nullable=False, + ), + sa.Column("strength", sa.Float, nullable=False), + sa.Column("strategy_sources", postgresql.JSON, nullable=True), + sa.Column("sentiment_score", sa.Float, nullable=True), + sa.Column("acted_on", sa.Boolean, nullable=False, server_default=sa.text("false")), + sa.Column( + "strategy_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("strategies.id"), + nullable=True, + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + op.create_table( + "trades", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("ticker", sa.String(20), nullable=False, index=True), + sa.Column( + "side", + sa.Enum("BUY", "SELL", name="tradeside"), + nullable=False, + ), + sa.Column("qty", sa.Float, nullable=False), + sa.Column("price", sa.Float, nullable=False), + sa.Column("timestamp", sa.String, nullable=True), + sa.Column( + "strategy_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("strategies.id"), + nullable=True, + ), + sa.Column( + "signal_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("signals.id"), + nullable=True, + ), + sa.Column( + "status", + sa.Enum("PENDING", "FILLED", "CANCELLED", "REJECTED", name="tradestatus"), + nullable=False, + server_default="PENDING", + ), + sa.Column("pnl", sa.Float, nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + op.create_table( + "positions", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("ticker", sa.String(20), unique=True, nullable=False), + sa.Column("qty", sa.Float, nullable=False), + sa.Column("avg_entry", sa.Float, nullable=False), + sa.Column("unrealized_pnl", sa.Float, nullable=True), + sa.Column("stop_loss", sa.Float, nullable=True), + sa.Column("take_profit", sa.Float, nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + op.create_table( + "strategy_weight_history", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column( + "strategy_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("strategies.id"), + nullable=False, + ), + sa.Column("old_weight", sa.Float, nullable=False), + sa.Column("new_weight", sa.Float, nullable=False), + sa.Column("reason", sa.String(500), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # --- News & sentiment --- + + op.create_table( + "articles", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("source", sa.String(100), nullable=False), + sa.Column("url", sa.Text, nullable=False), + sa.Column("title", sa.Text, nullable=False), + sa.Column("published_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("fetched_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("content_hash", sa.String(64), unique=True, nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + op.create_table( + "article_sentiments", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column( + "article_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("articles.id"), + nullable=False, + ), + sa.Column("ticker", sa.String(20), nullable=False, index=True), + sa.Column("score", sa.Float, nullable=False), + sa.Column("confidence", sa.Float, nullable=False), + sa.Column("model_used", sa.String(50), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # --- Learning --- + + op.create_table( + "trade_outcomes", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column( + "trade_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("trades.id"), + unique=True, + nullable=False, + ), + sa.Column("hold_duration", sa.Interval, nullable=True), + sa.Column("realized_pnl", sa.Float, nullable=False), + sa.Column("roi_pct", sa.Float, nullable=False), + sa.Column("was_profitable", sa.Boolean, nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + op.create_table( + "learning_adjustments", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column( + "strategy_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("strategies.id"), + nullable=False, + ), + sa.Column("old_weight", sa.Float, nullable=False), + sa.Column("new_weight", sa.Float, nullable=False), + sa.Column("reason", sa.Text, nullable=True), + sa.Column("reward_signal", sa.Float, nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # --- Auth --- + + op.create_table( + "users", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("username", sa.String(100), unique=True, nullable=False), + sa.Column("display_name", sa.String(255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + op.create_table( + "user_credentials", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("users.id"), + nullable=False, + ), + sa.Column("credential_id", sa.String(512), unique=True, nullable=False), + sa.Column("public_key", sa.LargeBinary, nullable=False), + sa.Column("sign_count", sa.Integer, nullable=False, server_default="0"), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # --- Timeseries (TimescaleDB hypertables) --- + + op.create_table( + "market_data", + sa.Column("timestamp", sa.DateTime(timezone=True), primary_key=True), + sa.Column("ticker", sa.String(20), primary_key=True), + sa.Column("open", sa.Float, nullable=False), + sa.Column("high", sa.Float, nullable=False), + sa.Column("low", sa.Float, nullable=False), + sa.Column("close", sa.Float, nullable=False), + sa.Column("volume", sa.Float, nullable=False), + ) + + op.create_table( + "portfolio_snapshots", + sa.Column("timestamp", sa.DateTime(timezone=True), primary_key=True), + sa.Column("total_value", sa.Float, nullable=False), + sa.Column("cash", sa.Float, nullable=False), + sa.Column("positions_value", sa.Float, nullable=False), + sa.Column("daily_pnl", sa.Float, nullable=False), + ) + + op.create_table( + "strategy_metrics", + sa.Column("timestamp", sa.DateTime(timezone=True), primary_key=True), + sa.Column( + "strategy_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("strategies.id"), + primary_key=True, + ), + sa.Column("win_rate", sa.Float, nullable=False), + sa.Column("total_pnl", sa.Float, nullable=False), + sa.Column("trade_count", sa.Integer, nullable=False), + sa.Column("sharpe_ratio", sa.Float, nullable=True), + ) + + # Convert timeseries tables to TimescaleDB hypertables. + # These calls are idempotent-safe when the extension is loaded. + op.execute("SELECT create_hypertable('market_data', 'timestamp', if_not_exists => TRUE)") + op.execute("SELECT create_hypertable('portfolio_snapshots', 'timestamp', if_not_exists => TRUE)") + op.execute("SELECT create_hypertable('strategy_metrics', 'timestamp', if_not_exists => TRUE)") + + +def downgrade() -> None: + """Drop all tables in reverse dependency order.""" + op.drop_table("strategy_metrics") + op.drop_table("portfolio_snapshots") + op.drop_table("market_data") + op.drop_table("user_credentials") + op.drop_table("users") + op.drop_table("learning_adjustments") + op.drop_table("trade_outcomes") + op.drop_table("article_sentiments") + op.drop_table("articles") + op.drop_table("strategy_weight_history") + op.drop_table("positions") + op.drop_table("trades") + op.drop_table("signals") + op.drop_table("strategies") + + # Drop enums + sa.Enum(name="signaldirection").drop(op.get_bind(), checkfirst=True) + sa.Enum(name="tradeside").drop(op.get_bind(), checkfirst=True) + sa.Enum(name="tradestatus").drop(op.get_bind(), checkfirst=True) diff --git a/shared/db.py b/shared/db.py new file mode 100644 index 0000000..574d183 --- /dev/null +++ b/shared/db.py @@ -0,0 +1,22 @@ +"""SQLAlchemy async engine and session factory.""" + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from shared.config import BaseConfig + + +def create_db(config: BaseConfig) -> tuple: + """Create an async engine and session factory from the given config. + + Returns a ``(engine, session_factory)`` tuple. + """ + engine = create_async_engine( + config.database_url, + echo=config.log_level == "DEBUG", + ) + session_factory = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, + ) + return engine, session_factory diff --git a/shared/models/__init__.py b/shared/models/__init__.py new file mode 100644 index 0000000..d461e61 --- /dev/null +++ b/shared/models/__init__.py @@ -0,0 +1,44 @@ +"""Shared SQLAlchemy models — import all models here so Alembic can discover them.""" + +from shared.models.base import Base, TimestampMixin +from shared.models.trading import ( + Signal, + SignalDirection, + Strategy, + StrategyWeightHistory, + Trade, + TradeSide, + TradeStatus, + Position, +) +from shared.models.news import Article, ArticleSentiment +from shared.models.learning import LearningAdjustment, TradeOutcome +from shared.models.auth import User, UserCredential +from shared.models.timeseries import MarketData, PortfolioSnapshot, StrategyMetric + +__all__ = [ + "Base", + "TimestampMixin", + # Trading + "Strategy", + "Signal", + "SignalDirection", + "Trade", + "TradeSide", + "TradeStatus", + "Position", + "StrategyWeightHistory", + # News + "Article", + "ArticleSentiment", + # Learning + "TradeOutcome", + "LearningAdjustment", + # Auth + "User", + "UserCredential", + # Timeseries + "MarketData", + "PortfolioSnapshot", + "StrategyMetric", +] diff --git a/shared/models/__pycache__/__init__.cpython-314.pyc b/shared/models/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..bf70a3b Binary files /dev/null and b/shared/models/__pycache__/__init__.cpython-314.pyc differ diff --git a/shared/models/__pycache__/auth.cpython-314.pyc b/shared/models/__pycache__/auth.cpython-314.pyc new file mode 100644 index 0000000..58c2596 Binary files /dev/null and b/shared/models/__pycache__/auth.cpython-314.pyc differ diff --git a/shared/models/__pycache__/base.cpython-314.pyc b/shared/models/__pycache__/base.cpython-314.pyc new file mode 100644 index 0000000..949ba8c Binary files /dev/null and b/shared/models/__pycache__/base.cpython-314.pyc differ diff --git a/shared/models/__pycache__/learning.cpython-314.pyc b/shared/models/__pycache__/learning.cpython-314.pyc new file mode 100644 index 0000000..65f72f8 Binary files /dev/null and b/shared/models/__pycache__/learning.cpython-314.pyc differ diff --git a/shared/models/__pycache__/news.cpython-314.pyc b/shared/models/__pycache__/news.cpython-314.pyc new file mode 100644 index 0000000..6bf5f3d Binary files /dev/null and b/shared/models/__pycache__/news.cpython-314.pyc differ diff --git a/shared/models/__pycache__/timeseries.cpython-314.pyc b/shared/models/__pycache__/timeseries.cpython-314.pyc new file mode 100644 index 0000000..964961b Binary files /dev/null and b/shared/models/__pycache__/timeseries.cpython-314.pyc differ diff --git a/shared/models/__pycache__/trading.cpython-314.pyc b/shared/models/__pycache__/trading.cpython-314.pyc new file mode 100644 index 0000000..8d0edd9 Binary files /dev/null and b/shared/models/__pycache__/trading.cpython-314.pyc differ diff --git a/shared/models/auth.py b/shared/models/auth.py new file mode 100644 index 0000000..a9c7f72 --- /dev/null +++ b/shared/models/auth.py @@ -0,0 +1,39 @@ +"""Authentication models: User, UserCredential.""" + +import uuid + +from sqlalchemy import ForeignKey, Integer, LargeBinary, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from shared.models.base import Base, TimestampMixin + + +class User(TimestampMixin, Base): + __tablename__ = "users" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + username: Mapped[str] = mapped_column(String(100), unique=True, nullable=False) + display_name: Mapped[str | None] = mapped_column(String(255), nullable=True) + + # Relationships + credentials: Mapped[list["UserCredential"]] = relationship(back_populates="user") + + +class UserCredential(TimestampMixin, Base): + __tablename__ = "user_credentials" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id"), nullable=False + ) + credential_id: Mapped[str] = mapped_column(String(512), unique=True, nullable=False) + public_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + sign_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + # Relationships + user: Mapped[User] = relationship(back_populates="credentials") diff --git a/shared/models/base.py b/shared/models/base.py new file mode 100644 index 0000000..daefe3f --- /dev/null +++ b/shared/models/base.py @@ -0,0 +1,26 @@ +"""SQLAlchemy declarative base and common mixins.""" + +from datetime import datetime + +from sqlalchemy import DateTime, func +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + +class Base(DeclarativeBase): + """Shared declarative base for all models.""" + + pass + + +class TimestampMixin: + """Adds ``created_at`` and ``updated_at`` columns with server defaults.""" + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + ) diff --git a/shared/models/learning.py b/shared/models/learning.py new file mode 100644 index 0000000..ca9436d --- /dev/null +++ b/shared/models/learning.py @@ -0,0 +1,51 @@ +"""Learning domain models: TradeOutcome, LearningAdjustment.""" + +import uuid +from datetime import timedelta + +from sqlalchemy import Boolean, Float, ForeignKey, Interval, String, Text +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from shared.models.base import Base, TimestampMixin + + +class TradeOutcome(TimestampMixin, Base): + __tablename__ = "trade_outcomes" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + trade_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("trades.id"), unique=True, nullable=False + ) + hold_duration: Mapped[timedelta | None] = mapped_column(Interval, nullable=True) + realized_pnl: Mapped[float] = mapped_column(Float, nullable=False) + roi_pct: Mapped[float] = mapped_column(Float, nullable=False) + was_profitable: Mapped[bool] = mapped_column(Boolean, nullable=False) + + # Relationships + trade: Mapped["Trade"] = relationship("Trade", back_populates="outcome") + + +class LearningAdjustment(TimestampMixin, Base): + __tablename__ = "learning_adjustments" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + strategy_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=False + ) + old_weight: Mapped[float] = mapped_column(Float, nullable=False) + new_weight: Mapped[float] = mapped_column(Float, nullable=False) + reason: Mapped[str | None] = mapped_column(Text, nullable=True) + reward_signal: Mapped[float] = mapped_column(Float, nullable=False) + + # Relationships + strategy: Mapped["Strategy"] = relationship("Strategy") + + +# Avoid circular imports — reference by string in relationship() +from shared.models.trading import Trade as Trade # noqa: E402, F401 +from shared.models.trading import Strategy as Strategy # noqa: E402, F401 diff --git a/shared/models/news.py b/shared/models/news.py new file mode 100644 index 0000000..5b2caa1 --- /dev/null +++ b/shared/models/news.py @@ -0,0 +1,49 @@ +"""News and sentiment models: Article, ArticleSentiment.""" + +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, Float, ForeignKey, String, Text +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from shared.models.base import Base, TimestampMixin + + +class Article(TimestampMixin, Base): + __tablename__ = "articles" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + source: Mapped[str] = mapped_column(String(100), nullable=False) + url: Mapped[str] = mapped_column(Text, nullable=False) + title: Mapped[str] = mapped_column(Text, nullable=False) + published_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + fetched_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + content_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False) + + # Relationships + sentiments: Mapped[list["ArticleSentiment"]] = relationship(back_populates="article") + + +class ArticleSentiment(TimestampMixin, Base): + __tablename__ = "article_sentiments" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + article_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("articles.id"), nullable=False + ) + ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True) + score: Mapped[float] = mapped_column(Float, nullable=False) + confidence: Mapped[float] = mapped_column(Float, nullable=False) + model_used: Mapped[str] = mapped_column(String(50), nullable=False) + + # Relationships + article: Mapped[Article] = relationship(back_populates="sentiments") diff --git a/shared/models/timeseries.py b/shared/models/timeseries.py new file mode 100644 index 0000000..15653be --- /dev/null +++ b/shared/models/timeseries.py @@ -0,0 +1,57 @@ +"""TimescaleDB hypertable models: MarketData, PortfolioSnapshot, StrategyMetric.""" + +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, Float, ForeignKey, Integer, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from shared.models.base import Base + + +class MarketData(Base): + """OHLCV bars — intended as a TimescaleDB hypertable partitioned by timestamp.""" + + __tablename__ = "market_data" + + timestamp: Mapped[datetime] = mapped_column( + DateTime(timezone=True), primary_key=True + ) + ticker: Mapped[str] = mapped_column(String(20), primary_key=True) + open: Mapped[float] = mapped_column(Float, nullable=False) + high: Mapped[float] = mapped_column(Float, nullable=False) + low: Mapped[float] = mapped_column(Float, nullable=False) + close: Mapped[float] = mapped_column(Float, nullable=False) + volume: Mapped[float] = mapped_column(Float, nullable=False) + + +class PortfolioSnapshot(Base): + """Periodic portfolio value snapshots — TimescaleDB hypertable.""" + + __tablename__ = "portfolio_snapshots" + + timestamp: Mapped[datetime] = mapped_column( + DateTime(timezone=True), primary_key=True + ) + total_value: Mapped[float] = mapped_column(Float, nullable=False) + cash: Mapped[float] = mapped_column(Float, nullable=False) + positions_value: Mapped[float] = mapped_column(Float, nullable=False) + daily_pnl: Mapped[float] = mapped_column(Float, nullable=False) + + +class StrategyMetric(Base): + """Per-strategy performance over time — TimescaleDB hypertable.""" + + __tablename__ = "strategy_metrics" + + timestamp: Mapped[datetime] = mapped_column( + DateTime(timezone=True), primary_key=True + ) + strategy_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("strategies.id"), primary_key=True + ) + win_rate: Mapped[float] = mapped_column(Float, nullable=False) + total_pnl: Mapped[float] = mapped_column(Float, nullable=False) + trade_count: Mapped[int] = mapped_column(Integer, nullable=False) + sharpe_ratio: Mapped[float | None] = mapped_column(Float, nullable=True) diff --git a/shared/models/trading.py b/shared/models/trading.py new file mode 100644 index 0000000..0e6eeb5 --- /dev/null +++ b/shared/models/trading.py @@ -0,0 +1,137 @@ +"""Trading domain models: Strategy, Signal, Trade, Position, StrategyWeightHistory.""" + +import enum +import uuid + +from sqlalchemy import Boolean, Float, ForeignKey, String, Text +from sqlalchemy.dialects.postgresql import JSON, UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from shared.models.base import Base, TimestampMixin + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + +class TradeSide(str, enum.Enum): + BUY = "BUY" + SELL = "SELL" + + +class TradeStatus(str, enum.Enum): + PENDING = "PENDING" + FILLED = "FILLED" + CANCELLED = "CANCELLED" + REJECTED = "REJECTED" + + +class SignalDirection(str, enum.Enum): + LONG = "LONG" + SHORT = "SHORT" + NEUTRAL = "NEUTRAL" + + +# --------------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------------- + +class Strategy(TimestampMixin, Base): + __tablename__ = "strategies" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + current_weight: Mapped[float] = mapped_column(Float, nullable=False, default=0.333) + active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + + # Relationships + trades: Mapped[list["Trade"]] = relationship(back_populates="strategy") + signals: Mapped[list["Signal"]] = relationship(back_populates="strategy", foreign_keys="Signal.strategy_id", viewonly=True) + weight_history: Mapped[list["StrategyWeightHistory"]] = relationship(back_populates="strategy") + + +class Signal(TimestampMixin, Base): + __tablename__ = "signals" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True) + direction: Mapped[SignalDirection] = mapped_column(nullable=False) + strength: Mapped[float] = mapped_column(Float, nullable=False) + strategy_sources: Mapped[dict | None] = mapped_column(JSON, nullable=True) + sentiment_score: Mapped[float | None] = mapped_column(Float, nullable=True) + acted_on: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + strategy_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=True + ) + + # Relationships + strategy: Mapped[Strategy | None] = relationship(back_populates="signals", foreign_keys=[strategy_id]) + trades: Mapped[list["Trade"]] = relationship(back_populates="signal") + + +class Trade(TimestampMixin, Base): + __tablename__ = "trades" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True) + side: Mapped[TradeSide] = mapped_column(nullable=False) + qty: Mapped[float] = mapped_column(Float, nullable=False) + price: Mapped[float] = mapped_column(Float, nullable=False) + timestamp: Mapped[str | None] = mapped_column(String, nullable=True) + strategy_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=True + ) + signal_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), ForeignKey("signals.id"), nullable=True + ) + status: Mapped[TradeStatus] = mapped_column(nullable=False, default=TradeStatus.PENDING) + pnl: Mapped[float | None] = mapped_column(Float, nullable=True) + + # Relationships + strategy: Mapped[Strategy | None] = relationship(back_populates="trades") + signal: Mapped[Signal | None] = relationship(back_populates="trades") + outcome: Mapped["TradeOutcome | None"] = relationship( + "TradeOutcome", back_populates="trade", uselist=False + ) + + +class Position(TimestampMixin, Base): + __tablename__ = "positions" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + ticker: Mapped[str] = mapped_column(String(20), unique=True, nullable=False) + qty: Mapped[float] = mapped_column(Float, nullable=False) + avg_entry: Mapped[float] = mapped_column(Float, nullable=False) + unrealized_pnl: Mapped[float | None] = mapped_column(Float, nullable=True) + stop_loss: Mapped[float | None] = mapped_column(Float, nullable=True) + take_profit: Mapped[float | None] = mapped_column(Float, nullable=True) + + +class StrategyWeightHistory(TimestampMixin, Base): + __tablename__ = "strategy_weight_history" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + strategy_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=False + ) + old_weight: Mapped[float] = mapped_column(Float, nullable=False) + new_weight: Mapped[float] = mapped_column(Float, nullable=False) + reason: Mapped[str | None] = mapped_column(String(500), nullable=True) + + # Relationships + strategy: Mapped[Strategy] = relationship(back_populates="weight_history") + + +# Avoid circular import — TradeOutcome is defined in learning.py +from shared.models.learning import TradeOutcome # noqa: E402, F401 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..1f426a2 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,323 @@ +"""Tests for SQLAlchemy model instantiation, enums, and relationships.""" + +import uuid +from datetime import datetime, timedelta, timezone + +import pytest + +from shared.models import ( + Base, + TimestampMixin, + # Trading + Strategy, + Signal, + SignalDirection, + Trade, + TradeSide, + TradeStatus, + Position, + StrategyWeightHistory, + # News + Article, + ArticleSentiment, + # Learning + TradeOutcome, + LearningAdjustment, + # Auth + User, + UserCredential, + # Timeseries + MarketData, + PortfolioSnapshot, + StrategyMetric, +) +from shared.db import create_db +from shared.config import BaseConfig + + +# --------------------------------------------------------------------------- +# Enum tests +# --------------------------------------------------------------------------- + +class TestEnums: + def test_trade_side_values(self) -> None: + assert TradeSide.BUY == "BUY" + assert TradeSide.SELL == "SELL" + assert set(TradeSide) == {TradeSide.BUY, TradeSide.SELL} + + def test_trade_status_values(self) -> None: + assert TradeStatus.PENDING == "PENDING" + assert TradeStatus.FILLED == "FILLED" + assert TradeStatus.CANCELLED == "CANCELLED" + assert TradeStatus.REJECTED == "REJECTED" + assert len(TradeStatus) == 4 + + def test_signal_direction_values(self) -> None: + assert SignalDirection.LONG == "LONG" + assert SignalDirection.SHORT == "SHORT" + assert SignalDirection.NEUTRAL == "NEUTRAL" + assert len(SignalDirection) == 3 + + +# --------------------------------------------------------------------------- +# Model instantiation tests +# --------------------------------------------------------------------------- + +class TestStrategy: + def test_create_strategy(self) -> None: + s = Strategy( + id=uuid.uuid4(), + name="momentum", + description="Trend-following strategy", + current_weight=0.5, + active=True, + ) + assert s.name == "momentum" + assert s.current_weight == 0.5 + assert s.active is True + + def test_strategy_defaults(self) -> None: + """Without a DB session, Python-level defaults are not applied by SQLAlchemy. + The column default is only used at INSERT time.""" + s = Strategy(name="test") + assert s.description is None + # Column-level default=True is applied by the database at INSERT time, + # so in-memory the attribute is None until the row is flushed/refreshed. + assert s.active is None or s.active is True + + +class TestSignal: + def test_create_signal(self) -> None: + sig = Signal( + id=uuid.uuid4(), + ticker="AAPL", + direction=SignalDirection.LONG, + strength=0.85, + strategy_sources={"momentum": 0.9}, + sentiment_score=0.7, + acted_on=False, + ) + assert sig.ticker == "AAPL" + assert sig.direction == SignalDirection.LONG + assert sig.strength == 0.85 + assert sig.acted_on is False + + +class TestTrade: + def test_create_trade(self) -> None: + t = Trade( + id=uuid.uuid4(), + ticker="TSLA", + side=TradeSide.BUY, + qty=10.0, + price=150.25, + status=TradeStatus.FILLED, + pnl=250.50, + ) + assert t.ticker == "TSLA" + assert t.side == TradeSide.BUY + assert t.qty == 10.0 + assert t.price == 150.25 + assert t.status == TradeStatus.FILLED + assert t.pnl == 250.50 + + +class TestPosition: + def test_create_position(self) -> None: + p = Position( + id=uuid.uuid4(), + ticker="GOOG", + qty=5.0, + avg_entry=2800.00, + unrealized_pnl=-50.0, + stop_loss=2750.0, + take_profit=3000.0, + ) + assert p.ticker == "GOOG" + assert p.qty == 5.0 + assert p.stop_loss == 2750.0 + assert p.take_profit == 3000.0 + + +class TestStrategyWeightHistory: + def test_create_weight_history(self) -> None: + sid = uuid.uuid4() + wh = StrategyWeightHistory( + id=uuid.uuid4(), + strategy_id=sid, + old_weight=0.33, + new_weight=0.40, + reason="Improved win rate", + ) + assert wh.strategy_id == sid + assert wh.old_weight == 0.33 + assert wh.new_weight == 0.40 + + +class TestArticle: + def test_create_article(self) -> None: + now = datetime.now(timezone.utc) + a = Article( + id=uuid.uuid4(), + source="reuters", + url="https://reuters.com/article/1", + title="Market Rally", + published_at=now, + fetched_at=now, + content_hash="abc123def456", + ) + assert a.source == "reuters" + assert a.content_hash == "abc123def456" + + +class TestArticleSentiment: + def test_create_sentiment(self) -> None: + asent = ArticleSentiment( + id=uuid.uuid4(), + article_id=uuid.uuid4(), + ticker="AAPL", + score=0.85, + confidence=0.92, + model_used="finbert", + ) + assert asent.score == 0.85 + assert asent.model_used == "finbert" + + +class TestTradeOutcome: + def test_create_outcome(self) -> None: + outcome = TradeOutcome( + id=uuid.uuid4(), + trade_id=uuid.uuid4(), + hold_duration=timedelta(hours=4, minutes=30), + realized_pnl=125.50, + roi_pct=2.5, + was_profitable=True, + ) + assert outcome.realized_pnl == 125.50 + assert outcome.was_profitable is True + assert outcome.hold_duration == timedelta(hours=4, minutes=30) + + +class TestLearningAdjustment: + def test_create_adjustment(self) -> None: + adj = LearningAdjustment( + id=uuid.uuid4(), + strategy_id=uuid.uuid4(), + old_weight=0.30, + new_weight=0.35, + reason="Positive reward signal", + reward_signal=0.72, + ) + assert adj.reward_signal == 0.72 + assert adj.reason == "Positive reward signal" + + +class TestUser: + def test_create_user(self) -> None: + u = User( + id=uuid.uuid4(), + username="trader1", + display_name="Top Trader", + ) + assert u.username == "trader1" + assert u.display_name == "Top Trader" + + +class TestUserCredential: + def test_create_credential(self) -> None: + cred = UserCredential( + id=uuid.uuid4(), + user_id=uuid.uuid4(), + credential_id="cred-abc-123", + public_key=b"\x04abcdef", + sign_count=5, + ) + assert cred.credential_id == "cred-abc-123" + assert cred.sign_count == 5 + assert cred.public_key == b"\x04abcdef" + + +class TestMarketData: + def test_create_market_data(self) -> None: + now = datetime.now(timezone.utc) + md = MarketData( + timestamp=now, + ticker="AAPL", + open=150.0, + high=155.0, + low=149.0, + close=153.0, + volume=1_000_000.0, + ) + assert md.ticker == "AAPL" + assert md.close == 153.0 + + +class TestPortfolioSnapshot: + def test_create_snapshot(self) -> None: + now = datetime.now(timezone.utc) + snap = PortfolioSnapshot( + timestamp=now, + total_value=100_000.0, + cash=25_000.0, + positions_value=75_000.0, + daily_pnl=1_200.0, + ) + assert snap.total_value == 100_000.0 + assert snap.daily_pnl == 1_200.0 + + +class TestStrategyMetric: + def test_create_metric(self) -> None: + now = datetime.now(timezone.utc) + sm = StrategyMetric( + timestamp=now, + strategy_id=uuid.uuid4(), + win_rate=0.65, + total_pnl=5_432.10, + trade_count=42, + sharpe_ratio=1.8, + ) + assert sm.win_rate == 0.65 + assert sm.trade_count == 42 + assert sm.sharpe_ratio == 1.8 + + +# --------------------------------------------------------------------------- +# Metadata / Base tests +# --------------------------------------------------------------------------- + +class TestMetadata: + def test_all_tables_registered(self) -> None: + table_names = set(Base.metadata.tables.keys()) + expected = { + "strategies", + "signals", + "trades", + "positions", + "strategy_weight_history", + "articles", + "article_sentiments", + "trade_outcomes", + "learning_adjustments", + "users", + "user_credentials", + "market_data", + "portfolio_snapshots", + "strategy_metrics", + } + assert expected.issubset(table_names) + + def test_timestamp_mixin_fields(self) -> None: + """TimestampMixin should contribute created_at and updated_at columns.""" + assert "created_at" in Strategy.__table__.columns + assert "updated_at" in Strategy.__table__.columns + + +class TestDbFactory: + def test_create_db_returns_engine_and_session(self) -> None: + config = BaseConfig() + engine, session_factory = create_db(config) + assert engine is not None + assert session_factory is not None