3677e4aace
FW-01: job_subscriber.py now has on_disconnect callback (5-arg paho v2 signature), reconnect_delay_set(1,16) for exponential backoff, and with_retry-wrapped initial connect (5 attempts). paho loop_start() handles auto-reconnect internally. FW-05: publish_event.py signs payloads with HMAC-SHA256 using auth_token as key (replaces plaintext token in wire). mqtt_common.py adds verify_hmac() helper using hmac.compare_digest (timing-safe). job_subscriber.py validates incoming events via verify_hmac. PoC mode (auth_token=None) skips verification — backward compatible. Reviewed by agy-existing (PASS) and claude-existing (FAIL: on_disconnect 4-arg signature → fixed to 5-arg matching paho v2 CallbackAPIVersion).
567 lines
22 KiB
Python
567 lines
22 KiB
Python
"""Shared MQTT + registry helpers for the tmux-agent-orchestrate-delegate-job skill.
|
|
|
|
Single entry point for:
|
|
- broker configuration (env -> dataclass),
|
|
- paho client construction (auth + TLS + unique client id),
|
|
- monotonic per-job sequence counters,
|
|
- retry-with-exponential-backoff,
|
|
- atomic registry record load/update under an fcntl lock.
|
|
|
|
Requires paho-mqtt >= 2.0 (uses CallbackAPIVersion.VERSION2).
|
|
|
|
This module is the *only* place that talks to the broker config and to the
|
|
raw job record file, so PoC -> production migration touches just env/registry
|
|
values, never code (see references/mqtt-broker-setup.md).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import time
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, Iterable, List, Optional
|
|
|
|
import paho.mqtt.client as mqtt
|
|
|
|
logger = logging.getLogger("delegate_job.mqtt_common")

# --------------------------------------------------------------------------
# Constants
# --------------------------------------------------------------------------
SCHEMA_VERSION = 1                       # version stamp for the job-record schema
DEFAULT_REGISTRY_DIR = ".hermes/jobs"    # live job records, one <job_id>.json each
DEFAULT_TOPIC_ROOT = "python/mqtt/jobs"  # topics are built as <root>/<job_id>/...
LOCK_FILENAME = ".lock"                  # flock target inside the registry dir

# Persistent audit-log layout: .hermes/delegate_job_logs/<job_id>/{meta,events,status}.
# The audit log is a different artifact from the registry. The registry holds
# the live job record and is mutated in place; the audit log is an append-only
# history that is kept even when the registry dir gets cleaned up.
META_FILENAME = "meta.json"
EVENTS_FILENAME = "events.ndjson"
STATUS_FILENAME = "status.json"
|
|
|
|
|
|
def _default_logs_dir() -> str:
    """Resolve the root directory of the persistent audit log.

    A non-blank ``DELEGATE_JOB_LOGS_DIR`` wins and is returned exactly as set
    (it is not stripped). Otherwise the root is
    ``<cwd>/.hermes/delegate_job_logs``, which keeps the audit logs beside the
    live registry (``.hermes/jobs/``) so both runtime artifacts share a parent
    dir and the same ``.gitignore`` rule. The anchor is the cwd of whichever
    process calls this (the bash wrapper and scripts).
    """
    override = os.environ.get("DELEGATE_JOB_LOGS_DIR") or ""
    has_override = bool(override.strip())
    if has_override:
        return override
    return os.path.join(os.getcwd(), ".hermes", "delegate_job_logs")
|
|
|
|
|
|
# Audit-log root, resolved once when this module is imported. Later changes to
# DELEGATE_JOB_LOGS_DIR or to the cwd do not affect it; the audit-log helpers
# fall back to this value when no explicit ``logs_dir`` is passed.
LOGS_DIR = _default_logs_dir()
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Broker configuration
|
|
# --------------------------------------------------------------------------
|
|
@dataclass
class BrokerConfig:
    """Connection settings for the MQTT broker, after resolution.

    The defaults describe the PoC setup (public HiveMQ broker, no TLS, no
    auth). Production values come from environment variables or from the
    ``broker.*`` block of a job record (see ``broker_config_from_job``).
    """

    host: str = "broker.hivemq.com"
    port: int = 1883
    tls: bool = False
    username: Optional[str] = None
    password: Optional[str] = None
    client_id_prefix: str = "hermes"
    # TLS material; only looked at when ``tls`` is True.
    ca_certs: Optional[str] = None
    certfile: Optional[str] = None
    keyfile: Optional[str] = None
    keepalive: int = 60

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict."""
        return asdict(self)

    def to_registry_block(self) -> Dict[str, Any]:
        """Return only the fields persisted into a job record's broker block.

        Note that this includes ``password`` as-is.
        """
        persisted = ("host", "port", "tls", "username", "password")
        return {field_name: getattr(self, field_name) for field_name in persisted}
