feat: database models and alembic migrations — all tables per design
- shared/db.py: async engine + session factory - shared/models/base.py: DeclarativeBase + TimestampMixin - shared/models/trading.py: Strategy, Signal, Trade, Position, StrategyWeightHistory - shared/models/news.py: Article, ArticleSentiment - shared/models/learning.py: TradeOutcome, LearningAdjustment - shared/models/auth.py: User, UserCredential - shared/models/timeseries.py: MarketData, PortfolioSnapshot, StrategyMetric - Alembic async env.py with initial migration including TimescaleDB hypertables - 21 model tests covering enums, instantiation, metadata registration
This commit is contained in:
parent
ae5b3f89d1
commit
72cb1b6fe5
23 changed files with 1283 additions and 0 deletions
151
alembic.ini
Normal file
151
alembic.ini
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
|
||||
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the tzdata library which can be installed by adding
|
||||
# `alembic[tz]` to the pip requirements.
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
# Database URL is read from the TRADING_DATABASE_URL environment variable
|
||||
# in alembic/env.py. The value here is a fallback only.
|
||||
sqlalchemy.url = postgresql+asyncpg://trading:trading@localhost:5432/trading
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
1
alembic/README
Normal file
1
alembic/README
Normal file
|
|
@ -0,0 +1 @@
|
|||
Generic single-database configuration.
|
||||
BIN
alembic/__pycache__/env.cpython-314.pyc
Normal file
BIN
alembic/__pycache__/env.cpython-314.pyc
Normal file
Binary file not shown.
71
alembic/env.py
Normal file
71
alembic/env.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""Alembic environment — async-aware, imports all shared models."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
# Import all models so metadata is populated for autogenerate.
|
||||
from shared.models import Base # noqa: F401
|
||||
import shared.models # noqa: F401 — triggers side-effect imports of every model
|
||||
|
||||
# Alembic Config object — provides access to values in alembic.ini.
config = context.config

# Override sqlalchemy.url from environment if available.
# The value in alembic.ini is a fallback only; deployments set
# TRADING_DATABASE_URL (see the comment block in alembic.ini).
db_url = os.environ.get("TRADING_DATABASE_URL")
if db_url:
    config.set_main_option("sqlalchemy.url", db_url)

# Interpret the config file for Python logging.
# Configures loggers per the [loggers]/[handlers]/[formatters] ini sections.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Metadata used for autogenerate — populated by the model imports above.
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    No DBAPI connection is made; the configured URL is used only so the
    dialect can render the migration DDL as literal SQL.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection) -> None:  # noqa: ANN001
    """Configure the migration context on *connection* and run migrations.

    Invoked synchronously via ``AsyncConnection.run_sync`` from
    ``run_async_migrations`` — *connection* is the sync ``Connection`` adapter.
    """
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
    """Run migrations in 'online' mode using an async engine.

    Builds an engine from the ``sqlalchemy.``-prefixed keys of the ini
    section, runs the migrations through the sync adapter, and always
    disposes the engine — even when a migration fails — so connections
    are never leaked.
    """
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,  # one-shot use: no pooling needed
    )

    try:
        async with connectable.connect() as connection:
            await connection.run_sync(do_run_migrations)
    finally:
        # The original only disposed on the success path; a failing
        # migration would leave the engine (and its connections) open.
        await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
    """Entry-point for online migrations — delegates to the async helper."""
    # asyncio.run creates a fresh event loop; assumes no loop is already
    # running (true when invoked from the alembic CLI).
    asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
