feat(api): add per-task sse throttling and batching for robust updates

che-pj committed 2025-08-30 09:32:44 +02:00
parent f800251de1
commit f9cf953de1
2 changed files with 85 additions and 17 deletions


@@ -8,7 +8,7 @@ from typing import Set, Optional
import redis
import threading
-from routes.utils.celery_config import REDIS_URL
+from routes.utils.celery_config import REDIS_URL, get_config_params
from routes.utils.celery_tasks import (
get_task_info,
@@ -37,6 +37,11 @@ router = APIRouter()
class SSEBroadcaster:
def __init__(self):
self.clients: Set[asyncio.Queue] = set()
# Per-task throttling/batching/deduplication state
self._task_state = {} # task_id -> dict with last_sent, last_event, last_send_time, scheduled_handle
# Load configurable interval
config = get_config_params()
self.sse_update_interval = float(config.get("sseUpdateIntervalSeconds", 1))
async def add_client(self, queue: asyncio.Queue):
"""Add a new SSE client"""
@@ -49,43 +54,105 @@ class SSEBroadcaster:
logger.debug(f"SSE: Client disconnected (total: {len(self.clients)})")
async def broadcast_event(self, event_data: dict):
"""Broadcast an event to all connected clients"""
logger.debug(
f"SSE Broadcaster: Attempting to broadcast to {len(self.clients)} clients"
)
"""
Throttle, batch, and deduplicate SSE events per task.
Only emit at most 1 update/sec per task, aggregate within window, suppress redundant updates.
"""
if not self.clients:
logger.debug("SSE Broadcaster: No clients connected, skipping broadcast")
return
-# Add global task counts right before broadcasting - this is the single source of truth
+# Defensive: always work with a list of tasks
tasks = event_data.get("tasks", [])
if not isinstance(tasks, list):
tasks = [tasks]
# For each task, throttle/batch/dedupe
for task in tasks:
task_id = task.get("task_id")
if not task_id:
continue
now = time.time()
state = self._task_state.setdefault(task_id, {
"last_sent": None,
"last_event": None,
"last_send_time": 0,
"scheduled_handle": None,
})
# Deduplication: if event is identical to last sent, skip
if state["last_sent"] is not None and self._events_equal(state["last_sent"], task):
logger.debug(f"SSE: Deduped event for task {task_id}")
continue
# Throttling: if within interval, batch (store as last_event, schedule send)
elapsed = now - state["last_send_time"]
if elapsed < self.sse_update_interval:
state["last_event"] = task
if state["scheduled_handle"] is None:
delay = self.sse_update_interval - elapsed
loop = asyncio.get_running_loop()
# Bind task_id as a default argument so the callback captures this
# iteration's task rather than the loop variable's final value.
state["scheduled_handle"] = loop.call_later(
delay, lambda tid=task_id: asyncio.create_task(self._send_batched_event(tid))
)
continue
# Otherwise, send immediately
await self._send_event(task_id, task)
state["last_send_time"] = now
state["last_sent"] = task
state["last_event"] = None
if state["scheduled_handle"]:
state["scheduled_handle"].cancel()
state["scheduled_handle"] = None
async def _send_batched_event(self, task_id):
state = self._task_state.get(task_id)
if not state:
return
# Clear the handle up front so a stale handle can never block future
# scheduling for this task, even if there is nothing left to flush.
state["scheduled_handle"] = None
if not state["last_event"]:
return
await self._send_event(task_id, state["last_event"])
state["last_send_time"] = time.time()
state["last_sent"] = state["last_event"]
state["last_event"] = None
async def _send_event(self, task_id, task):
# Compose event_data for this task
event_data = {
"tasks": [task],
"current_timestamp": time.time(),
"change_type": "update",
}
enhanced_event_data = add_global_task_counts_to_event(event_data.copy())
event_json = json.dumps(enhanced_event_data)
sse_data = f"data: {event_json}\n\n"
logger.debug(
f"SSE Broadcaster: Broadcasting event: {enhanced_event_data.get('change_type', 'unknown')} with {enhanced_event_data.get('active_tasks', 0)} active tasks"
)
# Send to all clients, remove disconnected ones
disconnected = set()
sent_count = 0
for client_queue in self.clients.copy():
try:
await client_queue.put(sse_data)
sent_count += 1
logger.debug("SSE: Successfully sent to client queue")
except Exception as e:
logger.error(f"SSE: Failed to send to client: {e}")
disconnected.add(client_queue)
# Clean up disconnected clients
for client in disconnected:
self.clients.discard(client)
logger.debug(
f"SSE Broadcaster: Successfully sent to {sent_count} clients, removed {len(disconnected)} disconnected clients"
f"SSE Broadcaster: Sent throttled/batched event for task {task_id} to {sent_count} clients"
)
def _events_equal(self, a, b):
# Compare two task dicts for deduplication (ignore timestamps)
if not isinstance(a, dict) or not isinstance(b, dict):
return False
a_copy = dict(a)
b_copy = dict(b)
a_copy.pop("timestamp", None)
b_copy.pop("timestamp", None)
return a_copy == b_copy
# Global broadcaster instance
sse_broadcaster = SSEBroadcaster()
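
For reference, a minimal sketch (not part of this commit) of how the new throttle behaves under a burst of updates. The driver below is hypothetical: "demo-task" and the progress payloads are made up, and it assumes the app config and the Redis-backed helpers used by _send_event are available in the environment.

import asyncio

async def demo_burst():
    broadcaster = SSEBroadcaster()
    queue = asyncio.Queue()
    await broadcaster.add_client(queue)

    # Five rapid updates for one task inside the default 1s window:
    # the first is sent immediately, the rest are coalesced into a
    # single trailing batched send at the end of the window.
    for pct in (0, 25, 50, 75, 100):
        await broadcaster.broadcast_event(
            {"tasks": [{"task_id": "demo-task", "progress": pct}]}
        )
        await asyncio.sleep(0.1)
    await asyncio.sleep(1.5)  # let the scheduled batched send fire

    # An identical repeat (timestamps aside) is suppressed by
    # _events_equal, so it produces no third event.
    await broadcaster.broadcast_event(
        {"tasks": [{"task_id": "demo-task", "progress": 100}]}
    )
    print(queue.qsize())  # expect 2 events, not 6

asyncio.run(demo_burst())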


@@ -52,6 +52,7 @@ DEFAULT_MAIN_CONFIG = {
"watch": {},
"realTimeMultiplier": 0,
"padNumberWidth": 3,
"sseUpdateIntervalSeconds": 1, # Configurable SSE update interval (default: 1s)
}
@@ -188,7 +189,7 @@ task_annotations = {
"rate_limit": f"{MAX_CONCURRENT_DL}/m",
},
"routes.utils.celery_tasks.trigger_sse_update_task": {
"rate_limit": "500/m", # Allow high rate for real-time SSE updates
"rate_limit": "60/m", # Throttle to 1 update/sec per task (matches SSE throttle)
"default_retry_delay": 1, # Quick retry for SSE updates
"max_retries": 1, # Limited retries for best-effort delivery
"ignore_result": True, # Don't store results for SSE tasks