# Listing metadata (from the code viewer): 1533 lines, 57 KiB, Python, executable file.
from fastapi import APIRouter, HTTPException, Request, Depends
|
|
from fastapi.responses import StreamingResponse
|
|
import logging
|
|
import time
|
|
import json
|
|
import asyncio
|
|
from typing import Set, Optional
|
|
|
|
import redis
|
|
import threading
|
|
from routes.utils.celery_config import REDIS_URL, get_config_params
|
|
|
|
from routes.utils.celery_tasks import (
|
|
get_task_info,
|
|
get_task_status,
|
|
get_last_task_status,
|
|
get_all_tasks,
|
|
cancel_task,
|
|
delete_task_data_and_log,
|
|
ProgressState,
|
|
)
|
|
|
|
# Import authentication dependencies
|
|
from routes.auth.middleware import (
|
|
require_auth_from_state,
|
|
get_current_user_from_state,
|
|
User,
|
|
)
|
|
|
|
# Configure logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
# Global SSE Event Broadcaster
|
|
class SSEBroadcaster:
    """Fan task-update events out to every connected SSE client queue.

    Events are processed per task: an event identical to the last one sent
    (ignoring its ``timestamp``) is dropped, and events arriving faster than
    ``sse_update_interval`` seconds are coalesced so that only the newest one
    is sent when the interval elapses.
    """

    def __init__(self):
        # One asyncio.Queue per connected client; payloads are SSE-framed strings.
        self.clients: Set[asyncio.Queue] = set()
        # task_id -> {"last_sent", "last_event", "last_send_time", "scheduled_handle"}
        self._task_state = {}
        # Minimum number of seconds between two sends for the same task.
        config = get_config_params()
        self.sse_update_interval = float(config.get("sseUpdateIntervalSeconds", 1))

    async def add_client(self, queue: asyncio.Queue):
        """Register a client queue so it receives subsequent events."""
        self.clients.add(queue)
        logger.debug(f"SSE: Client connected (total: {len(self.clients)})")

    async def remove_client(self, queue: asyncio.Queue):
        """Unregister a client queue; unknown queues are ignored."""
        self.clients.discard(queue)
        logger.debug(f"SSE: Client disconnected (total: {len(self.clients)})")

    async def broadcast_event(self, event_data: dict):
        """Throttle, coalesce and de-duplicate the tasks in ``event_data``.

        Args:
            event_data: Event payload; only its ``"tasks"`` entry is read. A
                non-list value is treated as a single task.

        Each task dict is either sent immediately, stored as the pending
        event for a delayed send, or dropped as a duplicate. Entries that are
        not dicts or carry no ``task_id`` are skipped.
        """
        if not self.clients:
            logger.debug("SSE Broadcaster: No clients connected, skipping broadcast")
            return

        # Defensive: always work with a list of tasks
        tasks = event_data.get("tasks", [])
        if not isinstance(tasks, list):
            tasks = [tasks]

        for task in tasks:
            # A malformed entry must not abort processing of the remaining tasks.
            if not isinstance(task, dict):
                continue
            task_id = task.get("task_id")
            if not task_id:
                continue

            now = time.time()
            state = self._task_state.setdefault(
                task_id,
                {
                    "last_sent": None,
                    "last_event": None,
                    "last_send_time": 0,
                    "scheduled_handle": None,
                },
            )

            # Deduplication: if event is identical to last sent, skip
            if state["last_sent"] is not None and self._events_equal(
                state["last_sent"], task
            ):
                logger.debug(f"SSE: Deduped event for task {task_id}")
                continue

            # Throttling: inside the interval only remember the newest event
            # and make sure exactly one delayed flush is scheduled.
            elapsed = now - state["last_send_time"]
            if elapsed < self.sse_update_interval:
                state["last_event"] = task
                if state["scheduled_handle"] is None:
                    delay = self.sse_update_interval - elapsed
                    loop = asyncio.get_running_loop()
                    # Bind task_id as a default argument: a plain closure
                    # would see whatever task_id the loop variable holds when
                    # the timer fires, i.e. the last task of this batch.
                    state["scheduled_handle"] = loop.call_later(
                        delay,
                        lambda tid=task_id: asyncio.create_task(
                            self._send_batched_event(tid)
                        ),
                    )
                continue

            # Otherwise, send immediately
            await self._send_event(task_id, task)
            state["last_send_time"] = now
            state["last_sent"] = task
            state["last_event"] = None
            if state["scheduled_handle"]:
                state["scheduled_handle"].cancel()
                state["scheduled_handle"] = None

    async def _send_batched_event(self, task_id):
        """Flush the pending (coalesced) event for ``task_id``, if any."""
        state = self._task_state.get(task_id)
        if not state:
            return
        # The timer that invoked us has fired, so its handle is stale either
        # way; clearing it first lets later throttled events schedule a new
        # flush even if the send below fails.
        state["scheduled_handle"] = None
        pending = state["last_event"]
        if not pending:
            return
        state["last_event"] = None
        await self._send_event(task_id, pending)
        state["last_send_time"] = time.time()
        state["last_sent"] = pending

    async def _send_event(self, task_id, task):
        """Wrap ``task`` in an update payload and put it on every client queue.

        Queues whose ``put`` raises are removed from ``self.clients``.
        """
        # Compose event_data for this task
        event_data = {
            "tasks": [task],
            "current_timestamp": time.time(),
            "change_type": "update",
        }
        enhanced_event_data = add_global_task_counts_to_event(event_data.copy())
        event_json = json.dumps(enhanced_event_data)
        sse_data = f"data: {event_json}\n\n"

        disconnected = set()
        sent_count = 0
        # Iterate over a copy: the client set may change while we await.
        for client_queue in self.clients.copy():
            try:
                await client_queue.put(sse_data)
                sent_count += 1
            except Exception as e:
                logger.error(f"SSE: Failed to send to client: {e}")
                disconnected.add(client_queue)
        for client in disconnected:
            self.clients.discard(client)
        logger.debug(
            f"SSE Broadcaster: Sent throttled/batched event for task {task_id} to {sent_count} clients"
        )

    def _events_equal(self, a, b):
        """Return True when two task dicts match once ``timestamp`` is ignored.

        Anything that is not a dict never compares equal.
        """
        if not isinstance(a, dict) or not isinstance(b, dict):
            return False
        a_copy = dict(a)
        b_copy = dict(b)
        a_copy.pop("timestamp", None)
        b_copy.pop("timestamp", None)
        return a_copy == b_copy
