278 lines
9.4 KiB
Python
278 lines
9.4 KiB
Python
|
|
import asyncio
|
||
|
|
import click
|
||
|
|
import logging
|
||
|
|
import time
|
||
|
|
from typing import List, Union, Optional
|
||
|
|
from kubernetes_asyncio import client, config
|
||
|
|
from kubernetes_asyncio.client.api_client import ApiClient
|
||
|
|
|
||
|
|
# Root logging config: INFO level, plain "LEVEL: message" lines.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Module-level logger shared by every coroutine in this script.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
async def wait_for_healthy(
    api_instance: client.AppsV1Api,
    resource_type: str,
    namespace: str,
    name: str,
    target_replicas: int,
    timeout: int = 300,
) -> None:
    """Poll a Deployment/StatefulSet until it reports ``target_replicas`` ready.

    Args:
        api_instance: AppsV1 API client used for the status reads.
        resource_type: "deployment" (case-insensitive); any other value is
            treated as a StatefulSet.
        namespace: Namespace of the resource.
        name: Name of the resource.
        target_replicas: Replica count that counts as healthy.
        timeout: Maximum seconds to wait; on expiry an error is logged and
            the function returns without raising.
    """
    # time.monotonic() is immune to wall-clock adjustments (NTP, manual
    # changes), unlike time.time(), so the timeout cannot mis-fire.
    start_time = time.monotonic()
    logger.info(
        f"Waiting for {resource_type} {name} to reach {target_replicas} replicas..."
    )

    while True:
        if time.monotonic() - start_time > timeout:
            logger.error(f"❌ Timeout reached for {resource_type} {name}")
            return

        try:
            if resource_type.lower() == "deployment":
                res = await api_instance.read_namespaced_deployment_status(
                    name, namespace
                )
                # Require both ready AND updated so an in-progress rollout
                # (old pods still ready) is not reported as healthy.
                ready = res.status.ready_replicas or 0
                updated = res.status.updated_replicas or 0
                if ready == target_replicas and updated == target_replicas:
                    break
            else:  # StatefulSet
                res = await api_instance.read_namespaced_stateful_set_status(
                    name, namespace
                )
                ready = res.status.ready_replicas or 0
                if ready == target_replicas:
                    break

        except Exception as e:
            # Transient API errors are expected while polling; log at debug
            # and keep retrying until the timeout fires.
            logger.debug(f"Retrying status check for {name}: {e}")

        await asyncio.sleep(5)

    logger.info(f"✅ {resource_type} {name} is now healthy.")
|
||
|
|
|
||
|
|
|
||
|
|
async def wait_for_zero(
    api: client.AppsV1Api, kind: str, ns: str, name: str, timeout: int
) -> tuple[str, str]:
    """Poll until the resource reports zero ready replicas.

    Returns ``(ns, name)`` in every outcome (drained, presumed deleted on
    API error, or timeout) so the caller can unconditionally mark the job
    as finished.
    """
    # time.monotonic() replaces asyncio.get_event_loop().time():
    # get_event_loop() is deprecated for this use inside coroutines, and
    # the monotonic clock gives the same jump-free elapsed-time semantics.
    start_time = time.monotonic()
    while (time.monotonic() - start_time) < timeout:
        try:
            res = await (
                api.read_namespaced_deployment_status(name, ns)
                if kind.lower() == "deployment"
                else api.read_namespaced_stateful_set_status(name, ns)
            )
            if (res.status.ready_replicas or 0) == 0:
                return ns, name
        except Exception:
            # Best-effort: any read error (e.g. 404 after deletion) is
            # treated as "resource gone" rather than blocking the shutdown.
            return ns, name  # Assume gone if error
        await asyncio.sleep(3)
    logger.error(f"Timeout: {kind} {ns}/{name} still has running pods.")
    return ns, name
|
||
|
|
|
||
|
|
|
||
|
|
async def scale_resource(
    api_instance: client.AppsV1Api,
    resource_type: str,
    namespace: str,
    name: str,
    replicas: int,
) -> None:
    """Patch the scale of a Deployment/StatefulSet to ``replicas``.

    Failures are logged and swallowed so one bad resource does not abort
    the tier being processed.
    """
    try:
        # Pick the kind-specific scale endpoint, then issue a single patch.
        if resource_type.lower() == "deployment":
            patcher = api_instance.patch_namespaced_deployment_scale
        else:
            patcher = api_instance.patch_namespaced_stateful_set_scale
        await patcher(name, namespace, {"spec": {"replicas": replicas}})
    except Exception as e:
        logger.error(f"Failed to scale {resource_type} {name}: {e}")
