feat: database models and alembic migrations — all tables per design

- shared/db.py: async engine + session factory
- shared/models/base.py: DeclarativeBase + TimestampMixin
- shared/models/trading.py: Strategy, Signal, Trade, Position, StrategyWeightHistory
- shared/models/news.py: Article, ArticleSentiment
- shared/models/learning.py: TradeOutcome, LearningAdjustment
- shared/models/auth.py: User, UserCredential
- shared/models/timeseries.py: MarketData, PortfolioSnapshot, StrategyMetric
- Alembic async env.py with initial migration including TimescaleDB hypertables
- 21 model tests covering enums, instantiation, metadata registration
This commit is contained in:
Viktor Barzin 2026-02-22 15:17:07 +00:00
parent ae5b3f89d1
commit 72cb1b6fe5
No known key found for this signature in database
GPG key ID: 0EB088298288D958
23 changed files with 1283 additions and 0 deletions

151
alembic.ini Normal file
View file

@ -0,0 +1,151 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
# Database URL is read from the TRADING_DATABASE_URL environment variable
# in alembic/env.py. The value here is a fallback only.
sqlalchemy.url = postgresql+asyncpg://trading:trading@localhost:5432/trading
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
alembic/README Normal file
View file

@ -0,0 +1 @@
Generic single-database configuration.

Binary file not shown.

71
alembic/env.py Normal file
View file