|
|
|
|
|
|
# Global broadcaster instance
|
|
# Module-level singleton shared by the subscriber thread and the helpers
# below; constructing it calls get_config_params() at import time.
sse_broadcaster = SSEBroadcaster()
|
|
|
|
# Redis subscriber for cross-process SSE events
|
|
|
|
# Redis client for SSE pub/sub
|
|
# Connection used by start_sse_redis_subscriber() to subscribe to the
# "sse_events" channel.
sse_redis_client = redis.Redis.from_url(REDIS_URL)
|
|
|
|
|
|
def start_sse_redis_subscriber():
    """Start a daemon thread relaying Redis ``sse_events`` messages to SSE clients.

    The thread subscribes to the ``sse_events`` channel and, per message:

    * ``progress_update`` and unrecognised event types are standardized with
      ``standardize_incoming_event`` and broadcast;
    * ``summary_update`` triggers ``trigger_sse_update`` for the task.

    A failure while handling one message is logged and the loop continues; a
    failure of the subscription itself is logged and ends the thread.

    NOTE(review): each coroutine runs on its own short-lived event loop in
    this thread. Timers that ``broadcast_event`` schedules with
    ``loop.call_later`` belong to that loop and are discarded when it closes,
    so coalesced events from this path look like they are only delivered by
    a later immediate send - confirm whether that is intended.
    """

    def _broadcast_standardized(event_data, label):
        # Shared by the "progress_update" and fallback branches.
        standardized = standardize_incoming_event(event_data)
        if not standardized:
            return
        # asyncio.run creates a fresh loop, runs the coroutine to completion
        # and always closes the loop afterwards.
        asyncio.run(sse_broadcaster.broadcast_event(standardized))
        logger.debug(
            f"SSE Redis Subscriber: Broadcasted standardized {label} to {len(sse_broadcaster.clients)} clients"
        )

    def _handle_message(message):
        raw = message["data"]
        # The payload is bytes unless the client was created with
        # decode_responses=True; accept both.
        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode("utf-8")
        event_data = json.loads(raw)
        event_type = event_data.get("event_type", "unknown")
        task_id = event_data.get("task_id", "unknown")

        logger.debug(
            f"SSE Redis Subscriber: Received {event_type} for task {task_id}"
        )

        if event_type == "progress_update":
            # Transform callback data into the standardized update format
            _broadcast_standardized(event_data, "progress update")
        elif event_type == "summary_update":
            # Task summary update - use standardized trigger
            asyncio.run(
                trigger_sse_update(task_id, event_data.get("reason", "update"))
            )
            logger.debug(
                f"SSE Redis Subscriber: Processed summary update for {task_id}"
            )
        else:
            # Unknown event type - attempt to standardize and broadcast
            _broadcast_standardized(event_data, event_type)

    def redis_subscriber_thread():
        try:
            pubsub = sse_redis_client.pubsub()
            pubsub.subscribe("sse_events")
            logger.info("SSE Redis Subscriber: Started listening for events")

            for message in pubsub.listen():
                # Skip subscribe confirmations and other control messages.
                if message["type"] != "message":
                    continue
                try:
                    _handle_message(message)
                except Exception as e:
                    logger.error(
                        f"SSE Redis Subscriber: Error processing message: {e}",
                        exc_info=True,
                    )

        except Exception as e:
            logger.error(f"SSE Redis Subscriber: Fatal error: {e}", exc_info=True)

    # Start Redis subscriber in background thread
    thread = threading.Thread(target=redis_subscriber_thread, daemon=True)
    thread.start()
|
|
logger.debug("SSE Redis Subscriber: Background thread started")
|
|
|
|
|
|
def build_task_object_from_callback(
    task_id: str, callback_data: dict
) -> Optional[dict]:
    """Combine a callback payload with stored task info into one task object.

    Args:
        task_id: Identifier passed to ``get_task_info``.
        callback_data: Callback payload, stored under ``"last_line"``.

    Returns:
        The task dict, or ``None`` when no task info exists or building fails
        (the error is logged).
    """
    try:
        info = get_task_info(task_id)
        if not info:
            return None

        kind = info.get("download_type", "track")
        source_url = info.get("url")
        # Last path segment of the stored URL; empty when no URL is stored.
        item_id = source_url.split("/")[-1] if source_url else ""

        task_object = {"task_id": task_id}
        task_object["original_url"] = (
            f"http://localhost:7171/api/{kind}/download/{item_id}"
        )
        task_object["last_line"] = callback_data
        task_object["timestamp"] = time.time()
        task_object["download_type"] = kind
        task_object["type"] = info.get("type", kind)
        task_object["name"] = info.get("name", "Unknown")
        task_object["artist"] = info.get("artist", "")
        task_object["created_at"] = info.get("created_at")
        return task_object
    except Exception as e:
        logger.error(
            f"Error building task object from callback for {task_id}: {e}",
            exc_info=True,
        )
        return None
|
|
|
|
|
|
def standardize_incoming_event(event_data: dict) -> Optional[dict]:
    """Normalize an incoming event into the SSE payload shape.

    The result always carries ``change_type``, ``tasks`` and
    ``current_timestamp``; ``trigger_reason`` is added for pass-through and
    callback events.

    Handled shapes, in order:
      1. heartbeat events -> heartbeat with an empty task list;
      2. events already holding a ``tasks`` list -> passed through;
      3. callback events (``callback_data`` + ``task_id``) -> one task object;
      4. anything else -> an empty update.

    Returns ``None`` only if normalization raises (the error is logged).
    """
    try:
        kind = event_data.get("change_type")

        # 1. Heartbeat passthrough (tasks array is always present)
        if kind == "heartbeat":
            return {
                "change_type": "heartbeat",
                "tasks": [],
                "current_timestamp": time.time(),
            }

        # 2. Already task-shaped: keep the tasks, default the change type
        if isinstance(event_data.get("tasks"), list):
            passthrough = {
                "change_type": event_data.get("change_type", "update"),
                "tasks": event_data["tasks"],
            }
            passthrough["current_timestamp"] = time.time()
            passthrough["trigger_reason"] = event_data.get("trigger_reason")
            return passthrough

        # 3. Callback-shaped event
        payload = event_data.get("callback_data")
        tid = event_data.get("task_id")
        built = (
            build_task_object_from_callback(tid, payload)
            if payload and tid
            else None
        )
        if built:
            return {
                "change_type": "update",
                "tasks": [built],
                "current_timestamp": time.time(),
                "trigger_reason": event_data.get("event_type", "callback_update"),
            }

        # 4. Fallback to empty update
        return {
            "change_type": "update",
            "tasks": [],
            "current_timestamp": time.time(),
        }
    except Exception as e:
        logger.error(f"Failed to standardize incoming event: {e}", exc_info=True)
        return None
