Files
spotizerr-dev/routes/system/progress.py
2025-08-22 22:06:42 +02:00

1220 lines
50 KiB
Python
Executable File

from fastapi import APIRouter, HTTPException, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse
import logging
import time
import json
import asyncio
from typing import Dict, Set
from routes.utils.celery_tasks import (
get_task_info,
get_task_status,
get_last_task_status,
get_all_tasks,
cancel_task,
ProgressState,
)
# Import authentication dependencies
from routes.auth.middleware import require_auth_from_state, get_current_user_from_state, User
# Configure logging
logger = logging.getLogger(__name__)
router = APIRouter()
# Global SSE Event Broadcaster
class SSEBroadcaster:
def __init__(self):
self.clients: Set[asyncio.Queue] = set()
async def add_client(self, queue: asyncio.Queue):
"""Add a new SSE client"""
self.clients.add(queue)
logger.info(f"SSE: Client connected (total: {len(self.clients)})")
async def remove_client(self, queue: asyncio.Queue):
"""Remove an SSE client"""
self.clients.discard(queue)
logger.info(f"SSE: Client disconnected (total: {len(self.clients)})")
async def broadcast_event(self, event_data: dict):
"""Broadcast an event to all connected clients"""
logger.debug(f"SSE Broadcaster: Attempting to broadcast to {len(self.clients)} clients")
if not self.clients:
logger.debug("SSE Broadcaster: No clients connected, skipping broadcast")
return
# Add global task counts right before broadcasting - this is the single source of truth
enhanced_event_data = add_global_task_counts_to_event(event_data.copy())
event_json = json.dumps(enhanced_event_data)
sse_data = f"data: {event_json}\n\n"
logger.debug(f"SSE Broadcaster: Broadcasting event: {enhanced_event_data.get('change_type', 'unknown')} with {enhanced_event_data.get('active_tasks', 0)} active tasks")
# Send to all clients, remove disconnected ones
disconnected = set()
sent_count = 0
for client_queue in self.clients.copy():
try:
await client_queue.put(sse_data)
sent_count += 1
logger.debug(f"SSE: Successfully sent to client queue")
except Exception as e:
logger.error(f"SSE: Failed to send to client: {e}")
disconnected.add(client_queue)
# Clean up disconnected clients
for client in disconnected:
self.clients.discard(client)
logger.info(f"SSE Broadcaster: Successfully sent to {sent_count} clients, removed {len(disconnected)} disconnected clients")
# Global broadcaster instance
sse_broadcaster = SSEBroadcaster()
# Redis subscriber for cross-process SSE events
import redis
import threading
from routes.utils.celery_config import REDIS_URL
# Redis client for SSE pub/sub
sse_redis_client = redis.Redis.from_url(REDIS_URL)
def start_sse_redis_subscriber():
"""Start Redis subscriber to listen for SSE events from Celery workers"""
def redis_subscriber_thread():
try:
pubsub = sse_redis_client.pubsub()
pubsub.subscribe("sse_events")
logger.info("SSE Redis Subscriber: Started listening for events")
for message in pubsub.listen():
if message['type'] == 'message':
try:
event_data = json.loads(message['data'].decode('utf-8'))
event_type = event_data.get('event_type', 'unknown')
task_id = event_data.get('task_id', 'unknown')
logger.debug(f"SSE Redis Subscriber: Received {event_type} for task {task_id}")
# Handle different event types
if event_type == 'progress_update':
# Transform callback data into task format expected by frontend
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
broadcast_data = loop.run_until_complete(transform_callback_to_task_format(task_id, event_data))
if broadcast_data:
loop.run_until_complete(sse_broadcaster.broadcast_event(broadcast_data))
logger.debug(f"SSE Redis Subscriber: Broadcasted callback to {len(sse_broadcaster.clients)} clients")
finally:
loop.close()
elif event_type == 'summary_update':
# Task summary update - use existing trigger_sse_update logic
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(trigger_sse_update(task_id, event_data.get('reason', 'update')))
logger.debug(f"SSE Redis Subscriber: Processed summary update for {task_id}")
finally:
loop.close()
else:
# Unknown event type - broadcast as-is
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(sse_broadcaster.broadcast_event(event_data))
logger.debug(f"SSE Redis Subscriber: Broadcasted {event_type} to {len(sse_broadcaster.clients)} clients")
finally:
loop.close()
except Exception as e:
logger.error(f"SSE Redis Subscriber: Error processing message: {e}", exc_info=True)
except Exception as e:
logger.error(f"SSE Redis Subscriber: Fatal error: {e}", exc_info=True)
# Start Redis subscriber in background thread
thread = threading.Thread(target=redis_subscriber_thread, daemon=True)
thread.start()
logger.info("SSE Redis Subscriber: Background thread started")
async def transform_callback_to_task_format(task_id: str, event_data: dict) -> dict:
"""Transform callback event data into the task format expected by frontend"""
try:
# Import here to avoid circular imports
from routes.utils.celery_tasks import get_task_info, get_all_tasks
# Get task info to build complete task object
task_info = get_task_info(task_id)
if not task_info:
logger.warning(f"SSE Transform: No task info found for {task_id}")
return None
# Extract callback data
callback_data = event_data.get('callback_data', {})
# Build task object in the format expected by frontend
task_object = {
"task_id": task_id,
"original_url": f"http://localhost:7171/api/{task_info.get('download_type', 'track')}/download/{task_info.get('url', '').split('/')[-1] if task_info.get('url') else ''}",
"last_line": callback_data, # This is what frontend expects for callback data
"timestamp": event_data.get('timestamp', time.time()),
"download_type": task_info.get('download_type', 'track'),
"type": task_info.get('type', task_info.get('download_type', 'track')),
"name": task_info.get('name', 'Unknown'),
"artist": task_info.get('artist', ''),
"created_at": task_info.get('created_at'),
}
# Build minimal event data - global counts will be added at broadcast time
return {
"change_type": "update", # Use "update" so it gets processed by existing frontend logic
"tasks": [task_object], # Frontend expects tasks array
"current_timestamp": time.time(),
"updated_count": 1,
"since_timestamp": time.time(),
"trigger_reason": "callback_update"
}
except Exception as e:
logger.error(f"SSE Transform: Error transforming callback for task {task_id}: {e}", exc_info=True)
return None
# Start the Redis subscriber when module loads
start_sse_redis_subscriber()
async def trigger_sse_update(task_id: str, reason: str = "task_update"):
"""Trigger an immediate SSE update for a specific task"""
try:
current_time = time.time()
# Find the specific task that changed
task_info = get_task_info(task_id)
if not task_info:
logger.warning(f"SSE: Task {task_id} not found for update")
return
last_status = get_last_task_status(task_id)
# Create a dummy request for the _build_task_response function
from fastapi import Request
class DummyRequest:
def __init__(self):
self.base_url = "http://localhost:7171"
dummy_request = DummyRequest()
task_response = _build_task_response(task_info, last_status, task_id, current_time, dummy_request)
# Create minimal event data - global counts will be added at broadcast time
event_data = {
"tasks": [task_response],
"current_timestamp": current_time,
"since_timestamp": current_time,
"change_type": "realtime",
"trigger_reason": reason
}
await sse_broadcaster.broadcast_event(event_data)
logger.debug(f"SSE: Broadcast update for task {task_id} (reason: {reason})")
except Exception as e:
logger.error(f"SSE: Failed to trigger update for task {task_id}: {e}")
# Define active task states using ProgressState constants
ACTIVE_TASK_STATES = {
ProgressState.INITIALIZING, # "initializing" - task is starting up
ProgressState.PROCESSING, # "processing" - task is being processed
ProgressState.DOWNLOADING, # "downloading" - actively downloading
ProgressState.PROGRESS, # "progress" - album/playlist progress updates
ProgressState.TRACK_PROGRESS, # "track_progress" - real-time track progress
ProgressState.REAL_TIME, # "real_time" - real-time download progress
ProgressState.RETRYING, # "retrying" - task is retrying after error
"real-time", # "real-time" - real-time download progress (hyphenated version)
ProgressState.QUEUED, # "queued" - task is queued and waiting
"pending", # "pending" - legacy queued status
}
# Define terminal task states that should be included when recently completed
TERMINAL_TASK_STATES = {
ProgressState.COMPLETE, # "complete" - task completed successfully
ProgressState.DONE, # "done" - task finished processing
ProgressState.ERROR, # "error" - task failed
ProgressState.CANCELLED, # "cancelled" - task was cancelled
ProgressState.SKIPPED, # "skipped" - task was skipped
}
def get_task_status_from_last_status(last_status):
"""
Extract the task status from last_status, checking both possible locations.
Uses improved priority logic to handle real-time downloads correctly.
Args:
last_status: The last status dict from get_last_task_status()
Returns:
str: The task status string
"""
if not last_status:
return "unknown"
# For real-time downloads, prioritize status_info.status as it contains the actual progress state
status_info = last_status.get("status_info", {})
if isinstance(status_info, dict) and "status" in status_info:
status_info_status = status_info["status"]
# If status_info contains an active status, use it regardless of top-level status
if status_info_status in ACTIVE_TASK_STATES:
return status_info_status
# Fall back to top-level status
top_level_status = last_status.get("status", "unknown")
# If both exist but neither is active, prefer the more recent one (usually top-level)
# For active states, we already handled status_info above
return top_level_status
def is_task_active(task_status):
"""
Determine if a task is currently active (working/processing).
Args:
task_status: The status string from the task
Returns:
bool: True if the task is active, False otherwise
"""
if not task_status or task_status == "unknown":
return False
return task_status in ACTIVE_TASK_STATES
def get_global_task_counts():
"""
Get comprehensive task counts for ALL tasks in Redis.
This is called right before sending SSE events to ensure accurate counts.
Returns:
dict: Task counts by status
"""
task_counts = {
"active": 0,
"queued": 0,
"completed": 0,
"error": 0,
"cancelled": 0,
"retrying": 0,
"skipped": 0
}
try:
# Get ALL tasks from Redis - this is the source of truth
all_tasks = get_all_tasks()
for task_summary in all_tasks:
task_id = task_summary.get("task_id")
if not task_id:
continue
task_info = get_task_info(task_id)
if not task_info:
continue
last_status = get_last_task_status(task_id)
task_status = get_task_status_from_last_status(last_status)
is_active_task = is_task_active(task_status)
# Categorize tasks by status using ProgressState constants
if task_status == ProgressState.RETRYING:
task_counts["retrying"] += 1
elif task_status in {ProgressState.QUEUED, "pending"}:
task_counts["queued"] += 1
elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}:
task_counts["completed"] += 1
elif task_status == ProgressState.ERROR:
task_counts["error"] += 1
elif task_status == ProgressState.CANCELLED:
task_counts["cancelled"] += 1
elif task_status == ProgressState.SKIPPED:
task_counts["skipped"] += 1
elif is_active_task:
task_counts["active"] += 1
logger.debug(f"Global task counts: {task_counts} (total: {len(all_tasks)} tasks)")
except Exception as e:
logger.error(f"Error getting global task counts: {e}", exc_info=True)
return task_counts
def add_global_task_counts_to_event(event_data):
"""
Add global task counts to any SSE event data right before broadcasting.
This ensures all SSE events have accurate, up-to-date counts of ALL tasks.
Args:
event_data: The event data dictionary to be sent via SSE
Returns:
dict: Enhanced event data with global task counts
"""
try:
# Get fresh counts of ALL tasks right before sending
global_task_counts = get_global_task_counts()
# Add/update the counts in the event data
event_data["task_counts"] = global_task_counts
event_data["active_tasks"] = global_task_counts["active"]
event_data["all_tasks_count"] = sum(global_task_counts.values())
return event_data
except Exception as e:
logger.error(f"Error adding global task counts to SSE event: {e}", exc_info=True)
return event_data
def _build_error_callback_object(last_status):
"""
Constructs a structured error callback object based on the last status of a task.
This conforms to the CallbackObject types in the frontend.
"""
# The 'type' from the status update corresponds to the download_type (album, playlist, track)
download_type = last_status.get("type")
name = last_status.get("name")
# The 'artist' field from the status may contain artist names or a playlist owner's name
artist_or_owner = last_status.get("artist")
error_message = last_status.get("error", "An unknown error occurred.")
status_info = {"status": "error", "error": error_message}
callback_object = {"status_info": status_info}
if download_type == "album":
callback_object["album"] = {
"type": "album",
"title": name,
"artists": [{
"type": "artistAlbum",
"name": artist_or_owner
}] if artist_or_owner else [],
}
elif download_type == "playlist":
playlist_payload = {"type": "playlist", "title": name}
if artist_or_owner:
playlist_payload["owner"] = {"type": "user", "name": artist_or_owner}
callback_object["playlist"] = playlist_payload
elif download_type == "track":
callback_object["track"] = {
"type": "track",
"title": name,
"artists": [{
"type": "artistTrack",
"name": artist_or_owner
}] if artist_or_owner else [],
}
else:
# Fallback for unknown types to avoid breaking the client, returning a basic error structure.
return {
"status_info": status_info,
"unstructured_error": True,
"details": {
"type": download_type,
"name": name,
"artist_or_owner": artist_or_owner,
},
}
return callback_object
def _build_task_response(task_info, last_status, task_id, current_time, request: Request):
"""
Helper function to build a standardized task response object.
"""
# Dynamically construct original_url
dynamic_original_url = ""
download_type = task_info.get("download_type")
item_url = task_info.get("url")
if download_type and item_url:
try:
item_id = item_url.split("/")[-1]
if item_id:
base_url = str(request.base_url).rstrip("/")
dynamic_original_url = (
f"{base_url}/api/{download_type}/download/{item_id}"
)
else:
logger.warning(
f"Could not extract item ID from URL: {item_url} for task {task_id}. Falling back for original_url."
)
original_request_obj = task_info.get("original_request", {})
dynamic_original_url = original_request_obj.get("original_url", "")
except Exception as e:
logger.error(
f"Error constructing dynamic original_url for task {task_id}: {e}",
exc_info=True,
)
original_request_obj = task_info.get("original_request", {})
dynamic_original_url = original_request_obj.get("original_url", "")
else:
logger.warning(
f"Missing download_type ('{download_type}') or item_url ('{item_url}') in task_info for task {task_id}. Falling back for original_url."
)
original_request_obj = task_info.get("original_request", {})
dynamic_original_url = original_request_obj.get("original_url", "")
status_count = len(get_task_status(task_id))
# Determine last_line content
if last_status and "raw_callback" in last_status:
last_line_content = last_status["raw_callback"]
elif last_status and get_task_status_from_last_status(last_status) == "error":
last_line_content = _build_error_callback_object(last_status)
else:
last_line_content = last_status
task_response = {
"original_url": dynamic_original_url,
"last_line": last_line_content,
"timestamp": last_status.get("timestamp") if last_status else current_time,
"task_id": task_id,
"status_count": status_count,
"created_at": task_info.get("created_at"),
"name": task_info.get("name"),
"artist": task_info.get("artist"),
"type": task_info.get("type"),
"download_type": task_info.get("download_type"),
}
if last_status and last_status.get("summary"):
task_response["summary"] = last_status["summary"]
return task_response
async def get_paginated_tasks(page=1, limit=20, active_only=False, request: Request = None):
"""
Get paginated list of tasks.
"""
try:
all_tasks = get_all_tasks()
# Get global task counts
task_counts = get_global_task_counts()
active_tasks = []
other_tasks = []
# Process tasks for pagination and response building
for task_summary in all_tasks:
task_id = task_summary.get("task_id")
if not task_id:
continue
task_info = get_task_info(task_id)
if not task_info:
continue
last_status = get_last_task_status(task_id)
task_status = get_task_status_from_last_status(last_status)
is_active_task = is_task_active(task_status)
task_response = _build_task_response(task_info, last_status, task_id, time.time(), request)
if is_active_task:
active_tasks.append(task_response)
else:
other_tasks.append(task_response)
# Sort other tasks by creation time (newest first)
other_tasks.sort(key=lambda x: x.get("created_at", 0), reverse=True)
if active_only:
paginated_tasks = active_tasks
pagination_info = {
"page": page,
"limit": limit,
"total_non_active": 0,
"has_more": False,
"returned_non_active": 0
}
else:
# Apply pagination to non-active tasks
offset = (page - 1) * limit
paginated_other_tasks = other_tasks[offset:offset + limit]
paginated_tasks = active_tasks + paginated_other_tasks
pagination_info = {
"page": page,
"limit": limit,
"total_non_active": len(other_tasks),
"has_more": len(other_tasks) > offset + limit,
"returned_non_active": len(paginated_other_tasks)
}
response = {
"tasks": paginated_tasks,
"current_timestamp": time.time(),
"total_tasks": task_counts["active"] + task_counts["retrying"], # Only active/retrying tasks for counter
"all_tasks_count": len(all_tasks), # Total count of all tasks
"task_counts": task_counts, # Categorized counts
"active_tasks": len(active_tasks),
"updated_count": len(paginated_tasks),
"pagination": pagination_info
}
return response
except Exception as e:
logger.error(f"Error in get_paginated_tasks: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": "Failed to retrieve paginated tasks"})
# IMPORTANT: Specific routes MUST come before parameterized routes in FastAPI
# Otherwise "updates" gets matched as a {task_id} parameter!
@router.get("/list")
async def list_tasks(request: Request, current_user: User = Depends(require_auth_from_state)):
"""
Retrieve a paginated list of all tasks in the system.
Returns a detailed list of task objects including status and metadata.
Query parameters:
page (int): Page number for pagination (default: 1)
limit (int): Number of tasks per page (default: 50, max: 100)
active_only (bool): If true, only return active tasks (downloading, processing, etc.)
"""
try:
# Get query parameters
page = int(request.query_params.get('page', 1))
limit = min(int(request.query_params.get('limit', 50)), 100) # Cap at 100
active_only = request.query_params.get('active_only', '').lower() == 'true'
tasks = get_all_tasks()
active_tasks = []
other_tasks = []
# Task categorization counters
task_counts = {
"active": 0,
"queued": 0,
"completed": 0,
"error": 0,
"cancelled": 0,
"retrying": 0,
"skipped": 0
}
for task_summary in tasks:
task_id = task_summary.get("task_id")
if not task_id:
continue
task_info = get_task_info(task_id)
if not task_info:
continue
last_status = get_last_task_status(task_id)
task_status = get_task_status_from_last_status(last_status)
is_active_task = is_task_active(task_status)
# Categorize tasks by status using ProgressState constants
if task_status == ProgressState.RETRYING:
task_counts["retrying"] += 1
elif task_status in {ProgressState.QUEUED, "pending"}: # Keep "pending" for backward compatibility
task_counts["queued"] += 1
elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}:
task_counts["completed"] += 1
elif task_status == ProgressState.ERROR:
task_counts["error"] += 1
elif task_status == ProgressState.CANCELLED:
task_counts["cancelled"] += 1
elif task_status == ProgressState.SKIPPED:
task_counts["skipped"] += 1
elif is_active_task:
task_counts["active"] += 1
task_response = _build_task_response(task_info, last_status, task_id, time.time(), request)
if is_active_task:
active_tasks.append(task_response)
else:
other_tasks.append(task_response)
# Sort other tasks by creation time (newest first)
other_tasks.sort(key=lambda x: x.get("created_at", 0), reverse=True)
if active_only:
# Return only active tasks without pagination
response_tasks = active_tasks
pagination_info = {
"page": page,
"limit": limit,
"total_items": len(active_tasks),
"total_pages": 1,
"has_more": False
}
else:
# Apply pagination to non-active tasks and combine with active tasks
offset = (page - 1) * limit
# Always include active tasks at the top
if page == 1:
# For first page, include active tasks + first batch of other tasks
available_space = limit - len(active_tasks)
paginated_other_tasks = other_tasks[:max(0, available_space)]
response_tasks = active_tasks + paginated_other_tasks
else:
# For subsequent pages, only include other tasks
# Adjust offset to account for active tasks shown on first page
adjusted_offset = offset - len(active_tasks)
if adjusted_offset < 0:
adjusted_offset = 0
paginated_other_tasks = other_tasks[adjusted_offset:adjusted_offset + limit]
response_tasks = paginated_other_tasks
total_items = len(active_tasks) + len(other_tasks)
total_pages = ((total_items - 1) // limit) + 1 if total_items > 0 else 1
pagination_info = {
"page": page,
"limit": limit,
"total_items": total_items,
"total_pages": total_pages,
"has_more": page < total_pages,
"active_tasks": len(active_tasks),
"total_other_tasks": len(other_tasks)
}
response = {
"tasks": response_tasks,
"pagination": pagination_info,
"total_tasks": task_counts["active"] + task_counts["retrying"], # Only active/retrying tasks for counter
"all_tasks_count": len(tasks), # Total count of all tasks
"task_counts": task_counts, # Categorized counts
"active_tasks": len(active_tasks),
"timestamp": time.time()
}
return response
except Exception as e:
logger.error(f"Error in /api/prgs/list: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": "Failed to retrieve task list"})
@router.get("/updates")
async def get_task_updates(request: Request, current_user: User = Depends(require_auth_from_state)):
"""
Retrieve only tasks that have been updated since the specified timestamp.
This endpoint is optimized for polling to reduce unnecessary data transfer.
Query parameters:
since (float): Unix timestamp - only return tasks updated after this time
page (int): Page number for pagination (default: 1)
limit (int): Number of queued/completed tasks per page (default: 20, max: 100)
active_only (bool): If true, only return active tasks (downloading, processing, etc.)
Returns:
JSON object containing:
- tasks: Array of updated task objects
- current_timestamp: Current server timestamp for next poll
- total_tasks: Total number of tasks in system
- active_tasks: Number of active tasks
- pagination: Pagination info for queued/completed tasks
"""
try:
# Get query parameters
since_param = request.query_params.get('since')
page = int(request.query_params.get('page', 1))
limit = min(int(request.query_params.get('limit', 20)), 100) # Cap at 100
active_only = request.query_params.get('active_only', '').lower() == 'true'
if not since_param:
# If no 'since' parameter, return paginated tasks (fallback behavior)
response = await get_paginated_tasks(page, limit, active_only, request)
return response
try:
since_timestamp = float(since_param)
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail={"error": "Invalid 'since' timestamp format"})
# Get all tasks
all_tasks = get_all_tasks()
current_time = time.time()
# Get global task counts
task_counts = get_global_task_counts()
updated_tasks = []
active_tasks = []
# Process tasks for filtering and response building
for task_summary in all_tasks:
task_id = task_summary.get("task_id")
if not task_id:
continue
task_info = get_task_info(task_id)
if not task_info:
continue
last_status = get_last_task_status(task_id)
task_status = get_task_status_from_last_status(last_status)
is_active_task = is_task_active(task_status)
# Check if task has been updated since the given timestamp
task_timestamp = last_status.get("timestamp") if last_status else task_info.get("created_at", 0)
# Always include active tasks in updates, apply filtering to others
# Also include recently completed/terminal tasks to ensure "done" status gets sent
is_recently_terminal = task_status in TERMINAL_TASK_STATES and task_timestamp > since_timestamp
should_include = is_active_task or (task_timestamp > since_timestamp and not active_only) or is_recently_terminal
if should_include:
# Construct the same detailed task object as in list_tasks()
task_response = _build_task_response(task_info, last_status, task_id, current_time, request)
if is_active_task:
active_tasks.append(task_response)
else:
updated_tasks.append(task_response)
# Apply pagination to non-active tasks
offset = (page - 1) * limit
paginated_updated_tasks = updated_tasks[offset:offset + limit] if not active_only else []
# Combine active tasks (always shown) with paginated updated tasks
all_returned_tasks = active_tasks + paginated_updated_tasks
# Sort by priority (active first, then by creation time)
all_returned_tasks.sort(key=lambda x: (
0 if x.get("task_id") in [t["task_id"] for t in active_tasks] else 1,
-(x.get("created_at") or 0)
))
response = {
"tasks": all_returned_tasks,
"current_timestamp": current_time,
"total_tasks": task_counts["active"] + task_counts["retrying"], # Only active/retrying tasks for counter
"all_tasks_count": len(all_tasks), # Total count of all tasks
"task_counts": task_counts, # Categorized counts
"active_tasks": len(active_tasks),
"updated_count": len(updated_tasks),
"since_timestamp": since_timestamp,
"pagination": {
"page": page,
"limit": limit,
"total_non_active": len(updated_tasks),
"has_more": len(updated_tasks) > offset + limit,
"returned_non_active": len(paginated_updated_tasks)
}
}
logger.debug(f"Returning {len(active_tasks)} active + {len(paginated_updated_tasks)} paginated tasks out of {len(all_tasks)} total")
return response
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in /api/prgs/updates: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": "Failed to retrieve task updates"})
@router.post("/cancel/all")
async def cancel_all_tasks(current_user: User = Depends(require_auth_from_state)):
"""
Cancel all active (running or queued) tasks.
"""
try:
tasks_to_cancel = get_all_tasks()
cancelled_count = 0
errors = []
for task_summary in tasks_to_cancel:
task_id = task_summary.get("task_id")
if not task_id:
continue
try:
cancel_task(task_id)
cancelled_count += 1
except Exception as e:
error_message = f"Failed to cancel task {task_id}: {e}"
logger.error(error_message)
errors.append(error_message)
response = {
"message": f"Attempted to cancel all active tasks. {cancelled_count} tasks cancelled.",
"cancelled_count": cancelled_count,
"errors": errors,
}
return response
except Exception as e:
logger.error(f"Error in /api/prgs/cancel/all: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": "Failed to cancel all tasks"})
@router.post("/cancel/{task_id}")
async def cancel_task_endpoint(task_id: str, current_user: User = Depends(require_auth_from_state)):
"""
Cancel a running or queued task.
Args:
task_id: The ID of the task to cancel
"""
try:
# First check if this is a task ID in the new system
task_info = get_task_info(task_id)
if task_info:
# This is a task ID in the new system
result = cancel_task(task_id)
try:
# Push an immediate SSE update so clients reflect cancellation and partial summary
await trigger_sse_update(task_id, "cancelled")
result["sse_notified"] = True
except Exception as e:
logger.error(f"SSE notify after cancel failed for {task_id}: {e}")
return result
# If not found in new system, we need to handle the old system cancellation
# For now, return an error as we're transitioning to the new system
raise HTTPException(
status_code=400,
detail={
"status": "error",
"message": "Cancellation for old system is not supported in the new API. Please use the new task ID format.",
}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
@router.delete("/delete/{task_id}")
async def delete_task(task_id: str, current_user: User = Depends(require_auth_from_state)):
"""
Delete a task's information and history.
Args:
task_id: A task UUID from Celery
"""
# Only support new task IDs
task_info = get_task_info(task_id)
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
# First, cancel the task if it's running
cancel_task(task_id)
return {"message": f"Task {task_id} deleted successfully"}
@router.get("/stream")
async def stream_task_updates(request: Request, current_user: User = Depends(get_current_user_from_state)):
"""
Stream real-time task updates via Server-Sent Events (SSE).
Now uses event-driven architecture for true real-time updates.
Uses optional authentication to avoid breaking SSE connections.
Query parameters:
active_only (bool): If true, only stream active tasks (downloading, processing, etc.)
Returns:
Server-Sent Events stream with task update data in JSON format
"""
# Get query parameters
active_only = request.query_params.get('active_only', '').lower() == 'true'
async def event_generator():
# Create a queue for this client
client_queue = asyncio.Queue()
try:
# Register this client with the broadcaster
logger.info(f"SSE Stream: New client connecting...")
await sse_broadcaster.add_client(client_queue)
logger.info(f"SSE Stream: Client registered successfully, total clients: {len(sse_broadcaster.clients)}")
# Send initial data immediately upon connection
initial_data = await generate_task_update_event(time.time(), active_only, request)
yield initial_data
# Also send any active tasks as callback-style events to newly connected clients
all_tasks = get_all_tasks()
for task_summary in all_tasks:
task_id = task_summary.get("task_id")
if not task_id:
continue
task_info = get_task_info(task_id)
if not task_info:
continue
last_status = get_last_task_status(task_id)
task_status = get_task_status_from_last_status(last_status)
# Send recent callback data for active or recently completed tasks
if is_task_active(task_status) or (last_status and last_status.get("timestamp", 0) > time.time() - 30):
if last_status and "raw_callback" in last_status:
callback_event = {
"task_id": task_id,
"callback_data": last_status["raw_callback"],
"timestamp": last_status.get("timestamp", time.time()),
"change_type": "callback",
"event_type": "progress_update",
"replay": True # Mark as replay for client
}
event_json = json.dumps(callback_event)
yield f"data: {event_json}\n\n"
logger.info(f"SSE Stream: Sent replay callback for task {task_id}")
# Send periodic heartbeats and listen for real-time events
last_heartbeat = time.time()
heartbeat_interval = 30.0
while True:
try:
# Wait for either an event or timeout for heartbeat
try:
event_data = await asyncio.wait_for(client_queue.get(), timeout=heartbeat_interval)
# Send the real-time event
yield event_data
last_heartbeat = time.time()
except asyncio.TimeoutError:
# Send heartbeat if no events for a while
current_time = time.time()
if current_time - last_heartbeat >= heartbeat_interval:
# Generate current task counts for heartbeat
all_tasks = get_all_tasks()
task_counts = {"active": 0, "queued": 0, "completed": 0, "error": 0, "cancelled": 0, "retrying": 0, "skipped": 0}
for task_summary in all_tasks:
task_id = task_summary.get("task_id")
if not task_id:
continue
task_info = get_task_info(task_id)
if not task_info:
continue
last_status = get_last_task_status(task_id)
task_status = get_task_status_from_last_status(last_status)
if task_status == ProgressState.RETRYING:
task_counts["retrying"] += 1
elif task_status in {ProgressState.QUEUED, "pending"}:
task_counts["queued"] += 1
elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}:
task_counts["completed"] += 1
elif task_status == ProgressState.ERROR:
task_counts["error"] += 1
elif task_status == ProgressState.CANCELLED:
task_counts["cancelled"] += 1
elif task_status == ProgressState.SKIPPED:
task_counts["skipped"] += 1
elif is_task_active(task_status):
task_counts["active"] += 1
heartbeat_data = {
"current_timestamp": current_time,
"total_tasks": task_counts["active"] + task_counts["retrying"],
"task_counts": task_counts,
"change_type": "heartbeat"
}
event_json = json.dumps(heartbeat_data)
yield f"data: {event_json}\n\n"
last_heartbeat = current_time
except Exception as e:
logger.error(f"Error in SSE event streaming: {e}", exc_info=True)
# Send error event and continue
error_data = json.dumps({"error": "Internal server error", "timestamp": time.time(), "change_type": "error"})
yield f"data: {error_data}\n\n"
await asyncio.sleep(1)
except asyncio.CancelledError:
logger.info("SSE client disconnected")
return
except Exception as e:
logger.error(f"SSE connection error: {e}", exc_info=True)
return
finally:
# Clean up - remove client from broadcaster
await sse_broadcaster.remove_client(client_queue)
return StreamingResponse(
event_generator(),
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "text/event-stream",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control"
}
)
async def generate_task_update_event(since_timestamp: float, active_only: bool, request: Request) -> str:
"""
Generate initial task update event for SSE connection.
This replicates the logic from get_task_updates but for SSE format.
"""
try:
# Get all tasks for filtering
all_tasks = get_all_tasks()
current_time = time.time()
updated_tasks = []
active_tasks = []
# Process tasks for filtering only - no counting here
for task_summary in all_tasks:
task_id = task_summary.get("task_id")
if not task_id:
continue
task_info = get_task_info(task_id)
if not task_info:
continue
last_status = get_last_task_status(task_id)
task_status = get_task_status_from_last_status(last_status)
is_active_task = is_task_active(task_status)
# Check if task has been updated since the given timestamp
task_timestamp = last_status.get("timestamp") if last_status else task_info.get("created_at", 0)
# Always include active tasks in updates, apply filtering to others
# Also include recently completed/terminal tasks to ensure "done" status gets sent
is_recently_terminal = task_status in TERMINAL_TASK_STATES and task_timestamp > since_timestamp
should_include = is_active_task or (task_timestamp > since_timestamp and not active_only) or is_recently_terminal
if should_include:
# Construct the same detailed task object as in updates endpoint
task_response = _build_task_response(task_info, last_status, task_id, current_time, request)
if is_active_task:
active_tasks.append(task_response)
else:
updated_tasks.append(task_response)
# Combine active tasks (always shown) with updated tasks
all_returned_tasks = active_tasks + updated_tasks
# Sort by priority (active first, then by creation time)
all_returned_tasks.sort(key=lambda x: (
0 if x.get("task_id") in [t["task_id"] for t in active_tasks] else 1,
-(x.get("created_at") or 0)
))
initial_data = {
"tasks": all_returned_tasks,
"current_timestamp": current_time,
"updated_count": len(updated_tasks),
"since_timestamp": since_timestamp,
"initial": True # Mark as initial load
}
# Add global task counts since this bypasses the broadcaster
enhanced_data = add_global_task_counts_to_event(initial_data)
event_data = json.dumps(enhanced_data)
return f"data: {event_data}\n\n"
except Exception as e:
logger.error(f"Error generating initial SSE event: {e}", exc_info=True)
error_data = json.dumps({"error": "Failed to load initial data", "timestamp": time.time()})
return f"data: {error_data}\n\n"
# IMPORTANT: This parameterized route MUST come AFTER all specific routes
# Otherwise FastAPI will match specific routes like "/updates" as task_id parameters
@router.get("/{task_id}")
async def get_task_details(task_id: str, request: Request, current_user: User = Depends(require_auth_from_state)):
"""
Return a JSON object with the resource type, its name (title),
the last progress update, and, if available, the original request parameters.
This function works with the new task ID based system.
Args:
task_id: A task UUID from Celery
"""
# Only support new task IDs
task_info = get_task_info(task_id)
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
# Dynamically construct original_url
dynamic_original_url = ""
download_type = task_info.get("download_type")
# The 'url' field in task_info stores the Spotify/Deezer URL of the item
# e.g., https://open.spotify.com/album/albumId or https://www.deezer.com/track/trackId
item_url = task_info.get("url")
if download_type and item_url:
try:
# Extract the ID from the item_url (last part of the path)
item_id = item_url.split("/")[-1]
if item_id: # Ensure item_id is not empty
base_url = str(request.base_url).rstrip("/")
dynamic_original_url = (
f"{base_url}/api/{download_type}/download/{item_id}"
)
else:
logger.warning(
f"Could not extract item ID from URL: {item_url} for task {task_id}. Falling back for original_url."
)
original_request_obj = task_info.get("original_request", {})
dynamic_original_url = original_request_obj.get("original_url", "")
except Exception as e:
logger.error(
f"Error constructing dynamic original_url for task {task_id}: {e}",
exc_info=True,
)
original_request_obj = task_info.get("original_request", {})
dynamic_original_url = original_request_obj.get(
"original_url", ""
) # Fallback on any error
else:
logger.warning(
f"Missing download_type ('{download_type}') or item_url ('{item_url}') in task_info for task {task_id}. Falling back for original_url."
)
original_request_obj = task_info.get("original_request", {})
dynamic_original_url = original_request_obj.get("original_url", "")
last_status = get_last_task_status(task_id)
status_count = len(get_task_status(task_id))
# Determine last_line content
if last_status and "raw_callback" in last_status:
last_line_content = last_status["raw_callback"]
elif last_status and get_task_status_from_last_status(last_status) == "error":
last_line_content = _build_error_callback_object(last_status)
else:
# Fallback for non-error, no raw_callback, or if last_status is None
last_line_content = last_status
response = {
"original_url": dynamic_original_url,
"last_line": last_line_content,
"timestamp": last_status.get("timestamp") if last_status else time.time(),
"task_id": task_id,
"status_count": status_count,
"created_at": task_info.get("created_at"),
"name": task_info.get("name"),
"artist": task_info.get("artist"),
"type": task_info.get("type"),
"download_type": task_info.get("download_type"),
}
if last_status and last_status.get("summary"):
response["summary"] = last_status["summary"]
return response