Initial extraction from monorepo

This commit is contained in:
Viktor Barzin 2026-05-07 17:06:11 +00:00
commit 5c7baa8acc
20 changed files with 1974 additions and 0 deletions

0
hmrc_sync/__init__.py Normal file
View file

36
hmrc_sync/__main__.py Normal file
View file

@ -0,0 +1,36 @@
import logging
import os
import subprocess
import sys
import click
import uvicorn
@click.group()
def cli() -> None:
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
@cli.command()
def serve() -> None:
"""Run the FastAPI server (K8s entrypoint)."""
uvicorn.run("hmrc_sync.app:app", host="0.0.0.0", port=8080)
@cli.command()
@click.option("--tax-year", default="current", help="Tax year to fetch, e.g. 2024-25 or 'current'.")
def sync(tax_year: str) -> None:
"""One-shot sync of HMRC figures — used by the CronJob."""
raise click.ClickException("Sync stub — implement after HMRC prod approval lands")
@cli.command()
def migrate() -> None:
"""Run `alembic upgrade head`."""
result = subprocess.run(["alembic", "upgrade", "head"], check=False)
sys.exit(result.returncode)
if __name__ == "__main__":
cli()

129
hmrc_sync/app.py Normal file
View file

@ -0,0 +1,129 @@
"""FastAPI entrypoint for hmrc-sync.
Endpoints:
- GET /authorize redirect to HMRC OAuth, primes refresh_token
- GET /callback OAuth callback; exchange code, persist token
- POST /callback-metadata browser-side session attributes (fraud headers)
- POST /sync pull latest HMRC figures for a given tax year
- GET /healthz readiness + queue depth
"""
from __future__ import annotations
import logging
import os
import secrets
import urllib.parse
from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from prometheus_fastapi_instrumentator import Instrumentator
from hmrc_sync import oauth
from hmrc_sync.fraud_headers import SessionContext
log = logging.getLogger(__name__)
REQUIRED_ENV = [
"HMRC_PROD_CLIENT_ID",
"HMRC_PROD_CLIENT_SECRET",
"HMRC_PROD_REDIRECT_URI",
"DB_CONNECTION_STRING",
]
def _verify_env() -> None:
missing = [k for k in REQUIRED_ENV if not os.environ.get(k)]
if missing:
raise RuntimeError(f"Missing required env vars: {', '.join(missing)}")
@asynccontextmanager
async def lifespan(app: FastAPI): # type: ignore[no-untyped-def]
_verify_env()
app.state.session_context = SessionContext(
device_id=os.environ.get("HMRC_DEVICE_ID", ""),
public_ip=os.environ.get("HMRC_VENDOR_PUBLIC_IP", ""),
)
app.state.oauth_states = {} # anti-CSRF state → expires_at
yield
app = FastAPI(title="HMRC Sync", lifespan=lifespan)
Instrumentator().instrument(app).expose(app, endpoint="/metrics")
@app.get("/healthz")
async def healthz() -> dict[str, Any]:
return {"status": "ok"}
@app.get("/authorize")
async def authorize() -> RedirectResponse:
creds = oauth.load_creds_from_env()
state = secrets.token_urlsafe(24)
app.state.oauth_states[state] = True
params = urllib.parse.urlencode({
"response_type": "code",
"client_id": creds.client_id,
"scope": "read:self-assessment",
"redirect_uri": creds.redirect_uri,
"state": state,
})
return RedirectResponse(f"{oauth.PROD_BASE}/oauth/authorize?{params}")
@app.get("/callback", response_class=HTMLResponse)
async def callback(code: str, state: str) -> HTMLResponse:
if state not in app.state.oauth_states:
raise HTTPException(status_code=400, detail="unknown state (CSRF)")
del app.state.oauth_states[state]
creds = oauth.load_creds_from_env()
token = await oauth.exchange_code(creds, code)
oauth.persist_to_vault(token)
# Serve a 1-page form that POSTs browser attributes to /callback-metadata
# so we capture the per-session values HMRC wants in fraud headers.
return HTMLResponse(_metadata_capture_html())
@app.post("/callback-metadata")
async def callback_metadata(request: Request) -> dict[str, str]:
body = await request.json()
session: SessionContext = app.state.session_context
session.user_agent = str(body.get("user_agent", "") or "")
session.screen_width = int(body.get("screen_width", 0) or 0)
session.screen_height = int(body.get("screen_height", 0) or 0)
session.screen_colour_depth = int(body.get("screen_colour_depth", 0) or 0)
session.window_width = int(body.get("window_width", 0) or 0)
session.window_height = int(body.get("window_height", 0) or 0)
session.timezone_offset = int(body.get("timezone_offset", 0) or 0)
return {"status": "captured"}
@app.post("/sync")
async def sync(tax_year: str | None = None) -> dict[str, Any]:
"""Pull latest HMRC figures for `tax_year` (default: current fiscal year)."""
raise HTTPException(status_code=501, detail="Sync not yet implemented — awaiting HMRC prod approval")
def _metadata_capture_html() -> str:
return """<!doctype html>
<html><head><title>hmrc-sync capturing session</title></head><body>
<h2>Capturing session attributes for HMRC fraud headers...</h2>
<script>
fetch('/callback-metadata', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({
user_agent: navigator.userAgent,
screen_width: screen.width,
screen_height: screen.height,
screen_colour_depth: screen.colorDepth,
window_width: window.innerWidth,
window_height: window.innerHeight,
timezone_offset: -new Date().getTimezoneOffset()
})
}).then(() => document.body.innerHTML = '<h2>Done. You can close this tab.</h2>');
</script>
</body></html>"""

