infra/scripts/cluster_manager.py

278 lines
9.4 KiB
Python
Raw Normal View History

import asyncio
import click
import logging
import time
from typing import List, Union, Optional
from kubernetes_asyncio import client, config
from kubernetes_asyncio.client.api_client import ApiClient
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
async def wait_for_healthy(
api_instance: client.AppsV1Api,
resource_type: str,
namespace: str,
name: str,
target_replicas: int,
timeout: int = 300,
) -> None:
start_time = time.time()
logger.info(
f"Waiting for {resource_type} {name} to reach {target_replicas} replicas..."
)
while True:
if time.time() - start_time > timeout:
logger.error(f"❌ Timeout reached for {resource_type} {name}")
return
try:
if resource_type.lower() == "deployment":
res = await api_instance.read_namespaced_deployment_status(
name, namespace
)
ready = res.status.ready_replicas or 0
updated = res.status.updated_replicas or 0
if ready == target_replicas and updated == target_replicas:
break
else: # StatefulSet
res = await api_instance.read_namespaced_stateful_set_status(
name, namespace
)
ready = res.status.ready_replicas or 0
if ready == target_replicas:
break
except Exception as e:
logger.debug(f"Retrying status check for {name}: {e}")
await asyncio.sleep(5)
logger.info(f"{resource_type} {name} is now healthy.")
async def wait_for_zero(
api: client.AppsV1Api, kind: str, ns: str, name: str, timeout: int
) -> tuple[str, str]:
start_time = asyncio.get_event_loop().time()
while (asyncio.get_event_loop().time() - start_time) < timeout:
try:
res = await (
api.read_namespaced_deployment_status(name, ns)
if kind.lower() == "deployment"
else api.read_namespaced_stateful_set_status(name, ns)
)
if (res.status.ready_replicas or 0) == 0:
return ns, name
except Exception:
return ns, name # Assume gone if error
await asyncio.sleep(3)
logger.error(f"Timeout: {kind} {ns}/{name} still has running pods.")
return ns, name
async def scale_resource(
api_instance: client.AppsV1Api,
resource_type: str,
namespace: str,
name: str,
replicas: int,
) -> None:
body = {"spec": {"replicas": replicas}}
try:
if resource_type.lower() == "deployment":
await api_instance.patch_namespaced_deployment_scale(name, namespace, body)
else:
await api_instance.patch_namespaced_stateful_set_scale(
name, namespace, body
)
except Exception as e:
logger.error(f"Failed to scale {resource_type} {name}: {e}")
async def run_stop_tier(
api_v1: client.AppsV1Api, label: str, output_file: str, timeout: int
) -> None:
"""Processes a single label tier: saves, scales to 0, and waits."""
excluded_ns = ["kube-system", "kube-public", "kube-node-lease"]
# 1. Discover
targets = [
("Deployment", api_v1.list_deployment_for_all_namespaces),
("StatefulSet", api_v1.list_stateful_set_for_all_namespaces),
]
tier_resources = []
for kind, list_func in targets:
resp = await list_func(label_selector=label)
tier_resources.extend(
[
(kind, item)
for item in resp.items
if item.metadata.namespace not in excluded_ns
]
)
if not tier_resources:
logger.warning(f"No resources found for label: {label}")
return
# 2. Save & Scale
active_jobs: set[tuple[str, str]] = set()
wait_tasks = []
# Append to file so we don't overwrite previous tiers
with open(output_file, "a") as f:
for kind, item in tier_resources:
ns, name = item.metadata.namespace, item.metadata.name
reps = item.spec.replicas or 0
f.write(f"{kind} {ns} {name} {reps}\n")
active_jobs.add((ns, name))
await scale_resource(api_v1, kind, ns, name, 0)
wait_tasks.append(wait_for_zero(api_v1, kind, ns, name, timeout))
# 3. Wait for this tier to finish before moving to next
logger.info(f"Tier [{label}]: Waiting for {len(active_jobs)} resources to stop...")
for coro in asyncio.as_completed(wait_tasks):
finished_ns, finished_name = await coro
active_jobs.discard((finished_ns, finished_name))
if active_jobs:
remaining_ns = sorted({ns for ns, name in active_jobs})
logger.info(
f"[{label}] Pending: {len(active_jobs)} | Namespaces: {', '.join(remaining_ns)}"
)
logger.info(f"✅ Tier [{label}] successfully shut down.")
@click.group()
def cli():
pass
@cli.command()
@click.argument("labels", nargs=-1, required=True)
@click.option("--output", "-o", default="resources.txt", help="Output state file")
@click.option("--timeout", "-t", default=3600)
def stop(labels: List[str], output: str, timeout: int):
"""Stop tiers sequentially. Usage: stop 'app=web' 'app=db'"""
async def main():
await config.load_kube_config()
# Clear/Create file at start
open(output, "w").close()
async with ApiClient() as api_client:
api_v1 = client.AppsV1Api(api_client)
for label in labels:
logger.info(f"🚀 Processing Shutdown Tier: {label}")
await run_stop_tier(api_v1, label, output, timeout)
logger.info("🏁 Sequence complete. Cluster is gracefully stopped.")
asyncio.run(main())
@cli.command()
@click.argument("labels", nargs=-1, required=True)
@click.option("--file", "-f", default="resources.txt")
@click.option("--timeout", "-t", default=3600, help="Seconds to wait per resource")
def start(labels: List[str], file: str, timeout: int):
asyncio.run(run_start_sequence(labels, file, timeout))
async def run_start_sequence(labels: List[str], file_path: str, timeout: int) -> None:
await config.load_kube_config()
async with ApiClient() as api_client:
apps_v1 = client.AppsV1Api(api_client)
# 1. Load the entire snapshot into memory for filtering
try:
with open(file_path, "r") as f:
# Format: Kind Namespace Name Replicas
snapshot_lines = [line.strip().split() for line in f if line.strip()]
except FileNotFoundError:
logger.error(f"Snapshot file {file_path} not found.")
return
# 2. Iterate through labels in the order provided
for label in labels:
logger.info(f"🚀 Starting Tier: {label}")
# Find resources in this tier by querying K8s for the label
# then matching against our snapshot file data
tier_resources = await get_resources_by_label(apps_v1, label)
# Cross-reference: Only start things that are in BOTH the K8s label query AND our file
# This ensures we restore them to the CORRECT previous replica count
to_restore = []
tier_keys = {(r["ns"], r["name"]) for r in tier_resources}
for kind, ns, name, reps in snapshot_lines:
if (ns, name) in tier_keys:
to_restore.append((kind, ns, name, int(reps)))
if not to_restore:
logger.warning(f"No resources found in snapshot for tier: {label}")
continue
# 3. Scale and Wait for this specific tier
await process_start_tier(apps_v1, to_restore, timeout, label)
logger.info("🏁 All tiers started successfully.")
async def get_resources_by_label(api: client.AppsV1Api, label: str) -> List[dict]:
"""Helper to find what currently exists in the cluster with this label."""
targets = [
api.list_deployment_for_all_namespaces,
api.list_stateful_set_for_all_namespaces,
]
found = []
for list_func in targets:
resp = await list_func(label_selector=label)
for item in resp.items:
found.append({"ns": item.metadata.namespace, "name": item.metadata.name})
return found
async def process_start_tier(
api: client.AppsV1Api, resources: list, timeout: int, label: str
):
active_jobs = set()
scale_tasks = []
wait_tasks = []
# Wrapper to track which job finishes
async def tracked_wait(kind, ns, name, target, t_out):
await wait_for_healthy(api, kind, ns, name, target, t_out)
return (ns, name)
for kind, ns, name, reps in resources:
active_jobs.add((ns, name))
scale_tasks.append(scale_resource(api, kind, ns, name, reps))
wait_tasks.append(tracked_wait(kind, ns, name, reps, timeout))
# Trigger all scales for this tier
await asyncio.gather(*scale_tasks)
# Monitor health
for coro in asyncio.as_completed(wait_tasks):
finished_ns, finished_name = await coro
active_jobs.discard((finished_ns, finished_name))
if active_jobs:
remaining_ns = sorted({ns for ns, name in active_jobs})
logger.info(
f"[{label}] Pending Health: {len(active_jobs)} | Namespaces: {', '.join(remaining_ns)}"
)
logger.info(f"✅ Tier [{label}] is healthy.")
if __name__ == "__main__":
cli()