|
||
|
|
|
||
|
|
|
||
|
|
async def run_stop_tier(
    api_v1: client.AppsV1Api, label: str, output_file: str, timeout: int
) -> None:
    """Processes a single label tier: saves, scales to 0, and waits."""
    skip_namespaces = ["kube-system", "kube-public", "kube-node-lease"]

    # Discover every Deployment/StatefulSet carrying this tier's label,
    # skipping cluster-infrastructure namespaces.
    discovered = []
    listers = [
        ("Deployment", api_v1.list_deployment_for_all_namespaces),
        ("StatefulSet", api_v1.list_stateful_set_for_all_namespaces),
    ]
    for kind, lister in listers:
        listing = await lister(label_selector=label)
        for item in listing.items:
            if item.metadata.namespace not in skip_namespaces:
                discovered.append((kind, item))

    if not discovered:
        logger.warning(f"No resources found for label: {label}")
        return

    # Snapshot current replica counts, then drive everything to zero.
    pending: set[tuple[str, str]] = set()
    waiters = []

    # Open in append mode so snapshots from earlier tiers survive.
    with open(output_file, "a") as state:
        for kind, item in discovered:
            ns = item.metadata.namespace
            name = item.metadata.name
            replicas = item.spec.replicas or 0
            # Record BEFORE scaling so the count can be restored later.
            state.write(f"{kind} {ns} {name} {replicas}\n")
            pending.add((ns, name))

            await scale_resource(api_v1, kind, ns, name, 0)
            waiters.append(wait_for_zero(api_v1, kind, ns, name, timeout))

    # Block until every resource in this tier has drained, logging progress.
    logger.info(f"Tier [{label}]: Waiting for {len(pending)} resources to stop...")
    for done in asyncio.as_completed(waiters):
        done_ns, done_name = await done
        pending.discard((done_ns, done_name))
        if pending:
            remaining_ns = sorted({ns for ns, name in pending})
            logger.info(
                f"[{label}] Pending: {len(pending)} | Namespaces: {', '.join(remaining_ns)}"
            )

    logger.info(f"✅ Tier [{label}] successfully shut down.")
|
||
|
|
|
||
|
|
|
||
|
|
@click.group()
def cli():
    # Root command group; subcommands register themselves via @cli.command().
    # No docstring on purpose: click would surface it as help text, and the
    # original group exposes none.
    pass
|
||
|
|
|
||
|
|
|
||
|
|
@cli.command()
@click.argument("labels", nargs=-1, required=True)
@click.option("--output", "-o", default="resources.txt", help="Output state file")
@click.option("--timeout", "-t", default=3600)
def stop(labels: List[str], output: str, timeout: int):
    """Stop tiers sequentially. Usage: stop 'app=web' 'app=db'"""

    async def main():
        await config.load_kube_config()
        # Truncate/create the snapshot file at the start of a run; the
        # context manager guarantees the handle is closed (the previous
        # open(...).close() chain leaked the handle if truncation raised).
        with open(output, "w"):
            pass

        async with ApiClient() as api_client:
            api_v1 = client.AppsV1Api(api_client)
            # Tiers are processed strictly in the order given on the CLI.
            for label in labels:
                logger.info(f"🚀 Processing Shutdown Tier: {label}")
                await run_stop_tier(api_v1, label, output, timeout)
            logger.info("🏁 Sequence complete. Cluster is gracefully stopped.")

    asyncio.run(main())
|
||
|
|
|
||
|
|
|
||
|
|
@cli.command()
@click.argument("labels", nargs=-1, required=True)
@click.option("--file", "-f", default="resources.txt")
@click.option("--timeout", "-t", default=3600, help="Seconds to wait per resource")
def start(labels: List[str], file: str, timeout: int):
    # Thin sync wrapper: all real work happens in the async sequence runner.
    coro = run_start_sequence(labels, file, timeout)
    asyncio.run(coro)