@ -0,0 +1,71 @@
"""Alembic environment — async-aware, imports all shared models."""
import asyncio
import os
from logging.config import fileConfig
from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import async_engine_from_config
# Import all models so metadata is populated for autogenerate.
from shared.models import Base # noqa: F401
import shared.models # noqa: F401 — triggers side-effect imports of every model
config = context.config
# Override sqlalchemy.url from environment if available.
db_url = os.environ.get("TRADING_DATABASE_URL")
if db_url:
config.set_main_option("sqlalchemy.url", db_url)
# Interpret the config file for Python logging.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode (emit SQL without a live connection)."""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection) -> None: # noqa: ANN001
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""Run migrations in 'online' mode using an async engine."""
connectable = async_engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Entry-point for online migrations — delegates to the async helper."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

28
alembic/script.py.mako Normal file
View file

@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View file

@ -0,0 +1,284 @@
"""initial schema
Revision ID: a1b2c3d4e5f6
Revises:
Create Date: 2026-02-22 15:15:15.661206
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "a1b2c3d4e5f6"
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create all tables for the trading bot."""
# --- Core trading tables ---
op.create_table(
"strategies",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("name", sa.String(255), unique=True, nullable=False),
sa.Column("description", sa.Text, nullable=True),
sa.Column("current_weight", sa.Float, nullable=False, server_default="0.333"),
sa.Column("active", sa.Boolean, nullable=False, server_default=sa.text("true")),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_table(
"signals",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("ticker", sa.String(20), nullable=False, index=True),
sa.Column(
"direction",
sa.Enum("LONG", "SHORT", "NEUTRAL", name="signaldirection"),
nullable=False,
),
sa.Column("strength", sa.Float, nullable=False),
sa.Column("strategy_sources", postgresql.JSON, nullable=True),
sa.Column("sentiment_score", sa.Float, nullable=True),
sa.Column("acted_on", sa.Boolean, nullable=False, server_default=sa.text("false")),
sa.Column(
"strategy_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("strategies.id"),
nullable=True,
),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_table(
"trades",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("ticker", sa.String(20), nullable=False, index=True),
sa.Column(
"side",
sa.Enum("BUY", "SELL", name="tradeside"),
nullable=False,
),
sa.Column("qty", sa.Float, nullable=False),
sa.Column("price", sa.Float, nullable=False),
sa.Column("timestamp", sa.String, nullable=True),
sa.Column(
"strategy_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("strategies.id"),
nullable=True,
),
sa.Column(
"signal_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("signals.id"),
nullable=True,
),
sa.Column(
"status",
sa.Enum("PENDING", "FILLED", "CANCELLED", "REJECTED", name="tradestatus"),
nullable=False,
server_default="PENDING",
),
sa.Column("pnl", sa.Float, nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_table(
"positions",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("ticker", sa.String(20), unique=True, nullable=False),
sa.Column("qty", sa.Float, nullable=False),
sa.Column("avg_entry", sa.Float, nullable=False),
sa.Column("unrealized_pnl", sa.Float, nullable=True),
sa.Column("stop_loss", sa.Float, nullable=True),
sa.Column("take_profit", sa.Float, nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_table(
"strategy_weight_history",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"strategy_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("strategies.id"),
nullable=False,
),
sa.Column("old_weight", sa.Float, nullable=False),
sa.Column("new_weight", sa.Float, nullable=False),
sa.Column("reason", sa.String(500), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
# --- News & sentiment ---
op.create_table(
"articles",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("source", sa.String(100), nullable=False),
sa.Column("url", sa.Text, nullable=False),
sa.Column("title", sa.Text, nullable=False),
sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("fetched_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("content_hash", sa.String(64), unique=True, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_table(
"article_sentiments",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"article_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("articles.id"),
nullable=False,
),
sa.Column("ticker", sa.String(20), nullable=False, index=True),
sa.Column("score", sa.Float, nullable=False),
sa.Column("confidence", sa.Float, nullable=False),
sa.Column("model_used", sa.String(50), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
# --- Learning ---
op.create_table(
"trade_outcomes",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"trade_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("trades.id"),
unique=True,
nullable=False,
),
sa.Column("hold_duration", sa.Interval, nullable=True),
sa.Column("realized_pnl", sa.Float, nullable=False),
sa.Column("roi_pct", sa.Float, nullable=False),
sa.Column("was_profitable", sa.Boolean, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_table(
"learning_adjustments",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"strategy_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("strategies.id"),
nullable=False,
),
sa.Column("old_weight", sa.Float, nullable=False),
sa.Column("new_weight", sa.Float, nullable=False),
sa.Column("reason", sa.Text, nullable=True),
sa.Column("reward_signal", sa.Float, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
# --- Auth ---
op.create_table(
"users",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("username", sa.String(100), unique=True, nullable=False),
sa.Column("display_name", sa.String(255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_table(
"user_credentials",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id"),
nullable=False,
),
sa.Column("credential_id", sa.String(512), unique=True, nullable=False),
sa.Column("public_key", sa.LargeBinary, nullable=False),
sa.Column("sign_count", sa.Integer, nullable=False, server_default="0"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
# --- Timeseries (TimescaleDB hypertables) ---
op.create_table(
"market_data",
sa.Column("timestamp", sa.DateTime(timezone=True), primary_key=True),
sa.Column("ticker", sa.String(20), primary_key=True),
sa.Column("open", sa.Float, nullable=False),
sa.Column("high", sa.Float, nullable=False),
sa.Column("low", sa.Float, nullable=False),
sa.Column("close", sa.Float, nullable=False),
sa.Column("volume", sa.Float, nullable=False),
)
op.create_table(
"portfolio_snapshots",
sa.Column("timestamp", sa.DateTime(timezone=True), primary_key=True),
sa.Column("total_value", sa.Float, nullable=False),
sa.Column("cash", sa.Float, nullable=False),
sa.Column("positions_value", sa.Float, nullable=False),
sa.Column("daily_pnl", sa.Float, nullable=False),
)
op.create_table(
"strategy_metrics",
sa.Column("timestamp", sa.DateTime(timezone=True), primary_key=True),
sa.Column(
"strategy_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("strategies.id"),
primary_key=True,
),
sa.Column("win_rate", sa.Float, nullable=False),
sa.Column("total_pnl", sa.Float, nullable=False),
sa.Column("trade_count", sa.Integer, nullable=False),
sa.Column("sharpe_ratio", sa.Float, nullable=True),
)
# Convert timeseries tables to TimescaleDB hypertables.
# These calls are idempotent-safe when the extension is loaded.
op.execute("SELECT create_hypertable('market_data', 'timestamp', if_not_exists => TRUE)")
op.execute("SELECT create_hypertable('portfolio_snapshots', 'timestamp', if_not_exists => TRUE)")
op.execute("SELECT create_hypertable('strategy_metrics', 'timestamp', if_not_exists => TRUE)")
def downgrade() -> None:
"""Drop all tables in reverse dependency order."""
op.drop_table("strategy_metrics")
op.drop_table("portfolio_snapshots")
op.drop_table("market_data")
op.drop_table("user_credentials")
op.drop_table("users")
op.drop_table("learning_adjustments")
op.drop_table("trade_outcomes")
op.drop_table("article_sentiments")
op.drop_table("articles")
op.drop_table("strategy_weight_history")
op.drop_table("positions")
op.drop_table("trades")
op.drop_table("signals")
op.drop_table("strategies")
# Drop enums
sa.Enum(name="signaldirection").drop(op.get_bind(), checkfirst=True)
sa.Enum(name="tradeside").drop(op.get_bind(), checkfirst=True)
sa.Enum(name="tradestatus").drop(op.get_bind(), checkfirst=True)

22
shared/db.py Normal file
View file

@ -0,0 +1,22 @@
"""SQLAlchemy async engine and session factory."""
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from shared.config import BaseConfig
def create_db(config: BaseConfig) -> tuple:
"""Create an async engine and session factory from the given config.
Returns a ``(engine, session_factory)`` tuple.
"""
engine = create_async_engine(
config.database_url,
echo=config.log_level == "DEBUG",
)
session_factory = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)
return engine, session_factory

44
shared/models/__init__.py Normal file
View file

@ -0,0 +1,44 @@
"""Shared SQLAlchemy models — import all models here so Alembic can discover them."""
from shared.models.base import Base, TimestampMixin
from shared.models.trading import (
Signal,
SignalDirection,
Strategy,
StrategyWeightHistory,
Trade,
TradeSide,
TradeStatus,
Position,
)
from shared.models.news import Article, ArticleSentiment
from shared.models.learning import LearningAdjustment, TradeOutcome
from shared.models.auth import User, UserCredential
from shared.models.timeseries import MarketData, PortfolioSnapshot, StrategyMetric
__all__ = [
"Base",
"TimestampMixin",
# Trading
"Strategy",
"Signal",
"SignalDirection",
"Trade",
"TradeSide",
"TradeStatus",
"Position",
"StrategyWeightHistory",
# News
"Article",
"ArticleSentiment",
# Learning
"TradeOutcome",
"LearningAdjustment",
# Auth
"User",
"UserCredential",
# Timeseries
"MarketData",
"PortfolioSnapshot",
"StrategyMetric",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

39
shared/models/auth.py Normal file
View file

@ -0,0 +1,39 @@
"""Authentication models: User, UserCredential."""
import uuid
from sqlalchemy import ForeignKey, Integer, LargeBinary, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from shared.models.base import Base, TimestampMixin
class User(TimestampMixin, Base):
__tablename__ = "users"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
username: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
display_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
# Relationships
credentials: Mapped[list["UserCredential"]] = relationship(back_populates="user")
class UserCredential(TimestampMixin, Base):
__tablename__ = "user_credentials"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id"), nullable=False
)
credential_id: Mapped[str] = mapped_column(String(512), unique=True, nullable=False)
public_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
sign_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
# Relationships
user: Mapped[User] = relationship(back_populates="credentials")

26
shared/models/base.py Normal file
View file

@ -0,0 +1,26 @@
"""SQLAlchemy declarative base and common mixins."""
from datetime import datetime
from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
"""Shared declarative base for all models."""
pass
class TimestampMixin:
"""Adds ``created_at`` and ``updated_at`` columns with server defaults."""
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)

51
shared/models/learning.py Normal file
View file

@ -0,0 +1,51 @@
"""Learning domain models: TradeOutcome, LearningAdjustment."""
import uuid
from datetime import timedelta
from sqlalchemy import Boolean, Float, ForeignKey, Interval, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from shared.models.base import Base, TimestampMixin
class TradeOutcome(TimestampMixin, Base):
__tablename__ = "trade_outcomes"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
trade_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("trades.id"), unique=True, nullable=False
)
hold_duration: Mapped[timedelta | None] = mapped_column(Interval, nullable=True)
realized_pnl: Mapped[float] = mapped_column(Float, nullable=False)
roi_pct: Mapped[float] = mapped_column(Float, nullable=False)
was_profitable: Mapped[bool] = mapped_column(Boolean, nullable=False)
# Relationships
trade: Mapped["Trade"] = relationship("Trade", back_populates="outcome")
class LearningAdjustment(TimestampMixin, Base):
__tablename__ = "learning_adjustments"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
strategy_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=False
)
old_weight: Mapped[float] = mapped_column(Float, nullable=False)
new_weight: Mapped[float] = mapped_column(Float, nullable=False)
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
reward_signal: Mapped[float] = mapped_column(Float, nullable=False)
# Relationships
strategy: Mapped["Strategy"] = relationship("Strategy")
# Avoid circular imports — reference by string in relationship()
from shared.models.trading import Trade as Trade # noqa: E402, F401
from shared.models.trading import Strategy as Strategy # noqa: E402, F401

49
shared/models/news.py Normal file
View file

@ -0,0 +1,49 @@
"""News and sentiment models: Article, ArticleSentiment."""
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Float, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from shared.models.base import Base, TimestampMixin
class Article(TimestampMixin, Base):
__tablename__ = "articles"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
source: Mapped[str] = mapped_column(String(100), nullable=False)
url: Mapped[str] = mapped_column(Text, nullable=False)
title: Mapped[str] = mapped_column(Text, nullable=False)
published_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
fetched_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
content_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
# Relationships
sentiments: Mapped[list["ArticleSentiment"]] = relationship(back_populates="article")
class ArticleSentiment(TimestampMixin, Base):
__tablename__ = "article_sentiments"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
article_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("articles.id"), nullable=False
)
ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
score: Mapped[float] = mapped_column(Float, nullable=False)
confidence: Mapped[float] = mapped_column(Float, nullable=False)
model_used: Mapped[str] = mapped_column(String(50), nullable=False)
# Relationships
article: Mapped[Article] = relationship(back_populates="sentiments")

View file

@ -0,0 +1,57 @@
"""TimescaleDB hypertable models: MarketData, PortfolioSnapshot, StrategyMetric."""
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Float, ForeignKey, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from shared.models.base import Base
class MarketData(Base):
"""OHLCV bars — intended as a TimescaleDB hypertable partitioned by timestamp."""
__tablename__ = "market_data"
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), primary_key=True
)
ticker: Mapped[str] = mapped_column(String(20), primary_key=True)
open: Mapped[float] = mapped_column(Float, nullable=False)
high: Mapped[float] = mapped_column(Float, nullable=False)
low: Mapped[float] = mapped_column(Float, nullable=False)
close: Mapped[float] = mapped_column(Float, nullable=False)
volume: Mapped[float] = mapped_column(Float, nullable=False)
class PortfolioSnapshot(Base):
"""Periodic portfolio value snapshots — TimescaleDB hypertable."""
__tablename__ = "portfolio_snapshots"
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), primary_key=True
)
total_value: Mapped[float] = mapped_column(Float, nullable=False)
cash: Mapped[float] = mapped_column(Float, nullable=False)
positions_value: Mapped[float] = mapped_column(Float, nullable=False)
daily_pnl: Mapped[float] = mapped_column(Float, nullable=False)
class StrategyMetric(Base):
"""Per-strategy performance over time — TimescaleDB hypertable."""
__tablename__ = "strategy_metrics"
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), primary_key=True
)
strategy_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("strategies.id"), primary_key=True
)
win_rate: Mapped[float] = mapped_column(Float, nullable=False)
total_pnl: Mapped[float] = mapped_column(Float, nullable=False)
trade_count: Mapped[int] = mapped_column(Integer, nullable=False)
sharpe_ratio: Mapped[float | None] = mapped_column(Float, nullable=True)

137
shared/models/trading.py Normal file
View file

@ -0,0 +1,137 @@
"""Trading domain models: Strategy, Signal, Trade, Position, StrategyWeightHistory."""
import enum
import uuid
from sqlalchemy import Boolean, Float, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from shared.models.base import Base, TimestampMixin
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class TradeSide(str, enum.Enum):
BUY = "BUY"
SELL = "SELL"
class TradeStatus(str, enum.Enum):
PENDING = "PENDING"
FILLED = "FILLED"
CANCELLED = "CANCELLED"
REJECTED = "REJECTED"
class SignalDirection(str, enum.Enum):
LONG = "LONG"
SHORT = "SHORT"
NEUTRAL = "NEUTRAL"
# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
class Strategy(TimestampMixin, Base):
__tablename__ = "strategies"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
current_weight: Mapped[float] = mapped_column(Float, nullable=False, default=0.333)
active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
# Relationships
trades: Mapped[list["Trade"]] = relationship(back_populates="strategy")
signals: Mapped[list["Signal"]] = relationship(back_populates="strategy", foreign_keys="Signal.strategy_id", viewonly=True)
weight_history: Mapped[list["StrategyWeightHistory"]] = relationship(back_populates="strategy")
class Signal(TimestampMixin, Base):
__tablename__ = "signals"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
direction: Mapped[SignalDirection] = mapped_column(nullable=False)
strength: Mapped[float] = mapped_column(Float, nullable=False)
strategy_sources: Mapped[dict | None] = mapped_column(JSON, nullable=True)
sentiment_score: Mapped[float | None] = mapped_column(Float, nullable=True)
acted_on: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
strategy_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=True
)
# Relationships
strategy: Mapped[Strategy | None] = relationship(back_populates="signals", foreign_keys=[strategy_id])
trades: Mapped[list["Trade"]] = relationship(back_populates="signal")
class Trade(TimestampMixin, Base):
__tablename__ = "trades"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
side: Mapped[TradeSide] = mapped_column(nullable=False)
qty: Mapped[float] = mapped_column(Float, nullable=False)
price: Mapped[float] = mapped_column(Float, nullable=False)
timestamp: Mapped[str | None] = mapped_column(String, nullable=True)
strategy_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=True
)
signal_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("signals.id"), nullable=True
)
status: Mapped[TradeStatus] = mapped_column(nullable=False, default=TradeStatus.PENDING)
pnl: Mapped[float | None] = mapped_column(Float, nullable=True)
# Relationships
strategy: Mapped[Strategy | None] = relationship(back_populates="trades")
signal: Mapped[Signal | None] = relationship(back_populates="trades")
outcome: Mapped["TradeOutcome | None"] = relationship(
"TradeOutcome", back_populates="trade", uselist=False
)
class Position(TimestampMixin, Base):
__tablename__ = "positions"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
ticker: Mapped[str] = mapped_column(String(20), unique=True, nullable=False)
qty: Mapped[float] = mapped_column(Float, nullable=False)
avg_entry: Mapped[float] = mapped_column(Float, nullable=False)
unrealized_pnl: Mapped[float | None] = mapped_column(Float, nullable=True)
stop_loss: Mapped[float | None] = mapped_column(Float, nullable=True)
take_profit: Mapped[float | None] = mapped_column(Float, nullable=True)
class StrategyWeightHistory(TimestampMixin, Base):
__tablename__ = "strategy_weight_history"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
strategy_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=False
)
old_weight: Mapped[float] = mapped_column(Float, nullable=False)
new_weight: Mapped[float] = mapped_column(Float, nullable=False)
reason: Mapped[str | None] = mapped_column(String(500), nullable=True)
# Relationships
strategy: Mapped[Strategy] = relationship(back_populates="weight_history")
# Avoid circular import — TradeOutcome is defined in learning.py
from shared.models.learning import TradeOutcome # noqa: E402, F401

323
tests/test_models.py Normal file
View file

@ -0,0 +1,323 @@
"""Tests for SQLAlchemy model instantiation, enums, and relationships."""
import uuid
from datetime import datetime, timedelta, timezone
import pytest
from shared.models import (
Base,
TimestampMixin,
# Trading
Strategy,
Signal,
SignalDirection,
Trade,
TradeSide,
TradeStatus,
Position,
StrategyWeightHistory,
# News
Article,
ArticleSentiment,
# Learning
TradeOutcome,
LearningAdjustment,
# Auth
User,
UserCredential,
# Timeseries
MarketData,
PortfolioSnapshot,
StrategyMetric,
)
from shared.db import create_db
from shared.config import BaseConfig
# ---------------------------------------------------------------------------
# Enum tests
# ---------------------------------------------------------------------------
class TestEnums:
def test_trade_side_values(self) -> None:
assert TradeSide.BUY == "BUY"
assert TradeSide.SELL == "SELL"
assert set(TradeSide) == {TradeSide.BUY, TradeSide.SELL}
def test_trade_status_values(self) -> None:
assert TradeStatus.PENDING == "PENDING"
assert TradeStatus.FILLED == "FILLED"
assert TradeStatus.CANCELLED == "CANCELLED"
assert TradeStatus.REJECTED == "REJECTED"
assert len(TradeStatus) == 4
def test_signal_direction_values(self) -> None:
assert SignalDirection.LONG == "LONG"
assert SignalDirection.SHORT == "SHORT"
assert SignalDirection.NEUTRAL == "NEUTRAL"
assert len(SignalDirection) == 3
# ---------------------------------------------------------------------------
# Model instantiation tests
# ---------------------------------------------------------------------------
class TestStrategy:
def test_create_strategy(self) -> None:
s = Strategy(
id=uuid.uuid4(),
name="momentum",
description="Trend-following strategy",
current_weight=0.5,
active=True,
)
assert s.name == "momentum"
assert s.current_weight == 0.5
assert s.active is True
def test_strategy_defaults(self) -> None:
"""Without a DB session, Python-level defaults are not applied by SQLAlchemy.
The column default is only used at INSERT time."""
s = Strategy(name="test")
assert s.description is None
# Column-level default=True is applied by the database at INSERT time,
# so in-memory the attribute is None until the row is flushed/refreshed.
assert s.active is None or s.active is True
class TestSignal:
def test_create_signal(self) -> None:
sig = Signal(
id=uuid.uuid4(),
ticker="AAPL",
direction=SignalDirection.LONG,
strength=0.85,
strategy_sources={"momentum": 0.9},
sentiment_score=0.7,
acted_on=False,
)
assert sig.ticker == "AAPL"
assert sig.direction == SignalDirection.LONG
assert sig.strength == 0.85
assert sig.acted_on is False
class TestTrade:
def test_create_trade(self) -> None:
t = Trade(
id=uuid.uuid4(),
ticker="TSLA",
side=TradeSide.BUY,
qty=10.0,
price=150.25,
status=TradeStatus.FILLED,
pnl=250.50,
)
assert t.ticker == "TSLA"
assert t.side == TradeSide.BUY
assert t.qty == 10.0
assert t.price == 150.25
assert t.status == TradeStatus.FILLED
assert t.pnl == 250.50
class TestPosition:
def test_create_position(self) -> None:
p = Position(
id=uuid.uuid4(),
ticker="GOOG",
qty=5.0,
avg_entry=2800.00,
unrealized_pnl=-50.0,
stop_loss=2750.0,
take_profit=3000.0,
)
assert p.ticker == "GOOG"
assert p.qty == 5.0
assert p.stop_loss == 2750.0
assert p.take_profit == 3000.0
class TestStrategyWeightHistory:
def test_create_weight_history(self) -> None:
sid = uuid.uuid4()
wh = StrategyWeightHistory(
id=uuid.uuid4(),
strategy_id=sid,
old_weight=0.33,
new_weight=0.40,
reason="Improved win rate",
)
assert wh.strategy_id == sid
assert wh.old_weight == 0.33
assert wh.new_weight == 0.40
class TestArticle:
def test_create_article(self) -> None:
now = datetime.now(timezone.utc)
a = Article(
id=uuid.uuid4(),
source="reuters",
url="https://reuters.com/article/1",
title="Market Rally",
published_at=now,
fetched_at=now,
content_hash="abc123def456",
)
assert a.source == "reuters"
assert a.content_hash == "abc123def456"
class TestArticleSentiment:
def test_create_sentiment(self) -> None:
asent = ArticleSentiment(
id=uuid.uuid4(),
article_id=uuid.uuid4(),
ticker="AAPL",
score=0.85,
confidence=0.92,
model_used="finbert",
)
assert asent.score == 0.85
assert asent.model_used == "finbert"
class TestTradeOutcome:
def test_create_outcome(self) -> None:
outcome = TradeOutcome(
id=uuid.uuid4(),
trade_id=uuid.uuid4(),
hold_duration=timedelta(hours=4, minutes=30),
realized_pnl=125.50,
roi_pct=2.5,
was_profitable=True,
)
assert outcome.realized_pnl == 125.50
assert outcome.was_profitable is True
assert outcome.hold_duration == timedelta(hours=4, minutes=30)
class TestLearningAdjustment:
def test_create_adjustment(self) -> None:
adj = LearningAdjustment(
id=uuid.uuid4(),
strategy_id=uuid.uuid4(),
old_weight=0.30,
new_weight=0.35,
reason="Positive reward signal",
reward_signal=0.72,
)
assert adj.reward_signal == 0.72
assert adj.reason == "Positive reward signal"
class TestUser:
def test_create_user(self) -> None:
u = User(
id=uuid.uuid4(),
username="trader1",
display_name="Top Trader",
)
assert u.username == "trader1"
assert u.display_name == "Top Trader"
class TestUserCredential:
def test_create_credential(self) -> None:
cred = UserCredential(
id=uuid.uuid4(),
user_id=uuid.uuid4(),
credential_id="cred-abc-123",
public_key=b"\x04abcdef",
sign_count=5,
)
assert cred.credential_id == "cred-abc-123"
assert cred.sign_count == 5
assert cred.public_key == b"\x04abcdef"
class TestMarketData:
def test_create_market_data(self) -> None:
now = datetime.now(timezone.utc)
md = MarketData(
timestamp=now,
ticker="AAPL",
open=150.0,
high=155.0,
low=149.0,
close=153.0,
volume=1_000_000.0,
)
assert md.ticker == "AAPL"
assert md.close == 153.0
class TestPortfolioSnapshot:
def test_create_snapshot(self) -> None:
now = datetime.now(timezone.utc)
snap = PortfolioSnapshot(
timestamp=now,
total_value=100_000.0,
cash=25_000.0,
positions_value=75_000.0,
daily_pnl=1_200.0,
)
assert snap.total_value == 100_000.0
assert snap.daily_pnl == 1_200.0
class TestStrategyMetric:
def test_create_metric(self) -> None:
now = datetime.now(timezone.utc)
sm = StrategyMetric(
timestamp=now,
strategy_id=uuid.uuid4(),
win_rate=0.65,
total_pnl=5_432.10,
trade_count=42,
sharpe_ratio=1.8,
)
assert sm.win_rate == 0.65
assert sm.trade_count == 42
assert sm.sharpe_ratio == 1.8
# ---------------------------------------------------------------------------
# Metadata / Base tests
# ---------------------------------------------------------------------------
class TestMetadata:
def test_all_tables_registered(self) -> None:
table_names = set(Base.metadata.tables.keys())
expected = {
"strategies",
"signals",
"trades",
"positions",
"strategy_weight_history",
"articles",
"article_sentiments",
"trade_outcomes",
"learning_adjustments",
"users",
"user_credentials",
"market_data",
"portfolio_snapshots",
"strategy_metrics",
}
assert expected.issubset(table_names)
def test_timestamp_mixin_fields(self) -> None:
"""TimestampMixin should contribute created_at and updated_at columns."""
assert "created_at" in Strategy.__table__.columns
assert "updated_at" in Strategy.__table__.columns
class TestDbFactory:
def test_create_db_returns_engine_and_session(self) -> None:
config = BaseConfig()
engine, session_factory = create_db(config)
assert engine is not None
assert session_factory is not None