# Module-level dispatch: Alembic imports env.py and expects the migrations to
# run as a side effect of the import.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
|
||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
Binary file not shown.
284
alembic/versions/a1b2c3d4e5f6_initial_schema.py
Normal file
284
alembic/versions/a1b2c3d4e5f6_initial_schema.py
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
"""initial schema
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises:
|
||||
Create Date: 2026-02-22 15:15:15.661206
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a1b2c3d4e5f6"
|
||||
down_revision: Union[str, Sequence[str], None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create all tables for the trading bot.

    Tables are created in FK-dependency order (strategies first). The
    ``sa.Enum`` columns implicitly create the PostgreSQL enum types
    ``signaldirection``, ``tradeside`` and ``tradestatus``; ``downgrade``
    drops them explicitly.
    """

    # --- Core trading tables ---

    op.create_table(
        "strategies",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("name", sa.String(255), unique=True, nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        # Default 0.333 — presumably an equal split across three strategies;
        # TODO confirm against the strategy design doc.
        sa.Column("current_weight", sa.Float, nullable=False, server_default="0.333"),
        sa.Column("active", sa.Boolean, nullable=False, server_default=sa.text("true")),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    op.create_table(
        "signals",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("ticker", sa.String(20), nullable=False, index=True),
        sa.Column(
            "direction",
            sa.Enum("LONG", "SHORT", "NEUTRAL", name="signaldirection"),
            nullable=False,
        ),
        sa.Column("strength", sa.Float, nullable=False),
        sa.Column("strategy_sources", postgresql.JSON, nullable=True),
        sa.Column("sentiment_score", sa.Float, nullable=True),
        sa.Column("acted_on", sa.Boolean, nullable=False, server_default=sa.text("false")),
        sa.Column(
            "strategy_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("strategies.id"),
            nullable=True,
        ),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    op.create_table(
        "trades",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("ticker", sa.String(20), nullable=False, index=True),
        sa.Column(
            "side",
            sa.Enum("BUY", "SELL", name="tradeside"),
            nullable=False,
        ),
        sa.Column("qty", sa.Float, nullable=False),
        sa.Column("price", sa.Float, nullable=False),
        # NOTE(review): execution timestamp stored as String, not
        # DateTime(timezone=True) like every other timestamp column in this
        # migration — confirm this matches the Trade model in
        # shared/models/trading.py before shipping.
        sa.Column("timestamp", sa.String, nullable=True),
        sa.Column(
            "strategy_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("strategies.id"),
            nullable=True,
        ),
        sa.Column(
            "signal_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("signals.id"),
            nullable=True,
        ),
        sa.Column(
            "status",
            sa.Enum("PENDING", "FILLED", "CANCELLED", "REJECTED", name="tradestatus"),
            nullable=False,
            server_default="PENDING",
        ),
        sa.Column("pnl", sa.Float, nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    op.create_table(
        "positions",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        # One open position per ticker — enforced by the unique constraint.
        sa.Column("ticker", sa.String(20), unique=True, nullable=False),
        sa.Column("qty", sa.Float, nullable=False),
        sa.Column("avg_entry", sa.Float, nullable=False),
        sa.Column("unrealized_pnl", sa.Float, nullable=True),
        sa.Column("stop_loss", sa.Float, nullable=True),
        sa.Column("take_profit", sa.Float, nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    op.create_table(
        "strategy_weight_history",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column(
            "strategy_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("strategies.id"),
            nullable=False,
        ),
        sa.Column("old_weight", sa.Float, nullable=False),
        sa.Column("new_weight", sa.Float, nullable=False),
        sa.Column("reason", sa.String(500), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # --- News & sentiment ---

    op.create_table(
        "articles",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("source", sa.String(100), nullable=False),
        sa.Column("url", sa.Text, nullable=False),
        sa.Column("title", sa.Text, nullable=False),
        sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("fetched_at", sa.DateTime(timezone=True), nullable=False),
        # Dedupe key — 64 chars suggests a hex SHA-256 of the content;
        # TODO confirm against the fetcher.
        sa.Column("content_hash", sa.String(64), unique=True, nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    op.create_table(
        "article_sentiments",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column(
            "article_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("articles.id"),
            nullable=False,
        ),
        sa.Column("ticker", sa.String(20), nullable=False, index=True),
        sa.Column("score", sa.Float, nullable=False),
        sa.Column("confidence", sa.Float, nullable=False),
        sa.Column("model_used", sa.String(50), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # --- Learning ---

    op.create_table(
        "trade_outcomes",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column(
            "trade_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("trades.id"),
            # unique: at most one outcome per trade (1:1 relationship).
            unique=True,
            nullable=False,
        ),
        sa.Column("hold_duration", sa.Interval, nullable=True),
        sa.Column("realized_pnl", sa.Float, nullable=False),
        sa.Column("roi_pct", sa.Float, nullable=False),
        sa.Column("was_profitable", sa.Boolean, nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    op.create_table(
        "learning_adjustments",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column(
            "strategy_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("strategies.id"),
            nullable=False,
        ),
        sa.Column("old_weight", sa.Float, nullable=False),
        sa.Column("new_weight", sa.Float, nullable=False),
        sa.Column("reason", sa.Text, nullable=True),
        sa.Column("reward_signal", sa.Float, nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # --- Auth ---

    op.create_table(
        "users",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column("username", sa.String(100), unique=True, nullable=False),
        sa.Column("display_name", sa.String(255), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    op.create_table(
        "user_credentials",
        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
        sa.Column(
            "user_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("users.id"),
            nullable=False,
        ),
        sa.Column("credential_id", sa.String(512), unique=True, nullable=False),
        sa.Column("public_key", sa.LargeBinary, nullable=False),
        sa.Column("sign_count", sa.Integer, nullable=False, server_default="0"),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # --- Timeseries (TimescaleDB hypertables) ---
    # Composite PKs include the timestamp because hypertables require the
    # partitioning column in any unique constraint.

    op.create_table(
        "market_data",
        sa.Column("timestamp", sa.DateTime(timezone=True), primary_key=True),
        sa.Column("ticker", sa.String(20), primary_key=True),
        sa.Column("open", sa.Float, nullable=False),
        sa.Column("high", sa.Float, nullable=False),
        sa.Column("low", sa.Float, nullable=False),
        sa.Column("close", sa.Float, nullable=False),
        sa.Column("volume", sa.Float, nullable=False),
    )

    op.create_table(
        "portfolio_snapshots",
        sa.Column("timestamp", sa.DateTime(timezone=True), primary_key=True),
        sa.Column("total_value", sa.Float, nullable=False),
        sa.Column("cash", sa.Float, nullable=False),
        sa.Column("positions_value", sa.Float, nullable=False),
        sa.Column("daily_pnl", sa.Float, nullable=False),
    )

    op.create_table(
        "strategy_metrics",
        sa.Column("timestamp", sa.DateTime(timezone=True), primary_key=True),
        sa.Column(
            "strategy_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("strategies.id"),
            primary_key=True,
        ),
        sa.Column("win_rate", sa.Float, nullable=False),
        sa.Column("total_pnl", sa.Float, nullable=False),
        sa.Column("trade_count", sa.Integer, nullable=False),
        sa.Column("sharpe_ratio", sa.Float, nullable=True),
    )

    # Convert timeseries tables to TimescaleDB hypertables.
    # These calls are idempotent-safe when the extension is loaded.
    # NOTE(review): this migration does not CREATE EXTENSION timescaledb;
    # it assumes the extension is already installed in the target database —
    # confirm the provisioning step does so, or these statements will fail.
    op.execute("SELECT create_hypertable('market_data', 'timestamp', if_not_exists => TRUE)")
    op.execute("SELECT create_hypertable('portfolio_snapshots', 'timestamp', if_not_exists => TRUE)")
    op.execute("SELECT create_hypertable('strategy_metrics', 'timestamp', if_not_exists => TRUE)")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop all tables in reverse dependency order, then the enum types."""
    # Reverse of the creation order in upgrade(): children before parents so
    # no FK constraint blocks a drop.
    tables = (
        "strategy_metrics",
        "portfolio_snapshots",
        "market_data",
        "user_credentials",
        "users",
        "learning_adjustments",
        "trade_outcomes",
        "article_sentiments",
        "articles",
        "strategy_weight_history",
        "positions",
        "trades",
        "signals",
        "strategies",
    )
    for table_name in tables:
        op.drop_table(table_name)

    # Drop the PostgreSQL enum types created implicitly by upgrade().
    bind = op.get_bind()
    for enum_name in ("signaldirection", "tradeside", "tradestatus"):
        sa.Enum(name=enum_name).drop(bind, checkfirst=True)
|
||||
22
shared/db.py
Normal file
22
shared/db.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""SQLAlchemy async engine and session factory."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
|
||||
|
||||
from shared.config import BaseConfig
|
||||
|
||||
|
||||
def create_db(
    config: BaseConfig,
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
    """Create an async engine and session factory from the given config.

    Args:
        config: Application configuration providing ``database_url`` and
            ``log_level``.

    Returns:
        An ``(engine, session_factory)`` tuple. The previous annotation was a
        bare ``tuple``; the element types are now spelled out for callers and
        type checkers.
    """
    engine = create_async_engine(
        config.database_url,
        # Echo SQL statements only when running at DEBUG log level.
        echo=config.log_level == "DEBUG",
    )
    session_factory = async_sessionmaker(
        engine,
        class_=AsyncSession,
        # Keep ORM objects usable after commit; with an async session a lazy
        # refresh would otherwise require another awaited round-trip.
        expire_on_commit=False,
    )
    return engine, session_factory
|
||||
44
shared/models/__init__.py
Normal file
44
shared/models/__init__.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Shared SQLAlchemy models — import all models here so Alembic can discover them."""
|
||||
|
||||
from shared.models.base import Base, TimestampMixin
|
||||
from shared.models.trading import (
|
||||
Signal,
|
||||
SignalDirection,
|
||||
Strategy,
|
||||
StrategyWeightHistory,
|
||||
Trade,
|
||||
TradeSide,
|
||||
TradeStatus,
|
||||
Position,
|
||||
)
|
||||
from shared.models.news import Article, ArticleSentiment
|
||||
from shared.models.learning import LearningAdjustment, TradeOutcome
|
||||
from shared.models.auth import User, UserCredential
|
||||
from shared.models.timeseries import MarketData, PortfolioSnapshot, StrategyMetric
|
||||
|
||||
# Public API of shared.models — keep in sync with the imports in this module.
__all__ = [
    "Base",
    "TimestampMixin",
    # Trading
    "Strategy",
    "Signal",
    "SignalDirection",
    "Trade",
    "TradeSide",
    "TradeStatus",
    "Position",
    "StrategyWeightHistory",
    # News
    "Article",
    "ArticleSentiment",
    # Learning
    "TradeOutcome",
    "LearningAdjustment",
    # Auth
    "User",
    "UserCredential",
    # Timeseries
    "MarketData",
    "PortfolioSnapshot",
    "StrategyMetric",
]
|
||||
BIN
shared/models/__pycache__/__init__.cpython-314.pyc
Normal file
BIN
shared/models/__pycache__/__init__.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/models/__pycache__/auth.cpython-314.pyc
Normal file
BIN
shared/models/__pycache__/auth.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/models/__pycache__/base.cpython-314.pyc
Normal file
BIN
shared/models/__pycache__/base.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/models/__pycache__/learning.cpython-314.pyc
Normal file
BIN
shared/models/__pycache__/learning.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/models/__pycache__/news.cpython-314.pyc
Normal file
BIN
shared/models/__pycache__/news.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/models/__pycache__/timeseries.cpython-314.pyc
Normal file
BIN
shared/models/__pycache__/timeseries.cpython-314.pyc
Normal file
Binary file not shown.
BIN
shared/models/__pycache__/trading.cpython-314.pyc
Normal file
BIN
shared/models/__pycache__/trading.cpython-314.pyc
Normal file
Binary file not shown.
39
shared/models/auth.py
Normal file
39
shared/models/auth.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""Authentication models: User, UserCredential."""
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, LargeBinary, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from shared.models.base import Base, TimestampMixin
|
||||
|
||||
|
||||
class User(TimestampMixin, Base):
    """Application user account; timestamps come from TimestampMixin."""

    __tablename__ = "users"

    # Surrogate primary key, generated client-side by uuid4.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # Unique login handle.
    username: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
    # Optional human-readable name shown in UIs.
    display_name: Mapped[str | None] = mapped_column(String(255), nullable=True)

    # Relationships
    # One user owns many credentials (UserCredential.user_id FK).
    credentials: Mapped[list["UserCredential"]] = relationship(back_populates="user")
|
||||
|
||||
|
||||
class UserCredential(TimestampMixin, Base):
    """A registered credential belonging to a User.

    NOTE(review): the field set (credential_id, public_key, sign_count)
    mirrors a WebAuthn credential — confirm against the auth service before
    documenting it as such.
    """

    __tablename__ = "user_credentials"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # Owning user; many credentials per user.
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), ForeignKey("users.id"), nullable=False
    )
    # Externally supplied credential identifier; unique across all users.
    credential_id: Mapped[str] = mapped_column(String(512), unique=True, nullable=False)
    # Raw public key bytes.
    public_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
    # Monotonic usage counter; starts at 0.
    sign_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

    # Relationships
    user: Mapped[User] = relationship(back_populates="credentials")
|
||||
26
shared/models/base.py
Normal file
26
shared/models/base.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""SQLAlchemy declarative base and common mixins."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Shared declarative base for all models.

    Every model module subclasses this so all tables register on a single
    ``Base.metadata``, which alembic/env.py uses as ``target_metadata`` for
    autogenerate. (The redundant ``pass`` after the docstring was removed —
    a docstring is a sufficient class body.)
    """
|
||||
|
||||
|
||||
class TimestampMixin:
    """Adds ``created_at`` and ``updated_at`` columns with server defaults."""

    # Set by the database at INSERT time via server_default.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    # server_default covers INSERT; onupdate re-evaluates on ORM-issued
    # UPDATEs. NOTE(review): onupdate is applied by SQLAlchemy, not a DB
    # trigger — updates made outside the ORM will not refresh this column.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
51
shared/models/learning.py
Normal file
51
shared/models/learning.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
"""Learning domain models: TradeOutcome, LearningAdjustment."""
|
||||
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
|
||||
from sqlalchemy import Boolean, Float, ForeignKey, Interval, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from shared.models.base import Base, TimestampMixin
|
||||
|
||||
|
||||
class TradeOutcome(TimestampMixin, Base):
    """Realized result of a single trade (1:1 with Trade via unique FK)."""

    __tablename__ = "trade_outcomes"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # unique=True enforces at most one outcome per trade.
    trade_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), ForeignKey("trades.id"), unique=True, nullable=False
    )
    # How long the position was held; nullable when unknown.
    hold_duration: Mapped[timedelta | None] = mapped_column(Interval, nullable=True)
    realized_pnl: Mapped[float] = mapped_column(Float, nullable=False)
    # Return on investment in percent — presumably derived from realized_pnl
    # and entry cost; confirm units against the learning service.
    roi_pct: Mapped[float] = mapped_column(Float, nullable=False)
    was_profitable: Mapped[bool] = mapped_column(Boolean, nullable=False)

    # Relationships
    trade: Mapped["Trade"] = relationship("Trade", back_populates="outcome")
|
||||
|
||||
|
||||
class LearningAdjustment(TimestampMixin, Base):
    """A recorded change to a strategy's weight made by the learning loop."""

    __tablename__ = "learning_adjustments"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    strategy_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=False
    )
    # Weight before and after the adjustment.
    old_weight: Mapped[float] = mapped_column(Float, nullable=False)
    new_weight: Mapped[float] = mapped_column(Float, nullable=False)
    # Free-text explanation of why the adjustment was made.
    reason: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Scalar signal that drove the adjustment; semantics defined by the
    # learning service — confirm its range/scale there.
    reward_signal: Mapped[float] = mapped_column(Float, nullable=False)

    # Relationships
    strategy: Mapped["Strategy"] = relationship("Strategy")
|
||||
|
||||
|
||||
# Avoid circular imports — reference by string in relationship()
|
||||
from shared.models.trading import Trade as Trade # noqa: E402, F401
|
||||
from shared.models.trading import Strategy as Strategy # noqa: E402, F401
|
||||
49
shared/models/news.py
Normal file
49
shared/models/news.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
"""News and sentiment models: Article, ArticleSentiment."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Float, ForeignKey, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from shared.models.base import Base, TimestampMixin
|
||||
|
||||
|
||||
class Article(TimestampMixin, Base):
    """A fetched news article, deduplicated by content hash."""

    __tablename__ = "articles"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # Name of the feed/provider the article came from.
    source: Mapped[str] = mapped_column(String(100), nullable=False)
    url: Mapped[str] = mapped_column(Text, nullable=False)
    title: Mapped[str] = mapped_column(Text, nullable=False)
    # Publication time as reported by the source; may be absent.
    published_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    # When our fetcher retrieved the article.
    fetched_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # Dedupe key — 64 chars suggests a hex SHA-256; confirm with the fetcher.
    content_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)

    # Relationships
    # One sentiment row per (article, ticker) scored.
    sentiments: Mapped[list["ArticleSentiment"]] = relationship(back_populates="article")
|
||||
|
||||
|
||||
class ArticleSentiment(TimestampMixin, Base):
    """A per-ticker sentiment score produced by a model for one article."""

    __tablename__ = "article_sentiments"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    article_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), ForeignKey("articles.id"), nullable=False
    )
    ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
    # Sentiment score and model confidence — ranges not enforced here; confirm
    # expected bounds (e.g. [-1, 1] / [0, 1]) against the scoring service.
    score: Mapped[float] = mapped_column(Float, nullable=False)
    confidence: Mapped[float] = mapped_column(Float, nullable=False)
    # Identifier of the model that produced the score (e.g. "finbert").
    model_used: Mapped[str] = mapped_column(String(50), nullable=False)

    # Relationships
    article: Mapped[Article] = relationship(back_populates="sentiments")
|
||||
57
shared/models/timeseries.py
Normal file
57
shared/models/timeseries.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
"""TimescaleDB hypertable models: MarketData, PortfolioSnapshot, StrategyMetric."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Float, ForeignKey, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from shared.models.base import Base
|
||||
|
||||
|
||||
class MarketData(Base):
    """OHLCV bars — intended as a TimescaleDB hypertable partitioned by timestamp.

    Composite primary key (timestamp, ticker): one bar per ticker per instant.
    """

    __tablename__ = "market_data"

    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), primary_key=True
    )
    ticker: Mapped[str] = mapped_column(String(20), primary_key=True)
    # OHLCV fields; volume is Float (not Integer) — fractional volumes allowed.
    open: Mapped[float] = mapped_column(Float, nullable=False)
    high: Mapped[float] = mapped_column(Float, nullable=False)
    low: Mapped[float] = mapped_column(Float, nullable=False)
    close: Mapped[float] = mapped_column(Float, nullable=False)
    volume: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
|
||||
|
||||
class PortfolioSnapshot(Base):
    """Periodic portfolio value snapshots — TimescaleDB hypertable.

    Keyed by timestamp alone: one snapshot of the whole portfolio per instant.
    """

    __tablename__ = "portfolio_snapshots"

    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), primary_key=True
    )
    # total_value presumably equals cash + positions_value; not enforced here.
    total_value: Mapped[float] = mapped_column(Float, nullable=False)
    cash: Mapped[float] = mapped_column(Float, nullable=False)
    positions_value: Mapped[float] = mapped_column(Float, nullable=False)
    daily_pnl: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
|
||||
|
||||
class StrategyMetric(Base):
    """Per-strategy performance over time — TimescaleDB hypertable.

    Composite primary key (timestamp, strategy_id): one metrics row per
    strategy per sampling instant.
    """

    __tablename__ = "strategy_metrics"

    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), primary_key=True
    )
    strategy_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), ForeignKey("strategies.id"), primary_key=True
    )
    win_rate: Mapped[float] = mapped_column(Float, nullable=False)
    total_pnl: Mapped[float] = mapped_column(Float, nullable=False)
    trade_count: Mapped[int] = mapped_column(Integer, nullable=False)
    # Nullable: undefined until enough trades exist to compute it.
    sharpe_ratio: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
137
shared/models/trading.py
Normal file
137
shared/models/trading.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
"""Trading domain models: Strategy, Signal, Trade, Position, StrategyWeightHistory."""
|
||||
|
||||
import enum
import uuid
from datetime import datetime

from sqlalchemy import Boolean, DateTime, Float, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship

from shared.models.base import Base, TimestampMixin
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TradeSide(str, enum.Enum):
    """Order direction; members compare equal to their string values."""

    BUY = "BUY"
    SELL = "SELL"


class TradeStatus(str, enum.Enum):
    """Lifecycle state of an order."""

    PENDING = "PENDING"
    FILLED = "FILLED"
    CANCELLED = "CANCELLED"
    REJECTED = "REJECTED"


class SignalDirection(str, enum.Enum):
    """Directional view expressed by a signal."""

    LONG = "LONG"
    SHORT = "SHORT"
    NEUTRAL = "NEUTRAL"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Strategy(TimestampMixin, Base):
    """A trading strategy and its current allocation weight."""

    __tablename__ = "strategies"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    name: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # 0.333 default — presumably an even split across three strategies; confirm.
    current_weight: Mapped[float] = mapped_column(
        Float, nullable=False, default=0.333
    )
    active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)

    # Relationships
    trades: Mapped[list["Trade"]] = relationship(back_populates="strategy")
    # viewonly: signals are written via Signal.strategy_id, never through here.
    signals: Mapped[list["Signal"]] = relationship(
        back_populates="strategy",
        foreign_keys="Signal.strategy_id",
        viewonly=True,
    )
    weight_history: Mapped[list["StrategyWeightHistory"]] = relationship(
        back_populates="strategy"
    )
|
||||
|
||||
|
||||
class Signal(TimestampMixin, Base):
    """A directional trading signal, optionally attributed to one strategy."""

    __tablename__ = "signals"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
    # Column type (Enum) is derived by SQLAlchemy from the Mapped annotation.
    direction: Mapped[SignalDirection] = mapped_column(nullable=False)
    strength: Mapped[float] = mapped_column(Float, nullable=False)
    # Per-strategy contribution map, e.g. {"momentum": 0.9} — schema not
    # enforced by the DB; confirm against the signal aggregator.
    strategy_sources: Mapped[dict | None] = mapped_column(JSON, nullable=True)
    sentiment_score: Mapped[float | None] = mapped_column(Float, nullable=True)
    # Whether a trade was actually placed off this signal.
    acted_on: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    strategy_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=True
    )

    # Relationships
    strategy: Mapped[Strategy | None] = relationship(
        back_populates="signals", foreign_keys=[strategy_id]
    )
    trades: Mapped[list["Trade"]] = relationship(back_populates="signal")
|
||||
|
||||
|
||||
class Trade(TimestampMixin, Base):
    """An executed (or pending) order, optionally linked to the strategy and
    signal that produced it."""

    __tablename__ = "trades"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    ticker: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
    # Column type (Enum) is derived by SQLAlchemy from the Mapped annotation.
    side: Mapped[TradeSide] = mapped_column(nullable=False)
    qty: Mapped[float] = mapped_column(Float, nullable=False)
    price: Mapped[float] = mapped_column(Float, nullable=False)
    # Fix: execution time was declared as an unbounded String; every other
    # temporal column in this schema (articles.published_at, market_data.timestamp,
    # portfolio_snapshots.timestamp) uses timezone-aware DateTime, and a string
    # column loses ordering/timezone semantics in the database.
    timestamp: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    strategy_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=True
    )
    signal_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True), ForeignKey("signals.id"), nullable=True
    )
    # Python-level default applies at INSERT time only.
    status: Mapped[TradeStatus] = mapped_column(
        nullable=False, default=TradeStatus.PENDING
    )
    # Realized P&L; null until the trade is closed out.
    pnl: Mapped[float | None] = mapped_column(Float, nullable=True)

    # Relationships
    strategy: Mapped[Strategy | None] = relationship(back_populates="trades")
    signal: Mapped[Signal | None] = relationship(back_populates="trades")
    # uselist=False: at most one learning outcome per trade.
    outcome: Mapped["TradeOutcome | None"] = relationship(
        "TradeOutcome", back_populates="trade", uselist=False
    )
|
||||
|
||||
|
||||
class Position(TimestampMixin, Base):
    """An open holding; the unique constraint allows at most one row per ticker."""

    __tablename__ = "positions"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    ticker: Mapped[str] = mapped_column(String(20), unique=True, nullable=False)
    qty: Mapped[float] = mapped_column(Float, nullable=False)
    # Average entry price across the fills that built the position.
    avg_entry: Mapped[float] = mapped_column(Float, nullable=False)
    unrealized_pnl: Mapped[float | None] = mapped_column(Float, nullable=True)
    # Optional exit levels; null means no stop/target is set.
    stop_loss: Mapped[float | None] = mapped_column(Float, nullable=True)
    take_profit: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
|
||||
|
||||
class StrategyWeightHistory(TimestampMixin, Base):
    """Audit log of weight changes applied to a strategy."""

    __tablename__ = "strategy_weight_history"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    strategy_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), ForeignKey("strategies.id"), nullable=False
    )
    old_weight: Mapped[float] = mapped_column(Float, nullable=False)
    new_weight: Mapped[float] = mapped_column(Float, nullable=False)
    # Free-text explanation for the adjustment.
    reason: Mapped[str | None] = mapped_column(String(500), nullable=True)

    # Relationships
    strategy: Mapped[Strategy] = relationship(back_populates="weight_history")
|
||||
|
||||
|
||||
# Avoid circular import — TradeOutcome is defined in learning.py
|
||||
from shared.models.learning import TradeOutcome # noqa: E402, F401
|
||||
323
tests/test_models.py
Normal file
323
tests/test_models.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
"""Tests for SQLAlchemy model instantiation, enums, and relationships."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.models import (
|
||||
Base,
|
||||
TimestampMixin,
|
||||
# Trading
|
||||
Strategy,
|
||||
Signal,
|
||||
SignalDirection,
|
||||
Trade,
|
||||
TradeSide,
|
||||
TradeStatus,
|
||||
Position,
|
||||
StrategyWeightHistory,
|
||||
# News
|
||||
Article,
|
||||
ArticleSentiment,
|
||||
# Learning
|
||||
TradeOutcome,
|
||||
LearningAdjustment,
|
||||
# Auth
|
||||
User,
|
||||
UserCredential,
|
||||
# Timeseries
|
||||
MarketData,
|
||||
PortfolioSnapshot,
|
||||
StrategyMetric,
|
||||
)
|
||||
from shared.db import create_db
|
||||
from shared.config import BaseConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enum tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnums:
    """Enum members are str subclasses and compare equal to their values."""

    def test_trade_side_values(self) -> None:
        assert TradeSide.BUY == "BUY"
        assert TradeSide.SELL == "SELL"
        assert set(TradeSide) == {TradeSide.BUY, TradeSide.SELL}

    def test_trade_status_values(self) -> None:
        for member, value in [
            (TradeStatus.PENDING, "PENDING"),
            (TradeStatus.FILLED, "FILLED"),
            (TradeStatus.CANCELLED, "CANCELLED"),
            (TradeStatus.REJECTED, "REJECTED"),
        ]:
            assert member == value
        assert len(TradeStatus) == 4

    def test_signal_direction_values(self) -> None:
        assert SignalDirection.LONG == "LONG"
        assert SignalDirection.SHORT == "SHORT"
        assert SignalDirection.NEUTRAL == "NEUTRAL"
        assert len(SignalDirection) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model instantiation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStrategy:
    def test_create_strategy(self) -> None:
        strat = Strategy(
            id=uuid.uuid4(),
            name="momentum",
            description="Trend-following strategy",
            current_weight=0.5,
            active=True,
        )
        assert strat.name == "momentum"
        assert strat.current_weight == 0.5
        assert strat.active is True

    def test_strategy_defaults(self) -> None:
        """Without a DB session, Python-level defaults are not applied by SQLAlchemy.

        The column default is only used at INSERT time."""
        strat = Strategy(name="test")
        assert strat.description is None
        # default=True is applied at INSERT time, so in-memory the attribute
        # may still be None until the row is flushed/refreshed.
        assert strat.active is None or strat.active is True
|
||||
|
||||
|
||||
class TestSignal:
    def test_create_signal(self) -> None:
        signal = Signal(
            id=uuid.uuid4(),
            ticker="AAPL",
            direction=SignalDirection.LONG,
            strength=0.85,
            strategy_sources={"momentum": 0.9},
            sentiment_score=0.7,
            acted_on=False,
        )
        assert signal.ticker == "AAPL"
        assert signal.direction == SignalDirection.LONG
        assert signal.strength == 0.85
        assert signal.acted_on is False
|
||||
|
||||
|
||||
class TestTrade:
    def test_create_trade(self) -> None:
        trade = Trade(
            id=uuid.uuid4(),
            ticker="TSLA",
            side=TradeSide.BUY,
            qty=10.0,
            price=150.25,
            status=TradeStatus.FILLED,
            pnl=250.50,
        )
        assert trade.ticker == "TSLA"
        assert trade.side == TradeSide.BUY
        assert trade.qty == 10.0
        assert trade.price == 150.25
        assert trade.status == TradeStatus.FILLED
        assert trade.pnl == 250.50
|
||||
|
||||
|
||||
class TestPosition:
    def test_create_position(self) -> None:
        pos = Position(
            id=uuid.uuid4(),
            ticker="GOOG",
            qty=5.0,
            avg_entry=2800.00,
            unrealized_pnl=-50.0,
            stop_loss=2750.0,
            take_profit=3000.0,
        )
        assert pos.ticker == "GOOG"
        assert pos.qty == 5.0
        assert pos.stop_loss == 2750.0
        assert pos.take_profit == 3000.0
|
||||
|
||||
|
||||
class TestStrategyWeightHistory:
    def test_create_weight_history(self) -> None:
        strategy_id = uuid.uuid4()
        entry = StrategyWeightHistory(
            id=uuid.uuid4(),
            strategy_id=strategy_id,
            old_weight=0.33,
            new_weight=0.40,
            reason="Improved win rate",
        )
        assert entry.strategy_id == strategy_id
        assert entry.old_weight == 0.33
        assert entry.new_weight == 0.40
|
||||
|
||||
|
||||
class TestArticle:
    def test_create_article(self) -> None:
        moment = datetime.now(timezone.utc)
        article = Article(
            id=uuid.uuid4(),
            source="reuters",
            url="https://reuters.com/article/1",
            title="Market Rally",
            published_at=moment,
            fetched_at=moment,
            content_hash="abc123def456",
        )
        assert article.source == "reuters"
        assert article.content_hash == "abc123def456"
|
||||
|
||||
|
||||
class TestArticleSentiment:
    def test_create_sentiment(self) -> None:
        sentiment = ArticleSentiment(
            id=uuid.uuid4(),
            article_id=uuid.uuid4(),
            ticker="AAPL",
            score=0.85,
            confidence=0.92,
            model_used="finbert",
        )
        assert sentiment.score == 0.85
        assert sentiment.model_used == "finbert"
|
||||
|
||||
|
||||
class TestTradeOutcome:
    def test_create_outcome(self) -> None:
        held_for = timedelta(hours=4, minutes=30)
        result = TradeOutcome(
            id=uuid.uuid4(),
            trade_id=uuid.uuid4(),
            hold_duration=held_for,
            realized_pnl=125.50,
            roi_pct=2.5,
            was_profitable=True,
        )
        assert result.realized_pnl == 125.50
        assert result.was_profitable is True
        assert result.hold_duration == timedelta(hours=4, minutes=30)
|
||||
|
||||
|
||||
class TestLearningAdjustment:
    def test_create_adjustment(self) -> None:
        adjustment = LearningAdjustment(
            id=uuid.uuid4(),
            strategy_id=uuid.uuid4(),
            old_weight=0.30,
            new_weight=0.35,
            reason="Positive reward signal",
            reward_signal=0.72,
        )
        assert adjustment.reward_signal == 0.72
        assert adjustment.reason == "Positive reward signal"
|
||||
|
||||
|
||||
class TestUser:
    def test_create_user(self) -> None:
        user = User(
            id=uuid.uuid4(),
            username="trader1",
            display_name="Top Trader",
        )
        assert user.username == "trader1"
        assert user.display_name == "Top Trader"
|
||||
|
||||
|
||||
class TestUserCredential:
    def test_create_credential(self) -> None:
        credential = UserCredential(
            id=uuid.uuid4(),
            user_id=uuid.uuid4(),
            credential_id="cred-abc-123",
            public_key=b"\x04abcdef",
            sign_count=5,
        )
        assert credential.credential_id == "cred-abc-123"
        assert credential.sign_count == 5
        assert credential.public_key == b"\x04abcdef"
|
||||
|
||||
|
||||
class TestMarketData:
    def test_create_market_data(self) -> None:
        moment = datetime.now(timezone.utc)
        bar = MarketData(
            timestamp=moment,
            ticker="AAPL",
            open=150.0,
            high=155.0,
            low=149.0,
            close=153.0,
            volume=1_000_000.0,
        )
        assert bar.ticker == "AAPL"
        assert bar.close == 153.0
|
||||
|
||||
|
||||
class TestPortfolioSnapshot:
    def test_create_snapshot(self) -> None:
        moment = datetime.now(timezone.utc)
        snapshot = PortfolioSnapshot(
            timestamp=moment,
            total_value=100_000.0,
            cash=25_000.0,
            positions_value=75_000.0,
            daily_pnl=1_200.0,
        )
        assert snapshot.total_value == 100_000.0
        assert snapshot.daily_pnl == 1_200.0
|
||||
|
||||
|
||||
class TestStrategyMetric:
    def test_create_metric(self) -> None:
        moment = datetime.now(timezone.utc)
        metric = StrategyMetric(
            timestamp=moment,
            strategy_id=uuid.uuid4(),
            win_rate=0.65,
            total_pnl=5_432.10,
            trade_count=42,
            sharpe_ratio=1.8,
        )
        assert metric.win_rate == 0.65
        assert metric.trade_count == 42
        assert metric.sharpe_ratio == 1.8
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metadata / Base tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMetadata:
    def test_all_tables_registered(self) -> None:
        # Importing the model modules must register every table on Base.metadata.
        registered = set(Base.metadata.tables)
        required = {
            "strategies",
            "signals",
            "trades",
            "positions",
            "strategy_weight_history",
            "articles",
            "article_sentiments",
            "trade_outcomes",
            "learning_adjustments",
            "users",
            "user_credentials",
            "market_data",
            "portfolio_snapshots",
            "strategy_metrics",
        }
        assert required.issubset(registered)

    def test_timestamp_mixin_fields(self) -> None:
        """TimestampMixin should contribute created_at and updated_at columns."""
        columns = Strategy.__table__.columns
        assert "created_at" in columns
        assert "updated_at" in columns
|
||||
|
||||
|
||||
class TestDbFactory:
    def test_create_db_returns_engine_and_session(self) -> None:
        cfg = BaseConfig()
        engine, make_session = create_db(cfg)
        assert engine is not None
        assert make_session is not None
|
||||
Loading…
Add table
Add a link
Reference in a new issue