Initial extraction from monorepo
This commit is contained in:
commit
5c7baa8acc
20 changed files with 1974 additions and 0 deletions
0
hmrc_sync/__init__.py
Normal file
0
hmrc_sync/__init__.py
Normal file
36
hmrc_sync/__main__.py
Normal file
36
hmrc_sync/__main__.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import click
|
||||
import uvicorn
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli() -> None:
|
||||
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
|
||||
|
||||
|
||||
@cli.command()
|
||||
def serve() -> None:
|
||||
"""Run the FastAPI server (K8s entrypoint)."""
|
||||
uvicorn.run("hmrc_sync.app:app", host="0.0.0.0", port=8080)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--tax-year", default="current", help="Tax year to fetch, e.g. 2024-25 or 'current'.")
|
||||
def sync(tax_year: str) -> None:
|
||||
"""One-shot sync of HMRC figures — used by the CronJob."""
|
||||
raise click.ClickException("Sync stub — implement after HMRC prod approval lands")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def migrate() -> None:
|
||||
"""Run `alembic upgrade head`."""
|
||||
result = subprocess.run(["alembic", "upgrade", "head"], check=False)
|
||||
sys.exit(result.returncode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
129
hmrc_sync/app.py
Normal file
129
hmrc_sync/app.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""FastAPI entrypoint for hmrc-sync.
|
||||
|
||||
Endpoints:
|
||||
- GET /authorize — redirect to HMRC OAuth, primes refresh_token
|
||||
- GET /callback — OAuth callback; exchange code, persist token
|
||||
- POST /callback-metadata — browser-side session attributes (fraud headers)
|
||||
- POST /sync — pull latest HMRC figures for a given tax year
|
||||
- GET /healthz — readiness + queue depth
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
from hmrc_sync import oauth
|
||||
from hmrc_sync.fraud_headers import SessionContext
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
REQUIRED_ENV = [
|
||||
"HMRC_PROD_CLIENT_ID",
|
||||
"HMRC_PROD_CLIENT_SECRET",
|
||||
"HMRC_PROD_REDIRECT_URI",
|
||||
"DB_CONNECTION_STRING",
|
||||
]
|
||||
|
||||
|
||||
def _verify_env() -> None:
|
||||
missing = [k for k in REQUIRED_ENV if not os.environ.get(k)]
|
||||
if missing:
|
||||
raise RuntimeError(f"Missing required env vars: {', '.join(missing)}")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI): # type: ignore[no-untyped-def]
|
||||
_verify_env()
|
||||
app.state.session_context = SessionContext(
|
||||
device_id=os.environ.get("HMRC_DEVICE_ID", ""),
|
||||
public_ip=os.environ.get("HMRC_VENDOR_PUBLIC_IP", ""),
|
||||
)
|
||||
app.state.oauth_states = {} # anti-CSRF state → expires_at
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="HMRC Sync", lifespan=lifespan)
|
||||
Instrumentator().instrument(app).expose(app, endpoint="/metrics")
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz() -> dict[str, Any]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/authorize")
|
||||
async def authorize() -> RedirectResponse:
|
||||
creds = oauth.load_creds_from_env()
|
||||
state = secrets.token_urlsafe(24)
|
||||
app.state.oauth_states[state] = True
|
||||
params = urllib.parse.urlencode({
|
||||
"response_type": "code",
|
||||
"client_id": creds.client_id,
|
||||
"scope": "read:self-assessment",
|
||||
"redirect_uri": creds.redirect_uri,
|
||||
"state": state,
|
||||
})
|
||||
return RedirectResponse(f"{oauth.PROD_BASE}/oauth/authorize?{params}")
|
||||
|
||||
|
||||
@app.get("/callback", response_class=HTMLResponse)
|
||||
async def callback(code: str, state: str) -> HTMLResponse:
|
||||
if state not in app.state.oauth_states:
|
||||
raise HTTPException(status_code=400, detail="unknown state (CSRF)")
|
||||
del app.state.oauth_states[state]
|
||||
creds = oauth.load_creds_from_env()
|
||||
token = await oauth.exchange_code(creds, code)
|
||||
oauth.persist_to_vault(token)
|
||||
# Serve a 1-page form that POSTs browser attributes to /callback-metadata
|
||||
# so we capture the per-session values HMRC wants in fraud headers.
|
||||
return HTMLResponse(_metadata_capture_html())
|
||||
|
||||
|
||||
@app.post("/callback-metadata")
|
||||
async def callback_metadata(request: Request) -> dict[str, str]:
|
||||
body = await request.json()
|
||||
session: SessionContext = app.state.session_context
|
||||
session.user_agent = str(body.get("user_agent", "") or "")
|
||||
session.screen_width = int(body.get("screen_width", 0) or 0)
|
||||
session.screen_height = int(body.get("screen_height", 0) or 0)
|
||||
session.screen_colour_depth = int(body.get("screen_colour_depth", 0) or 0)
|
||||
session.window_width = int(body.get("window_width", 0) or 0)
|
||||
session.window_height = int(body.get("window_height", 0) or 0)
|
||||
session.timezone_offset = int(body.get("timezone_offset", 0) or 0)
|
||||
return {"status": "captured"}
|
||||
|
||||
|
||||
@app.post("/sync")
|
||||
async def sync(tax_year: str | None = None) -> dict[str, Any]:
|
||||
"""Pull latest HMRC figures for `tax_year` (default: current fiscal year)."""
|
||||
raise HTTPException(status_code=501, detail="Sync not yet implemented — awaiting HMRC prod approval")
|
||||
|
||||
|
||||
def _metadata_capture_html() -> str:
|
||||
return """<!doctype html>
|
||||
<html><head><title>hmrc-sync — capturing session</title></head><body>
|
||||
<h2>Capturing session attributes for HMRC fraud headers...</h2>
|
||||
<script>
|
||||
fetch('/callback-metadata', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify({
|
||||
user_agent: navigator.userAgent,
|
||||
screen_width: screen.width,
|
||||
screen_height: screen.height,
|
||||
screen_colour_depth: screen.colorDepth,
|
||||
window_width: window.innerWidth,
|
||||
window_height: window.innerHeight,
|
||||
timezone_offset: -new Date().getTimezoneOffset()
|
||||
})
|
||||
}).then(() => document.body.innerHTML = '<h2>Done. You can close this tab.</h2>');
|
||||
</script>
|
||||
</body></html>"""
|
||||
82
hmrc_sync/client.py
Normal file
82
hmrc_sync/client.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
"""HMRC Individual Tax API v1.1 wrapper.
|
||||
|
||||
One method per endpoint we consume. Every request attaches the full fraud-
|
||||
prevention header set built by `fraud_headers.build_headers()`.
|
||||
|
||||
Individual Tax API v1.1 returns tax-paid + income-breakdown figures per
|
||||
employment per tax year — exactly the ground-truth data we reconcile
|
||||
against the payslip-ingest monthly aggregate.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from hmrc_sync.fraud_headers import SessionContext, build_headers
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
PROD_BASE = "https://api.service.hmrc.gov.uk"
|
||||
INDIVIDUAL_TAX_VERSION = "application/vnd.hmrc.1.1+json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HmrcResponse:
|
||||
status_code: int
|
||||
body: dict[str, Any]
|
||||
duration_ms: int
|
||||
request_id: str | None
|
||||
correlation_id: str | None
|
||||
fraud_headers_sent: dict[str, str]
|
||||
|
||||
|
||||
class HmrcClient:
|
||||
|
||||
def __init__(self,
|
||||
access_token: str,
|
||||
session: SessionContext,
|
||||
connection_method: str = "BATCH_PROCESS_DIRECT",
|
||||
base_url: str = PROD_BASE):
|
||||
self._access_token = access_token
|
||||
self._session = session
|
||||
self._connection_method = connection_method
|
||||
self._base_url = base_url.rstrip("/")
|
||||
|
||||
async def individual_tax_summary(self, utr: str, tax_year: str) -> HmrcResponse:
|
||||
"""GET /individuals/tax/sa/{utr}/summary/{taxYear}
|
||||
|
||||
`utr` is the 10-digit Self Assessment reference; tax_year format
|
||||
is `YYYY-YY` (e.g. `2024-25`).
|
||||
"""
|
||||
path = f"/individuals/tax/sa/{utr}/summary/{tax_year}"
|
||||
return await self._get(path)
|
||||
|
||||
async def _get(self, path: str) -> HmrcResponse:
|
||||
fraud = build_headers(self._session, self._connection_method)
|
||||
headers = {
|
||||
"Accept": INDIVIDUAL_TAX_VERSION,
|
||||
"Authorization": f"Bearer {self._access_token}",
|
||||
}
|
||||
headers.update(fraud)
|
||||
started = time.perf_counter()
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.get(f"{self._base_url}{path}", headers=headers)
|
||||
duration_ms = int((time.perf_counter() - started) * 1000)
|
||||
body: dict[str, Any]
|
||||
try:
|
||||
body = resp.json() if resp.content else {}
|
||||
except ValueError:
|
||||
body = {"raw": resp.text[:2000]}
|
||||
log.info("hmrc %s status=%s duration=%dms", path, resp.status_code, duration_ms)
|
||||
return HmrcResponse(
|
||||
status_code=resp.status_code,
|
||||
body=body,
|
||||
duration_ms=duration_ms,
|
||||
request_id=resp.headers.get("x-request-id"),
|
||||
correlation_id=resp.headers.get("x-correlation-id"),
|
||||
fraud_headers_sent=fraud,
|
||||
)
|
||||
70
hmrc_sync/db.py
Normal file
70
hmrc_sync/db.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
import os
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, TIMESTAMP, Integer, Numeric, String, text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
SCHEMA_NAME = "hmrc_sync"
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
JSON_TYPE = JSONB().with_variant(JSON(), "sqlite")
|
||||
|
||||
|
||||
class TaxYearSnapshot(Base):
|
||||
"""One row per (tax_year, employer_paye_ref, snapshot_date).
|
||||
|
||||
HMRC returns the `hmrc-held` view of annual PAYE/NI for a given
|
||||
employment. Taking a daily snapshot lets us see HMRC's figures evolve
|
||||
as late RTI filings land, and lets the dashboard always show the
|
||||
latest value by snapshot_date.
|
||||
"""
|
||||
__tablename__ = "tax_year_snapshot"
|
||||
__table_args__ = {"schema": SCHEMA_NAME} # noqa: RUF012
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
tax_year: Mapped[str] = mapped_column(String, nullable=False, index=True)
|
||||
employer_paye_ref: Mapped[str] = mapped_column(String, nullable=False)
|
||||
snapshot_date: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False)
|
||||
gross_pay: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
|
||||
income_tax: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
|
||||
ni_contributions: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
|
||||
source: Mapped[str] = mapped_column(String, nullable=False, server_default="hmrc-held")
|
||||
raw_response: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)
|
||||
fetched_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text("now()"))
|
||||
|
||||
|
||||
class FetchLog(Base):
|
||||
"""Audit trail of every HMRC API call — for fraud-header compliance review."""
|
||||
__tablename__ = "fetch_log"
|
||||
__table_args__ = {"schema": SCHEMA_NAME} # noqa: RUF012
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
endpoint: Mapped[str] = mapped_column(String, nullable=False)
|
||||
status_code: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
request_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
correlation_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
fraud_headers_sent: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)
|
||||
response_snippet: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
duration_ms: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
fetched_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text("now()"))
|
||||
|
||||
|
||||
def create_engine_from_env() -> AsyncEngine:
|
||||
url = os.environ["DB_CONNECTION_STRING"]
|
||||
return create_async_engine(url, pool_pre_ping=True)
|
||||
|
||||
|
||||
def make_session_factory(engine: AsyncEngine) -> async_sessionmaker[Any]:
|
||||
return async_sessionmaker(engine, expire_on_commit=False)
|
||||
341
hmrc_sync/fraud_headers.py
Normal file
341
hmrc_sync/fraud_headers.py
Normal file
|
|
@ -0,0 +1,341 @@
|
|||
"""Build HMRC MTD fraud-prevention headers (Gov-Client-* / Gov-Vendor-*).
|
||||
|
||||
HMRC's BATCH_PROCESS_DIRECT connection method (what our CronJob uses)
|
||||
mandates 11 headers on every MTD API call; WEB_APP_VIA_SERVER adds a
|
||||
handful of browser-derived fields. Shipping without these risks fines
|
||||
and API-access revocation per the HMRC fraud-prevention guide.
|
||||
|
||||
Layout:
|
||||
|
||||
- **Static** — vendor-constant across runs (product name/version,
|
||||
hashed license id).
|
||||
- **Runtime** — collected at module load from the pod's own network
|
||||
stack + OS: MAC addresses, local IPs, OS family/version, device model.
|
||||
- **Per-request** — built at call time (timestamps, request ids).
|
||||
- **Per-session** — captured from the browser on `/callback-metadata`
|
||||
(screen dimensions, public IP, MFA timestamp). Only WEB_APP_VIA_SERVER.
|
||||
|
||||
The public entry point is `build_headers(session, connection_method)`.
|
||||
Run `tests/test_fraud_headers.py::test_headers_pass_hmrc_validator`
|
||||
with `HMRC_VALIDATOR=1` to verify against the HMRC sandbox validator.
|
||||
|
||||
Spec references:
|
||||
https://developer.service.hmrc.gov.uk/guides/fraud-prevention/
|
||||
https://developer.service.hmrc.gov.uk/guides/fraud-prevention/connection-method/batch-process-direct/
|
||||
https://developer.service.hmrc.gov.uk/api-documentation/docs/api/service/txm-fph-validator-api/1.0
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import secrets
|
||||
import socket
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
VENDOR_PRODUCT_NAME = "hmrc-sync"
|
||||
VENDOR_PRODUCT_VERSION = "0.1.0"
|
||||
# Self-assigned for a personal single-user tool. HMRC permits arbitrary
|
||||
# vendor strings; the header value is then SHA-256-hashed per spec
|
||||
# (`Gov-Vendor-License-IDs: <name>=<hashed-value>`).
|
||||
VENDOR_LICENSE_ID = os.environ.get("HMRC_VENDOR_LICENSE_ID",
|
||||
"hmrc-sync-private-single-user")
|
||||
VENDOR_PUBLIC_IP = os.environ.get("HMRC_VENDOR_PUBLIC_IP", "")
|
||||
|
||||
# Valid HMRC connection-method enum values.
|
||||
CONNECTION_METHOD_BATCH = "BATCH_PROCESS_DIRECT"
|
||||
CONNECTION_METHOD_WEB_APP = "WEB_APP_VIA_SERVER"
|
||||
CONNECTION_METHOD_MFA = "AUTH_USING_MFA"
|
||||
|
||||
_NET_CLASS = Path("/sys/class/net")
|
||||
_EMPTY_MAC = "00:00:00:00:00:00"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""Browser-side attributes captured on the `/callback-metadata` POST.
|
||||
|
||||
Only relevant for WEB_APP_VIA_SERVER flows (browser-initiated OAuth
|
||||
+ server-side API calls). BATCH_PROCESS_DIRECT flows derive their
|
||||
context from `RuntimeContext` (see below) without touching these.
|
||||
"""
|
||||
user_agent: str = ""
|
||||
screen_width: int = 0
|
||||
screen_height: int = 0
|
||||
screen_colour_depth: int = 0
|
||||
window_width: int = 0
|
||||
window_height: int = 0
|
||||
timezone_offset: int = 0
|
||||
device_id: str = ""
|
||||
mfa_timestamp: str = ""
|
||||
public_ip: str = ""
|
||||
public_port: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeContext:
|
||||
"""Pod-side environment values required on every API call.
|
||||
|
||||
Collected once at module load (cheap — all local syscalls). If any
|
||||
field is empty, the header emitter falls back to safe defaults so
|
||||
the call never goes out with an empty mandatory header.
|
||||
"""
|
||||
mac_addresses: list[str] = field(default_factory=list)
|
||||
local_ips: list[str] = field(default_factory=list)
|
||||
os_family: str = ""
|
||||
os_version: str = ""
|
||||
device_manufacturer: str = "Kubernetes"
|
||||
device_model: str = ""
|
||||
os_user: str = ""
|
||||
|
||||
|
||||
def _collect_mac_addresses() -> list[str]:
|
||||
"""Read every non-loopback interface MAC from `/sys/class/net/*/address`.
|
||||
|
||||
Colons are kept raw; `_format_mac_list` percent-encodes on output per spec.
|
||||
"""
|
||||
out: list[str] = []
|
||||
if not _NET_CLASS.exists():
|
||||
return out
|
||||
for iface in sorted(_NET_CLASS.iterdir()):
|
||||
if iface.name == "lo":
|
||||
continue
|
||||
addr_file = iface / "address"
|
||||
if not addr_file.exists():
|
||||
continue
|
||||
try:
|
||||
mac = addr_file.read_text().strip()
|
||||
except OSError:
|
||||
continue
|
||||
if mac and mac != _EMPTY_MAC:
|
||||
out.append(mac)
|
||||
return out
|
||||
|
||||
|
||||
def _collect_local_ips() -> list[str]:
|
||||
"""Every IP bound to this host — IPv4 + IPv6, loopback excluded."""
|
||||
ips: set[str] = set()
|
||||
try:
|
||||
hostname = socket.gethostname()
|
||||
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
|
||||
raw = sockaddr[0]
|
||||
if not isinstance(raw, str):
|
||||
continue
|
||||
if family == socket.AF_INET and not raw.startswith("127."):
|
||||
ips.add(raw)
|
||||
elif family == socket.AF_INET6 and not raw.startswith("::1"):
|
||||
ips.add(raw.split("%")[0]) # strip zone id
|
||||
except (socket.gaierror, OSError):
|
||||
pass
|
||||
# Also grab the primary outbound IP — `getaddrinfo(hostname)` can miss
|
||||
# it inside containers whose hostname has no DNS entry.
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(("10.255.255.255", 1))
|
||||
ips.add(s.getsockname()[0])
|
||||
except OSError:
|
||||
pass
|
||||
return sorted(ips)
|
||||
|
||||
|
||||
def _detect_runtime_context() -> RuntimeContext:
|
||||
uname = platform.uname()
|
||||
return RuntimeContext(
|
||||
mac_addresses=_collect_mac_addresses(),
|
||||
local_ips=_collect_local_ips(),
|
||||
os_family=uname.system or "Linux",
|
||||
os_version=uname.release or "unknown",
|
||||
device_manufacturer="Kubernetes",
|
||||
device_model=uname.node or socket.gethostname() or "pod",
|
||||
os_user=_safe_getuser(),
|
||||
)
|
||||
|
||||
|
||||
def _safe_getuser() -> str:
|
||||
try:
|
||||
return getpass.getuser()
|
||||
except (KeyError, OSError):
|
||||
return os.environ.get("USER", "unknown")
|
||||
|
||||
|
||||
RUNTIME_CONTEXT: RuntimeContext = _detect_runtime_context()
|
||||
|
||||
|
||||
def build_headers(session: SessionContext | None = None,
|
||||
connection_method: str = CONNECTION_METHOD_BATCH,
|
||||
runtime: RuntimeContext | None = None) -> dict[str, str]:
|
||||
"""Return the full header dict to attach to every HMRC API call.
|
||||
|
||||
Defaults to BATCH_PROCESS_DIRECT — the mode the CronJob uses. Pass
|
||||
a populated `SessionContext` + `connection_method=WEB_APP_VIA_SERVER`
|
||||
for browser-initiated flows; the browser-only fields layer on top.
|
||||
"""
|
||||
session = session or SessionContext()
|
||||
rt = runtime or RUNTIME_CONTEXT
|
||||
headers: dict[str, str] = {}
|
||||
headers.update(_static_headers())
|
||||
headers.update(_per_request_headers())
|
||||
headers.update(_mandatory_runtime_headers(rt, session, connection_method))
|
||||
if connection_method == CONNECTION_METHOD_WEB_APP:
|
||||
headers.update(_web_app_session_headers(session))
|
||||
if connection_method == CONNECTION_METHOD_MFA and session.mfa_timestamp:
|
||||
headers["Gov-Client-MFA-Timestamp"] = session.mfa_timestamp
|
||||
return headers
|
||||
|
||||
|
||||
def _static_headers() -> dict[str, str]:
|
||||
"""Vendor-constant identity headers that apply to every connection method.
|
||||
|
||||
Product-Name is percent-encoded per spec; License-IDs value is SHA-256-
|
||||
hashed per spec; Version is a key-value pair of `<software-name>=<version>`.
|
||||
|
||||
Gov-Vendor-Public-IP and Gov-Vendor-Forwarded are NOT emitted here — the
|
||||
HMRC validator rejects them for BATCH_PROCESS_DIRECT (where no vendor
|
||||
server sits between the client and the HMRC API). They're added in
|
||||
`_web_app_session_headers` for the WEB_APP_VIA_SERVER path only.
|
||||
"""
|
||||
license_hash = hashlib.sha256(VENDOR_LICENSE_ID.encode()).hexdigest()
|
||||
return {
|
||||
"Gov-Vendor-Product-Name": _pct(VENDOR_PRODUCT_NAME),
|
||||
"Gov-Vendor-Version": f"{VENDOR_PRODUCT_NAME}={VENDOR_PRODUCT_VERSION}",
|
||||
"Gov-Vendor-License-IDs": f"{VENDOR_PRODUCT_NAME}={license_hash}",
|
||||
}
|
||||
|
||||
|
||||
def _per_request_headers() -> dict[str, str]:
|
||||
"""Per-call trace + timestamp headers. Local-IPs-Timestamp uses HMRC's
|
||||
exact format `yyyy-MM-ddThh:mm:ss.sssZ` — always UTC, always millis."""
|
||||
now_ms = int(time.time() * 1000)
|
||||
iso_ms = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(now_ms / 1000))
|
||||
now_iso = f"{iso_ms}.{now_ms % 1000:03d}Z"
|
||||
return {
|
||||
"Gov-Client-Timezone": "UTC+00:00",
|
||||
"Gov-Client-Local-IPs-Timestamp": now_iso,
|
||||
"x-correlation-id": str(uuid.uuid4()),
|
||||
"x-request-id": secrets.token_hex(16),
|
||||
}
|
||||
|
||||
|
||||
def _mandatory_runtime_headers(rt: RuntimeContext, session: SessionContext,
|
||||
connection_method: str) -> dict[str, str]:
|
||||
"""The 8 headers mandatory for BATCH_PROCESS_DIRECT that come from the
|
||||
host — Connection-Method, Device-ID, User-IDs, User-Agent, Local-IPs,
|
||||
MAC-Addresses (+ Timezone and Local-IPs-Timestamp live in
|
||||
`_per_request_headers`)."""
|
||||
return {
|
||||
"Gov-Client-Connection-Method": connection_method,
|
||||
"Gov-Client-Device-ID": session.device_id or _fallback_device_id(),
|
||||
"Gov-Client-User-IDs": _user_ids(rt, session),
|
||||
"Gov-Client-User-Agent": _user_agent(rt, session),
|
||||
"Gov-Client-Local-IPs": _format_ip_list(rt.local_ips),
|
||||
"Gov-Client-MAC-Addresses": _format_mac_list(rt.mac_addresses),
|
||||
}
|
||||
|
||||
|
||||
def _web_app_session_headers(session: SessionContext) -> dict[str, str]:
|
||||
"""WEB_APP_VIA_SERVER-only headers — browser context + vendor hop trail.
|
||||
|
||||
Gov-Vendor-Public-IP and Gov-Vendor-Forwarded describe the vendor server
|
||||
that sits between the user's browser and HMRC — only meaningful for
|
||||
WEB_APP_VIA_SERVER. BATCH_PROCESS_DIRECT must omit them (validator
|
||||
rejects them there).
|
||||
"""
|
||||
out: dict[str, str] = {}
|
||||
if session.screen_width and session.screen_height:
|
||||
out["Gov-Client-Screens"] = (
|
||||
f"width={session.screen_width}&height={session.screen_height}"
|
||||
f"&scaling-factor=1&colour-depth={session.screen_colour_depth}")
|
||||
if session.window_width and session.window_height:
|
||||
out["Gov-Client-Window-Size"] = (f"width={session.window_width}&"
|
||||
f"height={session.window_height}")
|
||||
if session.public_ip:
|
||||
out["Gov-Client-Public-IP"] = session.public_ip
|
||||
if session.public_port:
|
||||
out["Gov-Client-Public-Port"] = str(session.public_port)
|
||||
vendor_ip = VENDOR_PUBLIC_IP or (RUNTIME_CONTEXT.local_ips[0] if RUNTIME_CONTEXT.local_ips
|
||||
else "")
|
||||
if vendor_ip:
|
||||
out["Gov-Vendor-Public-IP"] = vendor_ip
|
||||
out["Gov-Vendor-Forwarded"] = f"by={vendor_ip}&for={vendor_ip}"
|
||||
return out
|
||||
|
||||
|
||||
def _user_ids(rt: RuntimeContext, session: SessionContext) -> str:
|
||||
"""Per spec: `os=<device-user>&<app>=<app-user>`. The `os=` field is
|
||||
mandatory; we additionally tag our application with the OAuth user
|
||||
so HMRC can correlate activity in a breach investigation.
|
||||
"""
|
||||
os_user = rt.os_user or "unknown"
|
||||
pairs = [f"os={_pct(os_user)}"]
|
||||
oauth_user = os.environ.get("HMRC_OAUTH_USER", "viktor")
|
||||
pairs.append(f"hmrc-sync={_pct(oauth_user)}")
|
||||
_ = session # reserved for future per-session identity extension
|
||||
return "&".join(pairs)
|
||||
|
||||
|
||||
def _user_agent(rt: RuntimeContext, session: SessionContext) -> str:
|
||||
"""Per spec: `os-family=…&os-version=…&device-manufacturer=…&device-model=…`.
|
||||
|
||||
For WEB_APP_VIA_SERVER with a captured browser UA, the browser string
|
||||
is encoded under `device-model` with the rest of the fields defaulting
|
||||
to our pod's values — HMRC's validator accepts this hybrid shape.
|
||||
"""
|
||||
model = session.user_agent or rt.device_model or "pod"
|
||||
pairs = [
|
||||
f"os-family={_pct(rt.os_family)}",
|
||||
f"os-version={_pct(rt.os_version)}",
|
||||
f"device-manufacturer={_pct(rt.device_manufacturer)}",
|
||||
f"device-model={_pct(model)}",
|
||||
]
|
||||
return "&".join(pairs)
|
||||
|
||||
|
||||
def _format_ip_list(ips: list[str]) -> str:
|
||||
"""IPv6 addresses percent-encoded; IPv4 passes through. Joined with ','.
|
||||
|
||||
HMRC's validator accepts an empty header only if the request truly
|
||||
has no IPs; on a live pod we always have at least one — if the list
|
||||
comes back empty we fall back to the loopback so the header is
|
||||
syntactically valid.
|
||||
"""
|
||||
if not ips:
|
||||
return "127.0.0.1"
|
||||
out = []
|
||||
for ip in ips:
|
||||
out.append(_pct(ip) if ":" in ip else ip)
|
||||
return ",".join(out)
|
||||
|
||||
|
||||
def _format_mac_list(macs: list[str]) -> str:
|
||||
"""Each MAC percent-encoded (colons → %3A), comma-joined.
|
||||
|
||||
Empty list → single dummy MAC so we never ship a blank header;
|
||||
HMRC's validator treats blank as a violation.
|
||||
"""
|
||||
if not macs:
|
||||
return _pct("02:00:00:00:00:00")
|
||||
return ",".join(_pct(m) for m in macs)
|
||||
|
||||
|
||||
def _fallback_device_id() -> str:
|
||||
"""Deterministic UUID derived from hostname when no Vault-backed
|
||||
Device-ID is seeded. Stable across restarts on the same node."""
|
||||
return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"hmrc-sync-{platform.node()}"))
|
||||
|
||||
|
||||
def _pct(s: str) -> str:
|
||||
return urllib.parse.quote(s, safe="")
|
||||
|
||||
|
||||
def as_validator_payload(headers: dict[str, str]) -> dict[str, Any]:
|
||||
"""Reshape headers for the HMRC fraud-header validator API body."""
|
||||
return {"headers": [{"name": k, "value": v} for k, v in headers.items()]}
|
||||
125
hmrc_sync/oauth.py
Normal file
125
hmrc_sync/oauth.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""HMRC OAuth token persistence — Vault-backed refresh-token store.
|
||||
|
||||
The refresh_token is long-lived (HMRC grants 18 months). We keep it in
|
||||
Vault at `secret/viktor/hmrc_refresh_token` and let ESO sync it to a K8s
|
||||
Secret the pod mounts as an env var. On every refresh, we write the new
|
||||
token back to Vault so the next pod restart picks it up.
|
||||
|
||||
Writing back requires Vault write access — the pod uses a short-lived
|
||||
K8s-auth Vault token with a narrow policy that only allows writing
|
||||
`secret/viktor/hmrc_refresh_token`.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
VAULT_KEY = "secret/viktor/hmrc_refresh_token"
|
||||
PROD_BASE = "https://api.service.hmrc.gov.uk"
|
||||
TOKEN_PATH = "/oauth/token"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OAuthCreds:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenBundle:
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
scope: str
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data: dict[str, object]) -> TokenBundle:
|
||||
return cls(
|
||||
access_token=str(data["access_token"]),
|
||||
refresh_token=str(data["refresh_token"]),
|
||||
expires_in=int(data["expires_in"]), # type: ignore[arg-type]
|
||||
scope=str(data.get("scope", "")),
|
||||
)
|
||||
|
||||
|
||||
def load_creds_from_env() -> OAuthCreds:
|
||||
return OAuthCreds(
|
||||
client_id=os.environ["HMRC_PROD_CLIENT_ID"],
|
||||
client_secret=os.environ["HMRC_PROD_CLIENT_SECRET"],
|
||||
redirect_uri=os.environ["HMRC_PROD_REDIRECT_URI"],
|
||||
)
|
||||
|
||||
|
||||
async def exchange_code(creds: OAuthCreds, code: str) -> TokenBundle:
|
||||
"""Swap a fresh authorization_code for an access+refresh token pair."""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PROD_BASE}{TOKEN_PATH}",
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": creds.client_id,
|
||||
"client_secret": creds.client_secret,
|
||||
"redirect_uri": creds.redirect_uri,
|
||||
"code": code,
|
||||
},
|
||||
headers={"Accept": "application/vnd.hmrc.1.0+json"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return TokenBundle.from_json(resp.json())
|
||||
|
||||
|
||||
async def refresh(creds: OAuthCreds, refresh_token: str) -> TokenBundle:
|
||||
"""Exchange an old refresh_token for a fresh access+refresh pair.
|
||||
|
||||
HMRC rotates the refresh_token on every refresh — the old one becomes
|
||||
invalid immediately after this call returns. Persist the new one to
|
||||
Vault atomically; a failure between the refresh and the Vault write
|
||||
leaves us stranded.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PROD_BASE}{TOKEN_PATH}",
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": creds.client_id,
|
||||
"client_secret": creds.client_secret,
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
headers={"Accept": "application/vnd.hmrc.1.0+json"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return TokenBundle.from_json(resp.json())
|
||||
|
||||
|
||||
def persist_to_vault(token: TokenBundle) -> None:
|
||||
"""Write the new refresh_token back to Vault.
|
||||
|
||||
Uses the hvac client with K8s-auth — the pod's service-account token
|
||||
at /var/run/secrets/kubernetes.io/serviceaccount/token logs into
|
||||
Vault's kubernetes auth method and receives a short-lived Vault token
|
||||
with write access to `secret/viktor/hmrc_refresh_token` only.
|
||||
"""
|
||||
import hvac
|
||||
|
||||
addr = os.environ.get("VAULT_ADDR", "https://vault.viktorbarzin.me")
|
||||
role = os.environ.get("VAULT_K8S_ROLE", "hmrc-sync")
|
||||
jwt_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
|
||||
with open(jwt_path, encoding="utf-8") as fh:
|
||||
jwt = fh.read()
|
||||
client = hvac.Client(url=addr)
|
||||
client.auth.kubernetes.login(role=role, jwt=jwt)
|
||||
client.secrets.kv.v2.create_or_update_secret(
|
||||
path="viktor/hmrc_refresh_token",
|
||||
secret={
|
||||
"refresh_token": token.refresh_token,
|
||||
"expires_in": token.expires_in,
|
||||
"scope": token.scope,
|
||||
},
|
||||
)
|
||||
log.info("Rotated HMRC refresh_token persisted to Vault")
|
||||
Loading…
Add table
Add a link
Reference in a new issue