|
||
|
|
|
||
|
|
|
||
|
|
async def run_start_sequence(labels: List[str], file_path: str, timeout: int) -> None:
    """Restore tiers in the given label order from the snapshot file.

    Args:
        labels: Label selectors, processed sequentially (tier order).
        file_path: Snapshot written by the ``stop`` command, one
            "Kind Namespace Name Replicas" entry per line.
        timeout: Seconds to wait for each resource to become healthy.
    """
    await config.load_kube_config()

    async with ApiClient() as api_client:
        apps_v1 = client.AppsV1Api(api_client)

        # 1. Load the entire snapshot into memory for filtering
        try:
            with open(file_path, "r") as f:
                # Format: Kind Namespace Name Replicas
                snapshot_lines = [line.strip().split() for line in f if line.strip()]
        except FileNotFoundError:
            logger.error(f"Snapshot file {file_path} not found.")
            return

        # Validate once up front: a corrupted line (wrong field count or
        # non-numeric replica value) previously crashed the whole restore
        # at tuple-unpack / int(); now it is skipped with a warning.
        snapshot_entries = []
        for parts in snapshot_lines:
            if len(parts) == 4 and parts[3].isdigit():
                snapshot_entries.append((parts[0], parts[1], parts[2], int(parts[3])))
            else:
                logger.warning(f"Skipping malformed snapshot line: {' '.join(parts)}")

        # 2. Iterate through labels in the order provided
        for label in labels:
            logger.info(f"🚀 Starting Tier: {label}")

            # Find resources in this tier by querying K8s for the label
            # then matching against our snapshot file data
            tier_resources = await get_resources_by_label(apps_v1, label)

            # Cross-reference: Only start things that are in BOTH the K8s label query AND our file
            # This ensures we restore them to the CORRECT previous replica count
            to_restore = []
            tier_keys = {(r["ns"], r["name"]) for r in tier_resources}

            for kind, ns, name, reps in snapshot_entries:
                if (ns, name) in tier_keys:
                    to_restore.append((kind, ns, name, reps))

            if not to_restore:
                logger.warning(f"No resources found in snapshot for tier: {label}")
                continue

            # 3. Scale and Wait for this specific tier
            await process_start_tier(apps_v1, to_restore, timeout, label)

    logger.info("🏁 All tiers started successfully.")
|
||
|
|
|
||
|
|
|
||
|
|
async def get_resources_by_label(api: client.AppsV1Api, label: str) -> List[dict]:
    """Helper to find what currently exists in the cluster with this label."""
    # Query Deployments first, then StatefulSets, preserving listing order.
    listers = (
        api.list_deployment_for_all_namespaces,
        api.list_stateful_set_for_all_namespaces,
    )
    matches: List[dict] = []
    for lister in listers:
        listing = await lister(label_selector=label)
        matches.extend(
            {"ns": obj.metadata.namespace, "name": obj.metadata.name}
            for obj in listing.items
        )
    return matches
|
||
|
|
|
||
|
|
|
||
|
|
async def process_start_tier(
    api: client.AppsV1Api, resources: list, timeout: int, label: str
):
    """Scale one tier's resources back up, then wait until all are healthy."""

    async def _watch(kind, ns, name, target, deadline):
        # Tag the health wait with its identity so completed futures can be
        # matched back to the resource they belong to.
        await wait_for_healthy(api, kind, ns, name, target, deadline)
        return (ns, name)

    outstanding = set()
    scale_ops = []
    health_ops = []

    for kind, ns, name, reps in resources:
        outstanding.add((ns, name))
        scale_ops.append(scale_resource(api, kind, ns, name, reps))
        health_ops.append(_watch(kind, ns, name, reps, timeout))

    # Fire every scale patch for the tier concurrently.
    await asyncio.gather(*scale_ops)

    # Drain the health checks as they complete, logging what's still pending.
    for fut in asyncio.as_completed(health_ops):
        ns_done, name_done = await fut
        outstanding.discard((ns_done, name_done))

        if outstanding:
            remaining_ns = sorted({ns for ns, name in outstanding})
            logger.info(
                f"[{label}] Pending Health: {len(outstanding)} | Namespaces: {', '.join(remaining_ns)}"
            )

    logger.info(f"✅ Tier [{label}] is healthy.")
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: dispatch to the click command group.
if __name__ == "__main__":
    cli()
|