|
|
|
|
|
|
async def transform_callback_to_task_format(
    task_id: str, event_data: dict
) -> Optional[dict]:
    """Transform a callback event into the update payload sent over SSE.

    Args:
        task_id: Identifier passed to ``get_task_info``.
        event_data: Incoming event; ``"callback_data"`` becomes the task's
            ``last_line`` and ``"timestamp"``, when present, its timestamp.

    Returns:
        An update payload holding a single task object, or ``None`` when the
        task has no stored info or the transformation fails (logged).
    """
    try:
        # Get task info to build complete task object
        task_info = get_task_info(task_id)
        if not task_info:
            logger.warning(f"SSE Transform: No task info found for {task_id}")
            return None

        # One clock read so every timestamp in the payload agrees.
        now = time.time()
        callback_data = event_data.get("callback_data", {})
        download_type = task_info.get("download_type", "track")
        item_url = task_info.get("url")
        item_id = item_url.split("/")[-1] if item_url else ""

        # Build task object in the format expected by frontend
        task_object = {
            "task_id": task_id,
            "original_url": f"http://localhost:7171/api/{download_type}/download/{item_id}",
            "last_line": callback_data,
            "timestamp": event_data.get("timestamp", now),
            "download_type": download_type,
            "type": task_info.get("type", download_type),
            "name": task_info.get("name", "Unknown"),
            "artist": task_info.get("artist", ""),
            "created_at": task_info.get("created_at"),
        }

        # Minimal event data - global counts are added at broadcast time
        return {
            "change_type": "update",
            "tasks": [task_object],
            "current_timestamp": now,
            "updated_count": 1,
            "since_timestamp": now,
            "trigger_reason": "callback_update",
        }

    except Exception as e:
        logger.error(
            f"SSE Transform: Error transforming callback for task {task_id}: {e}",
            exc_info=True,
        )
        return None
|
|
|
|
|
|
# Start the Redis subscriber when module loads
|
|
# Import-time side effect: this spawns the daemon subscriber thread as soon
# as the module is imported.
start_sse_redis_subscriber()
|
|
|
|
|
|
async def trigger_sse_update(task_id: str, reason: str = "task_update"):
    """Build the current response for one task and broadcast it over SSE.

    Args:
        task_id: Task whose state should be pushed to clients.
        reason: Stored as ``trigger_reason`` in the broadcast payload.

    Nothing is broadcast when the task has no stored info. Errors are logged
    and never propagate to the caller.
    """
    try:
        current_time = time.time()

        # Find the specific task that changed
        task_info = get_task_info(task_id)
        if not task_info:
            logger.warning(f"SSE: Task {task_id} not found for update")
            return

        last_status = get_last_task_status(task_id)

        # There is no real HTTP request here, so hand _build_task_response a
        # minimal stand-in exposing only the attribute it reads.
        class DummyRequest:
            base_url = "http://localhost:7171"

        task_response = _build_task_response(
            task_info, last_status, task_id, current_time, DummyRequest()
        )

        # Standardized event data - global counts are added at broadcast time
        event_data = {
            "tasks": [task_response],
            "current_timestamp": current_time,
            "since_timestamp": current_time,
            "change_type": "update",
            "trigger_reason": reason,
        }

        await sse_broadcaster.broadcast_event(event_data)
        logger.debug(f"SSE: Broadcast update for task {task_id} (reason: {reason})")

    except Exception as e:
        # Include the traceback, matching the other handlers in this module.
        logger.error(
            f"SSE: Failed to trigger update for task {task_id}: {e}",
            exc_info=True,
        )
|
|
|
|
|
|
# Define active task states using ProgressState constants
|
|
ACTIVE_TASK_STATES = {
    # Waiting to start
    ProgressState.QUEUED,
    "pending",  # legacy spelling of "queued"
    # Starting up / being worked on
    ProgressState.INITIALIZING,
    ProgressState.PROCESSING,
    ProgressState.DOWNLOADING,
    ProgressState.RETRYING,
    # Progress reports
    ProgressState.PROGRESS,
    ProgressState.TRACK_PROGRESS,
    ProgressState.REAL_TIME,
    "real-time",  # hyphenated variant of the real-time state
}
|
|
|
|
# Define terminal task states that should be included when recently completed
|
|
TERMINAL_TASK_STATES = {
    # Finished successfully
    ProgressState.COMPLETE,
    ProgressState.DONE,
    # Finished without success
    ProgressState.ERROR,
    ProgressState.CANCELLED,
    ProgressState.SKIPPED,
}
|
|
|
|
|
|
def get_task_status_from_last_status(last_status):
    """Resolve a task's status string from its last status record.

    A status nested under ``status_info`` wins when it is one of
    ``ACTIVE_TASK_STATES``; otherwise the top-level ``status`` is used.

    Args:
        last_status: The last status dict from get_last_task_status()

    Returns:
        str: The task status, or ``"unknown"`` when the record is empty or
        has no top-level status.
    """
    if not last_status:
        return "unknown"

    nested = last_status.get("status_info", {})
    has_nested_status = isinstance(nested, dict) and "status" in nested

    # An active nested status takes priority over the top-level status.
    if has_nested_status and nested["status"] in ACTIVE_TASK_STATES:
        return nested["status"]

    return last_status.get("status", "unknown")
|
|
|
|
|
|
def is_task_active(task_status):
    """Tell whether a status string denotes a task that is still in flight.

    Args:
        task_status: The status string from the task

    Returns:
        bool: True when the status is in ``ACTIVE_TASK_STATES``; False for
        empty or ``"unknown"`` statuses and anything else.
    """
    is_known = bool(task_status) and task_status != "unknown"
    return is_known and task_status in ACTIVE_TASK_STATES
|
|
|
|
|
|
def get_global_task_counts():
    """Count every task returned by ``get_all_tasks`` by status category.

    Returns:
        dict: Counts keyed by ``active``, ``queued``, ``completed``,
        ``error``, ``cancelled``, ``retrying`` and ``skipped``. Tasks whose
        status fits no category are not counted. If counting fails, the
        error is logged and the counts gathered so far are returned.
    """
    buckets = ("active", "queued", "completed", "error", "cancelled", "retrying", "skipped")
    task_counts = dict.fromkeys(buckets, 0)

    try:
        # Status -> bucket. Specific statuses are resolved here; the generic
        # "active" bucket only applies when none of them matches.
        bucket_for_status = {
            ProgressState.SKIPPED: "skipped",
            ProgressState.CANCELLED: "cancelled",
            ProgressState.ERROR: "error",
            ProgressState.DONE: "completed",
            ProgressState.COMPLETE: "completed",
            "pending": "queued",
            ProgressState.QUEUED: "queued",
            ProgressState.RETRYING: "retrying",
        }

        all_tasks = get_all_tasks()

        for summary in all_tasks:
            tid = summary.get("task_id")
            # Skip entries without an id or without stored task info.
            if not tid or not get_task_info(tid):
                continue

            status = get_task_status_from_last_status(get_last_task_status(tid))
            bucket = bucket_for_status.get(status)
            if bucket is None and is_task_active(status):
                bucket = "active"
            if bucket is not None:
                task_counts[bucket] += 1

        logger.debug(
            f"Global task counts: {task_counts} (total: {len(all_tasks)} tasks)"
        )

    except Exception as e:
        logger.error(f"Error getting global task counts: {e}", exc_info=True)

    return task_counts