|
|
|
|
|
|
def _env_bool(name: str, default: bool = False) -> bool:
    """Read environment variable ``name`` as a boolean.

    Returns ``default`` when the variable is unset. Otherwise the value is
    true only if, stripped and lower-cased, it is one of ``1``, ``true``,
    ``yes`` or ``on``; anything else (including an empty string) is false.
    """
    try:
        raw = os.environ[name]
    except KeyError:
        return default
    truthy = {"1", "true", "yes", "on"}
    return raw.strip().lower() in truthy
|
|
|
|
|
|
def _env_int(name: str, default: int) -> int:
    """Read environment variable ``name`` as an integer.

    Returns ``default`` when the variable is unset or blank. A value that
    ``int()`` cannot parse is reported with a warning and also yields
    ``default``.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        parsed = int(raw)
    except ValueError:
        logger.warning("invalid int for %s=%r; using default %d", name, raw, default)
        return default
    return parsed
|
|
|
|
|
|
def broker_config_from_env(overrides: Optional[Dict[str, Any]] = None) -> BrokerConfig:
    """Build a :class:`BrokerConfig` from environment variables.

    Recognised vars (all optional, PoC defaults shown):
      MQTT_BROKER (broker.hivemq.com), MQTT_PORT (1883), MQTT_TLS (0),
      MQTT_USERNAME, MQTT_PASSWORD, MQTT_CLIENT_ID_PREFIX (hermes),
      MQTT_CA_CERTS, MQTT_CERTFILE, MQTT_KEYFILE, MQTT_KEEPALIVE (60).

    Empty-string values of the username/password/TLS-path variables are
    treated as unset (``None``).

    Args:
        overrides: Optional mapping (e.g. a job record's broker block). A key
            wins over the env value only when it names a declared
            ``BrokerConfig`` field and its value is not ``None``. Any other
            key is ignored, so a stray key such as ``to_dict`` can never
            replace a method on the returned object.

    Returns:
        The resolved configuration.
    """
    # Function-scope import: the file's top-level dataclasses import only
    # brings in ``asdict`` and ``dataclass``.
    from dataclasses import fields

    env = os.environ
    cfg = BrokerConfig(
        host=env.get("MQTT_BROKER", "broker.hivemq.com"),
        port=_env_int("MQTT_PORT", 1883),
        tls=_env_bool("MQTT_TLS", False),
        username=env.get("MQTT_USERNAME") or None,
        password=env.get("MQTT_PASSWORD") or None,
        client_id_prefix=env.get("MQTT_CLIENT_ID_PREFIX", "hermes"),
        ca_certs=env.get("MQTT_CA_CERTS") or None,
        certfile=env.get("MQTT_CERTFILE") or None,
        keyfile=env.get("MQTT_KEYFILE") or None,
        keepalive=_env_int("MQTT_KEEPALIVE", 60),
    )
    if overrides:
        # Only real dataclass fields are overridable. hasattr() would also
        # match methods and dunder attributes.
        overridable = {f.name for f in fields(cfg)}
        for key, value in overrides.items():
            if value is not None and key in overridable:
                setattr(cfg, key, value)
    return cfg
|
|
|
|
|
|
def broker_config_from_job(job: Dict[str, Any]) -> BrokerConfig:
    """Resolve the broker config for one job record.

    Environment defaults are computed first and the job's ``broker`` block is
    then applied on top as overrides. A missing or empty block means "env
    only". This lets ``publish_event.py`` connect from the registry alone
    while still honouring environment toggles (e.g. MQTT_TLS=1).
    """
    broker_block = job.get("broker")
    if not broker_block:
        broker_block = {}
    return broker_config_from_env(overrides=broker_block)
|
|
|
|
|
|
def make_client(role: str, config: Optional[BrokerConfig] = None) -> mqtt.Client:
    """Build a paho ``Client`` that is configured but not yet connected.

    The client id has the form ``<prefix>-<role>-<8 hex chars>``; the random
    suffix keeps concurrent publishers / subscribers from colliding on the
    broker. Credentials are applied when the config has a username, TLS when
    ``config.tls`` is set. Without a ``config`` the environment is consulted
    via ``broker_config_from_env``.
    """
    cfg = config if config else broker_config_from_env()
    unique_suffix = uuid.uuid4().hex[:8]
    client_id = f"{cfg.client_id_prefix}-{role}-{unique_suffix}"

    client = mqtt.Client(
        callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
        client_id=client_id,
    )

    if cfg.username:
        client.username_pw_set(cfg.username, cfg.password)

    if cfg.tls:
        # With ca_certs=None paho falls back to the system trust store (fine
        # for public CAs); a private CA bundle path is handed over untouched.
        tls_material = {
            "ca_certs": cfg.ca_certs,
            "certfile": cfg.certfile,
            "keyfile": cfg.keyfile,
        }
        client.tls_set(**tls_material)

    logger.debug("built client id=%s tls=%s host=%s", client_id, cfg.tls, cfg.host)
    return client
|
|
|
|
|
|
def reason_code_value(rc: Any) -> int:
    """Turn a paho connect reason code into a plain int.

    paho-mqtt 2.x passes callbacks a ``ReasonCode`` object rather than an int,
    while older paths may pass an int directly. When ``rc`` has a ``.value``
    attribute that is used; otherwise ``rc`` itself is converted. 0 means
    success.
    """
    raw = rc.value if hasattr(rc, "value") else rc
    return int(raw)
|
|
|
|
|
|
def verify_hmac(payload: dict, auth_token: Optional[str]) -> bool:
    """Check the HMAC-SHA256 signature carried in ``payload["data"]["hmac_sig"]``.

    The signed message is the payload with ``hmac_sig`` removed from its
    ``data`` block, serialised as compact JSON with sorted keys, and keyed
    with ``auth_token``. (Presumably this mirrors what the publisher signs;
    confirm against publish_event.py.)

    The payload comes off the broker and is untrusted, so every malformed
    shape is rejected with ``False`` instead of raising.

    Args:
        payload: Decoded event payload.
        auth_token: Shared secret. ``None`` or empty means PoC mode, in which
            verification is skipped.

    Returns:
        True when no token is configured or the signature matches. False when
        the payload or its ``data`` block is not a dict, the signature is
        missing, not a string or not ASCII, or it does not match.
    """
    if not auth_token:
        return True  # PoC mode — no auth

    if not isinstance(payload, dict):
        return False
    data = payload.get("data")
    if not isinstance(data, dict):
        return False

    sig = data.get("hmac_sig")
    # compare_digest() raises TypeError for non-str values and for str values
    # containing non-ASCII characters. A hex digest is always ASCII, so such a
    # signature can be rejected up front.
    if not isinstance(sig, str) or not sig or not sig.isascii():
        return False

    unsigned = dict(payload)
    unsigned["data"] = {k: v for k, v in data.items() if k != "hmac_sig"}
    try:
        msg = json.dumps(unsigned, sort_keys=True, separators=(",", ":")).encode()
    except (TypeError, ValueError):
        # Not serialisable the way a signer would have produced it.
        return False

    expected = hmac.new(auth_token.encode(), msg, hashlib.sha256).hexdigest()
    return hmac.compare_digest(sig, expected)  # timing-safe comparison
|
|
|
|
|
|
def topic_prefix_for(job_id: str, root: str = DEFAULT_TOPIC_ROOT) -> str:
    """Return the per-job topic prefix ``<root>/<job_id>``."""
    return "{}/{}".format(root, job_id)
|
|
|
|
|
|
def events_topic_for(job_id: str, root: str = DEFAULT_TOPIC_ROOT) -> str:
    """Return the events topic of a job: its topic prefix plus ``/events``."""
    prefix = topic_prefix_for(job_id, root)
    return "{}/events".format(prefix)
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Registry primitives (single source of truth for raw record I/O)
|
|
# --------------------------------------------------------------------------
|
|
def _job_path(job_id: str, registry_dir: str) -> Path:
    """Return the path of a job's record file, ``<registry_dir>/<job_id>.json``."""
    return Path(registry_dir, f"{job_id}.json")
|
|
|
|
|
|
def _lock_path(registry_dir: str) -> Path:
    """Return the path of the registry-wide lock file inside ``registry_dir``."""
    return Path(registry_dir, LOCK_FILENAME)
|
|
|
|
|
|
@contextmanager
def registry_lock(registry_dir: str):
    """Hold an advisory exclusive ``fcntl`` lock over the whole registry dir.

    This is PoC-grade, single-host concurrency control: tmux sessions and
    scripts serialise their read-modify-write of job records through it, so
    two sessions never claim the same pending job. The registry directory is
    created if needed. On exit the lock is released and then the lock file
    handle is closed, also when the body raises. For multi-host delegation
    move to SQLite WAL (see references/registry.md).
    """
    import fcntl  # POSIX only; imported lazily so import works on Windows.

    Path(registry_dir).mkdir(parents=True, exist_ok=True)
    # The with-statement closes the handle after the inner finally has
    # released the lock, even if releasing raised.
    with open(_lock_path(registry_dir), "a+") as handle:
        try:
            fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
            yield
        finally:
            fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
|
|
|
|
|
|
def load_job(job_id: str, registry_dir: str = DEFAULT_REGISTRY_DIR) -> Dict[str, Any]:
    """Read and parse the JSON record of ``job_id``.

    Raises:
        FileNotFoundError: when the record file does not exist.
    """
    record_path = _job_path(job_id, registry_dir)
    if record_path.exists():
        with open(record_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    raise FileNotFoundError(f"job record not found: {record_path}")
|
|
|
|
|
|
def _atomic_write_record(job_id: str, registry_dir: str, record: Dict[str, Any]) -> None:
    """Persist ``record`` as the job's file without exposing a partial write.

    The JSON is written to a temp file in the same directory, flushed and
    fsynced, then moved over the target with ``os.replace``. The rename is
    atomic on POSIX, so readers never see a half-written file. Afterwards the
    file mode is set to 0600 on a best-effort basis. If anything fails, the
    temp file is removed and the error re-raised.

    Callers MUST already hold ``registry_lock`` for read-modify-write
    correctness.
    """
    Path(registry_dir).mkdir(parents=True, exist_ok=True)
    target = _job_path(job_id, registry_dir)
    fd, tmp_name = tempfile.mkstemp(
        dir=str(target.parent), prefix=f".{job_id}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as out:
            serialized = json.dumps(record, ensure_ascii=False, indent=2)
            out.write(serialized + "\n")
            out.flush()
            os.fsync(out.fileno())  # data must be on disk before the rename
        os.replace(tmp_name, target)
        try:
            os.chmod(target, 0o600)
        except Exception:
            # Deliberately best-effort: a failed chmod must not fail the write.
            pass
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a stray temp file behind.
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise
|
|
|
|
|
|
def update_job_status(job_id: str, registry_dir: str = DEFAULT_REGISTRY_DIR, **fields: Any) -> Dict[str, Any]:
    """Merge ``fields`` into a job record atomically, under the registry lock.

    ``updated_at`` is refreshed on every call. Returns the new record and
    raises FileNotFoundError when the job does not exist.

    This is meant to be the single chokepoint for status writes (both
    ``registry.update_status`` and ``publish_event.py``'s status sync route
    through here), so when ``status`` is among ``fields`` it is mirrored into
    the persistent audit log, and a ``status_changed`` event is appended if
    the value actually changed. The mirroring is best-effort and runs after
    the registry lock is released so a slow/failed log write never blocks the
    record.
    """
    with registry_lock(registry_dir):
        record = load_job(job_id, registry_dir)
        previous_status = record.get("status")
        record.update(fields)
        record["updated_at"] = _utcnow()
        _atomic_write_record(job_id, registry_dir, record)

    # From here on the lock is released; only audit-log mirroring remains.
    if "status" not in fields:
        return record

    current_status = record.get("status")
    stamp = record["updated_at"]
    update_logged_status(job_id, current_status, updated_at=stamp)
    if previous_status != current_status:
        transition = {
            "event": "status_changed",
            "from": previous_status,
            "to": current_status,
            "timestamp": stamp,
        }
        append_event(job_id, transition)
    return record
|
|
|
|
|
|
def next_seq(job_id: str, registry_dir: str = DEFAULT_REGISTRY_DIR) -> int:
    """Allocate the next monotonic sequence number of a job.

    The counter lives in the record's ``last_seq`` field, so it stays
    consistent across process restarts; the increment happens under the
    registry lock and also refreshes ``updated_at``. The first call
    returns 1.
    """
    with registry_lock(registry_dir):
        record = load_job(job_id, registry_dir)
        seq = 1 + int(record.get("last_seq", 0))
        record.update(last_seq=seq, updated_at=_utcnow())
        _atomic_write_record(job_id, registry_dir, record)
        return seq
|
|
|
|
|
|
def _utcnow() -> str:
    """Current UTC time as ISO-8601 with one-second resolution and a trailing Z
    (the format of the payload `timestamp` field)."""
    utc_now = time.gmtime()
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", utc_now)
|
|
|
|
|
|
def _utcnow_precise() -> str:
    """Current UTC time as ISO-8601 with millisecond resolution and a trailing Z.

    Used for the audit log's ``logged_at`` so events sort cleanly even within
    the same second."""
    now = time.time()
    millis = int((now % 1) * 1000)  # fractional second, truncated to ms
    whole_seconds = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(now))
    return "{}.{:03d}Z".format(whole_seconds, millis)
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Persistent audit log (.hermes/delegate_job_logs/<job_id>/...)
|
|
#
|
|
# Every function here is idempotent, concurrency-safe, and *best-effort*: a
|
|
# logging failure is swallowed with a logger.warning and never propagated, so it
|
|
# can never break a publish, a subscribe, or a registry write. stdout is never
|
|
# touched (it is reserved for data output).
|
|
# --------------------------------------------------------------------------
|
|
def job_log_dir(job_id: str, logs_dir: Optional[str] = None) -> Path:
    """Return the audit-log directory of a job: ``<logs root>/<job_id>``.

    The root is ``logs_dir`` when given (and non-empty), else the module-level
    ``LOGS_DIR``."""
    root = logs_dir if logs_dir else LOGS_DIR
    return Path(root, job_id)
|
|
|
|
|
|
def job_log_path(job_id: str, kind: str, logs_dir: Optional[str] = None) -> Path:
    """Return the path of one audit-log file of a job.

    ``kind`` is the filename itself, e.g. one of the module constants
    META_FILENAME / EVENTS_FILENAME / STATUS_FILENAME."""
    return job_log_dir(job_id, logs_dir).joinpath(kind)
|
|
|
|
|
|
@contextmanager
def _file_lock(fh):
    """Hold an exclusive ``fcntl`` lock on the open file ``fh`` for the body.

    It exists so two processes appending to events.ndjson never interleave a
    line. Where ``fcntl`` cannot be imported (Windows) the context manager
    does nothing; a short append is considered atomic enough there."""
    try:
        import fcntl
    except ImportError:  # pragma: no cover - non-POSIX
        fcntl = None

    if fcntl is None:  # pragma: no cover - non-POSIX
        yield
        return

    fcntl.flock(fh.fileno(), fcntl.LOCK_EX)
    try:
        yield
    finally:
        fcntl.flock(fh.fileno(), fcntl.LOCK_UN)
|
|
|
|
|
|
def append_event(job_id: str, event_dict: Dict[str, Any], logs_dir: Optional[str] = None) -> None:
    """Append ``event_dict`` as one JSON line to ``<logs>/<job_id>/events.ndjson``.

    The write happens under an fcntl lock on the file, so concurrent
    appenders do not interleave. The caller's dict is copied, not mutated,
    and a millisecond ``logged_at`` is added when the caller supplied none.
    Best-effort: any failure is logged as a warning and swallowed."""
    try:
        target = job_log_path(job_id, EVENTS_FILENAME, logs_dir)
        target.parent.mkdir(parents=True, exist_ok=True)

        entry = dict(event_dict)
        if "logged_at" not in entry:
            entry["logged_at"] = _utcnow_precise()
        serialized = json.dumps(entry, ensure_ascii=False) + "\n"

        with open(target, "a", encoding="utf-8") as fh, _file_lock(fh):
            fh.write(serialized)
            fh.flush()  # flush while the lock is still held
    except Exception as exc:  # pragma: no cover - best effort
        logger.warning("append_event failed for job %s: %s", job_id, exc)
|
|
|
|
|
|
def update_logged_status(job_id: str, status: str, logs_dir: Optional[str] = None, **extras: Any) -> None:
    """Atomically rewrite ``<logs>/<job_id>/status.json``.

    status.json holds the current status for fast point queries. The record
    is ``{"job_id", "status", "updated_at"}`` with ``extras`` merged on top
    (so an ``updated_at`` passed in ``extras`` wins over the default stamp).

    The content is written to a uniquely named temp file in the same
    directory and moved into place with ``os.replace``. The unique name means
    concurrent writers never share a temp file, and the temp file is removed
    if the write fails. Because ``mkstemp`` creates the file owner-only, the
    resulting status.json has mode 0600.

    Best-effort: any failure is logged as a warning and swallowed.

    Args:
        job_id: Job whose status file is rewritten.
        status: New status value.
        logs_dir: Audit-log root; falls back to the module default when None.
        **extras: Additional keys to store in the record.
    """
    tmp_name: Optional[str] = None
    try:
        path = job_log_path(job_id, STATUS_FILENAME, logs_dir)
        path.parent.mkdir(parents=True, exist_ok=True)

        record: Dict[str, Any] = {"job_id": job_id, "status": status, "updated_at": _utcnow()}
        record.update(extras)

        fd, tmp_name = tempfile.mkstemp(
            dir=str(path.parent), prefix=f".{path.name}.", suffix=".tmp"
        )
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            json.dump(record, fh, ensure_ascii=False, indent=2)
            fh.write("\n")
        os.replace(tmp_name, path)
        tmp_name = None  # moved into place; nothing left to clean up
    except Exception as exc:  # pragma: no cover - best effort
        logger.warning("update_logged_status failed for job %s: %s", job_id, exc)
    finally:
        if tmp_name is not None:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass  # cleanup is best-effort as well
|
|
|
|
|
|
def init_job_log(job_id: str, meta: Dict[str, Any], logs_dir: Optional[str] = None) -> None:
    """Seed the audit-log directory of a job.

    Writes meta.json (overwritten on every call), status.json (status taken
    from ``meta``, default ``"pending"``) and, only when events.ndjson did not
    exist before the call, a first ``registered`` line in it. Best-effort: any
    failure is logged as a warning and swallowed."""
    try:
        log_dir = job_log_dir(job_id, logs_dir)
        log_dir.mkdir(parents=True, exist_ok=True)

        with open(log_dir / META_FILENAME, "w", encoding="utf-8") as fh:
            json.dump(meta, fh, ensure_ascii=False, indent=2)
            fh.write("\n")

        initial_status = meta.get("status", "pending")
        update_logged_status(
            job_id,
            initial_status,
            logs_dir=logs_dir,
            created_at=meta.get("created_at"),
            prompt=meta.get("prompt"),
        )

        events_file = log_dir / EVENTS_FILENAME
        already_seeded = events_file.exists()
        events_file.touch(exist_ok=True)
        if already_seeded:
            return  # keep repeated calls from duplicating the first line

        registration = {
            "event": "registered",
            "status": initial_status,
            "agent": meta.get("agent"),
            "agent_session": meta.get("agent_session"),
            "topic_prefix": meta.get("topic_prefix"),
            "timestamp": meta.get("created_at"),
        }
        append_event(job_id, registration, logs_dir=logs_dir)
    except Exception as exc:  # pragma: no cover - best effort
        logger.warning("init_job_log failed for job %s: %s", job_id, exc)
|
|
|
|
|
|
def read_logged_meta(job_id: str, logs_dir: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Load a job's audit meta.json (the registration snapshot).

    Returns None when the file cannot be read or is not valid JSON."""
    try:
        text = job_log_path(job_id, META_FILENAME, logs_dir).read_text(encoding="utf-8")
        return json.loads(text)
    except (OSError, json.JSONDecodeError):
        return None
|
|
|
|
|
|
def read_logged_status(job_id: str, logs_dir: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Load a job's status.json, the fast point-query file.

    It holds the current status only and is separate from the
    registration-time meta.json. Returns None when the file cannot be read or
    is not valid JSON."""
    try:
        text = job_log_path(job_id, STATUS_FILENAME, logs_dir).read_text(encoding="utf-8")
        return json.loads(text)
    except (OSError, json.JSONDecodeError):
        return None
|
|
|
|
|
|
def iter_logged_events(job_id: str, logs_dir: Optional[str] = None):
    """Yield the parsed events of a job's events.ndjson in file order.

    Yields nothing when the file does not exist. Blank lines are ignored and
    lines that are not valid JSON are skipped with a warning."""
    events_file = job_log_path(job_id, EVENTS_FILENAME, logs_dir)
    if not events_file.exists():
        return

    with open(events_file, "r", encoding="utf-8") as fh:
        for raw_line in fh:
            text = raw_line.strip()
            if text:
                try:
                    yield json.loads(text)
                except json.JSONDecodeError:
                    logger.warning("skipping malformed audit line in %s", events_file)
|
|
|
|
|
|
def list_logged_jobs(logs_dir: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return one summary record per job directory under the logs root.

    Each summary is the job's meta.json, or ``{"job_id": <dir>}`` when that
    is missing, with ``status`` and ``updated_at`` overlaid from status.json
    when it is available. The result is sorted by ``created_at``, oldest
    first; records without one sort first. A missing logs root yields an
    empty list."""
    root = Path(logs_dir or LOGS_DIR)
    if not root.exists():
        return []

    def created_key(summary: Dict[str, Any]) -> str:
        return summary.get("created_at") or ""

    summaries: List[Dict[str, Any]] = []
    for entry in sorted(root.iterdir()):
        if not entry.is_dir():
            continue
        job_id = entry.name
        summary = read_logged_meta(job_id, logs_dir) or {"job_id": job_id}
        # Overlay the live status.json so the summary shows the current state
        # rather than the registration-time snapshot frozen in meta.json.
        live = read_logged_status(job_id, logs_dir)
        if live:
            merged = {**summary}
            merged["status"] = live.get("status", summary.get("status"))
            merged["updated_at"] = live.get("updated_at", summary.get("updated_at"))
            summary = merged
        summaries.append(summary)

    summaries.sort(key=created_key)
    return summaries
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Retry helper
|
|
# --------------------------------------------------------------------------
|
|
def with_retry(
    fn: Optional[Callable] = None,
    *,
    attempts: int = 3,
    base_delay: float = 0.5,
    factor: float = 2.0,
    max_delay: float = 8.0,
    exceptions: Iterable[type] = (Exception,),
) -> Callable:
    """Retry ``fn`` with exponential backoff.

    Usable two ways::

        result = with_retry(do_publish, attempts=3)()   # wrap-and-call
        @with_retry(attempts=5, base_delay=1.0)         # decorator
        def do_publish(): ...

    Args:
        fn: Callable to wrap. When omitted, a decorator is returned.
        attempts: Total number of calls to make, including the first.
            Must be at least 1.
        base_delay: Seconds to sleep after the first failure.
        factor: Multiplier applied to the delay after every failure.
        max_delay: Upper bound for the delay.
        exceptions: Exception classes that trigger a retry; anything else
            propagates immediately. A single exception class is accepted as
            well as an iterable of classes.

    Returns:
        The wrapped callable, or a decorator when ``fn`` is None.

    Raises:
        ValueError: at wrap time when ``attempts`` is below 1. Such a wrapper
            could never call ``fn`` at all.

    Once ``attempts`` is exhausted, the last exception is re-raised.
    """
    if attempts < 1:
        raise ValueError(f"attempts must be >= 1, got {attempts!r}")

    # A lone class is not iterable, so tuple() would fail on it; wrap it.
    if isinstance(exceptions, type):
        exc_tuple = (exceptions,)
    else:
        exc_tuple = tuple(exceptions)

    def decorate(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            delay = base_delay
            # attempts >= 1 is guaranteed above, so every pass through this
            # loop either returns or raises; it never falls out the bottom.
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exc_tuple as exc:
                    if attempt >= attempts:
                        raise  # out of attempts: surface the last failure
                    logger.warning(
                        "attempt %d/%d failed: %s; retrying in %.1fs",
                        attempt, attempts, exc, delay,
                    )
                    time.sleep(delay)
                    delay = min(delay * factor, max_delay)

        return wrapper

    if fn is not None:
        return decorate(fn)
    return decorate
|
|
|
|
|
|
def setup_logging(level: int = logging.WARNING) -> None:
    """Point root logging at stderr with the given level.

    stdout stays reserved for data output (subscriber event lines, registry
    ids)."""
    import sys

    log_format = "%(asctime)s %(levelname)s %(name)s: %(message)s"
    logging.basicConfig(level=level, stream=sys.stderr, format=log_format)
|