From 9fdc0bde42c7e37892f69633cb57c1a05f184aa0 Mon Sep 17 00:00:00 2001 From: Xoconoch Date: Sat, 2 Aug 2025 12:46:36 -0600 Subject: [PATCH] Finally implemented SSE --- app.py | 315 +++++----- requirements.txt | 8 +- routes/album.py | 105 ++-- routes/artist.py | 303 ++++----- routes/config.py | 111 ++-- routes/credentials.py | 271 ++++---- routes/history.py | 244 ++++---- routes/playlist.py | 387 ++++++------ routes/prgs.py | 653 ++++++++++++++------ routes/search.py | 116 ++-- routes/track.py | 168 ++--- spotizerr-ui/src/contexts/QueueProvider.tsx | 205 +++--- 12 files changed, 1588 insertions(+), 1298 deletions(-) diff --git a/app.py b/app.py index 6c8fa73..753b9e7 100755 --- a/app.py +++ b/app.py @@ -1,14 +1,8 @@ -from flask import Flask, request, send_from_directory -from flask_cors import CORS -from routes.search import search_bp -from routes.credentials import credentials_bp -from routes.album import album_bp -from routes.track import track_bp -from routes.playlist import playlist_bp -from routes.prgs import prgs_bp -from routes.config import config_bp -from routes.artist import artist_bp -from routes.history import history_bp +from fastapi import FastAPI, Request, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles +from fastapi.responses import FileResponse +from contextlib import asynccontextmanager import logging import logging.handlers import time @@ -20,10 +14,24 @@ import redis import socket from urllib.parse import urlparse +# Import route routers (to be created) +from routes.search import router as search_router +from routes.credentials import router as credentials_router +from routes.album import router as album_router +from routes.track import router as track_router +from routes.playlist import router as playlist_router +from routes.prgs import router as prgs_router +from routes.config import router as config_router +from routes.artist import router as artist_router +from routes.history import router as history_router + # Import Celery configuration and manager from routes.utils.celery_manager import celery_manager from routes.utils.celery_config import REDIS_URL +# Import and initialize routes (this will start the watch manager) +import routes + # Configure application-wide logging def setup_logging(): @@ -66,175 +74,178 @@ def setup_logging(): root_logger.addHandler(console_handler) # Set up specific loggers - for logger_name in ["werkzeug", "celery", "routes", "flask", "waitress"]: - module_logger = logging.getLogger(logger_name) - module_logger.setLevel(logging.INFO) - # Handlers are inherited from root logger + for logger_name in [ + "routes", + "routes.utils", + "routes.utils.celery_manager", + "routes.utils.celery_tasks", + "routes.utils.watch", + ]: + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + logger.propagate = True # Propagate to root logger - # Enable propagation for all loggers - logging.getLogger("celery").propagate = True - - # Notify successful setup - root_logger.info("Logging system initialized") - - # Return the main file handler for permissions adjustment - return file_handler + logging.info("Logging system initialized") def check_redis_connection(): - """Check if Redis is reachable and retry with exponential backoff if not""" - max_retries = 5 - retry_count = 0 - retry_delay = 1 # start with 1 second + """Check if Redis is available and accessible""" + if not REDIS_URL: + logging.error("REDIS_URL is not configured. Please check your environment.") + return False - # Extract host and port from REDIS_URL - redis_host = "redis" # default - redis_port = 6379 # default + try: + # Parse Redis URL + parsed_url = urlparse(REDIS_URL) + host = parsed_url.hostname or "localhost" + port = parsed_url.port or 6379 - # Parse from REDIS_URL if possible - if REDIS_URL: - # parse hostname and port (handles optional auth) - try: - parsed = urlparse(REDIS_URL) - if parsed.hostname: - redis_host = parsed.hostname - if parsed.port: - redis_port = parsed.port - except Exception: - pass + logging.info(f"Testing Redis connection to {host}:{port}...") - # Log Redis connection details - logging.info(f"Checking Redis connection to {redis_host}:{redis_port}") + # Test socket connection first + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(5) + result = sock.connect_ex((host, port)) + sock.close() - while retry_count < max_retries: - try: - # First try socket connection to check if Redis port is open - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - result = sock.connect_ex((redis_host, redis_port)) - sock.close() + if result != 0: + logging.error(f"Cannot connect to Redis at {host}:{port}") + return False - if result != 0: - raise ConnectionError( - f"Cannot connect to Redis at {redis_host}:{redis_port}" - ) + # Test Redis client connection + r = redis.from_url(REDIS_URL, socket_connect_timeout=5, socket_timeout=5) + r.ping() + logging.info("Redis connection successful") + return True - # If socket connection successful, try Redis ping - r = redis.Redis.from_url(REDIS_URL) - r.ping() - logging.info("Successfully connected to Redis") - return True - except Exception as e: - retry_count += 1 - if retry_count >= max_retries: - logging.error( - f"Failed to connect to Redis after {max_retries} attempts: {e}" - ) - logging.error( - f"Make sure Redis is running at {redis_host}:{redis_port}" - ) - return False + except redis.ConnectionError as e: + logging.error(f"Redis connection error: {e}") + return False + except redis.TimeoutError as e: + logging.error(f"Redis timeout error: {e}") + return False + except Exception as e: + logging.error(f"Unexpected error checking Redis connection: {e}") + return False - logging.warning(f"Redis connection attempt {retry_count} failed: {e}") - logging.info(f"Retrying in {retry_delay} seconds...") - time.sleep(retry_delay) - retry_delay *= 2 # exponential backoff - return False +@asynccontextmanager +async def lifespan(app: FastAPI): + """Handle application startup and shutdown""" + # Startup + setup_logging() + + # Check Redis connection + if not check_redis_connection(): + logging.error("Failed to connect to Redis. Please ensure Redis is running and accessible.") + # Don't exit, but warn - some functionality may not work + + # Start Celery workers + try: + celery_manager.start() + logging.info("Celery workers started successfully") + except Exception as e: + logging.error(f"Failed to start Celery workers: {e}") + + yield + + # Shutdown + try: + celery_manager.stop() + logging.info("Celery workers stopped") + except Exception as e: + logging.error(f"Error stopping Celery workers: {e}") def create_app(): - app = Flask(__name__, static_folder="spotizerr-ui/dist", static_url_path="/") + app = FastAPI( + title="Spotizerr API", + description="Music download service API", + version="1.0.0", + lifespan=lifespan + ) # Set up CORS - CORS(app) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) - # Register blueprints - app.register_blueprint(config_bp, url_prefix="/api") - app.register_blueprint(search_bp, url_prefix="/api") - app.register_blueprint(credentials_bp, url_prefix="/api/credentials") - app.register_blueprint(album_bp, url_prefix="/api/album") - app.register_blueprint(track_bp, url_prefix="/api/track") - app.register_blueprint(playlist_bp, url_prefix="/api/playlist") - app.register_blueprint(artist_bp, url_prefix="/api/artist") - app.register_blueprint(prgs_bp, url_prefix="/api/prgs") - app.register_blueprint(history_bp, url_prefix="/api/history") - - # Serve React App - @app.route("/", defaults={"path": ""}) - @app.route("/") - def serve_react_app(path): - if path != "" and os.path.exists(os.path.join(app.static_folder, path)): - return send_from_directory(app.static_folder, path) - else: - return send_from_directory(app.static_folder, "index.html") + # Register routers with URL prefixes + app.include_router(config_router, prefix="/api", tags=["config"]) + app.include_router(search_router, prefix="/api", tags=["search"]) + app.include_router(credentials_router, prefix="/api/credentials", tags=["credentials"]) + app.include_router(album_router, prefix="/api/album", tags=["album"]) + app.include_router(track_router, prefix="/api/track", tags=["track"]) + app.include_router(playlist_router, prefix="/api/playlist", tags=["playlist"]) + app.include_router(artist_router, prefix="/api/artist", tags=["artist"]) + app.include_router(prgs_router, prefix="/api/prgs", tags=["progress"]) + app.include_router(history_router, prefix="/api/history", tags=["history"]) # Add request logging middleware - @app.before_request - def log_request(): - request.start_time = time.time() - app.logger.debug(f"Request: {request.method} {request.path}") + @app.middleware("http") + async def log_requests(request: Request, call_next): + start_time = time.time() + + # Log request + logger = logging.getLogger("uvicorn.access") + logger.debug(f"Request: {request.method} {request.url.path}") + + try: + response = await call_next(request) + + # Log response + duration = round((time.time() - start_time) * 1000, 2) + logger.debug(f"Response: {response.status_code} | Duration: {duration}ms") + + return response + except Exception as e: + # Log errors + logger.error(f"Server error: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal Server Error") - @app.after_request - def log_response(response): - if hasattr(request, "start_time"): - duration = round((time.time() - request.start_time) * 1000, 2) - app.logger.debug(f"Response: {response.status} | Duration: {duration}ms") - return response - - # Error logging - @app.errorhandler(Exception) - def handle_exception(e): - app.logger.error(f"Server error: {str(e)}", exc_info=True) - return "Internal Server Error", 500 + # Mount static files for React app + if os.path.exists("spotizerr-ui/dist"): + app.mount("/static", StaticFiles(directory="spotizerr-ui/dist"), name="static") + + # Serve React App - catch-all route for SPA + @app.get("/{full_path:path}") + async def serve_react_app(full_path: str): + """Serve React app with fallback to index.html for SPA routing""" + static_dir = "spotizerr-ui/dist" + + # If it's a file that exists, serve it + if full_path and os.path.exists(os.path.join(static_dir, full_path)): + return FileResponse(os.path.join(static_dir, full_path)) + else: + # Fallback to index.html for SPA routing + return FileResponse(os.path.join(static_dir, "index.html")) + else: + logging.warning("React app build directory not found at spotizerr-ui/dist") return app def start_celery_workers(): """Start Celery workers with dynamic configuration""" - logging.info("Starting Celery workers with dynamic configuration") - celery_manager.start() - - # Register shutdown handler - atexit.register(celery_manager.stop) + # This function is now handled by the lifespan context manager + # and the celery_manager.start() call + pass if __name__ == "__main__": - # Configure application logging - log_handler = setup_logging() - - # Set permissions for log file - try: - if os.name != "nt": # Not Windows - os.chmod(log_handler.baseFilename, 0o666) - except Exception as e: - logging.warning(f"Could not set permissions on log file: {e}") - - # Check Redis connection before starting - if not check_redis_connection(): - logging.error("Exiting: Could not establish Redis connection.") - sys.exit(1) - - # Start Celery workers in a separate thread - start_celery_workers() - - # Clean up Celery workers on exit - atexit.register(celery_manager.stop) - - # Create Flask app + import uvicorn + app = create_app() - - # Get host and port from environment variables or use defaults - host = os.environ.get("HOST", "0.0.0.0") - port = int(os.environ.get("PORT", 7171)) - - # Use Flask's built-in server for development - # logging.info(f"Starting Flask development server on http://{host}:{port}") - # app.run(host=host, port=port, debug=True) - - # The following uses Waitress, a production-ready server. - # To use it, comment out the app.run() line above and uncomment the lines below. - logging.info(f"Starting server with Waitress on http://{host}:{port}") - from waitress import serve - serve(app, host=host, port=port) + + # Run with uvicorn + uvicorn.run( + app, + host="0.0.0.0", + port=7171, + log_level="info", + access_log=True + ) diff --git a/requirements.txt b/requirements.txt index bc3c2bf..2446cc2 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -waitress==3.0.2 +fastapi==0.115.6 +uvicorn[standard]==0.32.1 celery==5.5.3 -Flask==3.1.1 -flask_cors==6.0.0 -deezspot-spotizerr==2.2.0 \ No newline at end of file +deezspot-spotizerr==2.2.0 +httpx \ No newline at end of file diff --git a/routes/album.py b/routes/album.py index 7b0fc19..e839da2 100755 --- a/routes/album.py +++ b/routes/album.py @@ -1,4 +1,5 @@ -from flask import Blueprint, Response, request +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse import json import traceback import uuid @@ -8,7 +9,7 @@ from routes.utils.celery_tasks import store_task_info, store_task_status, Progre from routes.utils.get_info import get_spotify_info from routes.utils.errors import DuplicateDownloadError -album_bp = Blueprint("album", __name__) +router = APIRouter() def construct_spotify_url(item_id: str, item_type: str = "track") -> str: @@ -16,8 +17,8 @@ def construct_spotify_url(item_id: str, item_type: str = "track") -> str: return f"https://open.spotify.com/{item_type}/{item_id}" -@album_bp.route("/download/", methods=["GET"]) -def handle_download(album_id): +@router.get("/download/{album_id}") +async def handle_download(album_id: str, request: Request): # Retrieve essential parameters from the request. # name = request.args.get('name') # artist = request.args.get('artist') @@ -33,12 +34,9 @@ def handle_download(album_id): or not album_info.get("name") or not album_info.get("artists") ): - return Response( - json.dumps( - {"error": f"Could not retrieve metadata for album ID: {album_id}"} - ), - status=404, - mimetype="application/json", + return JSONResponse( + content={"error": f"Could not retrieve metadata for album ID: {album_id}"}, + status_code=404 ) name_from_spotify = album_info.get("name") @@ -49,27 +47,23 @@ def handle_download(album_id): ) except Exception as e: - return Response( - json.dumps( - {"error": f"Failed to fetch metadata for album {album_id}: {str(e)}"} - ), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": f"Failed to fetch metadata for album {album_id}: {str(e)}"}, + status_code=500 ) # Validate required parameters if not url: - return Response( - json.dumps({"error": "Missing required parameter: url"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing required parameter: url"}, + status_code=400 ) # Add the task to the queue with only essential parameters # The queue manager will now handle all config parameters # Include full original request URL in metadata - orig_params = request.args.to_dict() - orig_params["original_url"] = request.url + orig_params = dict(request.query_params) + orig_params["original_url"] = str(request.url) try: task_id = download_queue_manager.add_task( { @@ -81,15 +75,12 @@ def handle_download(album_id): } ) except DuplicateDownloadError as e: - return Response( - json.dumps( - { - "error": "Duplicate download detected.", - "existing_task": e.existing_task, - } - ), - status=409, - mimetype="application/json", + return JSONResponse( + content={ + "error": "Duplicate download detected.", + "existing_task": e.existing_task, + }, + status_code=409 ) except Exception as e: # Generic error handling for other issues during task submission @@ -116,61 +107,57 @@ def handle_download(album_id): "timestamp": time.time(), }, ) - return Response( - json.dumps( - { - "error": f"Failed to queue album download: {str(e)}", - "task_id": error_task_id, - } - ), - status=500, - mimetype="application/json", + return JSONResponse( + content={ + "error": f"Failed to queue album download: {str(e)}", + "task_id": error_task_id, + }, + status_code=500 ) - return Response( - json.dumps({"task_id": task_id}), status=202, mimetype="application/json" + return JSONResponse( + content={"task_id": task_id}, + status_code=202 ) -@album_bp.route("/download/cancel", methods=["GET"]) -def cancel_download(): +@router.get("/download/cancel") +async def cancel_download(request: Request): """ Cancel a running download process by its task id. """ - task_id = request.args.get("task_id") + task_id = request.query_params.get("task_id") if not task_id: - return Response( - json.dumps({"error": "Missing process id (task_id) parameter"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing process id (task_id) parameter"}, + status_code=400 ) # Use the queue manager's cancellation method. result = download_queue_manager.cancel_task(task_id) status_code = 200 if result.get("status") == "cancelled" else 404 - return Response(json.dumps(result), status=status_code, mimetype="application/json") + return JSONResponse(content=result, status_code=status_code) -@album_bp.route("/info", methods=["GET"]) -def get_album_info(): +@router.get("/info") +async def get_album_info(request: Request): """ Retrieve Spotify album metadata given a Spotify album ID. Expects a query parameter 'id' that contains the Spotify album ID. """ - spotify_id = request.args.get("id") + spotify_id = request.query_params.get("id") if not spotify_id: - return Response( - json.dumps({"error": "Missing parameter: id"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing parameter: id"}, + status_code=400 ) try: # Use the get_spotify_info function (already imported at top) album_info = get_spotify_info(spotify_id, "album") - return Response(json.dumps(album_info), status=200, mimetype="application/json") + return JSONResponse(content=album_info, status_code=200) except Exception as e: error_data = {"error": str(e), "traceback": traceback.format_exc()} - return Response(json.dumps(error_data), status=500, mimetype="application/json") + return JSONResponse(content=error_data, status_code=500) diff --git a/routes/artist.py b/routes/artist.py index 98b605b..0332c50 100644 --- a/routes/artist.py +++ b/routes/artist.py @@ -1,8 +1,9 @@ """ -Artist endpoint blueprint. +Artist endpoint router. """ -from flask import Blueprint, Response, request, jsonify +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse import json import traceback from routes.utils.artist import download_artist_albums @@ -22,7 +23,7 @@ from routes.utils.watch.db import ( from routes.utils.watch.manager import check_watched_artists, get_watch_config from routes.utils.get_info import get_spotify_info -artist_bp = Blueprint("artist", __name__, url_prefix="/api/artist") +router = APIRouter() # Existing log_json can be used, or a logger instance. # Let's initialize a logger for consistency with merged code. @@ -38,8 +39,8 @@ def log_json(message_dict): print(json.dumps(message_dict)) -@artist_bp.route("/download/", methods=["GET"]) -def handle_artist_download(artist_id): +@router.get("/download/{artist_id}") +async def handle_artist_download(artist_id: str, request: Request): """ Enqueues album download tasks for the given artist. Expected query parameters: @@ -49,14 +50,13 @@ def handle_artist_download(artist_id): url = construct_spotify_url(artist_id, "artist") # Retrieve essential parameters from the request. - album_type = request.args.get("album_type", "album,single,compilation") + album_type = request.query_params.get("album_type", "album,single,compilation") # Validate required parameters if not url: # This check is mostly for safety, as url is constructed - return Response( - json.dumps({"error": "Missing required parameter: url"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing required parameter: url"}, + status_code=400 ) try: @@ -65,7 +65,7 @@ def handle_artist_download(artist_id): # Delegate to the download_artist_albums function which will handle album filtering successfully_queued_albums, duplicate_albums = download_artist_albums( - url=url, album_type=album_type, request_args=request.args.to_dict() + url=url, album_type=album_type, request_args=dict(request.query_params) ) # Return the list of album task IDs. @@ -80,51 +80,45 @@ def handle_artist_download(artist_id): f" {len(duplicate_albums)} albums were already in progress or queued." ) - return Response( - json.dumps(response_data), - status=202, # Still 202 Accepted as some operations may have succeeded - mimetype="application/json", + return JSONResponse( + content=response_data, + status_code=202 # Still 202 Accepted as some operations may have succeeded ) except Exception as e: - return Response( - json.dumps( - { - "status": "error", - "message": str(e), - "traceback": traceback.format_exc(), - } - ), - status=500, - mimetype="application/json", + return JSONResponse( + content={ + "status": "error", + "message": str(e), + "traceback": traceback.format_exc(), + }, + status_code=500 ) -@artist_bp.route("/download/cancel", methods=["GET"]) -def cancel_artist_download(): +@router.get("/download/cancel") +async def cancel_artist_download(): """ Cancelling an artist download is not supported since the endpoint only enqueues album tasks. (Cancellation for individual album tasks can be implemented via the queue manager.) """ - return Response( - json.dumps({"error": "Artist download cancellation is not supported."}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Artist download cancellation is not supported."}, + status_code=400 ) -@artist_bp.route("/info", methods=["GET"]) -def get_artist_info(): +@router.get("/info") +async def get_artist_info(request: Request): """ Retrieves Spotify artist metadata given a Spotify artist ID. Expects a query parameter 'id' with the Spotify artist ID. """ - spotify_id = request.args.get("id") + spotify_id = request.query_params.get("id") if not spotify_id: - return Response( - json.dumps({"error": "Missing parameter: id"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing parameter: id"}, + status_code=400 ) try: @@ -158,33 +152,30 @@ def get_artist_info(): # If not watched, or no albums, is_locally_known will not be added. # Frontend should handle absence of this key as false. - return Response( - json.dumps(artist_info), status=200, mimetype="application/json" + return JSONResponse( + content=artist_info, status_code=200 ) except Exception as e: - return Response( - json.dumps({"error": str(e), "traceback": traceback.format_exc()}), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": str(e), "traceback": traceback.format_exc()}, + status_code=500 ) # --- Merged Artist Watch Routes --- -@artist_bp.route("/watch/", methods=["PUT"]) -def add_artist_to_watchlist(artist_spotify_id): +@router.put("/watch/{artist_spotify_id}") +async def add_artist_to_watchlist(artist_spotify_id: str): """Adds an artist to the watchlist.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify({"error": "Watch feature is currently disabled globally."}), 403 + raise HTTPException(status_code=403, detail={"error": "Watch feature is currently disabled globally."}) logger.info(f"Attempting to add artist {artist_spotify_id} to watchlist.") try: if get_watched_artist(artist_spotify_id): - return jsonify( - {"message": f"Artist {artist_spotify_id} is already being watched."} - ), 200 + return {"message": f"Artist {artist_spotify_id} is already being watched."} # Get artist metadata directly for name and basic info artist_metadata = get_spotify_info(artist_spotify_id, "artist") @@ -199,11 +190,12 @@ def add_artist_to_watchlist(artist_spotify_id): logger.error( f"Could not fetch artist metadata for {artist_spotify_id} from Spotify." ) - return jsonify( - { + raise HTTPException( + status_code=404, + detail={ "error": f"Could not fetch artist metadata for {artist_spotify_id} to initiate watch." } - ), 404 + ) # Check if we got album data if not artist_album_list_data or not isinstance( @@ -228,115 +220,118 @@ def add_artist_to_watchlist(artist_spotify_id): logger.info( f"Artist {artist_spotify_id} ('{artist_metadata.get('name', 'Unknown Artist')}') added to watchlist. Their albums will be processed by the watch manager." ) - return jsonify( - { - "message": f"Artist {artist_spotify_id} added to watchlist. Albums will be processed shortly." - } - ), 201 + return { + "message": f"Artist {artist_spotify_id} added to watchlist. Albums will be processed shortly." + } + except HTTPException: + raise except Exception as e: logger.error( f"Error adding artist {artist_spotify_id} to watchlist: {e}", exc_info=True ) - return jsonify({"error": f"Could not add artist to watchlist: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not add artist to watchlist: {str(e)}"}) -@artist_bp.route("/watch//status", methods=["GET"]) -def get_artist_watch_status(artist_spotify_id): +@router.get("/watch/{artist_spotify_id}/status") +async def get_artist_watch_status(artist_spotify_id: str): """Checks if a specific artist is being watched.""" logger.info(f"Checking watch status for artist {artist_spotify_id}.") try: artist = get_watched_artist(artist_spotify_id) if artist: - return jsonify({"is_watched": True, "artist_data": dict(artist)}), 200 + return {"is_watched": True, "artist_data": dict(artist)} else: - return jsonify({"is_watched": False}), 200 + return {"is_watched": False} except Exception as e: logger.error( f"Error checking watch status for artist {artist_spotify_id}: {e}", exc_info=True, ) - return jsonify({"error": f"Could not check watch status: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not check watch status: {str(e)}"}) -@artist_bp.route("/watch/", methods=["DELETE"]) -def remove_artist_from_watchlist(artist_spotify_id): +@router.delete("/watch/{artist_spotify_id}") +async def remove_artist_from_watchlist(artist_spotify_id: str): """Removes an artist from the watchlist.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify({"error": "Watch feature is currently disabled globally."}), 403 + raise HTTPException(status_code=403, detail={"error": "Watch feature is currently disabled globally."}) logger.info(f"Attempting to remove artist {artist_spotify_id} from watchlist.") try: if not get_watched_artist(artist_spotify_id): - return jsonify( - {"error": f"Artist {artist_spotify_id} not found in watchlist."} - ), 404 + raise HTTPException( + status_code=404, + detail={"error": f"Artist {artist_spotify_id} not found in watchlist."} + ) remove_artist_db(artist_spotify_id) logger.info(f"Artist {artist_spotify_id} removed from watchlist successfully.") - return jsonify( - {"message": f"Artist {artist_spotify_id} removed from watchlist."} - ), 200 + return {"message": f"Artist {artist_spotify_id} removed from watchlist."} + except HTTPException: + raise except Exception as e: logger.error( f"Error removing artist {artist_spotify_id} from watchlist: {e}", exc_info=True, ) - return jsonify( - {"error": f"Could not remove artist from watchlist: {str(e)}"} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": f"Could not remove artist from watchlist: {str(e)}"} + ) -@artist_bp.route("/watch/list", methods=["GET"]) -def list_watched_artists_endpoint(): +@router.get("/watch/list") +async def list_watched_artists_endpoint(): """Lists all artists currently in the watchlist.""" try: artists = get_watched_artists() - return jsonify([dict(artist) for artist in artists]), 200 + return [dict(artist) for artist in artists] except Exception as e: logger.error(f"Error listing watched artists: {e}", exc_info=True) - return jsonify({"error": f"Could not list watched artists: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not list watched artists: {str(e)}"}) -@artist_bp.route("/watch/trigger_check", methods=["POST"]) -def trigger_artist_check_endpoint(): +@router.post("/watch/trigger_check") +async def trigger_artist_check_endpoint(): """Manually triggers the artist checking mechanism for all watched artists.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify( - { + raise HTTPException( + status_code=403, + detail={ "error": "Watch feature is currently disabled globally. Cannot trigger check." } - ), 403 + ) logger.info("Manual trigger for artist check received for all artists.") try: thread = threading.Thread(target=check_watched_artists, args=(None,)) thread.start() - return jsonify( - { - "message": "Artist check triggered successfully in the background for all artists." - } - ), 202 + return { + "message": "Artist check triggered successfully in the background for all artists." + } except Exception as e: logger.error( f"Error manually triggering artist check for all: {e}", exc_info=True ) - return jsonify( - {"error": f"Could not trigger artist check for all: {str(e)}"} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": f"Could not trigger artist check for all: {str(e)}"} + ) -@artist_bp.route("/watch/trigger_check/", methods=["POST"]) -def trigger_specific_artist_check_endpoint(artist_spotify_id: str): +@router.post("/watch/trigger_check/{artist_spotify_id}") +async def trigger_specific_artist_check_endpoint(artist_spotify_id: str): """Manually triggers the artist checking mechanism for a specific artist.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify( - { + raise HTTPException( + status_code=403, + detail={ "error": "Watch feature is currently disabled globally. Cannot trigger check." } - ), 403 + ) logger.info( f"Manual trigger for specific artist check received for ID: {artist_spotify_id}" @@ -347,11 +342,12 @@ def trigger_specific_artist_check_endpoint(artist_spotify_id: str): logger.warning( f"Trigger specific check: Artist ID {artist_spotify_id} not found in watchlist." ) - return jsonify( - { + raise HTTPException( + status_code=404, + detail={ "error": f"Artist {artist_spotify_id} is not in the watchlist. Add it first." } - ), 404 + ) thread = threading.Thread( target=check_watched_artists, args=(artist_spotify_id,) @@ -360,50 +356,54 @@ def trigger_specific_artist_check_endpoint(artist_spotify_id: str): logger.info( f"Artist check triggered in background for specific artist ID: {artist_spotify_id}" ) - return jsonify( - { - "message": f"Artist check triggered successfully in the background for {artist_spotify_id}." - } - ), 202 + return { + "message": f"Artist check triggered successfully in the background for {artist_spotify_id}." + } + except HTTPException: + raise except Exception as e: logger.error( f"Error manually triggering specific artist check for {artist_spotify_id}: {e}", exc_info=True, ) - return jsonify( - { + raise HTTPException( + status_code=500, + detail={ "error": f"Could not trigger artist check for {artist_spotify_id}: {str(e)}" } - ), 500 + ) -@artist_bp.route("/watch//albums", methods=["POST"]) -def mark_albums_as_known_for_artist(artist_spotify_id): +@router.post("/watch/{artist_spotify_id}/albums") +async def mark_albums_as_known_for_artist(artist_spotify_id: str, request: Request): """Fetches details for given album IDs and adds/updates them in the artist's local DB table.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify( - { + raise HTTPException( + status_code=403, + detail={ "error": "Watch feature is currently disabled globally. Cannot mark albums." } - ), 403 + ) logger.info(f"Attempting to mark albums as known for artist {artist_spotify_id}.") try: - album_ids = request.json + album_ids = await request.json() if not isinstance(album_ids, list) or not all( isinstance(aid, str) for aid in album_ids ): - return jsonify( - { + raise HTTPException( + status_code=400, + detail={ "error": "Invalid request body. Expecting a JSON array of album Spotify IDs." } - ), 400 + ) if not get_watched_artist(artist_spotify_id): - return jsonify( - {"error": f"Artist {artist_spotify_id} is not being watched."} - ), 404 + raise HTTPException( + status_code=404, + detail={"error": f"Artist {artist_spotify_id} is not being watched."} + ) fetched_albums_details = [] for album_id in album_ids: @@ -422,12 +422,10 @@ def mark_albums_as_known_for_artist(artist_spotify_id): ) if not fetched_albums_details: - return jsonify( - { - "message": "No valid album details could be fetched to mark as known.", - "processed_count": 0, - } - ), 200 + return { + "message": "No valid album details could be fetched to mark as known.", + "processed_count": 0, + } processed_count = add_specific_albums_to_artist_table( artist_spotify_id, fetched_albums_details @@ -435,48 +433,51 @@ def mark_albums_as_known_for_artist(artist_spotify_id): logger.info( f"Successfully marked/updated {processed_count} albums as known for artist {artist_spotify_id}." ) - return jsonify( - { - "message": f"Successfully processed {processed_count} albums for artist {artist_spotify_id}." - } - ), 200 + return { + "message": f"Successfully processed {processed_count} albums for artist {artist_spotify_id}." + } + except HTTPException: + raise except Exception as e: logger.error( f"Error marking albums as known for artist {artist_spotify_id}: {e}", exc_info=True, ) - return jsonify({"error": f"Could not mark albums as known: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not mark albums as known: {str(e)}"}) -@artist_bp.route("/watch//albums", methods=["DELETE"]) -def mark_albums_as_missing_locally_for_artist(artist_spotify_id): +@router.delete("/watch/{artist_spotify_id}/albums") +async def mark_albums_as_missing_locally_for_artist(artist_spotify_id: str, request: Request): """Removes specified albums from the artist's local DB table.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify( - { + raise HTTPException( + status_code=403, + detail={ "error": "Watch feature is currently disabled globally. Cannot mark albums." } - ), 403 + ) logger.info( f"Attempting to mark albums as missing (delete locally) for artist {artist_spotify_id}." ) try: - album_ids = request.json + album_ids = await request.json() if not isinstance(album_ids, list) or not all( isinstance(aid, str) for aid in album_ids ): - return jsonify( - { + raise HTTPException( + status_code=400, + detail={ "error": "Invalid request body. Expecting a JSON array of album Spotify IDs." } - ), 400 + ) if not get_watched_artist(artist_spotify_id): - return jsonify( - {"error": f"Artist {artist_spotify_id} is not being watched."} - ), 404 + raise HTTPException( + status_code=404, + detail={"error": f"Artist {artist_spotify_id} is not being watched."} + ) deleted_count = remove_specific_albums_from_artist_table( artist_spotify_id, album_ids @@ -484,14 +485,14 @@ def mark_albums_as_missing_locally_for_artist(artist_spotify_id): logger.info( f"Successfully removed {deleted_count} albums locally for artist {artist_spotify_id}." ) - return jsonify( - { - "message": f"Successfully removed {deleted_count} albums locally for artist {artist_spotify_id}." - } - ), 200 + return { + "message": f"Successfully removed {deleted_count} albums locally for artist {artist_spotify_id}." + } + except HTTPException: + raise except Exception as e: logger.error( f"Error marking albums as missing (deleting locally) for artist {artist_spotify_id}: {e}", exc_info=True, ) - return jsonify({"error": f"Could not mark albums as missing: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not mark albums as missing: {str(e)}"}) diff --git a/routes/config.py b/routes/config.py index 19a8adf..1565fac 100644 --- a/routes/config.py +++ b/routes/config.py @@ -1,4 +1,5 @@ -from flask import Blueprint, jsonify, request +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse import json import logging import os @@ -18,7 +19,7 @@ from routes.utils.watch.manager import ( logger = logging.getLogger(__name__) -config_bp = Blueprint("config", __name__) +router = APIRouter() # Flag for config change notifications @@ -108,26 +109,28 @@ def save_watch_config_http(watch_config_data): # Renamed return False, str(e) -@config_bp.route("/config", methods=["GET"]) -def handle_config(): +@router.get("/config") +async def handle_config(): """Handles GET requests for the main configuration.""" try: config = get_config() - return jsonify(config) + return config except Exception as e: logger.error(f"Error in GET /config: {e}", exc_info=True) - return jsonify( - {"error": "Failed to retrieve configuration", "details": str(e)} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": "Failed to retrieve configuration", "details": str(e)} + ) -@config_bp.route("/config", methods=["POST", "PUT"]) -def update_config(): +@router.post("/config") +@router.put("/config") +async def update_config(request: Request): """Handles POST/PUT requests to update the main configuration.""" try: - new_config = request.get_json() + new_config = await request.json() if not isinstance(new_config, dict): - return jsonify({"error": "Invalid config format"}), 400 + raise HTTPException(status_code=400, detail={"error": "Invalid config format"}) # Preserve the explicitFilter setting from environment explicit_filter_env = os.environ.get("EXPLICIT_FILTER", "false").lower() @@ -140,73 +143,83 @@ def update_config(): if updated_config_values is None: # This case should ideally not be reached if save_config succeeded # and get_config handles errors by returning a default or None. - return jsonify( - {"error": "Failed to retrieve configuration after saving"} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": "Failed to retrieve configuration after saving"} + ) - return jsonify(updated_config_values) + return updated_config_values else: - return jsonify( - {"error": "Failed to update configuration", "details": error_msg} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": "Failed to update configuration", "details": error_msg} + ) except json.JSONDecodeError: - return jsonify({"error": "Invalid JSON data"}), 400 + raise HTTPException(status_code=400, detail={"error": "Invalid JSON data"}) + except HTTPException: + raise except Exception as e: logger.error(f"Error in POST/PUT /config: {e}", exc_info=True) - return jsonify( - {"error": "Failed to update configuration", "details": str(e)} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": "Failed to update configuration", "details": str(e)} + ) -@config_bp.route("/config/check", methods=["GET"]) -def check_config_changes(): +@router.get("/config/check") +async def check_config_changes(): # This endpoint seems more related to dynamically checking if config changed # on disk, which might not be necessary if settings are applied on restart # or by a dedicated manager. For now, just return current config. try: config = get_config() - return jsonify( - {"message": "Current configuration retrieved.", "config": config} - ) + return {"message": "Current configuration retrieved.", "config": config} except Exception as e: logger.error(f"Error in GET /config/check: {e}", exc_info=True) - return jsonify( - {"error": "Failed to check configuration", "details": str(e)} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": "Failed to check configuration", "details": str(e)} + ) -@config_bp.route("/config/watch", methods=["GET"]) -def handle_watch_config(): +@router.get("/config/watch") +async def handle_watch_config(): """Handles GET requests for the watch configuration.""" try: watch_config = get_watch_config_http() - return jsonify(watch_config) + return watch_config except Exception as e: logger.error(f"Error in GET /config/watch: {e}", exc_info=True) - return jsonify( - {"error": "Failed to retrieve watch configuration", "details": str(e)} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": "Failed to retrieve watch configuration", "details": str(e)} + ) -@config_bp.route("/config/watch", methods=["POST", "PUT"]) -def update_watch_config(): +@router.post("/config/watch") +@router.put("/config/watch") +async def update_watch_config(request: Request): """Handles POST/PUT requests to update the watch configuration.""" try: - new_watch_config = request.get_json() + new_watch_config = await request.json() if not isinstance(new_watch_config, dict): - return jsonify({"error": "Invalid watch config format"}), 400 + raise HTTPException(status_code=400, detail={"error": "Invalid watch config format"}) success, error_msg = save_watch_config_http(new_watch_config) if success: - return jsonify({"message": "Watch configuration updated successfully"}), 200 + return {"message": "Watch configuration updated successfully"} else: - return jsonify( - {"error": "Failed to update watch configuration", "details": error_msg} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": "Failed to update watch configuration", "details": error_msg} + ) except json.JSONDecodeError: - return jsonify({"error": "Invalid JSON data for watch config"}), 400 + raise HTTPException(status_code=400, detail={"error": "Invalid JSON data for watch config"}) + except HTTPException: + raise except Exception as e: logger.error(f"Error in POST/PUT /config/watch: {e}", exc_info=True) - return jsonify( - {"error": "Failed to update watch configuration", "details": str(e)} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": "Failed to update watch configuration", "details": str(e)} + ) diff --git a/routes/credentials.py b/routes/credentials.py index dfc31f0..f58037d 100755 --- a/routes/credentials.py +++ b/routes/credentials.py @@ -1,4 +1,6 @@ -from flask import Blueprint, request, jsonify +from fastapi import APIRouter, HTTPException, Request +import json +import logging from routes.utils.credentials import ( get_credential, list_credentials, @@ -10,159 +12,210 @@ from routes.utils.credentials import ( _get_global_spotify_api_creds, save_global_spotify_api_creds, ) -import logging logger = logging.getLogger(__name__) -credentials_bp = Blueprint("credentials", __name__) +router = APIRouter() -# Initialize the database and tables when the blueprint is loaded +# Initialize the database and tables when the router is loaded init_credentials_db() -@credentials_bp.route("/spotify_api_config", methods=["GET", "PUT"]) -def handle_spotify_api_config(): +@router.get("/spotify_api_config") +@router.put("/spotify_api_config") +async def handle_spotify_api_config(request: Request): """Handles GET and PUT requests for the global Spotify API client_id and client_secret.""" try: if request.method == "GET": client_id, client_secret = _get_global_spotify_api_creds() if client_id is not None and client_secret is not None: - return jsonify( - {"client_id": client_id, "client_secret": client_secret} - ), 200 + return {"client_id": client_id, "client_secret": client_secret} else: # If search.json exists but is empty/incomplete, or doesn't exist - return jsonify( - { - "warning": "Global Spotify API credentials are not fully configured or file is missing.", - "client_id": client_id or "", - "client_secret": client_secret or "", - } - ), 200 + return { + "warning": "Global Spotify API credentials are not fully configured or file is missing.", + "client_id": client_id or "", + "client_secret": client_secret or "", + } elif request.method == "PUT": - data = request.get_json() + data = await request.json() if not data or "client_id" not in data or "client_secret" not in data: - return jsonify( - { - "error": "Request body must contain 'client_id' and 'client_secret'" - } - ), 400 + raise HTTPException( + status_code=400, + detail={"error": "Request body must contain 'client_id' and 'client_secret'"} + ) client_id = data["client_id"] client_secret = data["client_secret"] if not isinstance(client_id, str) or not isinstance(client_secret, str): - return jsonify( - {"error": "'client_id' and 'client_secret' must be strings"} - ), 400 + raise HTTPException( + status_code=400, + detail={"error": "'client_id' and 'client_secret' must be strings"} + ) if save_global_spotify_api_creds(client_id, client_secret): - return jsonify( - {"message": "Global Spotify API credentials updated successfully."} - ), 200 + return {"message": "Global Spotify API credentials updated successfully."} else: - return jsonify( - {"error": "Failed to save global Spotify API credentials."} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": "Failed to save global Spotify API credentials."} + ) + except HTTPException: + raise except Exception as e: logger.error(f"Error in /spotify_api_config: {e}", exc_info=True) - return jsonify({"error": f"An unexpected error occurred: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"An unexpected error occurred: {str(e)}"}) -@credentials_bp.route("/", methods=["GET"]) -def handle_list_credentials(service): +@router.get("/{service}") +async def handle_list_credentials(service: str): try: if service not in ["spotify", "deezer"]: - return jsonify( - {"error": "Invalid service. Must be 'spotify' or 'deezer'"} - ), 400 - return jsonify(list_credentials(service)) + raise HTTPException( + status_code=400, + detail={"error": "Invalid service. Must be 'spotify' or 'deezer'"} + ) + return list_credentials(service) except ValueError as e: # Should not happen with service check above - return jsonify({"error": str(e)}), 400 + raise HTTPException(status_code=400, detail={"error": str(e)}) except Exception as e: logger.error(f"Error listing credentials for {service}: {e}", exc_info=True) - return jsonify({"error": f"An unexpected error occurred: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"An unexpected error occurred: {str(e)}"}) -@credentials_bp.route("//", methods=["GET", "POST", "PUT", "DELETE"]) -def handle_single_credential(service, name): +@router.get("/{service}/{name}") +async def handle_get_credential(service: str, name: str): try: if service not in ["spotify", "deezer"]: - return jsonify( - {"error": "Invalid service. Must be 'spotify' or 'deezer'"} - ), 400 - - # cred_type logic is removed for Spotify as API keys are global. - # For Deezer, it's always 'credentials' type implicitly. - - if request.method == "GET": - # get_credential for Spotify now only returns region and blob_file_path - return jsonify(get_credential(service, name)) - - elif request.method == "POST": - data = request.get_json() - if not data: - return jsonify({"error": "Request body cannot be empty."}), 400 - # create_credential for Spotify now expects 'region' and 'blob_content' - # For Deezer, it expects 'arl' and 'region' - # Validation is handled within create_credential utility function - result = create_credential(service, name, data) - return jsonify( - { - "message": f"Credential for '{name}' ({service}) created successfully.", - "details": result, - } - ), 201 - - elif request.method == "PUT": - data = request.get_json() - if not data: - return jsonify({"error": "Request body cannot be empty."}), 400 - # edit_credential for Spotify now handles updates to 'region', 'blob_content' - # For Deezer, 'arl', 'region' - result = edit_credential(service, name, data) - return jsonify( - { - "message": f"Credential for '{name}' ({service}) updated successfully.", - "details": result, - } - ) - - elif request.method == "DELETE": - # delete_credential for Spotify also handles deleting the blob directory - result = delete_credential(service, name) - return jsonify( - { - "message": f"Credential for '{name}' ({service}) deleted successfully.", - "details": result, - } + raise HTTPException( + status_code=400, + detail={"error": "Invalid service. Must be 'spotify' or 'deezer'"} ) + # get_credential for Spotify now only returns region and blob_file_path + return get_credential(service, name) except (ValueError, FileNotFoundError, FileExistsError) as e: status_code = 400 if isinstance(e, FileNotFoundError): status_code = 404 elif isinstance(e, FileExistsError): status_code = 409 - logger.warning(f"Client error in /<{service}>/<{name}>: {str(e)}") - return jsonify({"error": str(e)}), status_code + logger.warning(f"Client error in /{service}/{name}: {str(e)}") + raise HTTPException(status_code=status_code, detail={"error": str(e)}) except Exception as e: - logger.error(f"Server error in /<{service}>/<{name}>: {e}", exc_info=True) - return jsonify({"error": f"An unexpected error occurred: {str(e)}"}), 500 + logger.error(f"Server error in /{service}/{name}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail={"error": f"An unexpected error occurred: {str(e)}"}) + + +@router.post("/{service}/{name}") +async def handle_create_credential(service: str, name: str, request: Request): + try: + if service not in ["spotify", "deezer"]: + raise HTTPException( + status_code=400, + detail={"error": "Invalid service. Must be 'spotify' or 'deezer'"} + ) + + data = await request.json() + if not data: + raise HTTPException(status_code=400, detail={"error": "Request body cannot be empty."}) + + # create_credential for Spotify now expects 'region' and 'blob_content' + # For Deezer, it expects 'arl' and 'region' + # Validation is handled within create_credential utility function + result = create_credential(service, name, data) + return { + "message": f"Credential for '{name}' ({service}) created successfully.", + "details": result, + } + except (ValueError, FileNotFoundError, FileExistsError) as e: + status_code = 400 + if isinstance(e, FileNotFoundError): + status_code = 404 + elif isinstance(e, FileExistsError): + status_code = 409 + logger.warning(f"Client error in /{service}/{name}: {str(e)}") + raise HTTPException(status_code=status_code, detail={"error": str(e)}) + except Exception as e: + logger.error(f"Server error in /{service}/{name}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail={"error": f"An unexpected error occurred: {str(e)}"}) + + +@router.put("/{service}/{name}") +async def handle_update_credential(service: str, name: str, request: Request): + try: + if service not in ["spotify", "deezer"]: + raise HTTPException( + status_code=400, + detail={"error": "Invalid service. Must be 'spotify' or 'deezer'"} + ) + + data = await request.json() + if not data: + raise HTTPException(status_code=400, detail={"error": "Request body cannot be empty."}) + + # edit_credential for Spotify now handles updates to 'region', 'blob_content' + # For Deezer, 'arl', 'region' + result = edit_credential(service, name, data) + return { + "message": f"Credential for '{name}' ({service}) updated successfully.", + "details": result, + } + except (ValueError, FileNotFoundError, FileExistsError) as e: + status_code = 400 + if isinstance(e, FileNotFoundError): + status_code = 404 + elif isinstance(e, FileExistsError): + status_code = 409 + logger.warning(f"Client error in /{service}/{name}: {str(e)}") + raise HTTPException(status_code=status_code, detail={"error": str(e)}) + except Exception as e: + logger.error(f"Server error in /{service}/{name}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail={"error": f"An unexpected error occurred: {str(e)}"}) + + +@router.delete("/{service}/{name}") +async def handle_delete_credential(service: str, name: str): + try: + if service not in ["spotify", "deezer"]: + raise HTTPException( + status_code=400, + detail={"error": "Invalid service. Must be 'spotify' or 'deezer'"} + ) + + # delete_credential for Spotify also handles deleting the blob directory + result = delete_credential(service, name) + return { + "message": f"Credential for '{name}' ({service}) deleted successfully.", + "details": result, + } + except (ValueError, FileNotFoundError, FileExistsError) as e: + status_code = 400 + if isinstance(e, FileNotFoundError): + status_code = 404 + elif isinstance(e, FileExistsError): + status_code = 409 + logger.warning(f"Client error in /{service}/{name}: {str(e)}") + raise HTTPException(status_code=status_code, detail={"error": str(e)}) + except Exception as e: + logger.error(f"Server error in /{service}/{name}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail={"error": f"An unexpected error occurred: {str(e)}"}) # The '/search//' route is now obsolete for Spotify and has been removed. -@credentials_bp.route("/all/", methods=["GET"]) -def handle_all_credentials(service): +@router.get("/all/{service}") +async def handle_all_credentials(service: str): """Lists all credentials for a given service. For Spotify, API keys are global and not listed per account.""" try: if service not in ["spotify", "deezer"]: - return jsonify( - {"error": "Invalid service. Must be 'spotify' or 'deezer'"} - ), 400 + raise HTTPException( + status_code=400, + detail={"error": "Invalid service. Must be 'spotify' or 'deezer'"} + ) credentials_list = [] account_names = list_credentials(service) # This lists names from DB @@ -190,14 +243,14 @@ def handle_all_credentials(service): } ) - return jsonify(credentials_list) + return credentials_list except Exception as e: logger.error(f"Error in /all/{service}: {e}", exc_info=True) - return jsonify({"error": f"An unexpected error occurred: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"An unexpected error occurred: {str(e)}"}) -@credentials_bp.route("/markets", methods=["GET"]) -def handle_markets(): +@router.get("/markets") +async def handle_markets(): """ Returns a list of unique market regions for Deezer and Spotify accounts. """ @@ -229,13 +282,11 @@ def handle_markets(): f"Could not retrieve region for spotify account {name}: {e}" ) - return jsonify( - { - "deezer": sorted(list(deezer_regions)), - "spotify": sorted(list(spotify_regions)), - } - ), 200 + return { + "deezer": sorted(list(deezer_regions)), + "spotify": sorted(list(spotify_regions)), + } except Exception as e: logger.error(f"Error in /markets: {e}", exc_info=True) - return jsonify({"error": f"An unexpected error occurred: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"An unexpected error occurred: {str(e)}"}) diff --git a/routes/history.py b/routes/history.py index a59879c..024831d 100644 --- a/routes/history.py +++ b/routes/history.py @@ -1,4 +1,5 @@ -from flask import Blueprint, Response, request, jsonify +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse import json import traceback import logging @@ -6,11 +7,11 @@ from routes.utils.history_manager import history_manager logger = logging.getLogger(__name__) -history_bp = Blueprint("history", __name__) +router = APIRouter() -@history_bp.route("/", methods=["GET"]) -def get_history(): +@router.get("/") +async def get_history(request: Request): """ Retrieve download history with optional filtering and pagination. @@ -22,27 +23,25 @@ def get_history(): """ try: # Parse query parameters - limit = min(int(request.args.get("limit", 100)), 500) # Cap at 500 - offset = max(int(request.args.get("offset", 0)), 0) - download_type = request.args.get("download_type") - status = request.args.get("status") + limit = min(int(request.query_params.get("limit", 100)), 500) # Cap at 500 + offset = max(int(request.query_params.get("offset", 0)), 0) + download_type = request.query_params.get("download_type") + status = request.query_params.get("status") # Validate download_type if provided valid_types = ["track", "album", "playlist"] if download_type and download_type not in valid_types: - return Response( - json.dumps({"error": f"Invalid download_type. Must be one of: {valid_types}"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": f"Invalid download_type. Must be one of: {valid_types}"}, + status_code=400 ) # Validate status if provided valid_statuses = ["completed", "failed", "skipped", "in_progress"] if status and status not in valid_statuses: - return Response( - json.dumps({"error": f"Invalid status. Must be one of: {valid_statuses}"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": f"Invalid status. Must be one of: {valid_statuses}"}, + status_code=400 ) # Get history from manager @@ -70,29 +69,26 @@ def get_history(): response_data["filters"] = {} response_data["filters"]["status"] = status - return Response( - json.dumps(response_data), - status=200, - mimetype="application/json" + return JSONResponse( + content=response_data, + status_code=200 ) except ValueError as e: - return Response( - json.dumps({"error": f"Invalid parameter value: {str(e)}"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": f"Invalid parameter value: {str(e)}"}, + status_code=400 ) except Exception as e: logger.error(f"Error retrieving download history: {e}", exc_info=True) - return Response( - json.dumps({"error": "Failed to retrieve download history", "details": str(e)}), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": "Failed to retrieve download history", "details": str(e)}, + status_code=500 ) -@history_bp.route("/", methods=["GET"]) -def get_download_by_task_id(task_id): +@router.get("/{task_id}") +async def get_download_by_task_id(task_id: str): """ Retrieve specific download history by task ID. @@ -103,29 +99,26 @@ def get_download_by_task_id(task_id): download = history_manager.get_download_by_task_id(task_id) if not download: - return Response( - json.dumps({"error": f"Download with task ID '{task_id}' not found"}), - status=404, - mimetype="application/json", + return JSONResponse( + content={"error": f"Download with task ID '{task_id}' not found"}, + status_code=404 ) - return Response( - json.dumps(download), - status=200, - mimetype="application/json" + return JSONResponse( + content=download, + status_code=200 ) except Exception as e: logger.error(f"Error retrieving download for task {task_id}: {e}", exc_info=True) - return Response( - json.dumps({"error": "Failed to retrieve download", "details": str(e)}), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": "Failed to retrieve download", "details": str(e)}, + status_code=500 ) -@history_bp.route("//children", methods=["GET"]) -def get_download_children(task_id): +@router.get("/{task_id}/children") +async def get_download_children(task_id: str): """ Retrieve children tracks for an album or playlist download. @@ -137,18 +130,16 @@ def get_download_children(task_id): download = history_manager.get_download_by_task_id(task_id) if not download: - return Response( - json.dumps({"error": f"Download with task ID '{task_id}' not found"}), - status=404, - mimetype="application/json", + return JSONResponse( + content={"error": f"Download with task ID '{task_id}' not found"}, + status_code=404 ) children_table = download.get("children_table") if not children_table: - return Response( - json.dumps({"error": f"Download '{task_id}' has no children tracks"}), - status=404, - mimetype="application/json", + return JSONResponse( + content={"error": f"Download '{task_id}' has no children tracks"}, + status_code=404 ) # Get children tracks @@ -163,46 +154,42 @@ def get_download_children(task_id): "track_count": len(children) } - return Response( - json.dumps(response_data), - status=200, - mimetype="application/json" + return JSONResponse( + content=response_data, + status_code=200 ) except Exception as e: logger.error(f"Error retrieving children for task {task_id}: {e}", exc_info=True) - return Response( - json.dumps({"error": "Failed to retrieve download children", "details": str(e)}), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": "Failed to retrieve download children", "details": str(e)}, + status_code=500 ) -@history_bp.route("/stats", methods=["GET"]) -def get_download_stats(): +@router.get("/stats") +async def get_download_stats(): """ Get download statistics and summary information. """ try: stats = history_manager.get_download_stats() - return Response( - json.dumps(stats), - status=200, - mimetype="application/json" + return JSONResponse( + content=stats, + status_code=200 ) except Exception as e: logger.error(f"Error retrieving download stats: {e}", exc_info=True) - return Response( - json.dumps({"error": "Failed to retrieve download statistics", "details": str(e)}), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": "Failed to retrieve download statistics", "details": str(e)}, + status_code=500 ) -@history_bp.route("/search", methods=["GET"]) -def search_history(): +@router.get("/search") +async def search_history(request: Request): """ Search download history by title or artist. @@ -211,15 +198,14 @@ def search_history(): - limit: Maximum number of results (default: 50, max: 200) """ try: - query = request.args.get("q") + query = request.query_params.get("q") if not query: - return Response( - json.dumps({"error": "Missing required parameter: q (search query)"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing required parameter: q (search query)"}, + status_code=400 ) - limit = min(int(request.args.get("limit", 50)), 200) # Cap at 200 + limit = min(int(request.query_params.get("limit", 50)), 200) # Cap at 200 # Search history results = history_manager.search_history(query, limit) @@ -231,29 +217,26 @@ def search_history(): "limit": limit } - return Response( - json.dumps(response_data), - status=200, - mimetype="application/json" + return JSONResponse( + content=response_data, + status_code=200 ) except ValueError as e: - return Response( - json.dumps({"error": f"Invalid parameter value: {str(e)}"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": f"Invalid parameter value: {str(e)}"}, + status_code=400 ) except Exception as e: logger.error(f"Error searching download history: {e}", exc_info=True) - return Response( - json.dumps({"error": "Failed to search download history", "details": str(e)}), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": "Failed to search download history", "details": str(e)}, + status_code=500 ) -@history_bp.route("/recent", methods=["GET"]) -def get_recent_downloads(): +@router.get("/recent") +async def get_recent_downloads(request: Request): """ Get most recent downloads. @@ -261,7 +244,7 @@ def get_recent_downloads(): - limit: Maximum number of results (default: 20, max: 100) """ try: - limit = min(int(request.args.get("limit", 20)), 100) # Cap at 100 + limit = min(int(request.query_params.get("limit", 20)), 100) # Cap at 100 recent = history_manager.get_recent_downloads(limit) @@ -271,29 +254,26 @@ def get_recent_downloads(): "limit": limit } - return Response( - json.dumps(response_data), - status=200, - mimetype="application/json" + return JSONResponse( + content=response_data, + status_code=200 ) except ValueError as e: - return Response( - json.dumps({"error": f"Invalid parameter value: {str(e)}"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": f"Invalid parameter value: {str(e)}"}, + status_code=400 ) except Exception as e: logger.error(f"Error retrieving recent downloads: {e}", exc_info=True) - return Response( - json.dumps({"error": "Failed to retrieve recent downloads", "details": str(e)}), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": "Failed to retrieve recent downloads", "details": str(e)}, + status_code=500 ) -@history_bp.route("/failed", methods=["GET"]) -def get_failed_downloads(): +@router.get("/failed") +async def get_failed_downloads(request: Request): """ Get failed downloads. @@ -301,7 +281,7 @@ def get_failed_downloads(): - limit: Maximum number of results (default: 50, max: 200) """ try: - limit = min(int(request.args.get("limit", 50)), 200) # Cap at 200 + limit = min(int(request.query_params.get("limit", 50)), 200) # Cap at 200 failed = history_manager.get_failed_downloads(limit) @@ -311,29 +291,26 @@ def get_failed_downloads(): "limit": limit } - return Response( - json.dumps(response_data), - status=200, - mimetype="application/json" + return JSONResponse( + content=response_data, + status_code=200 ) except ValueError as e: - return Response( - json.dumps({"error": f"Invalid parameter value: {str(e)}"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": f"Invalid parameter value: {str(e)}"}, + status_code=400 ) except Exception as e: logger.error(f"Error retrieving failed downloads: {e}", exc_info=True) - return Response( - json.dumps({"error": "Failed to retrieve failed downloads", "details": str(e)}), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": "Failed to retrieve failed downloads", "details": str(e)}, + status_code=500 ) -@history_bp.route("/cleanup", methods=["POST"]) -def cleanup_old_history(): +@router.post("/cleanup") +async def cleanup_old_history(request: Request): """ Clean up old download history. @@ -341,14 +318,13 @@ def cleanup_old_history(): - days_old: Number of days old to keep (default: 30) """ try: - data = request.get_json() or {} + data = await request.json() if request.headers.get("content-type") == "application/json" else {} days_old = data.get("days_old", 30) if not isinstance(days_old, int) or days_old <= 0: - return Response( - json.dumps({"error": "days_old must be a positive integer"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "days_old must be a positive integer"}, + status_code=400 ) deleted_count = history_manager.clear_old_history(days_old) @@ -359,16 +335,14 @@ def cleanup_old_history(): "days_old": days_old } - return Response( - json.dumps(response_data), - status=200, - mimetype="application/json" + return JSONResponse( + content=response_data, + status_code=200 ) except Exception as e: logger.error(f"Error cleaning up old history: {e}", exc_info=True) - return Response( - json.dumps({"error": "Failed to cleanup old history", "details": str(e)}), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": "Failed to cleanup old history", "details": str(e)}, + status_code=500 ) \ No newline at end of file diff --git a/routes/playlist.py b/routes/playlist.py index 8793b24..9582c3e 100755 --- a/routes/playlist.py +++ b/routes/playlist.py @@ -1,4 +1,5 @@ -from flask import Blueprint, Response, request, jsonify +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse import json import traceback import logging # Added logging import @@ -30,7 +31,7 @@ from routes.utils.watch.manager import ( from routes.utils.errors import DuplicateDownloadError logger = logging.getLogger(__name__) # Added logger initialization -playlist_bp = Blueprint("playlist", __name__, url_prefix="/api/playlist") +router = APIRouter() def construct_spotify_url(item_id: str, item_type: str = "track") -> str: @@ -38,18 +39,16 @@ def construct_spotify_url(item_id: str, item_type: str = "track") -> str: return f"https://open.spotify.com/{item_type}/{item_id}" -@playlist_bp.route("/download/", methods=["GET"]) -def handle_download(playlist_id): +@router.get("/download/{playlist_id}") +async def handle_download(playlist_id: str, request: Request): # Retrieve essential parameters from the request. # name = request.args.get('name') # Removed # artist = request.args.get('artist') # Removed - orig_params = request.args.to_dict() + orig_params = dict(request.query_params) # Construct the URL from playlist_id url = construct_spotify_url(playlist_id, "playlist") - orig_params["original_url"] = ( - request.url - ) # Update original_url to the constructed one + orig_params["original_url"] = str(request.url) # Update original_url to the constructed one # Fetch metadata from Spotify using optimized function try: @@ -60,14 +59,11 @@ def handle_download(playlist_id): or not playlist_info.get("name") or not playlist_info.get("owner") ): - return Response( - json.dumps( - { - "error": f"Could not retrieve metadata for playlist ID: {playlist_id}" - } - ), - status=404, - mimetype="application/json", + return JSONResponse( + content={ + "error": f"Could not retrieve metadata for playlist ID: {playlist_id}" + }, + status_code=404 ) name_from_spotify = playlist_info.get("name") @@ -76,22 +72,18 @@ def handle_download(playlist_id): artist_from_spotify = owner_info.get("display_name", "Unknown Owner") except Exception as e: - return Response( - json.dumps( - { - "error": f"Failed to fetch metadata for playlist {playlist_id}: {str(e)}" - } - ), - status=500, - mimetype="application/json", + return JSONResponse( + content={ + "error": f"Failed to fetch metadata for playlist {playlist_id}: {str(e)}" + }, + status_code=500 ) # Validate required parameters if not url: # This check might be redundant now but kept for safety - return Response( - json.dumps({"error": "Missing required parameter: url"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing required parameter: url"}, + status_code=400 ) try: @@ -105,15 +97,12 @@ def handle_download(playlist_id): } ) except DuplicateDownloadError as e: - return Response( - json.dumps( - { - "error": "Duplicate download detected.", - "existing_task": e.existing_task, - } - ), - status=409, - mimetype="application/json", + return JSONResponse( + content={ + "error": "Duplicate download detected.", + "existing_task": e.existing_task, + }, + status_code=409 ) except Exception as e: # Generic error handling for other issues during task submission @@ -138,58 +127,52 @@ def handle_download(playlist_id): "timestamp": time.time(), }, ) - return Response( - json.dumps( - { - "error": f"Failed to queue playlist download: {str(e)}", - "task_id": error_task_id, - } - ), - status=500, - mimetype="application/json", + return JSONResponse( + content={ + "error": f"Failed to queue playlist download: {str(e)}", + "task_id": error_task_id, + }, + status_code=500 ) - return Response( - json.dumps({"task_id": task_id}), - status=202, - mimetype="application/json", + return JSONResponse( + content={"task_id": task_id}, + status_code=202 ) -@playlist_bp.route("/download/cancel", methods=["GET"]) -def cancel_download(): +@router.get("/download/cancel") +async def cancel_download(request: Request): """ Cancel a running playlist download process by its task id. """ - task_id = request.args.get("task_id") + task_id = request.query_params.get("task_id") if not task_id: - return Response( - json.dumps({"error": "Missing task id (task_id) parameter"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing task id (task_id) parameter"}, + status_code=400 ) # Use the queue manager's cancellation method. result = download_queue_manager.cancel_task(task_id) status_code = 200 if result.get("status") == "cancelled" else 404 - return Response(json.dumps(result), status=status_code, mimetype="application/json") + return JSONResponse(content=result, status_code=status_code) -@playlist_bp.route("/info", methods=["GET"]) -def get_playlist_info(): +@router.get("/info") +async def get_playlist_info(request: Request): """ Retrieve Spotify playlist metadata given a Spotify playlist ID. Expects a query parameter 'id' that contains the Spotify playlist ID. """ - spotify_id = request.args.get("id") - include_tracks = request.args.get("include_tracks", "false").lower() == "true" + spotify_id = request.query_params.get("id") + include_tracks = request.query_params.get("include_tracks", "false").lower() == "true" if not spotify_id: - return Response( - json.dumps({"error": "Missing parameter: id"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing parameter: id"}, + status_code=400 ) try: @@ -216,27 +199,26 @@ def get_playlist_info(): # If not watched, or no tracks, is_locally_known will not be added, or tracks won't exist to add it to. # Frontend should handle absence of this key as false. - return Response( - json.dumps(playlist_info), status=200, mimetype="application/json" + return JSONResponse( + content=playlist_info, status_code=200 ) except Exception as e: error_data = {"error": str(e), "traceback": traceback.format_exc()} - return Response(json.dumps(error_data), status=500, mimetype="application/json") + return JSONResponse(content=error_data, status_code=500) -@playlist_bp.route("/metadata", methods=["GET"]) -def get_playlist_metadata(): +@router.get("/metadata") +async def get_playlist_metadata(request: Request): """ Retrieve only Spotify playlist metadata (no tracks) to avoid rate limiting. Expects a query parameter 'id' that contains the Spotify playlist ID. """ - spotify_id = request.args.get("id") + spotify_id = request.query_params.get("id") if not spotify_id: - return Response( - json.dumps({"error": "Missing parameter: id"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing parameter: id"}, + status_code=400 ) try: @@ -244,29 +226,28 @@ def get_playlist_metadata(): from routes.utils.get_info import get_playlist_metadata playlist_metadata = get_playlist_metadata(spotify_id) - return Response( - json.dumps(playlist_metadata), status=200, mimetype="application/json" + return JSONResponse( + content=playlist_metadata, status_code=200 ) except Exception as e: error_data = {"error": str(e), "traceback": traceback.format_exc()} - return Response(json.dumps(error_data), status=500, mimetype="application/json") + return JSONResponse(content=error_data, status_code=500) -@playlist_bp.route("/tracks", methods=["GET"]) -def get_playlist_tracks(): +@router.get("/tracks") +async def get_playlist_tracks(request: Request): """ Retrieve playlist tracks with pagination support for progressive loading. Expects query parameters: 'id' (playlist ID), 'limit' (optional), 'offset' (optional). """ - spotify_id = request.args.get("id") - limit = request.args.get("limit", 50, type=int) - offset = request.args.get("offset", 0, type=int) + spotify_id = request.query_params.get("id") + limit = int(request.query_params.get("limit", 50)) + offset = int(request.query_params.get("offset", 0)) if not spotify_id: - return Response( - json.dumps({"error": "Missing parameter: id"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing parameter: id"}, + status_code=400 ) try: @@ -274,28 +255,26 @@ def get_playlist_tracks(): from routes.utils.get_info import get_playlist_tracks tracks_data = get_playlist_tracks(spotify_id, limit=limit, offset=offset) - return Response( - json.dumps(tracks_data), status=200, mimetype="application/json" + return JSONResponse( + content=tracks_data, status_code=200 ) except Exception as e: error_data = {"error": str(e), "traceback": traceback.format_exc()} - return Response(json.dumps(error_data), status=500, mimetype="application/json") + return JSONResponse(content=error_data, status_code=500) -@playlist_bp.route("/watch/", methods=["PUT"]) -def add_to_watchlist(playlist_spotify_id): +@router.put("/watch/{playlist_spotify_id}") +async def add_to_watchlist(playlist_spotify_id: str): """Adds a playlist to the watchlist.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify({"error": "Watch feature is currently disabled globally."}), 403 + raise HTTPException(status_code=403, detail={"error": "Watch feature is currently disabled globally."}) logger.info(f"Attempting to add playlist {playlist_spotify_id} to watchlist.") try: # Check if already watched if get_watched_playlist(playlist_spotify_id): - return jsonify( - {"message": f"Playlist {playlist_spotify_id} is already being watched."} - ), 200 + return {"message": f"Playlist {playlist_spotify_id} is already being watched."} # Fetch playlist details from Spotify to populate our DB from routes.utils.get_info import get_playlist_metadata @@ -304,11 +283,12 @@ def add_to_watchlist(playlist_spotify_id): logger.error( f"Could not fetch details for playlist {playlist_spotify_id} from Spotify." ) - return jsonify( - { + raise HTTPException( + status_code=404, + detail={ "error": f"Could not fetch details for playlist {playlist_spotify_id} from Spotify." } - ), 404 + ) add_playlist_db(playlist_data) # This also creates the tracks table @@ -323,99 +303,104 @@ def add_to_watchlist(playlist_spotify_id): logger.info( f"Playlist {playlist_spotify_id} added to watchlist. Its tracks will be processed by the watch manager." ) - return jsonify( - { - "message": f"Playlist {playlist_spotify_id} added to watchlist. Tracks will be processed shortly." - } - ), 201 + return { + "message": f"Playlist {playlist_spotify_id} added to watchlist. Tracks will be processed shortly." + } + except HTTPException: + raise except Exception as e: logger.error( f"Error adding playlist {playlist_spotify_id} to watchlist: {e}", exc_info=True, ) - return jsonify({"error": f"Could not add playlist to watchlist: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not add playlist to watchlist: {str(e)}"}) -@playlist_bp.route("/watch//status", methods=["GET"]) -def get_playlist_watch_status(playlist_spotify_id): +@router.get("/watch/{playlist_spotify_id}/status") +async def get_playlist_watch_status(playlist_spotify_id: str): """Checks if a specific playlist is being watched.""" logger.info(f"Checking watch status for playlist {playlist_spotify_id}.") try: playlist = get_watched_playlist(playlist_spotify_id) if playlist: - return jsonify({"is_watched": True, "playlist_data": playlist}), 200 + return {"is_watched": True, "playlist_data": playlist} else: # Return 200 with is_watched: false, so frontend can clearly distinguish # between "not watched" and an actual error fetching status. - return jsonify({"is_watched": False}), 200 + return {"is_watched": False} except Exception as e: logger.error( f"Error checking watch status for playlist {playlist_spotify_id}: {e}", exc_info=True, ) - return jsonify({"error": f"Could not check watch status: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not check watch status: {str(e)}"}) -@playlist_bp.route("/watch/", methods=["DELETE"]) -def remove_from_watchlist(playlist_spotify_id): +@router.delete("/watch/{playlist_spotify_id}") +async def remove_from_watchlist(playlist_spotify_id: str): """Removes a playlist from the watchlist.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify({"error": "Watch feature is currently disabled globally."}), 403 + raise HTTPException(status_code=403, detail={"error": "Watch feature is currently disabled globally."}) logger.info(f"Attempting to remove playlist {playlist_spotify_id} from watchlist.") try: if not get_watched_playlist(playlist_spotify_id): - return jsonify( - {"error": f"Playlist {playlist_spotify_id} not found in watchlist."} - ), 404 + raise HTTPException( + status_code=404, + detail={"error": f"Playlist {playlist_spotify_id} not found in watchlist."} + ) remove_playlist_db(playlist_spotify_id) logger.info( f"Playlist {playlist_spotify_id} removed from watchlist successfully." ) - return jsonify( - {"message": f"Playlist {playlist_spotify_id} removed from watchlist."} - ), 200 + return {"message": f"Playlist {playlist_spotify_id} removed from watchlist."} + except HTTPException: + raise except Exception as e: logger.error( f"Error removing playlist {playlist_spotify_id} from watchlist: {e}", exc_info=True, ) - return jsonify( - {"error": f"Could not remove playlist from watchlist: {str(e)}"} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": f"Could not remove playlist from watchlist: {str(e)}"} + ) -@playlist_bp.route("/watch//tracks", methods=["POST"]) -def mark_tracks_as_known(playlist_spotify_id): +@router.post("/watch/{playlist_spotify_id}/tracks") +async def mark_tracks_as_known(playlist_spotify_id: str, request: Request): """Fetches details for given track IDs and adds/updates them in the playlist's local DB table.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify( - { + raise HTTPException( + status_code=403, + detail={ "error": "Watch feature is currently disabled globally. Cannot mark tracks." } - ), 403 + ) logger.info( f"Attempting to mark tracks as known for playlist {playlist_spotify_id}." ) try: - track_ids = request.json + track_ids = await request.json() if not isinstance(track_ids, list) or not all( isinstance(tid, str) for tid in track_ids ): - return jsonify( - { + raise HTTPException( + status_code=400, + detail={ "error": "Invalid request body. Expecting a JSON array of track Spotify IDs." } - ), 400 + ) if not get_watched_playlist(playlist_spotify_id): - return jsonify( - {"error": f"Playlist {playlist_spotify_id} is not being watched."} - ), 404 + raise HTTPException( + status_code=404, + detail={"error": f"Playlist {playlist_spotify_id} is not being watched."} + ) fetched_tracks_details = [] for track_id in track_ids: @@ -433,12 +418,10 @@ def mark_tracks_as_known(playlist_spotify_id): ) if not fetched_tracks_details: - return jsonify( - { - "message": "No valid track details could be fetched to mark as known.", - "processed_count": 0, - } - ), 200 + return { + "message": "No valid track details could be fetched to mark as known.", + "processed_count": 0, + } add_specific_tracks_to_playlist_table( playlist_spotify_id, fetched_tracks_details @@ -446,48 +429,51 @@ def mark_tracks_as_known(playlist_spotify_id): logger.info( f"Successfully marked/updated {len(fetched_tracks_details)} tracks as known for playlist {playlist_spotify_id}." ) - return jsonify( - { - "message": f"Successfully processed {len(fetched_tracks_details)} tracks for playlist {playlist_spotify_id}." - } - ), 200 + return { + "message": f"Successfully processed {len(fetched_tracks_details)} tracks for playlist {playlist_spotify_id}." + } + except HTTPException: + raise except Exception as e: logger.error( f"Error marking tracks as known for playlist {playlist_spotify_id}: {e}", exc_info=True, ) - return jsonify({"error": f"Could not mark tracks as known: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not mark tracks as known: {str(e)}"}) -@playlist_bp.route("/watch//tracks", methods=["DELETE"]) -def mark_tracks_as_missing_locally(playlist_spotify_id): +@router.delete("/watch/{playlist_spotify_id}/tracks") +async def mark_tracks_as_missing_locally(playlist_spotify_id: str, request: Request): """Removes specified tracks from the playlist's local DB table.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify( - { + raise HTTPException( + status_code=403, + detail={ "error": "Watch feature is currently disabled globally. Cannot mark tracks." } - ), 403 + ) logger.info( f"Attempting to mark tracks as missing (remove locally) for playlist {playlist_spotify_id}." ) try: - track_ids = request.json + track_ids = await request.json() if not isinstance(track_ids, list) or not all( isinstance(tid, str) for tid in track_ids ): - return jsonify( - { + raise HTTPException( + status_code=400, + detail={ "error": "Invalid request body. Expecting a JSON array of track Spotify IDs." } - ), 400 + ) if not get_watched_playlist(playlist_spotify_id): - return jsonify( - {"error": f"Playlist {playlist_spotify_id} is not being watched."} - ), 404 + raise HTTPException( + status_code=404, + detail={"error": f"Playlist {playlist_spotify_id} is not being watched."} + ) deleted_count = remove_specific_tracks_from_playlist_table( playlist_spotify_id, track_ids @@ -495,72 +481,71 @@ def mark_tracks_as_missing_locally(playlist_spotify_id): logger.info( f"Successfully removed {deleted_count} tracks locally for playlist {playlist_spotify_id}." ) - return jsonify( - { - "message": f"Successfully removed {deleted_count} tracks locally for playlist {playlist_spotify_id}." - } - ), 200 + return { + "message": f"Successfully removed {deleted_count} tracks locally for playlist {playlist_spotify_id}." + } + except HTTPException: + raise except Exception as e: logger.error( f"Error marking tracks as missing (deleting locally) for playlist {playlist_spotify_id}: {e}", exc_info=True, ) - return jsonify({"error": f"Could not mark tracks as missing: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not mark tracks as missing: {str(e)}"}) -@playlist_bp.route("/watch/list", methods=["GET"]) -def list_watched_playlists_endpoint(): +@router.get("/watch/list") +async def list_watched_playlists_endpoint(): """Lists all playlists currently in the watchlist.""" try: playlists = get_watched_playlists() - return jsonify(playlists), 200 + return playlists except Exception as e: logger.error(f"Error listing watched playlists: {e}", exc_info=True) - return jsonify({"error": f"Could not list watched playlists: {str(e)}"}), 500 + raise HTTPException(status_code=500, detail={"error": f"Could not list watched playlists: {str(e)}"}) -@playlist_bp.route("/watch/trigger_check", methods=["POST"]) -def trigger_playlist_check_endpoint(): +@router.post("/watch/trigger_check") +async def trigger_playlist_check_endpoint(): """Manually triggers the playlist checking mechanism for all watched playlists.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify( - { + raise HTTPException( + status_code=403, + detail={ "error": "Watch feature is currently disabled globally. Cannot trigger check." } - ), 403 + ) logger.info("Manual trigger for playlist check received for all playlists.") try: # Run check_watched_playlists without an ID to check all thread = threading.Thread(target=check_watched_playlists, args=(None,)) thread.start() - return jsonify( - { - "message": "Playlist check triggered successfully in the background for all playlists." - } - ), 202 + return { + "message": "Playlist check triggered successfully in the background for all playlists." + } except Exception as e: logger.error( f"Error manually triggering playlist check for all: {e}", exc_info=True ) - return jsonify( - {"error": f"Could not trigger playlist check for all: {str(e)}"} - ), 500 + raise HTTPException( + status_code=500, + detail={"error": f"Could not trigger playlist check for all: {str(e)}"} + ) -@playlist_bp.route( - "/watch/trigger_check/", methods=["POST"] -) -def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str): +@router.post("/watch/trigger_check/{playlist_spotify_id}") +async def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str): """Manually triggers the playlist checking mechanism for a specific playlist.""" watch_config = get_watch_config() if not watch_config.get("enabled", False): - return jsonify( - { + raise HTTPException( + status_code=403, + detail={ "error": "Watch feature is currently disabled globally. Cannot trigger check." } - ), 403 + ) logger.info( f"Manual trigger for specific playlist check received for ID: {playlist_spotify_id}" @@ -572,11 +557,12 @@ def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str): logger.warning( f"Trigger specific check: Playlist ID {playlist_spotify_id} not found in watchlist." ) - return jsonify( - { + raise HTTPException( + status_code=404, + detail={ "error": f"Playlist {playlist_spotify_id} is not in the watchlist. Add it first." } - ), 404 + ) # Run check_watched_playlists with the specific ID thread = threading.Thread( @@ -586,18 +572,19 @@ def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str): logger.info( f"Playlist check triggered in background for specific playlist ID: {playlist_spotify_id}" ) - return jsonify( - { - "message": f"Playlist check triggered successfully in the background for {playlist_spotify_id}." - } - ), 202 + return { + "message": f"Playlist check triggered successfully in the background for {playlist_spotify_id}." + } + except HTTPException: + raise except Exception as e: logger.error( f"Error manually triggering specific playlist check for {playlist_spotify_id}: {e}", exc_info=True, ) - return jsonify( - { + raise HTTPException( + status_code=500, + detail={ "error": f"Could not trigger playlist check for {playlist_spotify_id}: {str(e)}" } - ), 500 + ) diff --git a/routes/prgs.py b/routes/prgs.py index 0a60247..d6a2745 100755 --- a/routes/prgs.py +++ b/routes/prgs.py @@ -1,6 +1,9 @@ -from flask import Blueprint, abort, jsonify, request +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse import logging import time +import json +import asyncio from routes.utils.celery_tasks import ( get_task_info, @@ -14,9 +17,7 @@ from routes.utils.celery_tasks import ( # Configure logging logger = logging.getLogger(__name__) -prgs_bp = Blueprint("prgs", __name__, url_prefix="/api/prgs") - -# (Old .prg file system removed. Using new task system only.) +router = APIRouter() # Define active task states using ProgressState constants ACTIVE_TASK_STATES = { @@ -119,7 +120,7 @@ def _build_error_callback_object(last_status): return callback_object -def _build_task_response(task_info, last_status, task_id, current_time): +def _build_task_response(task_info, last_status, task_id, current_time, request: Request): """ Helper function to build a standardized task response object. """ @@ -132,7 +133,7 @@ def _build_task_response(task_info, last_status, task_id, current_time): try: item_id = item_url.split("/")[-1] if item_id: - base_url = request.host_url.rstrip("/") + base_url = str(request.base_url).rstrip("/") dynamic_original_url = ( f"{base_url}/api/{download_type}/download/{item_id}" ) @@ -184,7 +185,7 @@ def _build_task_response(task_info, last_status, task_id, current_time): return task_response -def get_paginated_tasks(page=1, limit=20, active_only=False): +async def get_paginated_tasks(page=1, limit=20, active_only=False, request: Request = None): """ Get paginated list of tasks. """ @@ -233,7 +234,7 @@ def get_paginated_tasks(page=1, limit=20, active_only=False): elif is_active_task: task_counts["active"] += 1 - task_response = _build_task_response(task_info, last_status, task_id, time.time()) + task_response = _build_task_response(task_info, last_status, task_id, time.time(), request) if is_active_task: active_tasks.append(task_response) @@ -277,117 +278,18 @@ def get_paginated_tasks(page=1, limit=20, active_only=False): "pagination": pagination_info } - return jsonify(response) + return response except Exception as e: logger.error(f"Error in get_paginated_tasks: {e}", exc_info=True) - return jsonify({"error": "Failed to retrieve paginated tasks"}), 500 + raise HTTPException(status_code=500, detail={"error": "Failed to retrieve paginated tasks"}) -@prgs_bp.route("/", methods=["GET"]) -def get_task_details(task_id): - """ - Return a JSON object with the resource type, its name (title), - the last progress update, and, if available, the original request parameters. +# IMPORTANT: Specific routes MUST come before parameterized routes in FastAPI +# Otherwise "updates" gets matched as a {task_id} parameter! - This function works with the new task ID based system. - - Args: - task_id: A task UUID from Celery - """ - # Only support new task IDs - task_info = get_task_info(task_id) - if not task_info: - abort(404, "Task not found") - - # Dynamically construct original_url - dynamic_original_url = "" - download_type = task_info.get("download_type") - # The 'url' field in task_info stores the Spotify/Deezer URL of the item - # e.g., https://open.spotify.com/album/albumId or https://www.deezer.com/track/trackId - item_url = task_info.get("url") - - if download_type and item_url: - try: - # Extract the ID from the item_url (last part of the path) - item_id = item_url.split("/")[-1] - if item_id: # Ensure item_id is not empty - base_url = request.host_url.rstrip("/") - dynamic_original_url = ( - f"{base_url}/api/{download_type}/download/{item_id}" - ) - else: - logger.warning( - f"Could not extract item ID from URL: {item_url} for task {task_id}. Falling back for original_url." - ) - original_request_obj = task_info.get("original_request", {}) - dynamic_original_url = original_request_obj.get("original_url", "") - except Exception as e: - logger.error( - f"Error constructing dynamic original_url for task {task_id}: {e}", - exc_info=True, - ) - original_request_obj = task_info.get("original_request", {}) - dynamic_original_url = original_request_obj.get( - "original_url", "" - ) # Fallback on any error - else: - logger.warning( - f"Missing download_type ('{download_type}') or item_url ('{item_url}') in task_info for task {task_id}. Falling back for original_url." - ) - original_request_obj = task_info.get("original_request", {}) - dynamic_original_url = original_request_obj.get("original_url", "") - - last_status = get_last_task_status(task_id) - status_count = len(get_task_status(task_id)) - - # Determine last_line content - if last_status and "raw_callback" in last_status: - last_line_content = last_status["raw_callback"] - elif last_status and get_task_status_from_last_status(last_status) == "error": - last_line_content = _build_error_callback_object(last_status) - else: - # Fallback for non-error, no raw_callback, or if last_status is None - last_line_content = last_status - - response = { - "original_url": dynamic_original_url, - "last_line": last_line_content, - "timestamp": last_status.get("timestamp") if last_status else time.time(), - "task_id": task_id, - "status_count": status_count, - "created_at": task_info.get("created_at"), - "name": task_info.get("name"), - "artist": task_info.get("artist"), - "type": task_info.get("type"), - "download_type": task_info.get("download_type"), - } - if last_status and last_status.get("summary"): - response["summary"] = last_status["summary"] - return jsonify(response) - - -@prgs_bp.route("/delete/", methods=["DELETE"]) -def delete_task(task_id): - """ - Delete a task's information and history. - - Args: - task_id: A task UUID from Celery - """ - # Only support new task IDs - task_info = get_task_info(task_id) - if not task_info: - abort(404, "Task not found") - - # First, cancel the task if it's running - cancel_task(task_id) - - return {"message": f"Task {task_id} deleted successfully"}, 200 - - -@prgs_bp.route("/list", methods=["GET"]) -def list_tasks(): +@router.get("/list") +async def list_tasks(request: Request): """ Retrieve a paginated list of all tasks in the system. Returns a detailed list of task objects including status and metadata. @@ -399,9 +301,9 @@ def list_tasks(): """ try: # Get query parameters - page = int(request.args.get('page', 1)) - limit = min(int(request.args.get('limit', 50)), 100) # Cap at 100 - active_only = request.args.get('active_only', '').lower() == 'true' + page = int(request.query_params.get('page', 1)) + limit = min(int(request.query_params.get('limit', 50)), 100) # Cap at 100 + active_only = request.query_params.get('active_only', '').lower() == 'true' tasks = get_all_tasks() active_tasks = [] @@ -447,7 +349,7 @@ def list_tasks(): elif is_active_task: task_counts["active"] += 1 - task_response = _build_task_response(task_info, last_status, task_id, time.time()) + task_response = _build_task_response(task_info, last_status, task_id, time.time(), request) if is_active_task: active_tasks.append(task_response) @@ -509,75 +411,14 @@ def list_tasks(): "timestamp": time.time() } - return jsonify(response) + return response except Exception as e: logger.error(f"Error in /api/prgs/list: {e}", exc_info=True) - return jsonify({"error": "Failed to retrieve task list"}), 500 - -@prgs_bp.route("/cancel/", methods=["POST"]) -def cancel_task_endpoint(task_id): - """ - Cancel a running or queued task. - - Args: - task_id: The ID of the task to cancel - """ - try: - # First check if this is a task ID in the new system - task_info = get_task_info(task_id) - - if task_info: - # This is a task ID in the new system - result = cancel_task(task_id) - return jsonify(result) - - # If not found in new system, we need to handle the old system cancellation - # For now, return an error as we're transitioning to the new system - return jsonify( - { - "status": "error", - "message": "Cancellation for old system is not supported in the new API. Please use the new task ID format.", - } - ), 400 - except Exception as e: - abort(500, f"An error occurred: {e}") + raise HTTPException(status_code=500, detail={"error": "Failed to retrieve task list"}) -@prgs_bp.route("/cancel/all", methods=["POST"]) -def cancel_all_tasks(): - """ - Cancel all active (running or queued) tasks. - """ - try: - tasks_to_cancel = get_all_tasks() - cancelled_count = 0 - errors = [] - - for task_summary in tasks_to_cancel: - task_id = task_summary.get("task_id") - if not task_id: - continue - try: - cancel_task(task_id) - cancelled_count += 1 - except Exception as e: - error_message = f"Failed to cancel task {task_id}: {e}" - logger.error(error_message) - errors.append(error_message) - - response = { - "message": f"Attempted to cancel all active tasks. {cancelled_count} tasks cancelled.", - "cancelled_count": cancelled_count, - "errors": errors, - } - return jsonify(response), 200 - except Exception as e: - logger.error(f"Error in /api/prgs/cancel/all: {e}", exc_info=True) - return jsonify({"error": "Failed to cancel all tasks"}), 500 - - -@prgs_bp.route("/updates", methods=["GET"]) -def get_task_updates(): +@router.get("/updates") +async def get_task_updates(request: Request): """ Retrieve only tasks that have been updated since the specified timestamp. This endpoint is optimized for polling to reduce unnecessary data transfer. @@ -598,19 +439,20 @@ def get_task_updates(): """ try: # Get query parameters - since_param = request.args.get('since') - page = int(request.args.get('page', 1)) - limit = min(int(request.args.get('limit', 20)), 100) # Cap at 100 - active_only = request.args.get('active_only', '').lower() == 'true' + since_param = request.query_params.get('since') + page = int(request.query_params.get('page', 1)) + limit = min(int(request.query_params.get('limit', 20)), 100) # Cap at 100 + active_only = request.query_params.get('active_only', '').lower() == 'true' if not since_param: # If no 'since' parameter, return paginated tasks (fallback behavior) - return get_paginated_tasks(page, limit, active_only) + response = await get_paginated_tasks(page, limit, active_only, request) + return response try: since_timestamp = float(since_param) except (ValueError, TypeError): - return jsonify({"error": "Invalid 'since' timestamp format"}), 400 + raise HTTPException(status_code=400, detail={"error": "Invalid 'since' timestamp format"}) # Get all tasks all_tasks = get_all_tasks() @@ -668,7 +510,7 @@ def get_task_updates(): if should_include: # Construct the same detailed task object as in list_tasks() - task_response = _build_task_response(task_info, last_status, task_id, current_time) + task_response = _build_task_response(task_info, last_status, task_id, current_time, request) if is_active_task: active_tasks.append(task_response) @@ -707,8 +549,437 @@ def get_task_updates(): } logger.debug(f"Returning {len(active_tasks)} active + {len(paginated_updated_tasks)} paginated tasks out of {len(all_tasks)} total") - return jsonify(response) + return response + except HTTPException: + raise except Exception as e: logger.error(f"Error in /api/prgs/updates: {e}", exc_info=True) - return jsonify({"error": "Failed to retrieve task updates"}), 500 + raise HTTPException(status_code=500, detail={"error": "Failed to retrieve task updates"}) + + +@router.post("/cancel/all") +async def cancel_all_tasks(): + """ + Cancel all active (running or queued) tasks. + """ + try: + tasks_to_cancel = get_all_tasks() + cancelled_count = 0 + errors = [] + + for task_summary in tasks_to_cancel: + task_id = task_summary.get("task_id") + if not task_id: + continue + try: + cancel_task(task_id) + cancelled_count += 1 + except Exception as e: + error_message = f"Failed to cancel task {task_id}: {e}" + logger.error(error_message) + errors.append(error_message) + + response = { + "message": f"Attempted to cancel all active tasks. {cancelled_count} tasks cancelled.", + "cancelled_count": cancelled_count, + "errors": errors, + } + return response + except Exception as e: + logger.error(f"Error in /api/prgs/cancel/all: {e}", exc_info=True) + raise HTTPException(status_code=500, detail={"error": "Failed to cancel all tasks"}) + + +@router.post("/cancel/{task_id}") +async def cancel_task_endpoint(task_id: str): + """ + Cancel a running or queued task. + + Args: + task_id: The ID of the task to cancel + """ + try: + # First check if this is a task ID in the new system + task_info = get_task_info(task_id) + + if task_info: + # This is a task ID in the new system + result = cancel_task(task_id) + return result + + # If not found in new system, we need to handle the old system cancellation + # For now, return an error as we're transitioning to the new system + raise HTTPException( + status_code=400, + detail={ + "status": "error", + "message": "Cancellation for old system is not supported in the new API. Please use the new task ID format.", + } + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + + +@router.delete("/delete/{task_id}") +async def delete_task(task_id: str): + """ + Delete a task's information and history. + + Args: + task_id: A task UUID from Celery + """ + # Only support new task IDs + task_info = get_task_info(task_id) + if not task_info: + raise HTTPException(status_code=404, detail="Task not found") + + # First, cancel the task if it's running + cancel_task(task_id) + + return {"message": f"Task {task_id} deleted successfully"} + + +@router.get("/stream") +async def stream_task_updates(request: Request): + """ + Stream real-time task updates via Server-Sent Events (SSE). + + This endpoint provides continuous updates for task status changes without polling. + Clients can connect and receive instant notifications when tasks update. + + Query parameters: + active_only (bool): If true, only stream active tasks (downloading, processing, etc.) + + Returns: + Server-Sent Events stream with task update data in JSON format + """ + + # Get query parameters + active_only = request.query_params.get('active_only', '').lower() == 'true' + + async def event_generator(): + # Track last update timestamp for this client connection + last_update_timestamp = time.time() + + try: + # Send initial data immediately upon connection + yield await generate_task_update_event(last_update_timestamp, active_only, request) + last_update_timestamp = time.time() + + # Continuous monitoring loop + while True: + try: + # Check for updates since last timestamp + current_time = time.time() + + # Get all tasks and check for updates + all_tasks = get_all_tasks() + updated_tasks = [] + active_tasks = [] + + # Task categorization counters + task_counts = { + "active": 0, + "queued": 0, + "completed": 0, + "error": 0, + "cancelled": 0, + "retrying": 0, + "skipped": 0 + } + + has_updates = False + + for task_summary in all_tasks: + task_id = task_summary.get("task_id") + if not task_id: + continue + + task_info = get_task_info(task_id) + if not task_info: + continue + + last_status = get_last_task_status(task_id) + + # Check if task has been updated since the given timestamp + task_timestamp = last_status.get("timestamp") if last_status else task_info.get("created_at", 0) + + # Determine task status and categorize + task_status = get_task_status_from_last_status(last_status) + is_active_task = is_task_active(task_status) + + # Categorize tasks by status using ProgressState constants + if task_status == ProgressState.RETRYING: + task_counts["retrying"] += 1 + elif task_status in {ProgressState.QUEUED, "pending"}: + task_counts["queued"] += 1 + elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: + task_counts["completed"] += 1 + elif task_status == ProgressState.ERROR: + task_counts["error"] += 1 + elif task_status == ProgressState.CANCELLED: + task_counts["cancelled"] += 1 + elif task_status == ProgressState.SKIPPED: + task_counts["skipped"] += 1 + elif is_active_task: + task_counts["active"] += 1 + + # Always include active tasks in updates, apply filtering to others + should_include = is_active_task or (task_timestamp > last_update_timestamp and not active_only) + + if should_include: + has_updates = True + # Construct the same detailed task object as in updates endpoint + task_response = _build_task_response(task_info, last_status, task_id, current_time, request) + + if is_active_task: + active_tasks.append(task_response) + else: + updated_tasks.append(task_response) + + # Only send update if there are changes + if has_updates: + # Combine active tasks (always shown) with updated tasks + all_returned_tasks = active_tasks + updated_tasks + + # Sort by priority (active first, then by creation time) + all_returned_tasks.sort(key=lambda x: ( + 0 if x.get("task_id") in [t["task_id"] for t in active_tasks] else 1, + -x.get("created_at", 0) + )) + + update_data = { + "tasks": all_returned_tasks, + "current_timestamp": current_time, + "total_tasks": task_counts["active"] + task_counts["retrying"], + "all_tasks_count": len(all_tasks), + "task_counts": task_counts, + "active_tasks": len(active_tasks), + "updated_count": len(updated_tasks), + "since_timestamp": last_update_timestamp, + } + + # Send SSE event with update data + event_data = json.dumps(update_data) + yield f"data: {event_data}\n\n" + + logger.debug(f"SSE: Sent {len(active_tasks)} active + {len(updated_tasks)} updated tasks") + + # Update last timestamp + last_update_timestamp = current_time + + # Wait before next check (much shorter than polling interval) + await asyncio.sleep(0.5) # Check every 500ms for real-time feel + + except Exception as e: + logger.error(f"Error in SSE event generation: {e}", exc_info=True) + # Send error event and continue + error_data = json.dumps({"error": "Internal server error", "timestamp": time.time()}) + yield f"data: {error_data}\n\n" + await asyncio.sleep(1) # Wait longer on error + + except asyncio.CancelledError: + logger.info("SSE client disconnected") + return + except Exception as e: + logger.error(f"SSE connection error: {e}", exc_info=True) + return + + return StreamingResponse( + event_generator(), + media_type="text/plain", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "Cache-Control" + } + ) + + +async def generate_task_update_event(since_timestamp: float, active_only: bool, request: Request) -> str: + """ + Generate initial task update event for SSE connection. + This replicates the logic from get_task_updates but for SSE format. + """ + try: + # Get all tasks + all_tasks = get_all_tasks() + updated_tasks = [] + active_tasks = [] + current_time = time.time() + + # Task categorization counters + task_counts = { + "active": 0, + "queued": 0, + "completed": 0, + "error": 0, + "cancelled": 0, + "retrying": 0, + "skipped": 0 + } + + for task_summary in all_tasks: + task_id = task_summary.get("task_id") + if not task_id: + continue + + task_info = get_task_info(task_id) + if not task_info: + continue + + last_status = get_last_task_status(task_id) + + # Check if task has been updated since the given timestamp + task_timestamp = last_status.get("timestamp") if last_status else task_info.get("created_at", 0) + + # Determine task status and categorize + task_status = get_task_status_from_last_status(last_status) + is_active_task = is_task_active(task_status) + + # Categorize tasks by status using ProgressState constants + if task_status == ProgressState.RETRYING: + task_counts["retrying"] += 1 + elif task_status in {ProgressState.QUEUED, "pending"}: + task_counts["queued"] += 1 + elif task_status in {ProgressState.COMPLETE, ProgressState.DONE}: + task_counts["completed"] += 1 + elif task_status == ProgressState.ERROR: + task_counts["error"] += 1 + elif task_status == ProgressState.CANCELLED: + task_counts["cancelled"] += 1 + elif task_status == ProgressState.SKIPPED: + task_counts["skipped"] += 1 + elif is_active_task: + task_counts["active"] += 1 + + # Always include active tasks in updates, apply filtering to others + should_include = is_active_task or (task_timestamp > since_timestamp and not active_only) + + if should_include: + # Construct the same detailed task object as in updates endpoint + task_response = _build_task_response(task_info, last_status, task_id, current_time, request) + + if is_active_task: + active_tasks.append(task_response) + else: + updated_tasks.append(task_response) + + # Combine active tasks (always shown) with updated tasks + all_returned_tasks = active_tasks + updated_tasks + + # Sort by priority (active first, then by creation time) + all_returned_tasks.sort(key=lambda x: ( + 0 if x.get("task_id") in [t["task_id"] for t in active_tasks] else 1, + -x.get("created_at", 0) + )) + + initial_data = { + "tasks": all_returned_tasks, + "current_timestamp": current_time, + "total_tasks": task_counts["active"] + task_counts["retrying"], + "all_tasks_count": len(all_tasks), + "task_counts": task_counts, + "active_tasks": len(active_tasks), + "updated_count": len(updated_tasks), + "since_timestamp": since_timestamp, + "initial": True # Mark as initial load + } + + event_data = json.dumps(initial_data) + return f"data: {event_data}\n\n" + + except Exception as e: + logger.error(f"Error generating initial SSE event: {e}", exc_info=True) + error_data = json.dumps({"error": "Failed to load initial data", "timestamp": time.time()}) + return f"data: {error_data}\n\n" + + +# IMPORTANT: This parameterized route MUST come AFTER all specific routes +# Otherwise FastAPI will match specific routes like "/updates" as task_id parameters +@router.get("/{task_id}") +async def get_task_details(task_id: str, request: Request): + """ + Return a JSON object with the resource type, its name (title), + the last progress update, and, if available, the original request parameters. + + This function works with the new task ID based system. + + Args: + task_id: A task UUID from Celery + """ + # Only support new task IDs + task_info = get_task_info(task_id) + if not task_info: + raise HTTPException(status_code=404, detail="Task not found") + + # Dynamically construct original_url + dynamic_original_url = "" + download_type = task_info.get("download_type") + # The 'url' field in task_info stores the Spotify/Deezer URL of the item + # e.g., https://open.spotify.com/album/albumId or https://www.deezer.com/track/trackId + item_url = task_info.get("url") + + if download_type and item_url: + try: + # Extract the ID from the item_url (last part of the path) + item_id = item_url.split("/")[-1] + if item_id: # Ensure item_id is not empty + base_url = str(request.base_url).rstrip("/") + dynamic_original_url = ( + f"{base_url}/api/{download_type}/download/{item_id}" + ) + else: + logger.warning( + f"Could not extract item ID from URL: {item_url} for task {task_id}. Falling back for original_url." + ) + original_request_obj = task_info.get("original_request", {}) + dynamic_original_url = original_request_obj.get("original_url", "") + except Exception as e: + logger.error( + f"Error constructing dynamic original_url for task {task_id}: {e}", + exc_info=True, + ) + original_request_obj = task_info.get("original_request", {}) + dynamic_original_url = original_request_obj.get( + "original_url", "" + ) # Fallback on any error + else: + logger.warning( + f"Missing download_type ('{download_type}') or item_url ('{item_url}') in task_info for task {task_id}. Falling back for original_url." + ) + original_request_obj = task_info.get("original_request", {}) + dynamic_original_url = original_request_obj.get("original_url", "") + + last_status = get_last_task_status(task_id) + status_count = len(get_task_status(task_id)) + + # Determine last_line content + if last_status and "raw_callback" in last_status: + last_line_content = last_status["raw_callback"] + elif last_status and get_task_status_from_last_status(last_status) == "error": + last_line_content = _build_error_callback_object(last_status) + else: + # Fallback for non-error, no raw_callback, or if last_status is None + last_line_content = last_status + + response = { + "original_url": dynamic_original_url, + "last_line": last_line_content, + "timestamp": last_status.get("timestamp") if last_status else time.time(), + "task_id": task_id, + "status_count": status_count, + "created_at": task_info.get("created_at"), + "name": task_info.get("name"), + "artist": task_info.get("artist"), + "type": task_info.get("type"), + "download_type": task_info.get("download_type"), + } + if last_status and last_status.get("summary"): + response["summary"] = last_status["summary"] + return response diff --git a/routes/search.py b/routes/search.py index ced9a08..2410e06 100755 --- a/routes/search.py +++ b/routes/search.py @@ -1,71 +1,65 @@ -from flask import Blueprint, jsonify, request -from routes.utils.search import search # Corrected import -from routes.config import get_config # Import get_config function +from fastapi import APIRouter, HTTPException, Request +import json +import traceback +import logging +from routes.utils.search import search -search_bp = Blueprint("search", __name__) +logger = logging.getLogger(__name__) +router = APIRouter() -@search_bp.route("/search", methods=["GET"]) -def handle_search(): +@router.get("/search") +async def handle_search(request: Request): + """ + Handle search requests for tracks, albums, playlists, or artists. + Frontend compatible endpoint that returns results in { items: [] } format. + """ + query = request.query_params.get("q") + # Frontend sends 'search_type', so check both 'search_type' and 'type' + search_type = request.query_params.get("search_type") or request.query_params.get("type", "track") + limit = request.query_params.get("limit", "20") + main = request.query_params.get("main") # Account context + + if not query: + raise HTTPException(status_code=400, detail={"error": "Missing parameter: q"}) + try: - # Get query parameters - query = request.args.get("q", "") - search_type = request.args.get("search_type", "") - limit = int(request.args.get("limit", 10)) - main = request.args.get( - "main", "" - ) # Get the main parameter for account selection + limit = int(limit) + except ValueError: + raise HTTPException(status_code=400, detail={"error": "limit must be an integer"}) - # If main parameter is not provided in the request, get it from config - if not main: - config = get_config() - if config and "spotify" in config: - main = config["spotify"] - print(f"Using main from config: {main}") - - # Validate parameters - if not query: - return jsonify({"error": "Missing search query"}), 400 - - valid_types = ["track", "album", "artist", "playlist", "episode"] - if search_type not in valid_types: - return jsonify({"error": "Invalid search type"}), 400 - - # Perform the search with corrected parameter name - raw_results = search( + try: + # Use the single search_type (not multiple types like before) + result = search( query=query, - search_type=search_type, # Fixed parameter name + search_type=search_type, limit=limit, - main=main, # Pass the main parameter + main=main ) - - # Extract items from the appropriate section of the response based on search_type + + # Extract items from the Spotify API response based on search type + # Spotify API returns results in format like { "tracks": { "items": [...] } } items = [] - if raw_results and search_type + "s" in raw_results: - type_key = search_type + "s" - items = raw_results[type_key].get("items", []) - elif raw_results and search_type in raw_results: - items = raw_results[search_type].get("items", []) - - # Filter out any null items from the results - if items: - items = [item for item in items if item is not None] - - # Return both the items array and the full data for debugging - return jsonify( - { - "items": items, - "data": raw_results, # Include full data for debugging - "error": None, - } - ) - - except ValueError as e: - print(f"ValueError in search: {str(e)}") - return jsonify({"error": str(e)}), 400 + + # Map search types to their plural forms in Spotify response + type_mapping = { + "track": "tracks", + "album": "albums", + "artist": "artists", + "playlist": "playlists", + "episode": "episodes", + "show": "shows" + } + + response_key = type_mapping.get(search_type.lower(), "tracks") + + if result and response_key in result: + items = result[response_key].get("items", []) + + # Return in the format expected by frontend: { items: [] } + return {"items": items} + except Exception as e: - import traceback - - print(f"Exception in search: {str(e)}") - print(traceback.format_exc()) - return jsonify({"error": f"Internal server error: {str(e)}"}), 500 + error_data = {"error": str(e), "traceback": traceback.format_exc()} + logger.error(f"Error in search: {error_data}") + raise HTTPException(status_code=500, detail=error_data) diff --git a/routes/track.py b/routes/track.py index 4057268..9dd14ec 100755 --- a/routes/track.py +++ b/routes/track.py @@ -1,21 +1,15 @@ -from flask import Blueprint, Response, request +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse import json import traceback -import uuid # For generating error task IDs -import time # For timestamps -from routes.utils.celery_queue_manager import ( - download_queue_manager, - get_existing_task_id, -) -from routes.utils.celery_tasks import ( - store_task_info, - store_task_status, - ProgressState, -) # For error task creation -from urllib.parse import urlparse # for URL validation -from routes.utils.get_info import get_spotify_info # Added import +import uuid +import time +from routes.utils.celery_queue_manager import download_queue_manager +from routes.utils.celery_tasks import store_task_info, store_task_status, ProgressState +from routes.utils.get_info import get_spotify_info +from routes.utils.errors import DuplicateDownloadError -track_bp = Blueprint("track", __name__) +router = APIRouter() def construct_spotify_url(item_id: str, item_type: str = "track") -> str: @@ -23,16 +17,14 @@ def construct_spotify_url(item_id: str, item_type: str = "track") -> str: return f"https://open.spotify.com/{item_type}/{item_id}" -@track_bp.route("/download/", methods=["GET"]) -def handle_download(track_id): +@router.get("/download/{track_id}") +async def handle_download(track_id: str, request: Request): # Retrieve essential parameters from the request. # name = request.args.get('name') # Removed # artist = request.args.get('artist') # Removed - orig_params = request.args.to_dict() # Construct the URL from track_id url = construct_spotify_url(track_id, "track") - orig_params["original_url"] = url # Update original_url to the constructed one # Fetch metadata from Spotify try: @@ -42,12 +34,9 @@ def handle_download(track_id): or not track_info.get("name") or not track_info.get("artists") ): - return Response( - json.dumps( - {"error": f"Could not retrieve metadata for track ID: {track_id}"} - ), - status=404, - mimetype="application/json", + return JSONResponse( + content={"error": f"Could not retrieve metadata for track ID: {track_id}"}, + status_code=404 ) name_from_spotify = track_info.get("name") @@ -58,72 +47,53 @@ def handle_download(track_id): ) except Exception as e: - return Response( - json.dumps( - {"error": f"Failed to fetch metadata for track {track_id}: {str(e)}"} - ), - status=500, - mimetype="application/json", + return JSONResponse( + content={"error": f"Failed to fetch metadata for track {track_id}: {str(e)}"}, + status_code=500 ) # Validate required parameters if not url: - return Response( - json.dumps( - {"error": "Missing required parameter: url", "original_url": url} - ), - status=400, - mimetype="application/json", - ) - # Validate URL domain - parsed = urlparse(url) - host = parsed.netloc.lower() - if not ( - host.endswith("deezer.com") - or host.endswith("open.spotify.com") - or host.endswith("spotify.com") - ): - return Response( - json.dumps({"error": f"Invalid Link {url} :(", "original_url": url}), - status=400, - mimetype="application/json", - ) - - # Check for existing task before adding to the queue - existing_task = get_existing_task_id(url) - if existing_task: - return Response( - json.dumps( - { - "error": "Duplicate download detected.", - "existing_task": existing_task, - } - ), - status=409, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing required parameter: url"}, + status_code=400 ) + # Add the task to the queue with only essential parameters + # The queue manager will now handle all config parameters + # Include full original request URL in metadata + orig_params = dict(request.query_params) + orig_params["original_url"] = str(request.url) try: task_id = download_queue_manager.add_task( { "download_type": "track", "url": url, - "name": name_from_spotify, # Use fetched name - "artist": artist_from_spotify, # Use fetched artist + "name": name_from_spotify, + "artist": artist_from_spotify, "orig_request": orig_params, } ) - # Removed DuplicateDownloadError handling, add_task now manages this by creating an error task. + except DuplicateDownloadError as e: + return JSONResponse( + content={ + "error": "Duplicate download detected.", + "existing_task": e.existing_task, + }, + status_code=409 + ) except Exception as e: # Generic error handling for other issues during task submission + # Create an error task ID if add_task itself fails before returning an ID error_task_id = str(uuid.uuid4()) + store_task_info( error_task_id, { "download_type": "track", "url": url, - "name": name_from_spotify, # Use fetched name - "artist": artist_from_spotify, # Use fetched artist + "name": name_from_spotify, + "artist": artist_from_spotify, "original_request": orig_params, "created_at": time.time(), "is_submission_error_task": True, @@ -137,65 +107,57 @@ def handle_download(track_id): "timestamp": time.time(), }, ) - return Response( - json.dumps( - { - "error": f"Failed to queue track download: {str(e)}", - "task_id": error_task_id, - } - ), - status=500, - mimetype="application/json", + return JSONResponse( + content={ + "error": f"Failed to queue track download: {str(e)}", + "task_id": error_task_id, + }, + status_code=500 ) - return Response( - json.dumps({"task_id": task_id}), - status=202, - mimetype="application/json", + return JSONResponse( + content={"task_id": task_id}, + status_code=202 ) -@track_bp.route("/download/cancel", methods=["GET"]) -def cancel_download(): +@router.get("/download/cancel") +async def cancel_download(request: Request): """ - Cancel a running track download process by its task id. + Cancel a running download process by its task id. """ - task_id = request.args.get("task_id") + task_id = request.query_params.get("task_id") if not task_id: - return Response( - json.dumps({"error": "Missing task id (task_id) parameter"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing process id (task_id) parameter"}, + status_code=400 ) # Use the queue manager's cancellation method. result = download_queue_manager.cancel_task(task_id) status_code = 200 if result.get("status") == "cancelled" else 404 - return Response(json.dumps(result), status=status_code, mimetype="application/json") + return JSONResponse(content=result, status_code=status_code) -@track_bp.route("/info", methods=["GET"]) -def get_track_info(): +@router.get("/info") +async def get_track_info(request: Request): """ Retrieve Spotify track metadata given a Spotify track ID. Expects a query parameter 'id' that contains the Spotify track ID. """ - spotify_id = request.args.get("id") + spotify_id = request.query_params.get("id") if not spotify_id: - return Response( - json.dumps({"error": "Missing parameter: id"}), - status=400, - mimetype="application/json", + return JSONResponse( + content={"error": "Missing parameter: id"}, + status_code=400 ) try: - # Import and use the get_spotify_info function from the utility module. - from routes.utils.get_info import get_spotify_info - + # Use the get_spotify_info function (already imported at top) track_info = get_spotify_info(spotify_id, "track") - return Response(json.dumps(track_info), status=200, mimetype="application/json") + return JSONResponse(content=track_info, status_code=200) except Exception as e: error_data = {"error": str(e), "traceback": traceback.format_exc()} - return Response(json.dumps(error_data), status=500, mimetype="application/json") + return JSONResponse(content=error_data, status_code=500) diff --git a/spotizerr-ui/src/contexts/QueueProvider.tsx b/spotizerr-ui/src/contexts/QueueProvider.tsx index e68bf01..d687c44 100644 --- a/spotizerr-ui/src/contexts/QueueProvider.tsx +++ b/spotizerr-ui/src/contexts/QueueProvider.tsx @@ -43,10 +43,12 @@ export function QueueProvider({ children }: { children: ReactNode }) { const pollingIntervals = useRef>({}); const cancelledRemovalTimers = useRef>({}); - // Smart polling state - const smartPollingInterval = useRef(null); - const lastUpdateTimestamp = useRef(0); + // SSE connection state + const sseConnection = useRef(null); const isInitialized = useRef(false); + const reconnectTimeoutRef = useRef(null); + const maxReconnectAttempts = 5; + const reconnectAttempts = useRef(0); // Pagination state const [currentPage, setCurrentPage] = useState(1); @@ -150,97 +152,135 @@ export function QueueProvider({ children }: { children: ReactNode }) { }, [scheduleCancelledTaskRemoval]); const startSmartPolling = useCallback(() => { - if (smartPollingInterval.current) return; // Already polling + if (sseConnection.current) return; // Already connected - console.log("Starting smart polling"); + console.log("Starting SSE connection"); - const intervalId = window.setInterval(async () => { + const connectSSE = () => { try { - const response = await apiClient.get<{ - tasks: any[]; - current_timestamp: number; - total_tasks: number; - active_tasks: number; - updated_count: number; - task_counts?: { - active: number; - queued: number; - retrying: number; - completed: number; - error: number; - cancelled: number; - skipped: number; - }; - }>(`/prgs/updates?since=${lastUpdateTimestamp.current}&active_only=true`); + // Create SSE connection + const eventSource = new EventSource(`/api/prgs/stream?active_only=true`); + sseConnection.current = eventSource; - const { tasks: updatedTasks, current_timestamp, total_tasks, task_counts } = response.data; - - // Update the last timestamp for next poll - lastUpdateTimestamp.current = current_timestamp; - - // Update total tasks count - use active + queued if task_counts available - const calculatedTotal = task_counts ? - (task_counts.active + task_counts.queued) : - (total_tasks || 0); - setTotalTasks(calculatedTotal); + eventSource.onopen = () => { + console.log("SSE connection established"); + reconnectAttempts.current = 0; // Reset reconnect attempts on successful connection + }; - if (updatedTasks.length > 0) { - console.log(`Smart polling: ${updatedTasks.length} tasks updated (${response.data.active_tasks} active) out of ${response.data.total_tasks} total`); - - // Create a map of updated tasks by task_id for efficient lookup - const updatedTasksMap = new Map(updatedTasks.map(task => [task.task_id, task])); - - setItems(prev => { - // Update existing items with new data, and add any new active tasks - const updatedItems = prev.map(item => { - const updatedTaskData = updatedTasksMap.get(item.taskId || item.id); - if (updatedTaskData) { - return updateItemFromPrgs(item, updatedTaskData); - } - return item; - }); + eventSource.onmessage = (event) => { + try { + const data = JSON.parse(event.data); + + // Handle error events + if (data.error) { + console.error("SSE error event:", data.error); + toast.error("Connection error: " + data.error); + return; + } - // Only add new active tasks that aren't in our current items and aren't in terminal state - const currentTaskIds = new Set(prev.map(item => item.taskId || item.id)); - const newActiveTasks = updatedTasks - .filter(task => { - const isNew = !currentTaskIds.has(task.task_id); - const status = task.last_line?.status_info?.status || task.last_line?.status || "unknown"; - const isActive = isActiveTaskStatus(status); - const isTerminal = ["completed", "error", "cancelled", "skipped", "done"].includes(status); - return isNew && isActive && !isTerminal; - }) - .map(task => { - const spotifyId = task.original_url?.split("/").pop() || ""; - const baseItem: QueueItem = { - id: task.task_id, - taskId: task.task_id, - name: task.name || "Unknown", - type: task.download_type || "track", - spotifyId: spotifyId, - status: "initializing", - artist: task.artist, - }; - return updateItemFromPrgs(baseItem, task); + const { tasks: updatedTasks, current_timestamp, total_tasks, task_counts } = data; + + // Update total tasks count - use active + queued if task_counts available + const calculatedTotal = task_counts ? + (task_counts.active + task_counts.queued) : + (total_tasks || 0); + setTotalTasks(calculatedTotal); + + if (updatedTasks && updatedTasks.length > 0) { + console.log(`SSE: ${updatedTasks.length} tasks updated (${data.active_tasks} active) out of ${data.total_tasks} total`); + + // Create a map of updated tasks by task_id for efficient lookup + const updatedTasksMap = new Map(updatedTasks.map((task: any) => [task.task_id, task])); + + setItems(prev => { + // Update existing items with new data, and add any new active tasks + const updatedItems = prev.map(item => { + const updatedTaskData = updatedTasksMap.get(item.taskId || item.id); + if (updatedTaskData) { + return updateItemFromPrgs(item, updatedTaskData); + } + return item; + }); + + // Only add new active tasks that aren't in our current items and aren't in terminal state + const currentTaskIds = new Set(prev.map(item => item.taskId || item.id)); + const newActiveTasks = updatedTasks + .filter((task: any) => { + const isNew = !currentTaskIds.has(task.task_id); + const status = task.last_line?.status_info?.status || task.last_line?.status || "unknown"; + const isActive = isActiveTaskStatus(status); + const isTerminal = ["completed", "error", "cancelled", "skipped", "done"].includes(status); + return isNew && isActive && !isTerminal; + }) + .map((task: any) => { + const spotifyId = task.original_url?.split("/").pop() || ""; + const baseItem: QueueItem = { + id: task.task_id, + taskId: task.task_id, + name: task.name || "Unknown", + type: task.download_type || "track", + spotifyId: spotifyId, + status: "initializing", + artist: task.artist, + }; + return updateItemFromPrgs(baseItem, task); + }); + + return newActiveTasks.length > 0 ? [...newActiveTasks, ...updatedItems] : updatedItems; }); + } + } catch (error) { + console.error("Failed to parse SSE message:", error); + } + }; + + eventSource.onerror = (error) => { + console.error("SSE connection error:", error); + + // Close the connection + eventSource.close(); + sseConnection.current = null; + + // Attempt to reconnect with exponential backoff + if (reconnectAttempts.current < maxReconnectAttempts) { + reconnectAttempts.current++; + const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current - 1), 30000); // Max 30 seconds + + console.log(`SSE reconnecting in ${delay}ms (attempt ${reconnectAttempts.current}/${maxReconnectAttempts})`); + + reconnectTimeoutRef.current = window.setTimeout(() => { + connectSSE(); + }, delay); + } else { + console.error("SSE max reconnection attempts reached"); + toast.error("Connection lost. Please refresh the page."); + } + }; - return newActiveTasks.length > 0 ? [...newActiveTasks, ...updatedItems] : updatedItems; - }); - } } catch (error) { - console.error("Smart polling failed:", error); + console.error("Failed to create SSE connection:", error); + toast.error("Failed to establish real-time connection"); } - }, 2000); // Poll every 2 seconds + }; - smartPollingInterval.current = intervalId; + connectSSE(); }, [updateItemFromPrgs]); const stopSmartPolling = useCallback(() => { - if (smartPollingInterval.current) { - console.log("Stopping smart polling"); - clearInterval(smartPollingInterval.current); - smartPollingInterval.current = null; + if (sseConnection.current) { + console.log("Closing SSE connection"); + sseConnection.current.close(); + sseConnection.current = null; } + + // Clear any pending reconnection timeout + if (reconnectTimeoutRef.current) { + clearTimeout(reconnectTimeoutRef.current); + reconnectTimeoutRef.current = null; + } + + // Reset reconnection attempts + reconnectAttempts.current = 0; }, []); const loadMoreTasks = useCallback(async () => { @@ -312,7 +352,7 @@ export function QueueProvider({ children }: { children: ReactNode }) { const startPolling = useCallback( (taskId: string) => { - // Legacy function - now just ensures smart polling is active + // Legacy function - now just ensures SSE connection is active startSmartPolling(); }, [startSmartPolling], @@ -373,10 +413,9 @@ export function QueueProvider({ children }: { children: ReactNode }) { setTotalTasks(calculatedTotal); // Set initial timestamp to current time - lastUpdateTimestamp.current = timestamp; isInitialized.current = true; - // Start smart polling for real-time updates + // Start SSE connection for real-time updates startSmartPolling(); } catch (error) { console.error("Failed to fetch queue from backend:", error); @@ -386,7 +425,7 @@ export function QueueProvider({ children }: { children: ReactNode }) { fetchQueue(); - // Cleanup function to stop polling when component unmounts + // Cleanup function to stop SSE connection when component unmounts return () => { stopSmartPolling(); // Clean up any remaining individual polling intervals (legacy cleanup)