|
|
|
|
|
|
def add_global_task_counts_to_event(event_data):
    """Stamp an SSE event with freshly computed global task counts.

    The event dict is modified in place and also returned.

    Args:
        event_data: The event data dictionary to be sent via SSE

    Returns:
        dict: The same dictionary with ``task_counts``, ``active_tasks`` and
        ``all_tasks_count`` set, and ``tasks`` / ``change_type`` defaulted
        when missing. If anything fails the error is logged and the event is
        returned as it stands.
    """
    try:
        counts = get_global_task_counts()

        event_data.update(
            {
                "task_counts": counts,
                "active_tasks": counts["active"],
                "all_tasks_count": sum(counts.values()),
            }
        )

        # Keep the schema stable for consumers: both keys always exist.
        event_data.setdefault("tasks", [])
        event_data.setdefault("change_type", "update")
    except Exception as e:
        logger.error(
            f"Error adding global task counts to SSE event: {e}", exc_info=True
        )

    return event_data
|
|
|
|
|
|
def _build_error_callback_object(last_status):
|
|
"""
|
|
Constructs a structured error callback object based on the last status of a task.
|
|
This conforms to the CallbackObject types in the frontend.
|
|
"""
|
|
# The 'type' from the status update corresponds to the download_type (album, playlist, track)
|
|
download_type = last_status.get("type")
|
|
name = last_status.get("name")
|
|
# The 'artist' field from the status may contain artist names or a playlist owner's name
|
|
artist_or_owner = last_status.get("artist")
|
|
error_message = last_status.get("error", "An unknown error occurred.")
|
|
|
|
status_info = {"status": "error", "error": error_message}
|
|
|
|
callback_object = {"status_info": status_info}
|
|
|
|
if download_type == "album":
|
|
callback_object["album"] = {
|
|
"type": "album",
|
|
"title": name,
|
|
"artists": [{"type": "artistAlbum", "name": artist_or_owner}]
|
|
if artist_or_owner
|
|
else [],
|
|
}
|
|
elif download_type == "playlist":
|
|
playlist_payload = {"type": "playlist", "title": name}
|
|
if artist_or_owner:
|
|
playlist_payload["owner"] = {"type": "user", "name": artist_or_owner}
|
|
callback_object["playlist"] = playlist_payload
|
|
elif download_type == "track":
|
|
callback_object["track"] = {
|
|
"type": "track",
|
|
"title": name,
|
|
"artists": [{"type": "artistTrack", "name": artist_or_owner}]
|
|
if artist_or_owner
|
|
else [],
|
|
}
|
|
else:
|
|
# Fallback for unknown types to avoid breaking the client, returning a basic error structure.
|
|
return {
|
|
"status_info": status_info,
|
|
"unstructured_error": True,
|
|
"details": {
|
|
"type": download_type,
|
|
"name": name,
|
|
"artist_or_owner": artist_or_owner,
|
|
},
|
|
}
|
|
|
|
return callback_object
|
|
|
|
|
|
def _build_task_response(
    task_info, last_status, task_id, current_time, request: Optional[Request]
):
    """Build the standardized task object returned by the API and sent over SSE.

    Args:
        task_info: Stored task metadata (dict).
        last_status: Most recent status record of the task, or a falsy value.
        task_id: Task identifier.
        current_time: Timestamp used when the task has no status timestamp
            or no numeric ``created_at``.
        request: Object whose ``base_url`` prefixes ``original_url``; when
            falsy, ``http://localhost:7171`` is used.

    Returns:
        dict: The task response. ``summary`` is included only when the last
        status carries one.

    Side effect: a task whose info lacks ``download_type`` or ``url`` is
    deleted via ``delete_task_data_and_log`` and, when an event loop is
    running, an SSE update is scheduled.
    """

    def _fallback_original_url():
        # URL recorded with the original request, if any.
        original_request_obj = task_info.get("original_request") or {}
        return original_request_obj.get("original_url", "")

    dynamic_original_url = ""
    download_type = task_info.get("download_type")
    item_url = task_info.get("url")

    if download_type and item_url:
        try:
            item_id = item_url.split("/")[-1]
            if item_id:
                base_url = (
                    str(request.base_url).rstrip("/")
                    if request
                    else "http://localhost:7171"
                )
                dynamic_original_url = (
                    f"{base_url}/api/{download_type}/download/{item_id}"
                )
            else:
                logger.warning(
                    f"Could not extract item ID from URL: {item_url} for task {task_id}. Falling back for original_url."
                )
                dynamic_original_url = _fallback_original_url()
        except Exception as e:
            logger.error(
                f"Error constructing dynamic original_url for task {task_id}: {e}",
                exc_info=True,
            )
            dynamic_original_url = _fallback_original_url()
    else:
        logger.warning(
            f"Missing download_type ('{download_type}') or item_url ('{item_url}') in task_info for task {task_id}. Falling back for original_url."
        )
        # Auto-delete faulty task data to keep the queue clean
        try:
            delete_task_data_and_log(
                task_id,
                reason="Auto-cleaned: Missing download_type or url in task_info.",
            )
        except Exception as _e:
            logger.error(f"Auto-delete failed for faulty task {task_id}: {_e}")
        else:
            # Fire-and-forget SSE refresh, only possible inside a running
            # event loop; get_running_loop raises RuntimeError otherwise.
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                logger.debug(
                    f"No running event loop; skipping SSE refresh for auto-deleted task {task_id}"
                )
            else:
                loop.create_task(trigger_sse_update(task_id, "auto_deleted_faulty"))
        dynamic_original_url = _fallback_original_url()

    # Tolerate a missing status list instead of failing on len(None).
    status_count = len(get_task_status(task_id) or [])

    # Determine last_line content: the raw callback when recorded, a
    # structured error object for error statuses, else the status record.
    if last_status and "raw_callback" in last_status:
        last_line_content = last_status["raw_callback"]
    elif last_status and get_task_status_from_last_status(last_status) == "error":
        last_line_content = _build_error_callback_object(last_status)
    else:
        last_line_content = last_status

    # Normalize created_at to a numeric timestamp
    created_at_value = task_info.get("created_at")
    if not isinstance(created_at_value, (int, float)):
        created_at_value = current_time

    task_response = {
        "original_url": dynamic_original_url,
        "last_line": last_line_content,
        "timestamp": last_status.get("timestamp") if last_status else current_time,
        "task_id": task_id,
        "status_count": status_count,
        "created_at": created_at_value,
        "name": task_info.get("name"),
        "artist": task_info.get("artist"),
        "type": task_info.get("type"),
        "download_type": task_info.get("download_type"),
    }
    if last_status and last_status.get("summary"):
        task_response["summary"] = last_status["summary"]

    return task_response