82
hmrc_sync/client.py Normal file
View file

@ -0,0 +1,82 @@
"""HMRC Individual Tax API v1.1 wrapper.
One method per endpoint we consume. Every request attaches the full fraud-
prevention header set built by `fraud_headers.build_headers()`.
Individual Tax API v1.1 returns tax-paid + income-breakdown figures per
employment per tax year exactly the ground-truth data we reconcile
against the payslip-ingest monthly aggregate.
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from typing import Any
import httpx
from hmrc_sync.fraud_headers import SessionContext, build_headers
log = logging.getLogger(__name__)
PROD_BASE = "https://api.service.hmrc.gov.uk"
INDIVIDUAL_TAX_VERSION = "application/vnd.hmrc.1.1+json"
@dataclass
class HmrcResponse:
status_code: int
body: dict[str, Any]
duration_ms: int
request_id: str | None
correlation_id: str | None
fraud_headers_sent: dict[str, str]
class HmrcClient:
def __init__(self,
access_token: str,
session: SessionContext,
connection_method: str = "BATCH_PROCESS_DIRECT",
base_url: str = PROD_BASE):
self._access_token = access_token
self._session = session
self._connection_method = connection_method
self._base_url = base_url.rstrip("/")
async def individual_tax_summary(self, utr: str, tax_year: str) -> HmrcResponse:
"""GET /individuals/tax/sa/{utr}/summary/{taxYear}
`utr` is the 10-digit Self Assessment reference; tax_year format
is `YYYY-YY` (e.g. `2024-25`).
"""
path = f"/individuals/tax/sa/{utr}/summary/{tax_year}"
return await self._get(path)
async def _get(self, path: str) -> HmrcResponse:
fraud = build_headers(self._session, self._connection_method)
headers = {
"Accept": INDIVIDUAL_TAX_VERSION,
"Authorization": f"Bearer {self._access_token}",
}
headers.update(fraud)
started = time.perf_counter()
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.get(f"{self._base_url}{path}", headers=headers)
duration_ms = int((time.perf_counter() - started) * 1000)
body: dict[str, Any]
try:
body = resp.json() if resp.content else {}
except ValueError:
body = {"raw": resp.text[:2000]}
log.info("hmrc %s status=%s duration=%dms", path, resp.status_code, duration_ms)
return HmrcResponse(
status_code=resp.status_code,
body=body,
duration_ms=duration_ms,
request_id=resp.headers.get("x-request-id"),
correlation_id=resp.headers.get("x-correlation-id"),
fraud_headers_sent=fraud,
)

70
hmrc_sync/db.py Normal file
View file

@ -0,0 +1,70 @@
import os
from datetime import datetime
from decimal import Decimal
from typing import Any
from sqlalchemy import JSON, TIMESTAMP, Integer, Numeric, String, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
SCHEMA_NAME = "hmrc_sync"
class Base(DeclarativeBase):
pass
JSON_TYPE = JSONB().with_variant(JSON(), "sqlite")
class TaxYearSnapshot(Base):
"""One row per (tax_year, employer_paye_ref, snapshot_date).
HMRC returns the `hmrc-held` view of annual PAYE/NI for a given
employment. Taking a daily snapshot lets us see HMRC's figures evolve
as late RTI filings land, and lets the dashboard always show the
latest value by snapshot_date.
"""
__tablename__ = "tax_year_snapshot"
__table_args__ = {"schema": SCHEMA_NAME} # noqa: RUF012
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
tax_year: Mapped[str] = mapped_column(String, nullable=False, index=True)
employer_paye_ref: Mapped[str] = mapped_column(String, nullable=False)
snapshot_date: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False)
gross_pay: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
income_tax: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
ni_contributions: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
source: Mapped[str] = mapped_column(String, nullable=False, server_default="hmrc-held")
raw_response: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)
fetched_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True),
nullable=False,
server_default=text("now()"))
class FetchLog(Base):
"""Audit trail of every HMRC API call — for fraud-header compliance review."""
__tablename__ = "fetch_log"
__table_args__ = {"schema": SCHEMA_NAME} # noqa: RUF012
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
endpoint: Mapped[str] = mapped_column(String, nullable=False)
status_code: Mapped[int] = mapped_column(Integer, nullable=False)
request_id: Mapped[str | None] = mapped_column(String, nullable=True)
correlation_id: Mapped[str | None] = mapped_column(String, nullable=True)
fraud_headers_sent: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)
response_snippet: Mapped[str | None] = mapped_column(String, nullable=True)
duration_ms: Mapped[int] = mapped_column(Integer, nullable=False)
fetched_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True),
nullable=False,
server_default=text("now()"))
def create_engine_from_env() -> AsyncEngine:
url = os.environ["DB_CONNECTION_STRING"]
return create_async_engine(url, pool_pre_ping=True)
def make_session_factory(engine: AsyncEngine) -> async_sessionmaker[Any]:
return async_sessionmaker(engine, expire_on_commit=False)

341
hmrc_sync/fraud_headers.py Normal file
View file

@ -0,0 +1,341 @@
"""Build HMRC MTD fraud-prevention headers (Gov-Client-* / Gov-Vendor-*).
HMRC's BATCH_PROCESS_DIRECT connection method (what our CronJob uses)
mandates 11 headers on every MTD API call; WEB_APP_VIA_SERVER adds a
handful of browser-derived fields. Shipping without these risks fines
and API-access revocation per the HMRC fraud-prevention guide.
Layout:
- **Static** vendor-constant across runs (product name/version,
hashed license id).
- **Runtime** collected at module load from the pod's own network
stack + OS: MAC addresses, local IPs, OS family/version, device model.
- **Per-request** built at call time (timestamps, request ids).
- **Per-session** captured from the browser on `/callback-metadata`
(screen dimensions, public IP, MFA timestamp). Only WEB_APP_VIA_SERVER.
The public entry point is `build_headers(session, connection_method)`.
Run `tests/test_fraud_headers.py::test_headers_pass_hmrc_validator`
with `HMRC_VALIDATOR=1` to verify against the HMRC sandbox validator.
Spec references:
https://developer.service.hmrc.gov.uk/guides/fraud-prevention/
https://developer.service.hmrc.gov.uk/guides/fraud-prevention/connection-method/batch-process-direct/
https://developer.service.hmrc.gov.uk/api-documentation/docs/api/service/txm-fph-validator-api/1.0
"""
from __future__ import annotations
import getpass
import hashlib
import logging
import os
import platform
import secrets
import socket
import time
import urllib.parse
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
log = logging.getLogger(__name__)
VENDOR_PRODUCT_NAME = "hmrc-sync"
VENDOR_PRODUCT_VERSION = "0.1.0"
# Self-assigned for a personal single-user tool. HMRC permits arbitrary
# vendor strings; the header value is then SHA-256-hashed per spec
# (`Gov-Vendor-License-IDs: <name>=<hashed-value>`).
VENDOR_LICENSE_ID = os.environ.get("HMRC_VENDOR_LICENSE_ID",
"hmrc-sync-private-single-user")
VENDOR_PUBLIC_IP = os.environ.get("HMRC_VENDOR_PUBLIC_IP", "")
# Valid HMRC connection-method enum values.
CONNECTION_METHOD_BATCH = "BATCH_PROCESS_DIRECT"
CONNECTION_METHOD_WEB_APP = "WEB_APP_VIA_SERVER"
CONNECTION_METHOD_MFA = "AUTH_USING_MFA"
_NET_CLASS = Path("/sys/class/net")
_EMPTY_MAC = "00:00:00:00:00:00"
@dataclass
class SessionContext:
"""Browser-side attributes captured on the `/callback-metadata` POST.
Only relevant for WEB_APP_VIA_SERVER flows (browser-initiated OAuth
+ server-side API calls). BATCH_PROCESS_DIRECT flows derive their
context from `RuntimeContext` (see below) without touching these.
"""
user_agent: str = ""
screen_width: int = 0
screen_height: int = 0
screen_colour_depth: int = 0
window_width: int = 0
window_height: int = 0
timezone_offset: int = 0
device_id: str = ""
mfa_timestamp: str = ""
public_ip: str = ""
public_port: int = 0
@dataclass
class RuntimeContext:
"""Pod-side environment values required on every API call.
Collected once at module load (cheap all local syscalls). If any
field is empty, the header emitter falls back to safe defaults so
the call never goes out with an empty mandatory header.
"""
mac_addresses: list[str] = field(default_factory=list)
local_ips: list[str] = field(default_factory=list)
os_family: str = ""
os_version: str = ""
device_manufacturer: str = "Kubernetes"
device_model: str = ""
os_user: str = ""
def _collect_mac_addresses() -> list[str]:
"""Read every non-loopback interface MAC from `/sys/class/net/*/address`.
Colons are kept raw; `_format_mac_list` percent-encodes on output per spec.
"""
out: list[str] = []
if not _NET_CLASS.exists():
return out
for iface in sorted(_NET_CLASS.iterdir()):
if iface.name == "lo":
continue
addr_file = iface / "address"
if not addr_file.exists():
continue
try:
mac = addr_file.read_text().strip()
except OSError:
continue
if mac and mac != _EMPTY_MAC:
out.append(mac)
return out
def _collect_local_ips() -> list[str]:
"""Every IP bound to this host — IPv4 + IPv6, loopback excluded."""
ips: set[str] = set()
try:
hostname = socket.gethostname()
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
raw = sockaddr[0]
if not isinstance(raw, str):
continue
if family == socket.AF_INET and not raw.startswith("127."):
ips.add(raw)
elif family == socket.AF_INET6 and not raw.startswith("::1"):
ips.add(raw.split("%")[0]) # strip zone id
except (socket.gaierror, OSError):
pass
# Also grab the primary outbound IP — `getaddrinfo(hostname)` can miss
# it inside containers whose hostname has no DNS entry.
try:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("10.255.255.255", 1))
ips.add(s.getsockname()[0])
except OSError:
pass
return sorted(ips)
def _detect_runtime_context() -> RuntimeContext:
uname = platform.uname()
return RuntimeContext(
mac_addresses=_collect_mac_addresses(),
local_ips=_collect_local_ips(),
os_family=uname.system or "Linux",
os_version=uname.release or "unknown",
device_manufacturer="Kubernetes",
device_model=uname.node or socket.gethostname() or "pod",
os_user=_safe_getuser(),
)
def _safe_getuser() -> str:
try:
return getpass.getuser()
except (KeyError, OSError):
return os.environ.get("USER", "unknown")
RUNTIME_CONTEXT: RuntimeContext = _detect_runtime_context()
def build_headers(session: SessionContext | None = None,
connection_method: str = CONNECTION_METHOD_BATCH,
runtime: RuntimeContext | None = None) -> dict[str, str]:
"""Return the full header dict to attach to every HMRC API call.
Defaults to BATCH_PROCESS_DIRECT the mode the CronJob uses. Pass
a populated `SessionContext` + `connection_method=WEB_APP_VIA_SERVER`
for browser-initiated flows; the browser-only fields layer on top.
"""
session = session or SessionContext()
rt = runtime or RUNTIME_CONTEXT
headers: dict[str, str] = {}
headers.update(_static_headers())
headers.update(_per_request_headers())
headers.update(_mandatory_runtime_headers(rt, session, connection_method))
if connection_method == CONNECTION_METHOD_WEB_APP:
headers.update(_web_app_session_headers(session))
if connection_method == CONNECTION_METHOD_MFA and session.mfa_timestamp:
headers["Gov-Client-MFA-Timestamp"] = session.mfa_timestamp
return headers
def _static_headers() -> dict[str, str]:
"""Vendor-constant identity headers that apply to every connection method.
Product-Name is percent-encoded per spec; License-IDs value is SHA-256-
hashed per spec; Version is a key-value pair of `<software-name>=<version>`.
Gov-Vendor-Public-IP and Gov-Vendor-Forwarded are NOT emitted here the
HMRC validator rejects them for BATCH_PROCESS_DIRECT (where no vendor
server sits between the client and the HMRC API). They're added in
`_web_app_session_headers` for the WEB_APP_VIA_SERVER path only.
"""
license_hash = hashlib.sha256(VENDOR_LICENSE_ID.encode()).hexdigest()
return {
"Gov-Vendor-Product-Name": _pct(VENDOR_PRODUCT_NAME),
"Gov-Vendor-Version": f"{VENDOR_PRODUCT_NAME}={VENDOR_PRODUCT_VERSION}",
"Gov-Vendor-License-IDs": f"{VENDOR_PRODUCT_NAME}={license_hash}",
}
def _per_request_headers() -> dict[str, str]:
"""Per-call trace + timestamp headers. Local-IPs-Timestamp uses HMRC's
exact format `yyyy-MM-ddThh:mm:ss.sssZ` always UTC, always millis."""
now_ms = int(time.time() * 1000)
iso_ms = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(now_ms / 1000))
now_iso = f"{iso_ms}.{now_ms % 1000:03d}Z"
return {
"Gov-Client-Timezone": "UTC+00:00",
"Gov-Client-Local-IPs-Timestamp": now_iso,
"x-correlation-id": str(uuid.uuid4()),
"x-request-id": secrets.token_hex(16),
}
def _mandatory_runtime_headers(rt: RuntimeContext, session: SessionContext,
connection_method: str) -> dict[str, str]:
"""The 8 headers mandatory for BATCH_PROCESS_DIRECT that come from the
host Connection-Method, Device-ID, User-IDs, User-Agent, Local-IPs,
MAC-Addresses (+ Timezone and Local-IPs-Timestamp live in
`_per_request_headers`)."""
return {
"Gov-Client-Connection-Method": connection_method,
"Gov-Client-Device-ID": session.device_id or _fallback_device_id(),
"Gov-Client-User-IDs": _user_ids(rt, session),
"Gov-Client-User-Agent": _user_agent(rt, session),
"Gov-Client-Local-IPs": _format_ip_list(rt.local_ips),
"Gov-Client-MAC-Addresses": _format_mac_list(rt.mac_addresses),
}
def _web_app_session_headers(session: SessionContext) -> dict[str, str]:
"""WEB_APP_VIA_SERVER-only headers — browser context + vendor hop trail.
Gov-Vendor-Public-IP and Gov-Vendor-Forwarded describe the vendor server
that sits between the user's browser and HMRC — only meaningful for
WEB_APP_VIA_SERVER. BATCH_PROCESS_DIRECT must omit them (validator
rejects them there).
"""
out: dict[str, str] = {}
if session.screen_width and session.screen_height:
out["Gov-Client-Screens"] = (
f"width={session.screen_width}&height={session.screen_height}"
f"&scaling-factor=1&colour-depth={session.screen_colour_depth}")
if session.window_width and session.window_height:
out["Gov-Client-Window-Size"] = (f"width={session.window_width}&"
f"height={session.window_height}")
if session.public_ip:
out["Gov-Client-Public-IP"] = session.public_ip
if session.public_port:
out["Gov-Client-Public-Port"] = str(session.public_port)
vendor_ip = VENDOR_PUBLIC_IP or (RUNTIME_CONTEXT.local_ips[0] if RUNTIME_CONTEXT.local_ips
else "")
if vendor_ip:
out["Gov-Vendor-Public-IP"] = vendor_ip
out["Gov-Vendor-Forwarded"] = f"by={vendor_ip}&for={vendor_ip}"
return out
def _user_ids(rt: RuntimeContext, session: SessionContext) -> str:
"""Per spec: `os=<device-user>&<app>=<app-user>`. The `os=` field is
mandatory; we additionally tag our application with the OAuth user
so HMRC can correlate activity in a breach investigation.
"""
os_user = rt.os_user or "unknown"
pairs = [f"os={_pct(os_user)}"]
oauth_user = os.environ.get("HMRC_OAUTH_USER", "viktor")
pairs.append(f"hmrc-sync={_pct(oauth_user)}")
_ = session # reserved for future per-session identity extension
return "&".join(pairs)
def _user_agent(rt: RuntimeContext, session: SessionContext) -> str:
"""Per spec: `os-family=…&os-version=…&device-manufacturer=…&device-model=…`.
For WEB_APP_VIA_SERVER with a captured browser UA, the browser string
is encoded under `device-model` with the rest of the fields defaulting
to our pod's values — HMRC's validator accepts this hybrid shape.
"""
model = session.user_agent or rt.device_model or "pod"
pairs = [
f"os-family={_pct(rt.os_family)}",
f"os-version={_pct(rt.os_version)}",
f"device-manufacturer={_pct(rt.device_manufacturer)}",
f"device-model={_pct(model)}",
]
return "&".join(pairs)
def _format_ip_list(ips: list[str]) -> str:
"""IPv6 addresses percent-encoded; IPv4 passes through. Joined with ','.
HMRC's validator accepts an empty header only if the request truly
has no IPs; on a live pod we always have at least one if the list
comes back empty we fall back to the loopback so the header is
syntactically valid.
"""
if not ips:
return "127.0.0.1"
out = []
for ip in ips:
out.append(_pct(ip) if ":" in ip else ip)
return ",".join(out)
def _format_mac_list(macs: list[str]) -> str:
"""Each MAC percent-encoded (colons → %3A), comma-joined.
Empty list single dummy MAC so we never ship a blank header;
HMRC's validator treats blank as a violation.
"""
if not macs:
return _pct("02:00:00:00:00:00")
return ",".join(_pct(m) for m in macs)
def _fallback_device_id() -> str:
"""Deterministic UUID derived from hostname when no Vault-backed
Device-ID is seeded. Stable across restarts on the same node."""
return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"hmrc-sync-{platform.node()}"))
def _pct(s: str) -> str:
return urllib.parse.quote(s, safe="")
def as_validator_payload(headers: dict[str, str]) -> dict[str, Any]:
"""Reshape headers for the HMRC fraud-header validator API body."""
return {"headers": [{"name": k, "value": v} for k, v in headers.items()]}

125
hmrc_sync/oauth.py Normal file
View file

@ -0,0 +1,125 @@
"""HMRC OAuth token persistence — Vault-backed refresh-token store.
The refresh_token is long-lived (HMRC grants 18 months). We keep it in
Vault at `secret/viktor/hmrc_refresh_token` and let ESO sync it to a K8s
Secret the pod mounts as an env var. On every refresh, we write the new
token back to Vault so the next pod restart picks it up.
Writing back requires Vault write access the pod uses a short-lived
K8s-auth Vault token with a narrow policy that only allows writing
`secret/viktor/hmrc_refresh_token`.
"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
import httpx
log = logging.getLogger(__name__)
VAULT_KEY = "secret/viktor/hmrc_refresh_token"
PROD_BASE = "https://api.service.hmrc.gov.uk"
TOKEN_PATH = "/oauth/token"
@dataclass(frozen=True)
class OAuthCreds:
client_id: str
client_secret: str
redirect_uri: str
@dataclass
class TokenBundle:
access_token: str
refresh_token: str
expires_in: int
scope: str
@classmethod
def from_json(cls, data: dict[str, object]) -> TokenBundle:
return cls(
access_token=str(data["access_token"]),
refresh_token=str(data["refresh_token"]),
expires_in=int(data["expires_in"]), # type: ignore[arg-type]
scope=str(data.get("scope", "")),
)
def load_creds_from_env() -> OAuthCreds:
return OAuthCreds(
client_id=os.environ["HMRC_PROD_CLIENT_ID"],
client_secret=os.environ["HMRC_PROD_CLIENT_SECRET"],
redirect_uri=os.environ["HMRC_PROD_REDIRECT_URI"],
)
async def exchange_code(creds: OAuthCreds, code: str) -> TokenBundle:
"""Swap a fresh authorization_code for an access+refresh token pair."""
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{PROD_BASE}{TOKEN_PATH}",
data={
"grant_type": "authorization_code",
"client_id": creds.client_id,
"client_secret": creds.client_secret,
"redirect_uri": creds.redirect_uri,
"code": code,
},
headers={"Accept": "application/vnd.hmrc.1.0+json"},
)
resp.raise_for_status()
return TokenBundle.from_json(resp.json())
async def refresh(creds: OAuthCreds, refresh_token: str) -> TokenBundle:
"""Exchange an old refresh_token for a fresh access+refresh pair.
HMRC rotates the refresh_token on every refresh the old one becomes
invalid immediately after this call returns. Persist the new one to
Vault atomically; a failure between the refresh and the Vault write
leaves us stranded.
"""
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{PROD_BASE}{TOKEN_PATH}",
data={
"grant_type": "refresh_token",
"client_id": creds.client_id,
"client_secret": creds.client_secret,
"refresh_token": refresh_token,
},
headers={"Accept": "application/vnd.hmrc.1.0+json"},
)
resp.raise_for_status()
return TokenBundle.from_json(resp.json())
def persist_to_vault(token: TokenBundle) -> None:
"""Write the new refresh_token back to Vault.
Uses the hvac client with K8s-auth the pod's service-account token
at /var/run/secrets/kubernetes.io/serviceaccount/token logs into
Vault's kubernetes auth method and receives a short-lived Vault token
with write access to `secret/viktor/hmrc_refresh_token` only.
"""
import hvac
addr = os.environ.get("VAULT_ADDR", "https://vault.viktorbarzin.me")
role = os.environ.get("VAULT_K8S_ROLE", "hmrc-sync")
jwt_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
with open(jwt_path, encoding="utf-8") as fh:
jwt = fh.read()
client = hvac.Client(url=addr)
client.auth.kubernetes.login(role=role, jwt=jwt)
client.secrets.kv.v2.create_or_update_secret(
path="viktor/hmrc_refresh_token",
secret={
"refresh_token": token.refresh_token,
"expires_in": token.expires_in,
"scope": token.scope,
},
)
log.info("Rotated HMRC refresh_token persisted to Vault")