from fastapi import APIRouter, HTTPException, Request, Depends from fastapi.responses import StreamingResponse import logging import time import json import asyncio from typing import Set, Optional import redis import threading from routes.utils.celery_config import REDIS_URL, get_config_params from routes.utils.celery_tasks import ( get_task_info, get_task_status, get_last_task_status, get_all_tasks, cancel_task, delete_task_data_and_log, ProgressState, ) # Import authentication dependencies from routes.auth.middleware import ( require_auth_from_state, get_current_user_from_state, User, ) # Configure logging logger = logging.getLogger(__name__) router = APIRouter() # Global SSE Event Broadcaster class SSEBroadcaster: def __init__(self): self.clients: Set[asyncio.Queue] = set() # Per-task throttling/batching/deduplication state self._task_state = {} # task_id -> dict with last_sent, last_event, last_send_time, scheduled_handle # Load configurable interval config = get_config_params() self.sse_update_interval = float(config.get("sseUpdateIntervalSeconds", 1)) async def add_client(self, queue: asyncio.Queue): """Add a new SSE client""" self.clients.add(queue) logger.debug(f"SSE: Client connected (total: {len(self.clients)})") async def remove_client(self, queue: asyncio.Queue): """Remove an SSE client""" self.clients.discard(queue) logger.debug(f"SSE: Client disconnected (total: {len(self.clients)})") async def broadcast_event(self, event_data: dict): """ Throttle, batch, and deduplicate SSE events per task. Only emit at most 1 update/sec per task, aggregate within window, suppress redundant updates. """ if not self.clients: logger.debug("SSE Broadcaster: No clients connected, skipping broadcast") return # Defensive: always work with a list of tasks tasks = event_data.get("tasks", []) if not isinstance(tasks, list): tasks = [tasks] # For each task, throttle/batch/dedupe for task in tasks: task_id = task.get("task_id") if not task_id: continue now = time.time() state = self._task_state.setdefault(task_id, { "last_sent": None, "last_event": None, "last_send_time": 0, "scheduled_handle": None, }) # Deduplication: if event is identical to last sent, skip if state["last_sent"] is not None and self._events_equal(state["last_sent"], task): logger.debug(f"SSE: Deduped event for task {task_id}") continue # Throttling: if within interval, batch (store as last_event, schedule send) elapsed = now - state["last_send_time"] if elapsed < self.sse_update_interval: state["last_event"] = task if state["scheduled_handle"] is None: delay = self.sse_update_interval - elapsed loop = asyncio.get_event_loop() state["scheduled_handle"] = loop.call_later( delay, lambda: asyncio.create_task(self._send_batched_event(task_id)) ) continue # Otherwise, send immediately await self._send_event(task_id, task) state["last_send_time"] = now state["last_sent"] = task state["last_event"] = None if state["scheduled_handle"]: state["scheduled_handle"].cancel() state["scheduled_handle"] = None async def _send_batched_event(self, task_id): state = self._task_state.get(task_id) if not state or not state["last_event"]: return await self._send_event(task_id, state["last_event"]) state["last_send_time"] = time.time() state["last_sent"] = state["last_event"] state["last_event"] = None state["scheduled_handle"] = None async def _send_event(self, task_id, task): # Compose event_data for this task event_data = { "tasks": [task], "current_timestamp": time.time(), "change_type": "update", } enhanced_event_data = add_global_task_counts_to_event(event_data.copy()) event_json = json.dumps(enhanced_event_data) sse_data = f"data: {event_json}\n\n" disconnected = set() sent_count = 0 for client_queue in self.clients.copy(): try: await client_queue.put(sse_data) sent_count += 1 except Exception as e: logger.error(f"SSE: Failed to send to client: {e}") disconnected.add(client_queue) for client in disconnected: self.clients.discard(client) logger.debug( f"SSE Broadcaster: Sent throttled/batched event for task {task_id} to {sent_count} clients" ) def _events_equal(self, a, b): # Compare two task dicts for deduplication (ignore timestamps) if not isinstance(a, dict) or not isinstance(b, dict): return False a_copy = dict(a) b_copy = dict(b) a_copy.pop("timestamp", None) b_copy.pop("timestamp", None) return a_copy == b_copy # Global broadcaster instance sse_broadcaster = SSEBroadcaster() # Redis subscriber for cross-process SSE events # Redis client for SSE pub/sub sse_redis_client = redis.Redis.from_url(REDIS_URL) def start_sse_redis_subscriber(): """Start Redis subscriber to listen for SSE events from Celery workers""" def redis_subscriber_thread(): try: pubsub = sse_redis_client.pubsub() pubsub.subscribe("sse_events") logger.info("SSE Redis Subscriber: Started listening for events") # Create a single event loop for this thread and reuse it loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) for message in pubsub.listen(): if message["type"] == "message": try: event_data = json.loads(message["data"].decode("utf-8")) event_type = event_data.get("event_type", "unknown") task_id = event_data.get("task_id", "unknown") logger.debug( f"SSE Redis Subscriber: Received {event_type} for task {task_id}" ) # Handle different event types if event_type == "progress_update": # Transform callback data into standardized update format expected by frontend standardized = standardize_incoming_event(event_data) if standardized: loop.run_until_complete( sse_broadcaster.broadcast_event(standardized) ) logger.debug( f"SSE Redis Subscriber: Broadcasted standardized progress update to {len(sse_broadcaster.clients)} clients" ) elif event_type == "summary_update": # Task summary update - use standardized trigger # Short-circuit if task no longer exists to avoid expensive processing try: if not get_task_info(task_id): logger.debug( f"SSE Redis Subscriber: summary_update for missing task {task_id}, skipping" ) else: loop.run_until_complete( trigger_sse_update( task_id, event_data.get("reason", "update") ) ) logger.debug( f"SSE Redis Subscriber: Processed summary update for {task_id}" ) except Exception as _e: logger.error( f"SSE Redis Subscriber: Error handling summary_update for {task_id}: {_e}", exc_info=True, ) else: # Unknown event type - attempt to standardize and broadcast standardized = standardize_incoming_event(event_data) if standardized: loop.run_until_complete( sse_broadcaster.broadcast_event(standardized) ) logger.debug( f"SSE Redis Subscriber: Broadcasted standardized {event_type} to {len(sse_broadcaster.clients)} clients" ) except Exception as e: logger.error( f"SSE Redis Subscriber: Error processing message: {e}", exc_info=True, ) except Exception as e: logger.error(f"SSE Redis Subscriber: Fatal error: {e}", exc_info=True) # Start Redis subscriber in background thread thread = threading.Thread(target=redis_subscriber_thread, daemon=True) thread.start() logger.debug("SSE Redis Subscriber: Background thread started") def build_task_object_from_callback( task_id: str, callback_data: dict ) -> Optional[dict]: """Build a standardized task object from callback payload and task info.""" try: task_info = get_task_info(task_id) if not task_info: return None return { "task_id": task_id, "original_url": f"http://localhost:7171/api/{task_info.get('download_type', 'track')}/download/{task_info.get('url', '').split('/')[-1] if task_info.get('url') else ''}", "last_line": callback_data, "timestamp": time.time(), "download_type": task_info.get("download_type", "track"), "type": task_info.get("type", task_info.get("download_type", "track")), "name": task_info.get("name", "Unknown"), "artist": task_info.get("artist", ""), "created_at": task_info.get("created_at"), } except Exception as e: logger.error( f"Error building task object from callback for {task_id}: {e}", exc_info=True, ) return None def standardize_incoming_event(event_data: dict) -> Optional[dict]: """ Convert various incoming event shapes into a standardized SSE payload: { 'change_type': 'update' | 'heartbeat', 'tasks': [...], 'current_timestamp': float, 'trigger_reason': str (optional) } """ try: # Heartbeat passthrough (ensure tasks array exists) if event_data.get("change_type") == "heartbeat": return { "change_type": "heartbeat", "tasks": [], "current_timestamp": time.time(), } # If already has tasks, just coerce change_type if isinstance(event_data.get("tasks"), list): return { "change_type": event_data.get("change_type", "update"), "tasks": event_data["tasks"], "current_timestamp": time.time(), "trigger_reason": event_data.get("trigger_reason"), } # If it's a callback-shaped event callback_data = event_data.get("callback_data") task_id = event_data.get("task_id") if callback_data and task_id: task_obj = build_task_object_from_callback(task_id, callback_data) if task_obj: return { "change_type": "update", "tasks": [task_obj], "current_timestamp": time.time(), "trigger_reason": event_data.get("event_type", "callback_update"), } # Fallback to empty update return { "change_type": "update", "tasks": [], "current_timestamp": time.time(), } except Exception as e: logger.error(f"Failed to standardize incoming event: {e}", exc_info=True) return None async def transform_callback_to_task_format(task_id: str, event_data: dict) -> dict: """Transform callback event data into the task format expected by frontend""" try: # Import here to avoid circular imports from routes.utils.celery_tasks import get_task_info # Get task info to build complete task object task_info = get_task_info(task_id) if not task_info: logger.warning(f"SSE Transform: No task info found for {task_id}") return None # Extract callback data callback_data = event_data.get("callback_data", {}) # Build task object in the format expected by frontend task_object = { "task_id": task_id, "original_url": f"http://localhost:7171/api/{task_info.get('download_type', 'track')}/download/{task_info.get('url', '').split('/')[-1] if task_info.get('url') else ''}", "last_line": callback_data, # This is what frontend expects for callback data "timestamp": event_data.get("timestamp", time.time()), "download_type": task_info.get("download_type", "track"), "type": task_info.get("type", task_info.get("download_type", "track")), "name": task_info.get("name", "Unknown"), "artist": task_info.get("artist", ""), "created_at": task_info.get("created_at"), } # Build minimal event data - global counts will be added at broadcast time return { "change_type": "update", "tasks": [task_object], # Frontend expects tasks array "current_timestamp": time.time(), "updated_count": 1, "since_timestamp": time.time(), "trigger_reason": "callback_update", } except Exception as e: logger.error( f"SSE Transform: Error transforming callback for task {task_id}: {e}", exc_info=True, ) return None # Start the Redis subscriber when module loads start_sse_redis_subscriber() async def trigger_sse_update(task_id: str, reason: str = "task_update"): """Trigger an immediate SSE update for a specific task""" try: current_time = time.time() # Find the specific task that changed task_info = get_task_info(task_id) if not task_info: logger.debug(f"SSE: Task {task_id} not found for update") return last_status = get_last_task_status(task_id) # Create a dummy request for the _build_task_response function class DummyRequest: def __init__(self): self.base_url = "http://localhost:7171" dummy_request = DummyRequest() task_response = _build_task_response( task_info, last_status, task_id, current_time, dummy_request ) # Create standardized event data - global counts will be added at broadcast time event_data = { "tasks": [task_response], "current_timestamp": current_time, "since_timestamp": current_time, "change_type": "update", "trigger_reason": reason, } await sse_broadcaster.broadcast_event(event_data) logger.debug(f"SSE: Broadcast update for task {task_id} (reason: {reason})") except Exception as e: logger.error(f"SSE: Failed to trigger update for task {task_id}: {e}") # Define active task states using ProgressState constants ACTIVE_TASK_STATES = { ProgressState.INITIALIZING, # "initializing" - task is starting up ProgressState.PROCESSING, # "processing" - task is being processed ProgressState.DOWNLOADING, # "downloading" - actively downloading ProgressState.PROGRESS, # "progress" - album/playlist progress updates ProgressState.TRACK_PROGRESS, # "track_progress" - real-time track progress ProgressState.REAL_TIME, # "real_time" - real-time download progress ProgressState.RETRYING, # "retrying" - task is retrying after error "real-time", # "real-time" - real-time download progress (hyphenated version) ProgressState.QUEUED, # "queued" - task is queued and waiting "pending", # "pending" - legacy queued status } # Define terminal task states that should be included when recently completed TERMINAL_TASK_STATES = { ProgressState.COMPLETE, # "complete" - task completed successfully ProgressState.DONE, # "done" - task finished processing ProgressState.ERROR, # "error" - task failed ProgressState.CANCELLED, # "cancelled" - task was cancelled ProgressState.SKIPPED, # "skipped" - task was skipped } def get_task_status_from_last_status(last_status): """ Extract the task status from last_status, checking both possible locations. Uses improved priority logic to handle real-time downloads correctly. Args: last_status: The last status dict from get_last_task_status() Returns: str: The task status string """ if not last_status: return "unknown" # For real-time downloads, prioritize status_info.status as it contains the actual progress state status_info = last_status.get("status_info", {}) if isinstance(status_info, dict) and "status" in status_info: status_info_status = status_info["status"] # If status_info contains an active status, use it regardless of top-level status if status_info_status in ACTIVE_TASK_STATES: return status_info_status # Fall back to top-level status top_level_status = last_status.get("status", "unknown") # If both exist but neither is active, prefer the more recent one (usually top-level) # For active states, we already handled status_info above return top_level_status def is_task_active(task_status): """ Determine if a task is currently active (working/processing). Args: task_status: The status string from the task Returns: bool: True if the task is active, False otherwise """ if not task_status or task_status == "unknown": return False return task_status in ACTIVE_TASK_STATES def get_global_task_counts(): """ Get comprehensive task counts for ALL tasks in Redis. This is called right before sending SSE events to ensure accurate counts. Returns: dict: Task counts by status """ task_counts = { "active": 0, "queued": 0, "completed": 0, "error": 0, "cancelled": 0, "retrying": 0, "skipped": 0, } try: # Get ALL tasks from Redis - this is the source of truth all_tasks = get_all_tasks() for task_summary in all_tasks: task_id = task_summary.get("task_id") if not task_id: continue task_info = get_task_info(task_id) if not task_info: continue last_status = get_last_task_status(task_id) task_status = get_task_status_from_last_status(last_status) is_active_task = is_task_active(task_status) # Categorize tasks by status using ProgressState constants if task_status == ProgressState.RETRYING: task_counts["retrying"] += 1 elif task_status in {ProgressState.QUEUED, "pending"}: task_counts["queued"] += 1 elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: task_counts["completed"] += 1 elif task_status == ProgressState.ERROR: task_counts["error"] += 1 elif task_status == ProgressState.CANCELLED: task_counts["cancelled"] += 1 elif task_status == ProgressState.SKIPPED: task_counts["skipped"] += 1 elif is_active_task: task_counts["active"] += 1 logger.debug( f"Global task counts: {task_counts} (total: {len(all_tasks)} tasks)" ) except Exception as e: logger.error(f"Error getting global task counts: {e}", exc_info=True) return task_counts def add_global_task_counts_to_event(event_data): """ Add global task counts to any SSE event data right before broadcasting. This ensures all SSE events have accurate, up-to-date counts of ALL tasks. Args: event_data: The event data dictionary to be sent via SSE Returns: dict: Enhanced event data with global task counts """ try: # Get fresh counts of ALL tasks right before sending global_task_counts = get_global_task_counts() # Add/update the counts in the event data event_data["task_counts"] = global_task_counts event_data["active_tasks"] = global_task_counts["active"] event_data["all_tasks_count"] = sum(global_task_counts.values()) # Ensure tasks array is present for schema consistency if "tasks" not in event_data: event_data["tasks"] = [] # Ensure change_type is present if "change_type" not in event_data: event_data["change_type"] = "update" return event_data except Exception as e: logger.error( f"Error adding global task counts to SSE event: {e}", exc_info=True ) return event_data def _build_error_callback_object(last_status): """ Constructs a structured error callback object based on the last status of a task. This conforms to the CallbackObject types in the frontend. """ # The 'type' from the status update corresponds to the download_type (album, playlist, track) download_type = last_status.get("type") name = last_status.get("name") # The 'artist' field from the status may contain artist names or a playlist owner's name artist_or_owner = last_status.get("artist") error_message = last_status.get("error", "An unknown error occurred.") status_info = {"status": "error", "error": error_message} callback_object = {"status_info": status_info} if download_type == "album": callback_object["album"] = { "type": "album", "title": name, "artists": [{"type": "artistAlbum", "name": artist_or_owner}] if artist_or_owner else [], } elif download_type == "playlist": playlist_payload = {"type": "playlist", "title": name} if artist_or_owner: playlist_payload["owner"] = {"type": "user", "name": artist_or_owner} callback_object["playlist"] = playlist_payload elif download_type == "track": callback_object["track"] = { "type": "track", "title": name, "artists": [{"type": "artistTrack", "name": artist_or_owner}] if artist_or_owner else [], } else: # Fallback for unknown types to avoid breaking the client, returning a basic error structure. return { "status_info": status_info, "unstructured_error": True, "details": { "type": download_type, "name": name, "artist_or_owner": artist_or_owner, }, } return callback_object def _build_task_response( task_info, last_status, task_id, current_time, request: Request ): """ Helper function to build a standardized task response object. """ # Dynamically construct original_url dynamic_original_url = "" download_type = task_info.get("download_type") item_url = task_info.get("url") if download_type and item_url: try: item_id = item_url.split("/")[-1] if item_id: base_url = ( str(request.base_url).rstrip("/") if request else "http://localhost:7171" ) dynamic_original_url = ( f"{base_url}/api/{download_type}/download/{item_id}" ) else: logger.warning( f"Could not extract item ID from URL: {item_url} for task {task_id}. Falling back for original_url." ) original_request_obj = task_info.get("original_request", {}) dynamic_original_url = original_request_obj.get("original_url", "") except Exception as e: logger.error( f"Error constructing dynamic original_url for task {task_id}: {e}", exc_info=True, ) original_request_obj = task_info.get("original_request", {}) dynamic_original_url = original_request_obj.get("original_url", "") else: logger.warning( f"Missing download_type ('{download_type}') or item_url ('{item_url}') in task_info for task {task_id}. Falling back for original_url." ) # Auto-delete faulty task data to keep the queue clean try: delete_task_data_and_log( task_id, reason="Auto-cleaned: Missing download_type or url in task_info.", ) # Trigger SSE so clients refresh their task lists try: # Avoid circular import at top-level import asyncio as _asyncio # Fire-and-forget; if no event loop available, ignore loop = _asyncio.get_event_loop() if loop.is_running(): _asyncio.create_task( trigger_sse_update(task_id, "auto_deleted_faulty") ) except Exception: pass except Exception as _e: logger.error(f"Auto-delete failed for faulty task {task_id}: {_e}") original_request_obj = task_info.get("original_request", {}) dynamic_original_url = original_request_obj.get("original_url", "") status_count = len(get_task_status(task_id)) # Determine last_line content if last_status and "raw_callback" in last_status: last_line_content = last_status["raw_callback"] elif last_status and get_task_status_from_last_status(last_status) == "error": last_line_content = _build_error_callback_object(last_status) else: last_line_content = last_status # Normalize created_at to a numeric timestamp created_at_value = task_info.get("created_at") if not isinstance(created_at_value, (int, float)): created_at_value = current_time task_response = { "original_url": dynamic_original_url, "last_line": last_line_content, "timestamp": last_status.get("timestamp") if last_status else current_time, "task_id": task_id, "status_count": status_count, "created_at": created_at_value, "name": task_info.get("name"), "artist": task_info.get("artist"), "type": task_info.get("type"), "download_type": task_info.get("download_type"), } if last_status and last_status.get("summary"): task_response["summary"] = last_status["summary"] return task_response async def get_paginated_tasks( page=1, limit=20, active_only=False, request: Optional[Request] = None ): """ Get paginated list of tasks. """ try: all_tasks = get_all_tasks() # Get global task counts task_counts = get_global_task_counts() active_tasks = [] other_tasks = [] # Process tasks for pagination and response building for task_summary in all_tasks: task_id = task_summary.get("task_id") if not task_id: continue task_info = get_task_info(task_id) if not task_info: continue last_status = get_last_task_status(task_id) task_status = get_task_status_from_last_status(last_status) is_active_task = is_task_active(task_status) task_response = _build_task_response( task_info, last_status, task_id, time.time(), request ) if is_active_task: active_tasks.append(task_response) else: other_tasks.append(task_response) # Sort other tasks by creation time (newest first) other_tasks.sort(key=lambda x: (x.get("created_at") or 0.0), reverse=True) if active_only: paginated_tasks = active_tasks pagination_info = { "page": page, "limit": limit, "total_non_active": 0, "has_more": False, "returned_non_active": 0, } else: # Apply pagination to non-active tasks offset = (page - 1) * limit paginated_other_tasks = other_tasks[offset : offset + limit] paginated_tasks = active_tasks + paginated_other_tasks pagination_info = { "page": page, "limit": limit, "total_non_active": len(other_tasks), "has_more": len(other_tasks) > offset + limit, "returned_non_active": len(paginated_other_tasks), } response = { "tasks": paginated_tasks, "current_timestamp": time.time(), "total_tasks": task_counts["active"] + task_counts["retrying"], # Only active/retrying tasks for counter "all_tasks_count": len(all_tasks), # Total count of all tasks "task_counts": task_counts, # Categorized counts "active_tasks": len(active_tasks), "updated_count": len(paginated_tasks), "pagination": pagination_info, } return response except Exception as e: logger.error(f"Error in get_paginated_tasks: {e}", exc_info=True) raise HTTPException( status_code=500, detail={"error": "Failed to retrieve paginated tasks"} ) # IMPORTANT: Specific routes MUST come before parameterized routes in FastAPI # Otherwise "updates" gets matched as a {task_id} parameter! @router.get("/list") async def list_tasks( request: Request, current_user: User = Depends(require_auth_from_state) ): """ Retrieve a paginated list of all tasks in the system. Returns a detailed list of task objects including status and metadata. Query parameters: page (int): Page number for pagination (default: 1) limit (int): Number of tasks per page (default: 50, max: 100) active_only (bool): If true, only return active tasks (downloading, processing, etc.) """ try: # Get query parameters page = int(request.query_params.get("page", 1)) limit = min(int(request.query_params.get("limit", 50)), 100) # Cap at 100 active_only = request.query_params.get("active_only", "").lower() == "true" tasks = get_all_tasks() active_tasks = [] other_tasks = [] # Task categorization counters task_counts = { "active": 0, "queued": 0, "completed": 0, "error": 0, "cancelled": 0, "retrying": 0, "skipped": 0, } for task_summary in tasks: task_id = task_summary.get("task_id") if not task_id: continue task_info = get_task_info(task_id) if not task_info: continue last_status = get_last_task_status(task_id) task_status = get_task_status_from_last_status(last_status) is_active_task = is_task_active(task_status) # Categorize tasks by status using ProgressState constants if task_status == ProgressState.RETRYING: task_counts["retrying"] += 1 elif task_status in { ProgressState.QUEUED, "pending", }: # Keep "pending" for backward compatibility task_counts["queued"] += 1 elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: task_counts["completed"] += 1 elif task_status == ProgressState.ERROR: task_counts["error"] += 1 elif task_status == ProgressState.CANCELLED: task_counts["cancelled"] += 1 elif task_status == ProgressState.SKIPPED: task_counts["skipped"] += 1 elif is_active_task: task_counts["active"] += 1 task_response = _build_task_response( task_info, last_status, task_id, time.time(), request ) if is_active_task: active_tasks.append(task_response) else: other_tasks.append(task_response) # Sort other tasks by creation time (newest first) other_tasks.sort(key=lambda x: (x.get("created_at") or 0.0), reverse=True) if active_only: # Return only active tasks without pagination response_tasks = active_tasks pagination_info = { "page": page, "limit": limit, "total_items": len(active_tasks), "total_pages": 1, "has_more": False, } else: # Apply pagination to non-active tasks and combine with active tasks offset = (page - 1) * limit # Always include active tasks at the top if page == 1: # For first page, include active tasks + first batch of other tasks available_space = limit - len(active_tasks) paginated_other_tasks = other_tasks[: max(0, available_space)] response_tasks = active_tasks + paginated_other_tasks else: # For subsequent pages, only include other tasks # Adjust offset to account for active tasks shown on first page adjusted_offset = offset - len(active_tasks) if adjusted_offset < 0: adjusted_offset = 0 paginated_other_tasks = other_tasks[ adjusted_offset : adjusted_offset + limit ] response_tasks = paginated_other_tasks total_items = len(active_tasks) + len(other_tasks) total_pages = ((total_items - 1) // limit) + 1 if total_items > 0 else 1 pagination_info = { "page": page, "limit": limit, "total_items": total_items, "total_pages": total_pages, "has_more": page < total_pages, "active_tasks": len(active_tasks), "total_other_tasks": len(other_tasks), } response = { "tasks": response_tasks, "pagination": pagination_info, "total_tasks": task_counts["active"] + task_counts["retrying"], # Only active/retrying tasks for counter "all_tasks_count": len(tasks), # Total count of all tasks "task_counts": task_counts, # Categorized counts "active_tasks": len(active_tasks), "timestamp": time.time(), } return response except Exception as e: logger.error(f"Error in /api/prgs/list: {e}", exc_info=True) raise HTTPException( status_code=500, detail={"error": "Failed to retrieve task list"} ) @router.get("/updates") async def get_task_updates( request: Request, current_user: User = Depends(require_auth_from_state) ): """ Retrieve only tasks that have been updated since the specified timestamp. This endpoint is optimized for polling to reduce unnecessary data transfer. Query parameters: since (float): Unix timestamp - only return tasks updated after this time page (int): Page number for pagination (default: 1) limit (int): Number of queued/completed tasks per page (default: 20, max: 100) active_only (bool): If true, only return active tasks (downloading, processing, etc.) Returns: JSON object containing: - tasks: Array of updated task objects - current_timestamp: Current server timestamp for next poll - total_tasks: Total number of tasks in system - active_tasks: Number of active tasks - pagination: Pagination info for queued/completed tasks """ try: # Get query parameters since_param = request.query_params.get("since") page = int(request.query_params.get("page", 1)) limit = min(int(request.query_params.get("limit", 20)), 100) # Cap at 100 active_only = request.query_params.get("active_only", "").lower() == "true" if not since_param: # If no 'since' parameter, return paginated tasks (fallback behavior) response = await get_paginated_tasks(page, limit, active_only, request) return response try: since_timestamp = float(since_param) except (ValueError, TypeError): raise HTTPException( status_code=400, detail={"error": "Invalid 'since' timestamp format"} ) # Get all tasks all_tasks = get_all_tasks() current_time = time.time() # Get global task counts task_counts = get_global_task_counts() updated_tasks = [] active_tasks = [] # Process tasks for filtering and response building for task_summary in all_tasks: task_id = task_summary.get("task_id") if not task_id: continue task_info = get_task_info(task_id) if not task_info: continue last_status = get_last_task_status(task_id) task_status = get_task_status_from_last_status(last_status) is_active_task = is_task_active(task_status) # Check if task has been updated since the given timestamp task_timestamp = ( last_status.get("timestamp") if last_status else task_info.get("created_at", 0) ) # Always include active tasks in updates, apply filtering to others # Also include recently completed/terminal tasks to ensure "done" status gets sent is_recently_terminal = ( task_status in TERMINAL_TASK_STATES and task_timestamp > since_timestamp ) should_include = ( is_active_task or (task_timestamp > since_timestamp and not active_only) or is_recently_terminal ) if should_include: # Construct the same detailed task object as in list_tasks() task_response = _build_task_response( task_info, last_status, task_id, current_time, request ) if is_active_task: active_tasks.append(task_response) else: updated_tasks.append(task_response) # Apply pagination to non-active tasks offset = (page - 1) * limit paginated_updated_tasks = ( updated_tasks[offset : offset + limit] if not active_only else [] ) # Combine active tasks (always shown) with paginated updated tasks all_returned_tasks = active_tasks + paginated_updated_tasks # Sort by priority (active first, then by creation time) all_returned_tasks.sort( key=lambda x: ( 0 if x.get("task_id") in [t["task_id"] for t in active_tasks] else 1, -(x.get("created_at") or 0), ) ) response = { "tasks": all_returned_tasks, "current_timestamp": current_time, "total_tasks": task_counts["active"] + task_counts["retrying"], # Only active/retrying tasks for counter "all_tasks_count": len(all_tasks), # Total count of all tasks "task_counts": task_counts, # Categorized counts "active_tasks": len(active_tasks), "updated_count": len(updated_tasks), "since_timestamp": since_timestamp, "pagination": { "page": page, "limit": limit, "total_non_active": len(updated_tasks), "has_more": len(updated_tasks) > offset + limit, "returned_non_active": len(paginated_updated_tasks), }, } logger.debug( f"Returning {len(active_tasks)} active + {len(paginated_updated_tasks)} paginated tasks out of {len(all_tasks)} total" ) return response except HTTPException: raise except Exception as e: logger.error(f"Error in /api/prgs/updates: {e}", exc_info=True) raise HTTPException( status_code=500, detail={"error": "Failed to retrieve task updates"} ) @router.post("/cancel/all") async def cancel_all_tasks(current_user: User = Depends(require_auth_from_state)): """ Cancel all active (running or queued) tasks. """ try: tasks_to_cancel = get_all_tasks() cancelled_count = 0 errors = [] for task_summary in tasks_to_cancel: task_id = task_summary.get("task_id") if not task_id: continue try: cancel_task(task_id) cancelled_count += 1 except Exception as e: error_message = f"Failed to cancel task {task_id}: {e}" logger.error(error_message) errors.append(error_message) response = { "message": f"Attempted to cancel all active tasks. {cancelled_count} tasks cancelled.", "cancelled_count": cancelled_count, "errors": errors, } return response except Exception as e: logger.error(f"Error in /api/prgs/cancel/all: {e}", exc_info=True) raise HTTPException( status_code=500, detail={"error": "Failed to cancel all tasks"} ) @router.post("/cancel/{task_id}") async def cancel_task_endpoint( task_id: str, current_user: User = Depends(require_auth_from_state) ): """ Cancel a running or queued task. Args: task_id: The ID of the task to cancel """ try: # First check if this is a task ID in the new system task_info = get_task_info(task_id) if task_info: # This is a task ID in the new system result = cancel_task(task_id) try: # Push an immediate SSE update so clients reflect cancellation and partial summary await trigger_sse_update(task_id, "cancelled") result["sse_notified"] = True except Exception as e: logger.error(f"SSE notify after cancel failed for {task_id}: {e}") return result # If not found in new system, we need to handle the old system cancellation # For now, return an error as we're transitioning to the new system raise HTTPException( status_code=400, detail={ "status": "error", "message": "Cancellation for old system is not supported in the new API. Please use the new task ID format.", }, ) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @router.delete("/delete/{task_id}") async def delete_task( task_id: str, current_user: User = Depends(require_auth_from_state) ): """ Delete a task's information and history. Args: task_id: A task UUID from Celery """ # Only support new task IDs task_info = get_task_info(task_id) if not task_info: raise HTTPException(status_code=404, detail="Task not found") # First, cancel the task if it's running cancel_task(task_id) return {"message": f"Task {task_id} deleted successfully"} @router.get("/stream") async def stream_task_updates( request: Request, current_user: User = Depends(get_current_user_from_state) ): """ Stream real-time task updates via Server-Sent Events (SSE). Now uses event-driven architecture for true real-time updates. Uses optional authentication to avoid breaking SSE connections. Query parameters: active_only (bool): If true, only stream active tasks (downloading, processing, etc.) Returns: Server-Sent Events stream with task update data in JSON format """ # Get query parameters active_only = request.query_params.get("active_only", "").lower() == "true" async def event_generator(): # Create a queue for this client client_queue = asyncio.Queue() try: # Register this client with the broadcaster logger.debug("SSE Stream: New client connecting...") await sse_broadcaster.add_client(client_queue) logger.debug( f"SSE Stream: Client registered successfully, total clients: {len(sse_broadcaster.clients)}" ) # Send initial data immediately upon connection (standardized 'update') initial_data = await generate_task_update_event( time.time(), active_only, request ) yield initial_data # Send periodic heartbeats and listen for real-time events last_heartbeat = time.time() heartbeat_interval = 30.0 while True: try: # Wait for either an event or timeout for heartbeat try: event_data = await asyncio.wait_for( client_queue.get(), timeout=heartbeat_interval ) # Send the real-time event yield event_data last_heartbeat = time.time() except asyncio.TimeoutError: # Send heartbeat if no events for a while current_time = time.time() if current_time - last_heartbeat >= heartbeat_interval: # Generate current task counts for heartbeat all_tasks = get_all_tasks() task_counts = { "active": 0, "queued": 0, "completed": 0, "error": 0, "cancelled": 0, "retrying": 0, "skipped": 0, } for task_summary in all_tasks: task_id = task_summary.get("task_id") if not task_id: continue task_info = get_task_info(task_id) if not task_info: continue last_status = get_last_task_status(task_id) task_status = get_task_status_from_last_status( last_status ) if task_status == ProgressState.RETRYING: task_counts["retrying"] += 1 elif task_status in {ProgressState.QUEUED, "pending"}: task_counts["queued"] += 1 elif task_status in { ProgressState.COMPLETE, ProgressState.DONE, }: task_counts["completed"] += 1 elif task_status == ProgressState.ERROR: task_counts["error"] += 1 elif task_status == ProgressState.CANCELLED: task_counts["cancelled"] += 1 elif task_status == ProgressState.SKIPPED: task_counts["skipped"] += 1 elif is_task_active(task_status): task_counts["active"] += 1 heartbeat_data = { "current_timestamp": current_time, "total_tasks": task_counts["active"] + task_counts["retrying"], "task_counts": task_counts, "change_type": "heartbeat", "tasks": [], } event_json = json.dumps(heartbeat_data) yield f"data: {event_json}\n\n" last_heartbeat = current_time except Exception as e: logger.error(f"Error in SSE event streaming: {e}", exc_info=True) # Send error event and continue error_data = json.dumps( { "error": "Internal server error", "timestamp": time.time(), "change_type": "error", "tasks": [], } ) yield f"data: {error_data}\n\n" await asyncio.sleep(1) except asyncio.CancelledError: logger.debug("SSE client disconnected") return except Exception as e: logger.error(f"SSE connection error: {e}", exc_info=True) return finally: # Clean up - remove client from broadcaster await sse_broadcaster.remove_client(client_queue) return StreamingResponse( event_generator(), media_type="text/plain", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "text/event-stream", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "Cache-Control", }, ) async def generate_task_update_event( since_timestamp: float, active_only: bool, request: Request ) -> str: """ Generate initial task update event for SSE connection. This replicates the logic from get_task_updates but for SSE format. """ try: # Get all tasks for filtering all_tasks = get_all_tasks() current_time = time.time() updated_tasks = [] active_tasks = [] # Process tasks for filtering only - no counting here for task_summary in all_tasks: task_id = task_summary.get("task_id") if not task_id: continue task_info = get_task_info(task_id) if not task_info: continue last_status = get_last_task_status(task_id) task_status = get_task_status_from_last_status(last_status) is_active_task = is_task_active(task_status) # Check if task has been updated since the given timestamp task_timestamp = ( last_status.get("timestamp") if last_status else task_info.get("created_at", 0) ) # Always include active tasks in updates, apply filtering to others # Also include recently completed/terminal tasks to ensure "done" status gets sent is_recently_terminal = ( task_status in TERMINAL_TASK_STATES and task_timestamp > since_timestamp ) should_include = ( is_active_task or (task_timestamp > since_timestamp and not active_only) or is_recently_terminal ) if should_include: # Construct the same detailed task object as in updates endpoint task_response = _build_task_response( task_info, last_status, task_id, current_time, request ) if is_active_task: active_tasks.append(task_response) else: updated_tasks.append(task_response) # Combine active tasks (always shown) with updated tasks all_returned_tasks = active_tasks + updated_tasks # Sort by priority (active first, then by creation time) all_returned_tasks.sort( key=lambda x: ( 0 if x.get("task_id") in [t["task_id"] for t in active_tasks] else 1, -(x.get("created_at") or 0), ) ) initial_data = { "tasks": all_returned_tasks, "current_timestamp": current_time, "updated_count": len(updated_tasks), "since_timestamp": since_timestamp, "change_type": "update", "initial": True, # Mark as initial load } # Add global task counts since this bypasses the broadcaster enhanced_data = add_global_task_counts_to_event(initial_data) event_data = json.dumps(enhanced_data) return f"data: {event_data}\n\n" except Exception as e: logger.error(f"Error generating initial SSE event: {e}", exc_info=True) error_data = json.dumps( { "error": "Failed to load initial data", "timestamp": time.time(), "tasks": [], "change_type": "error", } ) return f"data: {error_data}\n\n" # IMPORTANT: This parameterized route MUST come AFTER all specific routes # Otherwise FastAPI will match specific routes like "/updates" as task_id parameters @router.get("/{task_id}") async def get_task_details( task_id: str, request: Request, current_user: User = Depends(require_auth_from_state), ): """ Return a JSON object with the resource type, its name (title), the last progress update, and, if available, the original request parameters. This function works with the new task ID based system. Args: task_id: A task UUID from Celery """ # Only support new task IDs task_info = get_task_info(task_id) if not task_info: raise HTTPException(status_code=404, detail="Task not found") # Dynamically construct original_url dynamic_original_url = "" download_type = task_info.get("download_type") # The 'url' field in task_info stores the Spotify/Deezer URL of the item # e.g., https://open.spotify.com/album/albumId or https://www.deezer.com/track/trackId item_url = task_info.get("url") if download_type and item_url: try: # Extract the ID from the item_url (last part of the path) item_id = item_url.split("/")[-1] if item_id: # Ensure item_id is not empty base_url = str(request.base_url).rstrip("/") dynamic_original_url = ( f"{base_url}/api/{download_type}/download/{item_id}" ) else: logger.warning( f"Could not extract item ID from URL: {item_url} for task {task_id}. Falling back for original_url." ) original_request_obj = task_info.get("original_request", {}) dynamic_original_url = original_request_obj.get("original_url", "") except Exception as e: logger.error( f"Error constructing dynamic original_url for task {task_id}: {e}", exc_info=True, ) original_request_obj = task_info.get("original_request", {}) dynamic_original_url = original_request_obj.get( "original_url", "" ) # Fallback on any error else: logger.warning( f"Missing download_type ('{download_type}') or item_url ('{item_url}') in task_info for task {task_id}. Falling back for original_url." ) original_request_obj = task_info.get("original_request", {}) dynamic_original_url = original_request_obj.get("original_url", "") last_status = get_last_task_status(task_id) status_count = len(get_task_status(task_id)) # Determine last_line content if last_status and "raw_callback" in last_status: last_line_content = last_status["raw_callback"] elif last_status and get_task_status_from_last_status(last_status) == "error": last_line_content = _build_error_callback_object(last_status) else: # Fallback for non-error, no raw_callback, or if last_status is None last_line_content = last_status response = { "original_url": dynamic_original_url, "last_line": last_line_content, "timestamp": last_status.get("timestamp") if last_status else time.time(), "task_id": task_id, "status_count": status_count, "created_at": task_info.get("created_at"), "name": task_info.get("name"), "artist": task_info.get("artist"), "type": task_info.get("type"), "download_type": task_info.get("download_type"), } if last_status and last_status.get("summary"): response["summary"] = last_status["summary"] return response