From 999da3048a6dba83501cee518330229dc2c8e95b Mon Sep 17 00:00:00 2001 From: Xoconoch Date: Sat, 2 Aug 2025 18:59:00 -0600 Subject: [PATCH] This may be the cleanest this shit of a project has ever been --- routes/prgs.py | 748 +++++++++++--------- routes/utils/celery_config.py | 23 + routes/utils/celery_manager.py | 5 +- routes/utils/celery_tasks.py | 76 +- spotizerr-ui/src/components/Queue.tsx | 99 +-- spotizerr-ui/src/contexts/QueueProvider.tsx | 49 +- spotizerr-ui/src/contexts/queue-context.ts | 21 +- 7 files changed, 621 insertions(+), 400 deletions(-) diff --git a/routes/prgs.py b/routes/prgs.py index 9ac9c1a..bb9c601 100755 --- a/routes/prgs.py +++ b/routes/prgs.py @@ -4,6 +4,7 @@ import logging import time import json import asyncio +from typing import Dict, Set from routes.utils.celery_tasks import ( get_task_info, @@ -19,6 +20,206 @@ logger = logging.getLogger(__name__) router = APIRouter() +# Global SSE Event Broadcaster +class SSEBroadcaster: + def __init__(self): + self.clients: Set[asyncio.Queue] = set() + + async def add_client(self, queue: asyncio.Queue): + """Add a new SSE client""" + self.clients.add(queue) + logger.info(f"SSE: Client connected (total: {len(self.clients)})") + + async def remove_client(self, queue: asyncio.Queue): + """Remove an SSE client""" + self.clients.discard(queue) + logger.info(f"SSE: Client disconnected (total: {len(self.clients)})") + + async def broadcast_event(self, event_data: dict): + """Broadcast an event to all connected clients""" + logger.debug(f"SSE Broadcaster: Attempting to broadcast to {len(self.clients)} clients") + + if not self.clients: + logger.debug("SSE Broadcaster: No clients connected, skipping broadcast") + return + + # Add global task counts right before broadcasting - this is the single source of truth + enhanced_event_data = add_global_task_counts_to_event(event_data.copy()) + event_json = json.dumps(enhanced_event_data) + sse_data = f"data: {event_json}\n\n" + + logger.debug(f"SSE Broadcaster: Broadcasting event: {enhanced_event_data.get('change_type', 'unknown')} with {enhanced_event_data.get('active_tasks', 0)} active tasks") + + # Send to all clients, remove disconnected ones + disconnected = set() + sent_count = 0 + for client_queue in self.clients.copy(): + try: + await client_queue.put(sse_data) + sent_count += 1 + logger.debug(f"SSE: Successfully sent to client queue") + except Exception as e: + logger.error(f"SSE: Failed to send to client: {e}") + disconnected.add(client_queue) + + # Clean up disconnected clients + for client in disconnected: + self.clients.discard(client) + + logger.info(f"SSE Broadcaster: Successfully sent to {sent_count} clients, removed {len(disconnected)} disconnected clients") + +# Global broadcaster instance +sse_broadcaster = SSEBroadcaster() + +# Redis subscriber for cross-process SSE events +import redis +import threading +from routes.utils.celery_config import REDIS_URL + +# Redis client for SSE pub/sub +sse_redis_client = redis.Redis.from_url(REDIS_URL) + +def start_sse_redis_subscriber(): + """Start Redis subscriber to listen for SSE events from Celery workers""" + def redis_subscriber_thread(): + try: + pubsub = sse_redis_client.pubsub() + pubsub.subscribe("sse_events") + logger.info("SSE Redis Subscriber: Started listening for events") + + for message in pubsub.listen(): + if message['type'] == 'message': + try: + event_data = json.loads(message['data'].decode('utf-8')) + event_type = event_data.get('event_type', 'unknown') + task_id = event_data.get('task_id', 'unknown') + + 
logger.debug(f"SSE Redis Subscriber: Received {event_type} for task {task_id}") + + # Handle different event types + if event_type == 'progress_update': + # Transform callback data into task format expected by frontend + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + broadcast_data = loop.run_until_complete(transform_callback_to_task_format(task_id, event_data)) + if broadcast_data: + loop.run_until_complete(sse_broadcaster.broadcast_event(broadcast_data)) + logger.debug(f"SSE Redis Subscriber: Broadcasted callback to {len(sse_broadcaster.clients)} clients") + finally: + loop.close() + elif event_type == 'summary_update': + # Task summary update - use existing trigger_sse_update logic + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(trigger_sse_update(task_id, event_data.get('reason', 'update'))) + logger.debug(f"SSE Redis Subscriber: Processed summary update for {task_id}") + finally: + loop.close() + else: + # Unknown event type - broadcast as-is + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(sse_broadcaster.broadcast_event(event_data)) + logger.debug(f"SSE Redis Subscriber: Broadcasted {event_type} to {len(sse_broadcaster.clients)} clients") + finally: + loop.close() + + except Exception as e: + logger.error(f"SSE Redis Subscriber: Error processing message: {e}", exc_info=True) + + except Exception as e: + logger.error(f"SSE Redis Subscriber: Fatal error: {e}", exc_info=True) + + # Start Redis subscriber in background thread + thread = threading.Thread(target=redis_subscriber_thread, daemon=True) + thread.start() + logger.info("SSE Redis Subscriber: Background thread started") + +async def transform_callback_to_task_format(task_id: str, event_data: dict) -> dict: + """Transform callback event data into the task format expected by frontend""" + try: + # Import here to avoid circular imports + from routes.utils.celery_tasks import get_task_info, get_all_tasks + + # Get task info to build complete task object + task_info = get_task_info(task_id) + if not task_info: + logger.warning(f"SSE Transform: No task info found for {task_id}") + return None + + # Extract callback data + callback_data = event_data.get('callback_data', {}) + + # Build task object in the format expected by frontend + task_object = { + "task_id": task_id, + "original_url": f"http://localhost:7171/api/{task_info.get('download_type', 'track')}/download/{task_info.get('url', '').split('/')[-1] if task_info.get('url') else ''}", + "last_line": callback_data, # This is what frontend expects for callback data + "timestamp": event_data.get('timestamp', time.time()), + "download_type": task_info.get('download_type', 'track'), + "type": task_info.get('type', task_info.get('download_type', 'track')), + "name": task_info.get('name', 'Unknown'), + "artist": task_info.get('artist', ''), + "created_at": task_info.get('created_at'), + } + + # Build minimal event data - global counts will be added at broadcast time + return { + "change_type": "update", # Use "update" so it gets processed by existing frontend logic + "tasks": [task_object], # Frontend expects tasks array + "current_timestamp": time.time(), + "updated_count": 1, + "since_timestamp": time.time(), + "trigger_reason": "callback_update" + } + + except Exception as e: + logger.error(f"SSE Transform: Error transforming callback for task {task_id}: {e}", exc_info=True) + return None + +# Start the Redis subscriber when module loads 
+start_sse_redis_subscriber() + +async def trigger_sse_update(task_id: str, reason: str = "task_update"): + """Trigger an immediate SSE update for a specific task""" + try: + current_time = time.time() + + # Find the specific task that changed + task_info = get_task_info(task_id) + if not task_info: + logger.warning(f"SSE: Task {task_id} not found for update") + return + + last_status = get_last_task_status(task_id) + + # Create a dummy request for the _build_task_response function + from fastapi import Request + class DummyRequest: + def __init__(self): + self.base_url = "http://localhost:7171" + + dummy_request = DummyRequest() + task_response = _build_task_response(task_info, last_status, task_id, current_time, dummy_request) + + # Create minimal event data - global counts will be added at broadcast time + event_data = { + "tasks": [task_response], + "current_timestamp": current_time, + "since_timestamp": current_time, + "change_type": "realtime", + "trigger_reason": reason + } + + await sse_broadcaster.broadcast_event(event_data) + logger.debug(f"SSE: Broadcast update for task {task_id} (reason: {reason})") + + except Exception as e: + logger.error(f"SSE: Failed to trigger update for task {task_id}: {e}") + # Define active task states using ProgressState constants ACTIVE_TASK_STATES = { ProgressState.INITIALIZING, # "initializing" - task is starting up @@ -29,6 +230,8 @@ ACTIVE_TASK_STATES = { ProgressState.REAL_TIME, # "real_time" - real-time download progress ProgressState.RETRYING, # "retrying" - task is retrying after error "real-time", # "real-time" - real-time download progress (hyphenated version) + ProgressState.QUEUED, # "queued" - task is queued and waiting + "pending", # "pending" - legacy queued status } # Define terminal task states that should be included when recently completed @@ -43,6 +246,7 @@ TERMINAL_TASK_STATES = { def get_task_status_from_last_status(last_status): """ Extract the task status from last_status, checking both possible locations. + Uses improved priority logic to handle real-time downloads correctly. Args: last_status: The last status dict from get_last_task_status() @@ -53,13 +257,20 @@ def get_task_status_from_last_status(last_status): if not last_status: return "unknown" - # Check for status in nested status_info (for real-time downloads) + # For real-time downloads, prioritize status_info.status as it contains the actual progress state status_info = last_status.get("status_info", {}) if isinstance(status_info, dict) and "status" in status_info: - return status_info["status"] + status_info_status = status_info["status"] + # If status_info contains an active status, use it regardless of top-level status + if status_info_status in ACTIVE_TASK_STATES: + return status_info_status - # Fall back to top-level status (for other task types) - return last_status.get("status", "unknown") + # Fall back to top-level status + top_level_status = last_status.get("status", "unknown") + + # If both exist but neither is active, prefer the more recent one (usually top-level) + # For active states, we already handled status_info above + return top_level_status def is_task_active(task_status): @@ -72,9 +283,97 @@ def is_task_active(task_status): Returns: bool: True if the task is active, False otherwise """ + if not task_status or task_status == "unknown": + return False return task_status in ACTIVE_TASK_STATES +def get_global_task_counts(): + """ + Get comprehensive task counts for ALL tasks in Redis. 
+ This is called right before sending SSE events to ensure accurate counts. + + Returns: + dict: Task counts by status + """ + task_counts = { + "active": 0, + "queued": 0, + "completed": 0, + "error": 0, + "cancelled": 0, + "retrying": 0, + "skipped": 0 + } + + try: + # Get ALL tasks from Redis - this is the source of truth + all_tasks = get_all_tasks() + + for task_summary in all_tasks: + task_id = task_summary.get("task_id") + if not task_id: + continue + + task_info = get_task_info(task_id) + if not task_info: + continue + + last_status = get_last_task_status(task_id) + task_status = get_task_status_from_last_status(last_status) + is_active_task = is_task_active(task_status) + + # Categorize tasks by status using ProgressState constants + if task_status == ProgressState.RETRYING: + task_counts["retrying"] += 1 + elif task_status in {ProgressState.QUEUED, "pending"}: + task_counts["queued"] += 1 + elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: + task_counts["completed"] += 1 + elif task_status == ProgressState.ERROR: + task_counts["error"] += 1 + elif task_status == ProgressState.CANCELLED: + task_counts["cancelled"] += 1 + elif task_status == ProgressState.SKIPPED: + task_counts["skipped"] += 1 + elif is_active_task: + task_counts["active"] += 1 + + logger.debug(f"Global task counts: {task_counts} (total: {len(all_tasks)} tasks)") + + except Exception as e: + logger.error(f"Error getting global task counts: {e}", exc_info=True) + + return task_counts + + +def add_global_task_counts_to_event(event_data): + """ + Add global task counts to any SSE event data right before broadcasting. + This ensures all SSE events have accurate, up-to-date counts of ALL tasks. + + Args: + event_data: The event data dictionary to be sent via SSE + + Returns: + dict: Enhanced event data with global task counts + """ + try: + # Get fresh counts of ALL tasks right before sending + global_task_counts = get_global_task_counts() + + # Add/update the counts in the event data + event_data["task_counts"] = global_task_counts + event_data["active_tasks"] = global_task_counts["active"] + event_data["all_tasks_count"] = sum(global_task_counts.values()) + + return event_data + + except Exception as e: + logger.error(f"Error adding global task counts to SSE event: {e}", exc_info=True) + return event_data + + def _build_error_callback_object(last_status): """ Constructs a structured error callback object based on the last status of a task. 
@@ -200,20 +499,14 @@ async def get_paginated_tasks(page=1, limit=20, active_only=False, request: Requ """ try: all_tasks = get_all_tasks() + + # Get global task counts + task_counts = get_global_task_counts() + active_tasks = [] other_tasks = [] - # Task categorization counters - task_counts = { - "active": 0, - "queued": 0, - "completed": 0, - "error": 0, - "cancelled": 0, - "retrying": 0, - "skipped": 0 - } - + # Process tasks for pagination and response building for task_summary in all_tasks: task_id = task_summary.get("task_id") if not task_id: @@ -227,22 +520,6 @@ async def get_paginated_tasks(page=1, limit=20, active_only=False, request: Requ task_status = get_task_status_from_last_status(last_status) is_active_task = is_task_active(task_status) - # Categorize tasks by status using ProgressState constants - if task_status == ProgressState.RETRYING: - task_counts["retrying"] += 1 - elif task_status in {ProgressState.QUEUED, "pending"}: # Keep "pending" for backward compatibility - task_counts["queued"] += 1 - elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: - task_counts["completed"] += 1 - elif task_status == ProgressState.ERROR: - task_counts["error"] += 1 - elif task_status == ProgressState.CANCELLED: - task_counts["cancelled"] += 1 - elif task_status == ProgressState.SKIPPED: - task_counts["skipped"] += 1 - elif is_active_task: - task_counts["active"] += 1 - task_response = _build_task_response(task_info, last_status, task_id, time.time(), request) if is_active_task: @@ -465,21 +742,15 @@ async def get_task_updates(request: Request): # Get all tasks all_tasks = get_all_tasks() - updated_tasks = [] - active_tasks = [] current_time = time.time() - # Task categorization counters - task_counts = { - "active": 0, - "queued": 0, - "completed": 0, - "error": 0, - "cancelled": 0, - "retrying": 0, - "skipped": 0 - } + # Get global task counts + task_counts = get_global_task_counts() + updated_tasks = [] + active_tasks = [] + + # Process tasks for filtering and response building for task_summary in all_tasks: task_id = task_summary.get("task_id") if not task_id: @@ -490,29 +761,11 @@ async def get_task_updates(request: Request): continue last_status = get_last_task_status(task_id) - - # Check if task has been updated since the given timestamp - task_timestamp = last_status.get("timestamp") if last_status else task_info.get("created_at", 0) - - # Determine task status and categorize task_status = get_task_status_from_last_status(last_status) is_active_task = is_task_active(task_status) - # Categorize tasks by status using ProgressState constants - if task_status == ProgressState.RETRYING: - task_counts["retrying"] += 1 - elif task_status in {ProgressState.QUEUED, "pending"}: # Keep "pending" for backward compatibility - task_counts["queued"] += 1 - elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: - task_counts["completed"] += 1 - elif task_status == ProgressState.ERROR: - task_counts["error"] += 1 - elif task_status == ProgressState.CANCELLED: - task_counts["cancelled"] += 1 - elif task_status == ProgressState.SKIPPED: - task_counts["skipped"] += 1 - elif is_active_task: - task_counts["active"] += 1 + # Check if task has been updated since the given timestamp + task_timestamp = last_status.get("timestamp") if last_status else task_info.get("created_at", 0) # Always include active tasks in updates, apply filtering to others # Also include recently completed/terminal tasks to ensure "done" status gets sent @@ -657,9 +910,7 @@ async def delete_task(task_id: 
str): async def stream_task_updates(request: Request): """ Stream real-time task updates via Server-Sent Events (SSE). - - This endpoint provides continuous updates for task status changes without polling. - Clients can connect and receive instant notifications when tasks update. + Now uses event-driven architecture for true real-time updates. Query parameters: active_only (bool): If true, only stream active tasks (downloading, processing, etc.) @@ -672,254 +923,110 @@ async def stream_task_updates(request: Request): active_only = request.query_params.get('active_only', '').lower() == 'true' async def event_generator(): - # Track last known state of each task to detect actual changes - last_task_states = {} # task_id -> {"status": str, "timestamp": float, "status_count": int} - last_update_timestamp = time.time() - last_heartbeat = time.time() - heartbeat_interval = 10.0 # Reduced from 30s to 10s for faster connection monitoring - burst_mode_until = 0 # Timestamp until which we stay in burst mode + # Create a queue for this client + client_queue = asyncio.Queue() try: + # Register this client with the broadcaster + logger.info(f"SSE Stream: New client connecting...") + await sse_broadcaster.add_client(client_queue) + logger.info(f"SSE Stream: Client registered successfully, total clients: {len(sse_broadcaster.clients)}") + # Send initial data immediately upon connection - initial_data = await generate_task_update_event(last_update_timestamp, active_only, request) + initial_data = await generate_task_update_event(time.time(), active_only, request) yield initial_data - # Initialize task states from initial data - try: - initial_json = json.loads(initial_data.replace("data: ", "").strip()) - for task in initial_json.get("tasks", []): - task_id = task.get("task_id") - if task_id: - last_task_states[task_id] = { - "status": get_task_status_from_last_status(task.get("last_line")), - "timestamp": task.get("timestamp", last_update_timestamp), - "status_count": task.get("status_count", 0) + # Also send any active tasks as callback-style events to newly connected clients + all_tasks = get_all_tasks() + for task_summary in all_tasks: + task_id = task_summary.get("task_id") + if not task_id: + continue + + task_info = get_task_info(task_id) + if not task_info: + continue + + last_status = get_last_task_status(task_id) + task_status = get_task_status_from_last_status(last_status) + + # Send recent callback data for active or recently completed tasks + if is_task_active(task_status) or (last_status and last_status.get("timestamp", 0) > time.time() - 30): + if last_status and "raw_callback" in last_status: + callback_event = { + "task_id": task_id, + "callback_data": last_status["raw_callback"], + "timestamp": last_status.get("timestamp", time.time()), + "change_type": "callback", + "event_type": "progress_update", + "replay": True # Mark as replay for client } - except: - pass # Continue if initial state parsing fails + event_json = json.dumps(callback_event) + yield f"data: {event_json}\n\n" + logger.info(f"SSE Stream: Sent replay callback for task {task_id}") - last_update_timestamp = time.time() + # Send periodic heartbeats and listen for real-time events + last_heartbeat = time.time() + heartbeat_interval = 30.0 - # Optimized monitoring loop - only send when changes detected while True: try: - current_time = time.time() - - # Get all tasks and detect actual changes - all_tasks = get_all_tasks() - updated_tasks = [] - active_tasks = [] - - # Task categorization counters - task_counts = { - "active": 
0, - "queued": 0, - "completed": 0, - "error": 0, - "cancelled": 0, - "retrying": 0, - "skipped": 0 - } - - has_actual_changes = False - current_task_ids = set() - - for task_summary in all_tasks: - task_id = task_summary.get("task_id") - if not task_id: - continue - - current_task_ids.add(task_id) - task_info = get_task_info(task_id) - if not task_info: - continue - - last_status = get_last_task_status(task_id) - task_timestamp = last_status.get("timestamp") if last_status else task_info.get("created_at", 0) - task_status = get_task_status_from_last_status(last_status) - is_active_task = is_task_active(task_status) - status_count = len(get_task_status(task_id)) - - # Categorize tasks by status - if task_status == ProgressState.RETRYING: - task_counts["retrying"] += 1 - elif task_status in {ProgressState.QUEUED, "pending"}: - task_counts["queued"] += 1 - elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: - task_counts["completed"] += 1 - elif task_status == ProgressState.ERROR: - task_counts["error"] += 1 - elif task_status == ProgressState.CANCELLED: - task_counts["cancelled"] += 1 - elif task_status == ProgressState.SKIPPED: - task_counts["skipped"] += 1 - elif is_active_task: - task_counts["active"] += 1 - - # Check if this task has actually changed - previous_state = last_task_states.get(task_id) - - # Determine if task has meaningful changes - task_changed = False - is_new_task = previous_state is None - just_became_terminal = False - - if is_new_task: - # Include new tasks if they're active OR if they're recently terminal - # (avoid sending old completed/cancelled tasks on connection) - if not (task_status in TERMINAL_TASK_STATES): - task_changed = True - # Trigger burst mode for new active tasks to catch rapid completions - burst_mode_until = current_time + 10.0 # 10 seconds of frequent polling - logger.debug(f"SSE: New active task detected: {task_id} - entering burst mode") - else: - # Check if terminal task is recent (completed within last 30 seconds) - is_recently_terminal = (current_time - task_timestamp) <= 30.0 - if is_recently_terminal: - task_changed = True - logger.info(f"SSE: New recently terminal task detected: {task_id} (status: {task_status}, age: {current_time - task_timestamp:.1f}s)") - else: - logger.debug(f"SSE: Skipping old terminal task: {task_id} (status: {task_status}, age: {current_time - task_timestamp:.1f}s)") - else: - # Check for status changes - status_changed = previous_state["status"] != task_status - # Check for new status updates (more detailed progress) - status_count_changed = previous_state["status_count"] != status_count - # Check for significant timestamp changes (new activity) - significant_timestamp_change = task_timestamp > previous_state["timestamp"] + # Wait for either an event or timeout for heartbeat + try: + event_data = await asyncio.wait_for(client_queue.get(), timeout=heartbeat_interval) + # Send the real-time event + yield event_data + last_heartbeat = time.time() + except asyncio.TimeoutError: + # Send heartbeat if no events for a while + current_time = time.time() + if current_time - last_heartbeat >= heartbeat_interval: + # Generate current task counts for heartbeat + all_tasks = get_all_tasks() + task_counts = {"active": 0, "queued": 0, "completed": 0, "error": 0, "cancelled": 0, "retrying": 0, "skipped": 0} - if status_changed: - task_changed = True - # Check if this is a transition TO terminal state - was_terminal = previous_state["status"] in TERMINAL_TASK_STATES - is_now_terminal = task_status in 
TERMINAL_TASK_STATES - just_became_terminal = not was_terminal and is_now_terminal + for task_summary in all_tasks: + task_id = task_summary.get("task_id") + if not task_id: + continue + task_info = get_task_info(task_id) + if not task_info: + continue + last_status = get_last_task_status(task_id) + task_status = get_task_status_from_last_status(last_status) - # Extend burst mode on significant status changes - if not is_now_terminal: - burst_mode_until = max(burst_mode_until, current_time + 5.0) # 5 more seconds - - logger.debug(f"SSE: Status changed for {task_id}: {previous_state['status']} -> {task_status}") - if just_became_terminal: - logger.debug(f"SSE: Task {task_id} just became terminal") - elif status_count_changed and significant_timestamp_change and not (task_status in TERMINAL_TASK_STATES): - # Only track progress updates for non-terminal tasks - task_changed = True - logger.debug(f"SSE: Progress update for {task_id}: status_count {previous_state['status_count']} -> {status_count}") - - # Include task if it changed and meets criteria - should_include = False - if task_changed: - # For terminal state tasks, only include if they just became terminal - if task_status in TERMINAL_TASK_STATES: - if just_became_terminal: - should_include = True - has_actual_changes = True - logger.debug(f"SSE: Including terminal task {task_id} (just transitioned)") - # Note: we don't include new terminal tasks (handled above) - else: - # Non-terminal tasks are always included when they change - should_include = True - has_actual_changes = True - elif is_active_task and not active_only: - # For non-active_only streams, include active tasks periodically for frontend state sync - # But only if significant time has passed since last update - if current_time - last_update_timestamp > 10.0: # Every 10 seconds max - should_include = True - - if should_include: - # Update our tracked state - last_task_states[task_id] = { - "status": task_status, - "timestamp": task_timestamp, - "status_count": status_count + if task_status == ProgressState.RETRYING: + task_counts["retrying"] += 1 + elif task_status in {ProgressState.QUEUED, "pending"}: + task_counts["queued"] += 1 + elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: + task_counts["completed"] += 1 + elif task_status == ProgressState.ERROR: + task_counts["error"] += 1 + elif task_status == ProgressState.CANCELLED: + task_counts["cancelled"] += 1 + elif task_status == ProgressState.SKIPPED: + task_counts["skipped"] += 1 + elif is_task_active(task_status): + task_counts["active"] += 1 + + heartbeat_data = { + "current_timestamp": current_time, + "total_tasks": task_counts["active"] + task_counts["retrying"], + "task_counts": task_counts, + "change_type": "heartbeat" } - # Build response - task_response = _build_task_response(task_info, last_status, task_id, current_time, request) + event_json = json.dumps(heartbeat_data) + yield f"data: {event_json}\n\n" + last_heartbeat = current_time - if is_active_task: - active_tasks.append(task_response) - else: - updated_tasks.append(task_response) - - # Clean up states for tasks that no longer exist - removed_tasks = set(last_task_states.keys()) - current_task_ids - for removed_task_id in removed_tasks: - del last_task_states[removed_task_id] - has_actual_changes = True - logger.debug(f"SSE: Task removed: {removed_task_id}") - - # Send update only if there are actual changes - if has_actual_changes: - all_returned_tasks = active_tasks + updated_tasks - - # Sort by priority (active first, then by creation 
time) - all_returned_tasks.sort(key=lambda x: ( - 0 if x.get("task_id") in [t["task_id"] for t in active_tasks] else 1, - -x.get("created_at", 0) - )) - - update_data = { - "tasks": all_returned_tasks, - "current_timestamp": current_time, - "total_tasks": task_counts["active"] + task_counts["retrying"], - "all_tasks_count": len(all_tasks), - "task_counts": task_counts, - "active_tasks": len(active_tasks), - "updated_count": len(updated_tasks), - "since_timestamp": last_update_timestamp, - "change_type": "update" - } - - # Send SSE event with update data - event_data = json.dumps(update_data) - yield f"data: {event_data}\n\n" - - # Log details about what was sent - task_statuses = [f"{task.get('task_id', 'unknown')}:{get_task_status_from_last_status(task.get('last_line'))}" for task in all_returned_tasks] - logger.info(f"SSE: Sent {len(active_tasks)} active + {len(updated_tasks)} updated tasks: {task_statuses}") - - last_update_timestamp = current_time - last_heartbeat = current_time - - # Send heartbeat if no updates for a while (keeps connection alive) - elif current_time - last_heartbeat > heartbeat_interval: - heartbeat_data = { - "current_timestamp": current_time, - "total_tasks": task_counts["active"] + task_counts["retrying"], - "task_counts": task_counts, - "change_type": "heartbeat" - } - - event_data = json.dumps(heartbeat_data) - yield f"data: {event_data}\n\n" - - last_heartbeat = current_time - logger.debug("SSE: Sent heartbeat") - - # Responsive polling - much faster for real-time updates - active_task_count = task_counts["active"] + task_counts["retrying"] - - if current_time < burst_mode_until: - # Burst mode: poll every 100ms to catch rapid task completions - await asyncio.sleep(0.1) - elif has_actual_changes or active_task_count > 0: - # When there are changes or active tasks, poll very frequently - await asyncio.sleep(0.2) # 200ms for immediate responsiveness - elif current_time - last_update_timestamp < 30.0: - # For 30 seconds after last update, poll more frequently to catch fast completions - await asyncio.sleep(0.5) # 500ms to catch fast transitions - else: - # Only when truly idle for >30s, use longer interval - await asyncio.sleep(2.0) # 2 seconds max when completely idle - except Exception as e: - logger.error(f"Error in SSE event generation: {e}", exc_info=True) + logger.error(f"Error in SSE event streaming: {e}", exc_info=True) # Send error event and continue error_data = json.dumps({"error": "Internal server error", "timestamp": time.time(), "change_type": "error"}) yield f"data: {error_data}\n\n" - await asyncio.sleep(1) # Wait longer on error + await asyncio.sleep(1) except asyncio.CancelledError: logger.info("SSE client disconnected") @@ -927,6 +1034,9 @@ async def stream_task_updates(request: Request): except Exception as e: logger.error(f"SSE connection error: {e}", exc_info=True) return + finally: + # Clean up - remove client from broadcaster + await sse_broadcaster.remove_client(client_queue) return StreamingResponse( event_generator(), @@ -947,23 +1057,14 @@ async def generate_task_update_event(since_timestamp: float, active_only: bool, This replicates the logic from get_task_updates but for SSE format. 
""" try: - # Get all tasks + # Get all tasks for filtering all_tasks = get_all_tasks() - updated_tasks = [] - active_tasks = [] current_time = time.time() - # Task categorization counters - task_counts = { - "active": 0, - "queued": 0, - "completed": 0, - "error": 0, - "cancelled": 0, - "retrying": 0, - "skipped": 0 - } + updated_tasks = [] + active_tasks = [] + # Process tasks for filtering only - no counting here for task_summary in all_tasks: task_id = task_summary.get("task_id") if not task_id: @@ -974,29 +1075,11 @@ async def generate_task_update_event(since_timestamp: float, active_only: bool, continue last_status = get_last_task_status(task_id) - - # Check if task has been updated since the given timestamp - task_timestamp = last_status.get("timestamp") if last_status else task_info.get("created_at", 0) - - # Determine task status and categorize task_status = get_task_status_from_last_status(last_status) is_active_task = is_task_active(task_status) - # Categorize tasks by status using ProgressState constants - if task_status == ProgressState.RETRYING: - task_counts["retrying"] += 1 - elif task_status in {ProgressState.QUEUED, "pending"}: - task_counts["queued"] += 1 - elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: - task_counts["completed"] += 1 - elif task_status == ProgressState.ERROR: - task_counts["error"] += 1 - elif task_status == ProgressState.CANCELLED: - task_counts["cancelled"] += 1 - elif task_status == ProgressState.SKIPPED: - task_counts["skipped"] += 1 - elif is_active_task: - task_counts["active"] += 1 + # Check if task has been updated since the given timestamp + task_timestamp = last_status.get("timestamp") if last_status else task_info.get("created_at", 0) # Always include active tasks in updates, apply filtering to others # Also include recently completed/terminal tasks to ensure "done" status gets sent @@ -1024,16 +1107,15 @@ async def generate_task_update_event(since_timestamp: float, active_only: bool, initial_data = { "tasks": all_returned_tasks, "current_timestamp": current_time, - "total_tasks": task_counts["active"] + task_counts["retrying"], - "all_tasks_count": len(all_tasks), - "task_counts": task_counts, - "active_tasks": len(active_tasks), "updated_count": len(updated_tasks), "since_timestamp": since_timestamp, "initial": True # Mark as initial load } - event_data = json.dumps(initial_data) + # Add global task counts since this bypasses the broadcaster + enhanced_data = add_global_task_counts_to_event(initial_data) + + event_data = json.dumps(enhanced_data) return f"data: {event_data}\n\n" except Exception as e: diff --git a/routes/utils/celery_config.py b/routes/utils/celery_config.py index 35abb9c..c751c4e 100644 --- a/routes/utils/celery_config.py +++ b/routes/utils/celery_config.py @@ -121,6 +121,16 @@ task_default_queue = "downloads" task_default_exchange = "downloads" task_default_routing_key = "downloads" +# Task routing - ensure SSE and utility tasks go to utility_tasks queue +task_routes = { + 'routes.utils.celery_tasks.trigger_sse_update_task': {'queue': 'utility_tasks'}, + 'routes.utils.celery_tasks.cleanup_stale_errors': {'queue': 'utility_tasks'}, + 'routes.utils.celery_tasks.delayed_delete_task_data': {'queue': 'utility_tasks'}, + 'routes.utils.celery_tasks.download_track': {'queue': 'downloads'}, + 'routes.utils.celery_tasks.download_album': {'queue': 'downloads'}, + 'routes.utils.celery_tasks.download_playlist': {'queue': 'downloads'}, +} + # Celery task settings task_serializer = "json" accept_content = ["json"] @@ 
-141,6 +151,19 @@ task_annotations = { "routes.utils.celery_tasks.download_playlist": { "rate_limit": f"{MAX_CONCURRENT_DL}/m", }, + "routes.utils.celery_tasks.trigger_sse_update_task": { + "rate_limit": "500/m", # Allow high rate for real-time SSE updates + "default_retry_delay": 1, # Quick retry for SSE updates + "max_retries": 1, # Limited retries for best-effort delivery + "ignore_result": True, # Don't store results for SSE tasks + "track_started": False, # Don't track when SSE tasks start + }, + "routes.utils.celery_tasks.cleanup_stale_errors": { + "rate_limit": "10/m", # Moderate rate for cleanup tasks + }, + "routes.utils.celery_tasks.delayed_delete_task_data": { + "rate_limit": "100/m", # Moderate rate for cleanup + }, } # Configure retry settings diff --git a/routes/utils/celery_manager.py b/routes/utils/celery_manager.py index 1fc0504..faebe95 100644 --- a/routes/utils/celery_manager.py +++ b/routes/utils/celery_manager.py @@ -149,8 +149,9 @@ class CeleryManager: else: utility_cmd = self._get_worker_command( queues="utility_tasks,default", # Listen to utility and default - concurrency=3, + concurrency=5, # Increased concurrency for SSE updates and utility tasks worker_name_suffix="utw", # Utility Worker + log_level="ERROR" # Reduce log verbosity for utility worker (only errors) ) logger.info( f"Starting Celery Utility Worker with command: {' '.join(utility_cmd)}" @@ -174,7 +175,7 @@ class CeleryManager: self.utility_log_thread_stdout.start() self.utility_log_thread_stderr.start() logger.info( - f"Celery Utility Worker (PID: {self.utility_worker_process.pid}) started with concurrency 3." + f"Celery Utility Worker (PID: {self.utility_worker_process.pid}) started with concurrency 5." ) if ( diff --git a/routes/utils/celery_tasks.py b/routes/utils/celery_tasks.py index a6268ce..43ceba8 100644 --- a/routes/utils/celery_tasks.py +++ b/routes/utils/celery_tasks.py @@ -2,6 +2,7 @@ import time import json import logging import traceback +import asyncio from celery import Celery, Task, states from celery.signals import ( task_prerun, @@ -49,6 +50,26 @@ celery_app.config_from_object("routes.utils.celery_config") redis_client = redis.Redis.from_url(REDIS_URL) +def trigger_sse_event(task_id: str, reason: str = "status_change"): + """Trigger an SSE event using a dedicated Celery worker task""" + try: + # Submit SSE update task to utility worker queue + # This is non-blocking and more reliable than threads + trigger_sse_update_task.apply_async( + args=[task_id, reason], + queue="utility_tasks", + priority=9 # High priority for real-time updates + ) + # Only log at debug level to reduce verbosity + logger.debug(f"SSE: Submitted SSE update task for {task_id} (reason: {reason})") + + except Exception as e: + logger.error(f"Error submitting SSE update task for task {task_id}: {e}", exc_info=True) + + + + + class ProgressState: """Enum-like class for progress states""" @@ -131,6 +152,10 @@ def store_task_status(task_id, status_data): redis_client.publish( update_channel, json.dumps({"task_id": task_id, "status_id": status_id}) ) + + # Trigger immediate SSE event for real-time frontend updates + trigger_sse_event(task_id, "status_update") + except Exception as e: logger.error(f"Error storing task status: {e}") traceback.print_exc() @@ -611,7 +636,6 @@ class ProgressTrackingTask(Task): logger.debug(f"Task {task_id}: Real-time progress for '{track_name}': {percentage}%") - data["status"] = ProgressState.TRACK_PROGRESS data["song"] = track_name artist = data.get("artist", "Unknown") @@ -648,13 +672,6 
@@ class ProgressTrackingTask(Task): # Log at debug level logger.debug(f"Task {task_id} track progress: {track_name} by {artist}: {percent}%") - # Set appropriate status - # data["status"] = ( - # ProgressState.REAL_TIME - # if data.get("status") == "real_time" - # else ProgressState.TRACK_PROGRESS - # ) - def _handle_skipped(self, task_id, data, task_info): """Handle skipped status from deezspot""" @@ -991,6 +1008,10 @@ class ProgressTrackingTask(Task): def task_prerun_handler(task_id=None, task=None, *args, **kwargs): """Signal handler when a task begins running""" try: + # Skip verbose logging for SSE tasks + if task and hasattr(task, 'name') and task.name in ['trigger_sse_update_task']: + return + task_info = get_task_info(task_id) # Update task status to processing @@ -1018,6 +1039,10 @@ def task_postrun_handler( ): """Signal handler when a task finishes""" try: + # Skip verbose logging for SSE tasks + if task and hasattr(task, 'name') and task.name in ['trigger_sse_update_task']: + return + last_status_for_history = get_last_task_status(task_id) if last_status_for_history and last_status_for_history.get("status") in [ ProgressState.COMPLETE, @@ -1607,3 +1632,38 @@ def delayed_delete_task_data(task_id, reason): """ logger.info(f"Executing delayed deletion for task {task_id}. Reason: {reason}") delete_task_data_and_log(task_id, reason) + + +@celery_app.task( + name="trigger_sse_update_task", + queue="utility_tasks", + bind=True +) +def trigger_sse_update_task(self, task_id: str, reason: str = "status_update"): + """ + Dedicated Celery task for triggering SSE task summary updates. + Uses Redis pub/sub to communicate with the main FastAPI process. + """ + try: + # Send task summary update via Redis pub/sub + logger.debug(f"SSE Task: Processing summary update for task {task_id} (reason: {reason})") + + event_data = { + "task_id": task_id, + "reason": reason, + "timestamp": time.time(), + "change_type": "task_summary", + "event_type": "summary_update" + } + + # Use Redis pub/sub for cross-process communication + redis_client.publish("sse_events", json.dumps(event_data)) + logger.debug(f"SSE Task: Published summary update for task {task_id}") + + except Exception as e: + # Only log errors, not success cases + logger.error(f"SSE Task: Failed to publish summary update for task {task_id}: {e}", exc_info=True) + # Don't raise exception to avoid task retry - SSE updates are best-effort + + + diff --git a/spotizerr-ui/src/components/Queue.tsx b/spotizerr-ui/src/components/Queue.tsx index 356e191..7d18cea 100644 --- a/spotizerr-ui/src/components/Queue.tsx +++ b/spotizerr-ui/src/components/Queue.tsx @@ -206,10 +206,10 @@ const CancelledTaskCard = ({ item }: { item: QueueItem }) => { ); }; -const QueueItemCard = ({ item }: { item: QueueItem }) => { +const QueueItemCard = ({ item, cachedStatus }: { item: QueueItem, cachedStatus: string }) => { const { removeItem, cancelItem } = useContext(QueueContext) || {}; - const status = getStatus(item); + const status = cachedStatus; const progress = getProgress(item); const trackInfo = getCurrentTrackInfo(item); const styleInfo = statusStyles[status as keyof typeof statusStyles] || statusStyles.queued; @@ -497,13 +497,19 @@ export const Queue = () => { if (!context || !isVisible) return null; - const hasActive = items.some(item => isActiveStatus(getStatus(item))); - const hasFinished = items.some(item => isTerminalStatus(getStatus(item))); + // Optimize: Calculate status once per item and reuse throughout render + const itemsWithStatus = items.map(item => ({ + 
...item, + _cachedStatus: getStatus(item) + })); - // Sort items by priority - const sortedItems = [...items].sort((a, b) => { - const statusA = getStatus(a); - const statusB = getStatus(b); + const hasActive = itemsWithStatus.some(item => isActiveStatus(item._cachedStatus)); + const hasFinished = itemsWithStatus.some(item => isTerminalStatus(item._cachedStatus)); + + // Sort items by priority using cached status + const sortedItems = [...itemsWithStatus].sort((a, b) => { + const statusA = a._cachedStatus; + const statusB = b._cachedStatus; const getPriority = (status: string) => { const priorities = { @@ -581,50 +587,55 @@ export const Queue = () => { style={{ touchAction: isDragging ? 'none' : 'pan-y' }} > {(() => { - const visibleItems = sortedItems.filter(item => !isTerminalStatus(getStatus(item)) || (item.lastCallback && 'timestamp' in item.lastCallback)); + const visibleItems = sortedItems.filter(item => { + const status = item._cachedStatus; + return !isTerminalStatus(status) || + (item.lastCallback && 'timestamp' in item.lastCallback) || + status === "cancelled"; + }); return visibleItems.length === 0 ? ( -
[Remainder of this Queue.tsx hunk garbled in extraction: the JSX markup for the empty-state message ("The queue is empty." / "Downloads will appear here"), the loading indicator ("Loading more tasks..."), and the "Load more" button was stripped. What is recoverable: the hunk re-indents those blocks under the new visibleItems filter, the map now branches on item._cachedStatus instead of calling getStatus(item) again (rendering CancelledTaskCard for cancelled items), and QueueItemCard now receives the new cachedStatus prop. The surrounding markup itself is unrecoverable.]
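The Queue.tsx change above amounts to computing getStatus(item) once per render and reusing the cached value for the active/finished checks, the sort, and the visibility filter. A minimal TypeScript sketch of that pattern follows; Item, getStatus, and the status sets are simplified stand-ins, not the project's real QueueItem helpers.

// Minimal sketch of the per-render status caching pattern (simplified stand-ins,
// not the project's real QueueItem / getStatus implementations).
type Item = { id: string; status: string };
const getStatus = (item: Item): string => item.status; // stand-in for the real helper

const ACTIVE = new Set(["initializing", "downloading", "processing", "real_time", "retrying", "queued"]);
const TERMINAL = new Set(["completed", "done", "error", "cancelled", "skipped"]);

function prepareQueue(items: Item[]) {
  // Call getStatus() exactly once per item and carry the result alongside the item.
  const withStatus = items.map((item) => ({ ...item, _cachedStatus: getStatus(item) }));

  const hasActive = withStatus.some((i) => ACTIVE.has(i._cachedStatus));
  const hasFinished = withStatus.some((i) => TERMINAL.has(i._cachedStatus));

  // Sorting and filtering reuse the cached value instead of recomputing it per comparison.
  const sorted = [...withStatus].sort(
    (a, b) => Number(TERMINAL.has(a._cachedStatus)) - Number(TERMINAL.has(b._cachedStatus)),
  );
  const visible = sorted.filter((i) => !TERMINAL.has(i._cachedStatus) || i._cachedStatus === "cancelled");

  return { visible, hasActive, hasFinished };
}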
diff --git a/spotizerr-ui/src/contexts/QueueProvider.tsx b/spotizerr-ui/src/contexts/QueueProvider.tsx index 110a01f..15a635c 100644 --- a/spotizerr-ui/src/contexts/QueueProvider.tsx +++ b/spotizerr-ui/src/contexts/QueueProvider.tsx @@ -26,7 +26,7 @@ export function QueueProvider({ children }: { children: ReactNode }) { const reconnectAttempts = useRef(0); const maxReconnectAttempts = 5; const pageSize = 20; - + // Health check for SSE connection const lastHeartbeat = useRef(Date.now()); const healthCheckInterval = useRef(null); @@ -62,15 +62,15 @@ export function QueueProvider({ children }: { children: ReactNode }) { // Handle different callback structures if (task.last_line) { try { - if ("track" in task.last_line) { - name = task.last_line.track.title || name; - artist = task.last_line.track.artists?.[0]?.name || artist; - } else if ("album" in task.last_line) { - name = task.last_line.album.title || name; - artist = task.last_line.album.artists?.map((a: any) => a.name).join(", ") || artist; - } else if ("playlist" in task.last_line) { - name = task.last_line.playlist.title || name; - artist = task.last_line.playlist.owner?.name || artist; + if ("track" in task.last_line) { + name = task.last_line.track.title || name; + artist = task.last_line.track.artists?.[0]?.name || artist; + } else if ("album" in task.last_line) { + name = task.last_line.album.title || name; + artist = task.last_line.album.artists?.map((a: any) => a.name).join(", ") || artist; + } else if ("playlist" in task.last_line) { + name = task.last_line.playlist.title || name; + artist = task.last_line.playlist.owner?.name || artist; } } catch (error) { console.warn(`createQueueItemFromTask: Error parsing callback for task ${task.task_id}:`, error); @@ -157,6 +157,31 @@ export function QueueProvider({ children }: { children: ReactNode }) { try { const data = JSON.parse(event.data); + // Debug logging for all SSE events + console.log("🔄 SSE Event Received:", { + timestamp: new Date().toISOString(), + changeType: data.change_type || "update", + totalTasks: data.total_tasks, + taskCounts: data.task_counts, + tasksCount: data.tasks?.length || 0, + taskIds: data.tasks?.map((t: any) => { + const tempItem = createQueueItemFromTask(t); + const status = getStatus(tempItem); + // Special logging for playlist/album track progress + if (t.last_line?.current_track && t.last_line?.total_tracks) { + return { + id: t.task_id, + status, + type: t.download_type, + track: `${t.last_line.current_track}/${t.last_line.total_tracks}`, + trackStatus: t.last_line.status_info?.status + }; + } + return { id: t.task_id, status, type: t.download_type }; + }) || [], + rawData: data + }); + if (data.error) { console.error("SSE error:", data.error); toast.error("Connection error"); @@ -165,6 +190,7 @@ export function QueueProvider({ children }: { children: ReactNode }) { // Handle different message types from optimized backend const changeType = data.change_type || "update"; + const triggerReason = data.trigger_reason || ""; if (changeType === "heartbeat") { // Heartbeat - just update counts, no task processing @@ -197,7 +223,8 @@ export function QueueProvider({ children }: { children: ReactNode }) { setTotalTasks(calculatedTotal); if (updatedTasks?.length > 0) { - console.log(`SSE: Processing ${updatedTasks.length} task updates`); + const updateType = triggerReason === "callback_update" ? 
"real-time callback" : "task summary"; + console.log(`SSE: Processing ${updatedTasks.length} ${updateType} updates`); setItems(prev => { // Create improved deduplication maps diff --git a/spotizerr-ui/src/contexts/queue-context.ts b/spotizerr-ui/src/contexts/queue-context.ts index 00d4e5a..e870fdc 100644 --- a/spotizerr-ui/src/contexts/queue-context.ts +++ b/spotizerr-ui/src/contexts/queue-context.ts @@ -53,9 +53,26 @@ export const getStatus = (item: QueueItem): string => { } if (isTrackCallback(item.lastCallback)) { - // For parent downloads, if we're getting track callbacks, the parent is "downloading" + // For parent downloads, check if this is the final track if (item.downloadType === "album" || item.downloadType === "playlist") { - return item.lastCallback.status_info.status === "done" ? "downloading" : "downloading"; + const currentTrack = item.lastCallback.current_track || 1; + const totalTracks = item.lastCallback.total_tracks || 1; + const trackStatus = item.lastCallback.status_info.status; + + // If this is the last track and it's in a terminal state, the parent is done + if (currentTrack >= totalTracks && ["done", "skipped", "error"].includes(trackStatus)) { + console.log(`🎵 Playlist/Album completed: ${item.name} (track ${currentTrack}/${totalTracks}, status: ${trackStatus})`); + return "completed"; + } + + // If track is in terminal state but not the last track, parent is still downloading + if (["done", "skipped", "error"].includes(trackStatus)) { + console.log(`🎵 Playlist/Album progress: ${item.name} (track ${currentTrack}/${totalTracks}, status: ${trackStatus}) - continuing...`); + return "downloading"; + } + + // Track is actively being processed + return "downloading"; } return item.lastCallback.status_info.status; }