|
|
|
|
|
|
async def get_paginated_tasks(
    page=1, limit=20, active_only=False, request: Optional[Request] = None
):
    """Return one page of tasks together with global counts.

    Active tasks are always returned in full; pagination applies only to the
    remaining tasks, ordered newest first by ``created_at``. With
    ``active_only`` set, only active tasks are returned.

    Raises:
        HTTPException: 500 when the task list cannot be assembled.
    """
    try:
        summaries = get_all_tasks()
        counts = get_global_task_counts()

        running = []
        rest = []

        for summary in summaries:
            tid = summary.get("task_id")
            if not tid:
                continue

            info = get_task_info(tid)
            if not info:
                continue

            status_record = get_last_task_status(tid)
            active = is_task_active(get_task_status_from_last_status(status_record))
            built = _build_task_response(info, status_record, tid, time.time(), request)
            (running if active else rest).append(built)

        # Newest first; entries without created_at sort as 0.0.
        rest = sorted(rest, key=lambda x: (x.get("created_at") or 0.0), reverse=True)

        pagination_info = {"page": page, "limit": limit}
        if active_only:
            paginated_tasks = running
            pagination_info.update(
                total_non_active=0, has_more=False, returned_non_active=0
            )
        else:
            # Only the non-active tasks are paginated.
            start = (page - 1) * limit
            end = start + limit
            window = rest[start:end]
            paginated_tasks = running + window
            pagination_info.update(
                total_non_active=len(rest),
                has_more=len(rest) > end,
                returned_non_active=len(window),
            )

        return {
            "tasks": paginated_tasks,
            "current_timestamp": time.time(),
            # Only active/retrying tasks for the counter
            "total_tasks": counts["active"] + counts["retrying"],
            "all_tasks_count": len(summaries),
            "task_counts": counts,
            "active_tasks": len(running),
            "updated_count": len(paginated_tasks),
            "pagination": pagination_info,
        }

    except Exception as e:
        logger.error(f"Error in get_paginated_tasks: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail={"error": "Failed to retrieve paginated tasks"}
        )
