From f9cf953de142dc4de52fad48ffb49b2ca9ca1688 Mon Sep 17 00:00:00 2001
From: che-pj
Date: Sat, 30 Aug 2025 09:32:44 +0200
Subject: [PATCH] feat(api): add per-task SSE throttling and batching for
 robust updates

---
 routes/system/progress.py     | 99 +++++++++++++++++++++++++++++------
 routes/utils/celery_config.py |  3 +-
 2 files changed, 85 insertions(+), 17 deletions(-)

diff --git a/routes/system/progress.py b/routes/system/progress.py
index f5d5d78..e2a6cc0 100755
--- a/routes/system/progress.py
+++ b/routes/system/progress.py
@@ -8,7 +8,7 @@ from typing import Set, Optional
 import redis
 import threading
 
-from routes.utils.celery_config import REDIS_URL
+from routes.utils.celery_config import REDIS_URL, get_config_params
 from routes.utils.celery_tasks import (
     get_task_info,
@@ -37,6 +37,11 @@ router = APIRouter()
 class SSEBroadcaster:
     def __init__(self):
         self.clients: Set[asyncio.Queue] = set()
+        # Per-task throttling/batching/deduplication state
+        self._task_state = {}  # task_id -> {last_sent, last_event, last_send_time, scheduled_handle}
+        # Load the configurable update interval
+        config = get_config_params()
+        self.sse_update_interval = float(config.get("sseUpdateIntervalSeconds", 1))
 
     async def add_client(self, queue: asyncio.Queue):
         """Add a new SSE client"""
@@ -49,43 +54,105 @@ class SSEBroadcaster:
         logger.debug(f"SSE: Client disconnected (total: {len(self.clients)})")
 
     async def broadcast_event(self, event_data: dict):
-        """Broadcast an event to all connected clients"""
-        logger.debug(
-            f"SSE Broadcaster: Attempting to broadcast to {len(self.clients)} clients"
-        )
-
+        """
+        Throttle, batch, and deduplicate SSE events per task.
+        Emit at most one update per task per interval, aggregate events within the window, and suppress redundant updates.
+        """
         if not self.clients:
             logger.debug("SSE Broadcaster: No clients connected, skipping broadcast")
             return
 
-        # Add global task counts right before broadcasting - this is the single source of truth
+        # Defensive: always work with a list of tasks
+        tasks = event_data.get("tasks", [])
+        if not isinstance(tasks, list):
+            tasks = [tasks]
+
+        # For each task, throttle/batch/dedupe
+        for task in tasks:
+            task_id = task.get("task_id")
+            if not task_id:
+                continue
+
+            now = time.time()
+            state = self._task_state.setdefault(task_id, {
+                "last_sent": None,
+                "last_event": None,
+                "last_send_time": 0,
+                "scheduled_handle": None,
+            })
+
+            # Deduplication: if the event is identical to the last one sent, skip it
+            if state["last_sent"] is not None and self._events_equal(state["last_sent"], task):
+                logger.debug(f"SSE: Deduped event for task {task_id}")
+                continue
+
+            # Throttling: within the interval, batch (store as last_event, schedule a send)
+            elapsed = now - state["last_send_time"]
+            if elapsed < self.sse_update_interval:
+                state["last_event"] = task
+                if state["scheduled_handle"] is None:
+                    delay = self.sse_update_interval - elapsed
+                    loop = asyncio.get_running_loop()
+                    state["scheduled_handle"] = loop.call_later(
+                        delay, lambda tid=task_id: asyncio.create_task(self._send_batched_event(tid))
+                    )
+                continue
+
+            # Otherwise, send immediately
+            await self._send_event(task_id, task)
+            state["last_send_time"] = now
+            state["last_sent"] = task
+            state["last_event"] = None
+            if state["scheduled_handle"]:
+                state["scheduled_handle"].cancel()
+                state["scheduled_handle"] = None
+
+    async def _send_batched_event(self, task_id):
+        state = self._task_state.get(task_id)
+        if not state or not state["last_event"]:
+            return
+        await self._send_event(task_id, state["last_event"])
+        state["last_send_time"] = time.time()
+        state["last_sent"] = state["last_event"]
+        state["last_event"] = None
+        state["scheduled_handle"] = None
+
+    async def _send_event(self, task_id, task):
+        # Compose event_data for this task
+        event_data = {
+            "tasks": [task],
+            "current_timestamp": time.time(),
+            "change_type": "update",
+        }
         enhanced_event_data = add_global_task_counts_to_event(event_data.copy())
         event_json = json.dumps(enhanced_event_data)
         sse_data = f"data: {event_json}\n\n"
 
-        logger.debug(
-            f"SSE Broadcaster: Broadcasting event: {enhanced_event_data.get('change_type', 'unknown')} with {enhanced_event_data.get('active_tasks', 0)} active tasks"
-        )
-
-        # Send to all clients, remove disconnected ones
         disconnected = set()
         sent_count = 0
 
         for client_queue in self.clients.copy():
             try:
                 await client_queue.put(sse_data)
                 sent_count += 1
-                logger.debug("SSE: Successfully sent to client queue")
             except Exception as e:
                 logger.error(f"SSE: Failed to send to client: {e}")
                 disconnected.add(client_queue)
-
-        # Clean up disconnected clients
         for client in disconnected:
             self.clients.discard(client)
 
         logger.debug(
-            f"SSE Broadcaster: Successfully sent to {sent_count} clients, removed {len(disconnected)} disconnected clients"
+            f"SSE Broadcaster: Sent throttled/batched event for task {task_id} to {sent_count} clients"
        )
 
+    def _events_equal(self, a, b):
+        # Compare two task dicts for deduplication (ignore timestamps)
+        if not isinstance(a, dict) or not isinstance(b, dict):
+            return False
+        a_copy = dict(a)
+        b_copy = dict(b)
+        a_copy.pop("timestamp", None)
+        b_copy.pop("timestamp", None)
+        return a_copy == b_copy
+
 
 # Global broadcaster instance
 sse_broadcaster = SSEBroadcaster()
diff --git a/routes/utils/celery_config.py b/routes/utils/celery_config.py
index 83814fd..e97c24f 100644
--- a/routes/utils/celery_config.py
+++ b/routes/utils/celery_config.py
@@ -52,6 +52,7 @@ DEFAULT_MAIN_CONFIG = {
     "watch": {},
     "realTimeMultiplier": 0,
     "padNumberWidth": 3,
+    "sseUpdateIntervalSeconds": 1,  # Configurable SSE update interval (default: 1s)
 }
 
@@ -188,7 +189,7 @@ task_annotations = {
         "rate_limit": f"{MAX_CONCURRENT_DL}/m",
     },
     "routes.utils.celery_tasks.trigger_sse_update_task": {
-        "rate_limit": "500/m",  # Allow high rate for real-time SSE updates
+        "rate_limit": "60/m",  # ~1 update/sec per worker; complements the per-task SSE throttle
         "default_retry_delay": 1,  # Quick retry for SSE updates
         "max_retries": 1,  # Limited retries for best-effort delivery
         "ignore_result": True,  # Don't store results for SSE tasks