|
|
|
|
|
|
# IMPORTANT: Specific routes MUST come before parameterized routes in FastAPI
|
|
# Otherwise "updates" gets matched as a {task_id} parameter!
|
|
|
|
|
|
@router.get("/list")
async def list_tasks(
    request: Request, current_user: User = Depends(require_auth_from_state)
):
    """
    Retrieve a paginated list of all tasks in the system.
    Returns a detailed list of task objects including status and metadata.

    Query parameters:
        page (int): Page number for pagination (default: 1; values below 1
            are treated as 1)
        limit (int): Number of tasks per page (default: 50, clamped to 1-100)
        active_only (bool): If true, only return active tasks (downloading, processing, etc.)

    Raises:
        HTTPException: 400 when ``page`` or ``limit`` is not an integer,
            500 when the task list cannot be assembled.
    """
    # Validate outside the catch-all below so bad input is reported as a
    # client error (400) instead of being turned into a 500.
    query = request.query_params
    try:
        page = int(query.get("page", 1))
        limit = int(query.get("limit", 50))
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=400,
            detail={"error": "Invalid 'page' or 'limit' parameter"},
        ) from None
    # Clamp: a non-positive page gives negative offsets and limit=0 would
    # divide by zero when computing total_pages.
    page = max(page, 1)
    limit = max(1, min(limit, 100))
    active_only = query.get("active_only", "").lower() == "true"

    try:
        tasks = get_all_tasks()
        active_tasks = []
        other_tasks = []

        # Task categorization counters
        task_counts = {
            "active": 0,
            "queued": 0,
            "completed": 0,
            "error": 0,
            "cancelled": 0,
            "retrying": 0,
            "skipped": 0,
        }

        for task_summary in tasks:
            task_id = task_summary.get("task_id")
            if not task_id:
                continue

            task_info = get_task_info(task_id)
            if not task_info:
                continue

            last_status = get_last_task_status(task_id)
            task_status = get_task_status_from_last_status(last_status)
            is_active_task = is_task_active(task_status)

            # Specific statuses are checked first; "active" is the bucket for
            # the remaining in-flight states.
            if task_status == ProgressState.RETRYING:
                task_counts["retrying"] += 1
            elif task_status in {
                ProgressState.QUEUED,
                "pending",
            }:  # Keep "pending" for backward compatibility
                task_counts["queued"] += 1
            elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}:
                task_counts["completed"] += 1
            elif task_status == ProgressState.ERROR:
                task_counts["error"] += 1
            elif task_status == ProgressState.CANCELLED:
                task_counts["cancelled"] += 1
            elif task_status == ProgressState.SKIPPED:
                task_counts["skipped"] += 1
            elif is_active_task:
                task_counts["active"] += 1

            task_response = _build_task_response(
                task_info, last_status, task_id, time.time(), request
            )

            if is_active_task:
                active_tasks.append(task_response)
            else:
                other_tasks.append(task_response)

        # Sort other tasks by creation time (newest first)
        other_tasks.sort(key=lambda x: (x.get("created_at") or 0.0), reverse=True)

        if active_only:
            # Return only active tasks without pagination
            response_tasks = active_tasks
            pagination_info = {
                "page": page,
                "limit": limit,
                "total_items": len(active_tasks),
                "total_pages": 1,
                "has_more": False,
            }
        else:
            offset = (page - 1) * limit

            if page == 1:
                # First page: all active tasks, then as many other tasks as
                # still fit within the limit.
                available_space = limit - len(active_tasks)
                paginated_other_tasks = other_tasks[: max(0, available_space)]
                response_tasks = active_tasks + paginated_other_tasks
            else:
                # Later pages hold only other tasks; shift the offset by the
                # active tasks already shown on the first page.
                adjusted_offset = max(0, offset - len(active_tasks))
                paginated_other_tasks = other_tasks[
                    adjusted_offset : adjusted_offset + limit
                ]
                response_tasks = paginated_other_tasks

            total_items = len(active_tasks) + len(other_tasks)
            total_pages = ((total_items - 1) // limit) + 1 if total_items > 0 else 1

            pagination_info = {
                "page": page,
                "limit": limit,
                "total_items": total_items,
                "total_pages": total_pages,
                "has_more": page < total_pages,
                "active_tasks": len(active_tasks),
                "total_other_tasks": len(other_tasks),
            }

        response = {
            "tasks": response_tasks,
            "pagination": pagination_info,
            "total_tasks": task_counts["active"]
            + task_counts["retrying"],  # Only active/retrying tasks for counter
            "all_tasks_count": len(tasks),  # Total count of all tasks
            "task_counts": task_counts,  # Categorized counts
            "active_tasks": len(active_tasks),
            "timestamp": time.time(),
        }

        return response
    except Exception as e:
        logger.error(f"Error in /api/prgs/list: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail={"error": "Failed to retrieve task list"}
        )
|
|
|
|
|
|
@router.get("/updates")
async def get_task_updates(
    request: Request, current_user: User = Depends(require_auth_from_state)
):
    """
    Retrieve only tasks that have been updated since the specified timestamp.
    This endpoint is optimized for polling to reduce unnecessary data transfer.

    Query parameters:
        since (float): Unix timestamp - only return tasks updated after this time.
            When omitted, the request is answered by get_paginated_tasks() instead.
        page (int): Page number for pagination (default: 1, values below 1 are treated as 1)
        limit (int): Number of queued/completed tasks per page (default: 20, clamped to 1..100)
        active_only (bool): If true, only return active tasks (downloading, processing, etc.)

    Returns:
        JSON object containing:
        - tasks: Active tasks followed by the requested page of updated non-active tasks
        - current_timestamp: Current server timestamp for next poll
        - total_tasks: Sum of the "active" and "retrying" task counts
        - all_tasks_count: Number of entries returned by get_all_tasks()
        - task_counts: Categorized task counts from get_global_task_counts()
        - active_tasks: Number of active tasks in this response
        - updated_count: Number of matching non-active tasks before pagination
        - since_timestamp: The parsed 'since' value
        - pagination: Pagination info for queued/completed tasks

    Raises:
        HTTPException: 400 when 'since', 'page' or 'limit' cannot be parsed,
            500 when anything else fails while building the response.
    """
    try:
        # Get query parameters
        since_param = request.query_params.get("since")
        try:
            page = int(request.query_params.get("page", 1))
            limit = int(request.query_params.get("limit", 20))
        except (ValueError, TypeError):
            # Malformed client input is a client error, not a server failure
            raise HTTPException(
                status_code=400,
                detail={"error": "Invalid 'page' or 'limit' parameter"},
            )
        # Keep slice bounds sane: pages start at 1, limit stays within 1..100
        page = max(1, page)
        limit = max(1, min(limit, 100))
        active_only = request.query_params.get("active_only", "").lower() == "true"

        if not since_param:
            # If no 'since' parameter, return paginated tasks (fallback behavior)
            response = await get_paginated_tasks(page, limit, active_only, request)
            return response

        try:
            since_timestamp = float(since_param)
        except (ValueError, TypeError):
            raise HTTPException(
                status_code=400, detail={"error": "Invalid 'since' timestamp format"}
            )

        # Get all tasks
        all_tasks = get_all_tasks()
        current_time = time.time()

        # Get global task counts
        task_counts = get_global_task_counts()

        updated_tasks = []
        active_tasks = []

        # Process tasks for filtering and response building
        for task_summary in all_tasks:
            task_id = task_summary.get("task_id")
            if not task_id:
                continue

            task_info = get_task_info(task_id)
            if not task_info:
                continue

            last_status = get_last_task_status(task_id)
            task_status = get_task_status_from_last_status(last_status)
            is_active_task = is_task_active(task_status)

            # Check if task has been updated since the given timestamp
            task_timestamp = (
                last_status.get("timestamp")
                if last_status
                else task_info.get("created_at", 0)
            )
            if task_timestamp is None:
                # A status entry without a timestamp (or a null created_at) must
                # not break the comparisons below; fall back to created_at, then 0.
                task_timestamp = task_info.get("created_at") or 0

            # Always include active tasks in updates, apply filtering to others
            # Also include recently completed/terminal tasks to ensure "done" status gets sent
            is_recently_terminal = (
                task_status in TERMINAL_TASK_STATES and task_timestamp > since_timestamp
            )
            should_include = (
                is_active_task
                or (task_timestamp > since_timestamp and not active_only)
                or is_recently_terminal
            )

            if should_include:
                # Construct the same detailed task object as in list_tasks()
                task_response = _build_task_response(
                    task_info, last_status, task_id, current_time, request
                )

                if is_active_task:
                    active_tasks.append(task_response)
                else:
                    updated_tasks.append(task_response)

        # Apply pagination to non-active tasks
        offset = (page - 1) * limit
        paginated_updated_tasks = (
            updated_tasks[offset : offset + limit] if not active_only else []
        )

        # Combine active tasks (always shown) with paginated updated tasks
        all_returned_tasks = active_tasks + paginated_updated_tasks

        # Sort by priority (active first, then by creation time).
        # The id set is built once so the sort key stays O(1) per element.
        active_task_ids = {t["task_id"] for t in active_tasks}
        all_returned_tasks.sort(
            key=lambda x: (
                0 if x.get("task_id") in active_task_ids else 1,
                -(x.get("created_at") or 0),
            )
        )

        response = {
            "tasks": all_returned_tasks,
            "current_timestamp": current_time,
            "total_tasks": task_counts["active"]
            + task_counts["retrying"],  # Only active/retrying tasks for counter
            "all_tasks_count": len(all_tasks),  # Total count of all tasks
            "task_counts": task_counts,  # Categorized counts
            "active_tasks": len(active_tasks),
            "updated_count": len(updated_tasks),
            "since_timestamp": since_timestamp,
            "pagination": {
                "page": page,
                "limit": limit,
                "total_non_active": len(updated_tasks),
                "has_more": len(updated_tasks) > offset + limit,
                "returned_non_active": len(paginated_updated_tasks),
            },
        }

        logger.debug(
            f"Returning {len(active_tasks)} active + {len(paginated_updated_tasks)} paginated tasks out of {len(all_tasks)} total"
        )
        return response

    except HTTPException:
        # Deliberate client-facing errors pass through untouched
        raise
    except Exception as e:
        logger.error(f"Error in /api/prgs/updates: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail={"error": "Failed to retrieve task updates"}
        )
|
|
|
|
|
|
@router.post("/cancel/all")
async def cancel_all_tasks(current_user: User = Depends(require_auth_from_state)):
    """
    Request cancellation of every task returned by get_all_tasks().

    Each entry that carries a task_id is passed to cancel_task(); a failure
    for one task is logged and recorded, and the loop moves on to the next.

    Returns:
        JSON object with a human-readable message, the number of successful
        cancel_task() calls ("cancelled_count") and the collected error
        messages ("errors").

    Raises:
        HTTPException: 500 if the task list itself cannot be processed.
    """
    try:
        failures = []
        cancelled_count = 0

        for summary in get_all_tasks():
            task_id = summary.get("task_id")
            if task_id:
                try:
                    cancel_task(task_id)
                except Exception as e:
                    # Keep going: one failing task must not block the others
                    error_message = f"Failed to cancel task {task_id}: {e}"
                    logger.error(error_message)
                    failures.append(error_message)
                else:
                    cancelled_count += 1

        return {
            "message": f"Attempted to cancel all active tasks. {cancelled_count} tasks cancelled.",
            "cancelled_count": cancelled_count,
            "errors": failures,
        }
    except Exception as e:
        logger.error(f"Error in /api/prgs/cancel/all: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail={"error": "Failed to cancel all tasks"}
        )
|
|
|
|
|
|
@router.post("/cancel/{task_id}")
async def cancel_task_endpoint(
    task_id: str, current_user: User = Depends(require_auth_from_state)
):
    """
    Cancel a running or queued task.

    Args:
        task_id: The ID of the task to cancel

    Returns:
        The result of cancel_task(), with "sse_notified" set to True when the
        follow-up SSE notification succeeded.

    Raises:
        HTTPException: 400 when no task info exists for task_id (old-system
            IDs are not supported), 500 on any unexpected failure.
    """
    try:
        # First check if this is a task ID in the new system
        task_info = get_task_info(task_id)

        if task_info:
            # This is a task ID in the new system
            result = cancel_task(task_id)
            try:
                # Push an immediate SSE update so clients reflect cancellation and partial summary
                await trigger_sse_update(task_id, "cancelled")
                result["sse_notified"] = True
            except Exception as e:
                # Best effort: the cancellation itself already happened
                logger.error(f"SSE notify after cancel failed for {task_id}: {e}")
            return result

        # If not found in new system, we need to handle the old system cancellation
        # For now, return an error as we're transitioning to the new system
        raise HTTPException(
            status_code=400,
            detail={
                "status": "error",
                "message": "Cancellation for old system is not supported in the new API. Please use the new task ID format.",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        # Log with traceback like the sibling endpoints do; previously this
        # failure was turned into a 500 without leaving any trace in the logs.
        logger.error(f"Error in /api/prgs/cancel/{task_id}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
|
|
|
|
|
|
@router.delete("/delete/{task_id}")
async def delete_task(
    task_id: str, current_user: User = Depends(require_auth_from_state)
):
    """
    Cancel the given task and report it as deleted.

    Args:
        task_id: A task UUID from Celery

    Returns:
        JSON object with a confirmation message.

    Raises:
        HTTPException: 404 when no task info exists for task_id.
    """
    # Only new-style task IDs are known to get_task_info()
    if not get_task_info(task_id):
        raise HTTPException(status_code=404, detail="Task not found")

    # NOTE(review): only cancel_task() is invoked here; no explicit deletion
    # call is made in this handler. Presumably cancel_task() takes care of
    # removing the task data - confirm against its implementation.
    cancel_task(task_id)

    return {"message": f"Task {task_id} deleted successfully"}
|
|
|
|
|
|
@router.get("/stream")
async def stream_task_updates(
    request: Request, current_user: User = Depends(get_current_user_from_state)
):
    """
    Stream real-time task updates via Server-Sent Events (SSE).

    Each connection registers its own queue with the SSE broadcaster, sends
    one initial snapshot, then forwards whatever the broadcaster puts on the
    queue. When the queue stays silent for the heartbeat interval, a heartbeat
    event carrying the current task counts is sent instead.
    Uses optional authentication to avoid breaking SSE connections.

    Query parameters:
        active_only (bool): If true, only stream active tasks (downloading, processing, etc.)

    Returns:
        Server-Sent Events stream with task update data in JSON format
    """
    active_only = request.query_params.get("active_only", "").lower() == "true"

    def _count_tasks_by_state():
        """Tally all known tasks into the categories reported by heartbeats."""
        counts = {
            "active": 0,
            "queued": 0,
            "completed": 0,
            "error": 0,
            "cancelled": 0,
            "retrying": 0,
            "skipped": 0,
        }

        for summary in get_all_tasks():
            task_id = summary.get("task_id")
            if not task_id:
                continue
            if not get_task_info(task_id):
                continue
            state = get_task_status_from_last_status(get_last_task_status(task_id))

            # Branch order matters: explicit states are checked before the
            # generic "active" test.
            if state == ProgressState.RETRYING:
                counts["retrying"] += 1
            elif state in {ProgressState.QUEUED, "pending"}:
                counts["queued"] += 1
            elif state in {ProgressState.COMPLETE, ProgressState.DONE}:
                counts["completed"] += 1
            elif state == ProgressState.ERROR:
                counts["error"] += 1
            elif state == ProgressState.CANCELLED:
                counts["cancelled"] += 1
            elif state == ProgressState.SKIPPED:
                counts["skipped"] += 1
            elif is_task_active(state):
                counts["active"] += 1

        return counts

    async def event_generator():
        # Every connected client owns one queue fed by the broadcaster
        client_queue = asyncio.Queue()

        try:
            logger.debug("SSE Stream: New client connecting...")
            await sse_broadcaster.add_client(client_queue)
            logger.debug(
                f"SSE Stream: Client registered successfully, total clients: {len(sse_broadcaster.clients)}"
            )

            # Initial snapshot goes out right away (standardized 'update')
            yield await generate_task_update_event(time.time(), active_only, request)

            last_heartbeat = time.time()
            heartbeat_interval = 30.0

            while True:
                try:
                    try:
                        # Block until an event arrives or the heartbeat is due
                        event_data = await asyncio.wait_for(
                            client_queue.get(), timeout=heartbeat_interval
                        )
                        yield event_data
                        last_heartbeat = time.time()
                    except asyncio.TimeoutError:
                        current_time = time.time()
                        if current_time - last_heartbeat < heartbeat_interval:
                            continue

                        # Quiet period: send a heartbeat with fresh counts
                        task_counts = _count_tasks_by_state()
                        heartbeat_data = {
                            "current_timestamp": current_time,
                            "total_tasks": task_counts["active"]
                            + task_counts["retrying"],
                            "task_counts": task_counts,
                            "change_type": "heartbeat",
                            "tasks": [],
                        }

                        event_json = json.dumps(heartbeat_data)
                        yield f"data: {event_json}\n\n"
                        last_heartbeat = current_time

                except Exception as e:
                    logger.error(f"Error in SSE event streaming: {e}", exc_info=True)
                    # Tell the client something went wrong, then keep streaming
                    error_data = json.dumps(
                        {
                            "error": "Internal server error",
                            "timestamp": time.time(),
                            "change_type": "error",
                            "tasks": [],
                        }
                    )
                    yield f"data: {error_data}\n\n"
                    await asyncio.sleep(1)

        except asyncio.CancelledError:
            logger.debug("SSE client disconnected")
            return
        except Exception as e:
            logger.error(f"SSE connection error: {e}", exc_info=True)
            return
        finally:
            # Always unregister the queue, however the stream ended
            await sse_broadcaster.remove_client(client_queue)

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Content-Type": "text/event-stream",
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Headers": "Cache-Control",
    }
    return StreamingResponse(
        event_generator(),
        media_type="text/plain",
        headers=sse_headers,
    )
|
|
|
|
|
|
async def generate_task_update_event(
    since_timestamp: float, active_only: bool, request: Request
) -> str:
    """
    Generate initial task update event for SSE connection.
    This replicates the logic from get_task_updates but for SSE format.

    Args:
        since_timestamp: Only non-active tasks whose last status (or creation
            time when no status exists) is newer than this are included.
        active_only: If true, non-active tasks are only included when they are
            in a terminal state and newer than since_timestamp.
        request: Incoming request, forwarded to _build_task_response().

    Returns:
        One SSE frame ("data: <json>\\n\\n"). On failure the frame carries an
        error payload with change_type "error" instead of raising.
    """
    try:
        # Get all tasks for filtering
        all_tasks = get_all_tasks()
        current_time = time.time()

        updated_tasks = []
        active_tasks = []

        # Process tasks for filtering only - no counting here
        for task_summary in all_tasks:
            task_id = task_summary.get("task_id")
            if not task_id:
                continue

            task_info = get_task_info(task_id)
            if not task_info:
                continue

            last_status = get_last_task_status(task_id)
            task_status = get_task_status_from_last_status(last_status)
            is_active_task = is_task_active(task_status)

            # Check if task has been updated since the given timestamp
            task_timestamp = (
                last_status.get("timestamp")
                if last_status
                else task_info.get("created_at", 0)
            )
            if task_timestamp is None:
                # A status entry without a timestamp (or a null created_at) must
                # not abort the whole snapshot; fall back to created_at, then 0.
                task_timestamp = task_info.get("created_at") or 0

            # Always include active tasks in updates, apply filtering to others
            # Also include recently completed/terminal tasks to ensure "done" status gets sent
            is_recently_terminal = (
                task_status in TERMINAL_TASK_STATES and task_timestamp > since_timestamp
            )
            should_include = (
                is_active_task
                or (task_timestamp > since_timestamp and not active_only)
                or is_recently_terminal
            )

            if should_include:
                # Construct the same detailed task object as in updates endpoint
                task_response = _build_task_response(
                    task_info, last_status, task_id, current_time, request
                )

                if is_active_task:
                    active_tasks.append(task_response)
                else:
                    updated_tasks.append(task_response)

        # Combine active tasks (always shown) with updated tasks
        all_returned_tasks = active_tasks + updated_tasks

        # Sort by priority (active first, then by creation time).
        # The id set is built once so the sort key stays O(1) per element.
        active_task_ids = {t["task_id"] for t in active_tasks}
        all_returned_tasks.sort(
            key=lambda x: (
                0 if x.get("task_id") in active_task_ids else 1,
                -(x.get("created_at") or 0),
            )
        )

        initial_data = {
            "tasks": all_returned_tasks,
            "current_timestamp": current_time,
            "updated_count": len(updated_tasks),
            "since_timestamp": since_timestamp,
            "change_type": "update",
            "initial": True,  # Mark as initial load
        }

        # Add global task counts since this bypasses the broadcaster
        enhanced_data = add_global_task_counts_to_event(initial_data)

        event_data = json.dumps(enhanced_data)
        return f"data: {event_data}\n\n"

    except Exception as e:
        logger.error(f"Error generating initial SSE event: {e}", exc_info=True)
        error_data = json.dumps(
            {
                "error": "Failed to load initial data",
                "timestamp": time.time(),
                "tasks": [],
                "change_type": "error",
            }
        )
        return f"data: {error_data}\n\n"
|
|
|
|
|
|
# IMPORTANT: This parameterized route MUST come AFTER all specific routes
|
|
# Otherwise FastAPI will match specific routes like "/updates" as task_id parameters
|
|
@router.get("/{task_id}")
async def get_task_details(
    task_id: str,
    request: Request,
    current_user: User = Depends(require_auth_from_state),
):
    """
    Return the details of a single task as a JSON object.

    The response carries a reconstructed original_url, the last progress
    update ("last_line"), its timestamp, the number of stored status entries
    and the task's descriptive fields (name, artist, type, download_type,
    created_at). A "summary" key is added when the last status provides one.

    This function works with the new task ID based system.

    Args:
        task_id: A task UUID from Celery

    Raises:
        HTTPException: 404 when no task info exists for task_id.
    """
    task_info = get_task_info(task_id)
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")

    def _stored_original_url():
        """Fallback: the original_url recorded with the original request."""
        return task_info.get("original_request", {}).get("original_url", "")

    download_type = task_info.get("download_type")
    # The 'url' field in task_info stores the Spotify/Deezer URL of the item
    # e.g., https://open.spotify.com/album/albumId or https://www.deezer.com/track/trackId
    item_url = task_info.get("url")

    if not (download_type and item_url):
        logger.warning(
            f"Missing download_type ('{download_type}') or item_url ('{item_url}') in task_info for task {task_id}. Falling back for original_url."
        )
        dynamic_original_url = _stored_original_url()
    else:
        try:
            # The item ID is the last path segment of the item URL
            item_id = item_url.split("/")[-1]
            if not item_id:
                logger.warning(
                    f"Could not extract item ID from URL: {item_url} for task {task_id}. Falling back for original_url."
                )
                dynamic_original_url = _stored_original_url()
            else:
                base_url = str(request.base_url).rstrip("/")
                dynamic_original_url = (
                    f"{base_url}/api/{download_type}/download/{item_id}"
                )
        except Exception as e:
            logger.error(
                f"Error constructing dynamic original_url for task {task_id}: {e}",
                exc_info=True,
            )
            dynamic_original_url = _stored_original_url()  # Fallback on any error

    last_status = get_last_task_status(task_id)
    status_count = len(get_task_status(task_id))

    # Pick what to expose as "last_line"
    if not last_status:
        # Nothing recorded yet: pass the empty value through unchanged
        last_line_content = last_status
    elif "raw_callback" in last_status:
        last_line_content = last_status["raw_callback"]
    elif get_task_status_from_last_status(last_status) == "error":
        last_line_content = _build_error_callback_object(last_status)
    else:
        last_line_content = last_status

    response = {
        "original_url": dynamic_original_url,
        "last_line": last_line_content,
        "timestamp": last_status.get("timestamp") if last_status else time.time(),
        "task_id": task_id,
        "status_count": status_count,
        "created_at": task_info.get("created_at"),
        "name": task_info.get("name"),
        "artist": task_info.get("artist"),
        "type": task_info.get("type"),
        "download_type": task_info.get("download_type"),
    }
    if last_status and last_status.get("summary"):
        response["summary"] = last_status["summary"]
    return response
|