Merge branch 'dev' into fixup-bulk-add-celery

This commit is contained in:
Phlogi
2025-08-27 21:17:47 +02:00
committed by GitHub
26 changed files with 1617 additions and 1141 deletions

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi import APIRouter, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import Optional, List
@@ -14,6 +14,7 @@ security = HTTPBearer(auto_error=False)
# Include SSO sub-router
try:
from .sso import router as sso_router
router.include_router(sso_router, tags=["sso"])
logging.info("SSO sub-router included in auth router")
except ImportError as e:
@@ -34,6 +35,7 @@ class RegisterRequest(BaseModel):
class CreateUserRequest(BaseModel):
"""Admin-only request to create users when registration is disabled"""
username: str
password: str
email: Optional[str] = None
@@ -42,17 +44,20 @@ class CreateUserRequest(BaseModel):
class RoleUpdateRequest(BaseModel):
"""Request to update user role"""
role: str
class PasswordChangeRequest(BaseModel):
"""Request to change user password"""
current_password: str
new_password: str
class AdminPasswordResetRequest(BaseModel):
"""Request for admin to reset user password"""
new_password: str
@@ -87,20 +92,20 @@ class AuthStatusResponse(BaseModel):
# Dependency to get current user
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security)
credentials: HTTPAuthorizationCredentials = Depends(security),
) -> Optional[User]:
"""Get current user from JWT token"""
if not AUTH_ENABLED:
# When auth is disabled, return a mock admin user
return User(username="system", role="admin")
if not credentials:
return None
payload = token_manager.verify_token(credentials.credentials)
if not payload:
return None
user = user_manager.get_user(payload["username"])
return user
@@ -109,25 +114,22 @@ async def require_auth(current_user: User = Depends(get_current_user)) -> User:
"""Require authentication - raises HTTPException if not authenticated"""
if not AUTH_ENABLED:
return User(username="system", role="admin")
if not current_user:
raise HTTPException(
status_code=401,
detail="Authentication required",
headers={"WWW-Authenticate": "Bearer"},
)
return current_user
async def require_admin(current_user: User = Depends(require_auth)) -> User:
"""Require admin role - raises HTTPException if not admin"""
if current_user.role != "admin":
raise HTTPException(
status_code=403,
detail="Admin access required"
)
raise HTTPException(status_code=403, detail="Admin access required")
return current_user
@@ -138,24 +140,33 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
# Check if SSO is enabled and get available providers
sso_enabled = False
sso_providers = []
try:
from . import sso
sso_enabled = sso.SSO_ENABLED and AUTH_ENABLED
if sso.google_sso:
sso_providers.append("google")
if sso.github_sso:
sso_providers.append("github")
if getattr(sso, "custom_sso", None):
sso_providers.append("custom")
if getattr(sso, "custom_sso_providers", None):
if (
len(getattr(sso, "custom_sso_providers", {})) > 0
and "custom" not in sso_providers
):
sso_providers.append("custom")
except ImportError:
pass # SSO module not available
return AuthStatusResponse(
auth_enabled=AUTH_ENABLED,
authenticated=current_user is not None,
user=UserResponse(**current_user.to_public_dict()) if current_user else None,
registration_enabled=AUTH_ENABLED and not DISABLE_REGISTRATION,
sso_enabled=sso_enabled,
sso_providers=sso_providers
sso_providers=sso_providers,
)
@@ -163,23 +174,16 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
async def login(request: LoginRequest):
"""Authenticate user and return access token"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
user = user_manager.authenticate_user(request.username, request.password)
if not user:
raise HTTPException(
status_code=401,
detail="Invalid username or password"
)
raise HTTPException(status_code=401, detail="Invalid username or password")
access_token = token_manager.create_token(user)
return LoginResponse(
access_token=access_token,
user=UserResponse(**user.to_public_dict())
access_token=access_token, user=UserResponse(**user.to_public_dict())
)
@@ -187,31 +191,28 @@ async def login(request: LoginRequest):
async def register(request: RegisterRequest):
"""Register a new user"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
if DISABLE_REGISTRATION:
raise HTTPException(
status_code=403,
detail="Public registration is disabled. Contact an administrator to create an account."
detail="Public registration is disabled. Contact an administrator to create an account.",
)
# Check if this is the first user (should be admin)
existing_users = user_manager.list_users()
role = "admin" if len(existing_users) == 0 else "user"
success, message = user_manager.create_user(
username=request.username,
password=request.password,
email=request.email,
role=role
role=role,
)
if not success:
raise HTTPException(status_code=400, detail=message)
return MessageResponse(message=message)
@@ -233,70 +234,57 @@ async def list_users(current_user: User = Depends(require_admin)):
async def delete_user(username: str, current_user: User = Depends(require_admin)):
"""Delete a user (admin only)"""
if username == current_user.username:
raise HTTPException(
status_code=400,
detail="Cannot delete your own account"
)
raise HTTPException(status_code=400, detail="Cannot delete your own account")
success, message = user_manager.delete_user(username)
if not success:
raise HTTPException(status_code=404, detail=message)
return MessageResponse(message=message)
@router.put("/users/{username}/role", response_model=MessageResponse)
async def update_user_role(
username: str,
request: RoleUpdateRequest,
current_user: User = Depends(require_admin)
username: str,
request: RoleUpdateRequest,
current_user: User = Depends(require_admin),
):
"""Update user role (admin only)"""
if request.role not in ["user", "admin"]:
raise HTTPException(
status_code=400,
detail="Role must be 'user' or 'admin'"
)
raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
if username == current_user.username:
raise HTTPException(
status_code=400,
detail="Cannot change your own role"
)
raise HTTPException(status_code=400, detail="Cannot change your own role")
success, message = user_manager.update_user_role(username, request.role)
if not success:
raise HTTPException(status_code=404, detail=message)
return MessageResponse(message=message)
@router.post("/users/create", response_model=MessageResponse)
async def create_user_admin(request: CreateUserRequest, current_user: User = Depends(require_admin)):
async def create_user_admin(
request: CreateUserRequest, current_user: User = Depends(require_admin)
):
"""Create a new user (admin only) - for use when registration is disabled"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
# Validate role
if request.role not in ["user", "admin"]:
raise HTTPException(
status_code=400,
detail="Role must be 'user' or 'admin'"
)
raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
success, message = user_manager.create_user(
username=request.username,
password=request.password,
email=request.email,
role=request.role
role=request.role,
)
if not success:
raise HTTPException(status_code=400, detail=message)
return MessageResponse(message=message)
@@ -309,22 +297,18 @@ async def get_profile(current_user: User = Depends(require_auth)):
@router.put("/profile/password", response_model=MessageResponse)
async def change_password(
request: PasswordChangeRequest,
current_user: User = Depends(require_auth)
request: PasswordChangeRequest, current_user: User = Depends(require_auth)
):
"""Change current user's password"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
success, message = user_manager.change_password(
username=current_user.username,
current_password=request.current_password,
new_password=request.new_password
new_password=request.new_password,
)
if not success:
# Determine appropriate HTTP status code based on error message
if "Current password is incorrect" in message:
@@ -333,9 +317,9 @@ async def change_password(
status_code = 404
else:
status_code = 400
raise HTTPException(status_code=status_code, detail=message)
return MessageResponse(message=message)
@@ -343,30 +327,26 @@ async def change_password(
async def admin_reset_password(
username: str,
request: AdminPasswordResetRequest,
current_user: User = Depends(require_admin)
current_user: User = Depends(require_admin),
):
"""Admin reset user password (admin only)"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
success, message = user_manager.admin_reset_password(
username=username,
new_password=request.new_password
username=username, new_password=request.new_password
)
if not success:
# Determine appropriate HTTP status code based on error message
if "User not found" in message:
status_code = 404
else:
status_code = 400
raise HTTPException(status_code=status_code, detail=message)
return MessageResponse(message=message)
# Note: SSO routes are included in the main app, not here to avoid circular imports
# Note: SSO routes are included in the main app, not here to avoid circular imports

View File

@@ -1,17 +1,19 @@
"""
SSO (Single Sign-On) implementation for Google and GitHub authentication
"""
import os
import logging
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
from fastapi import APIRouter, Request, HTTPException, Depends
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import RedirectResponse
from fastapi_sso.sso.google import GoogleSSO
from fastapi_sso.sso.github import GithubSSO
from fastapi_sso.sso.base import OpenID
from pydantic import BaseModel
from fastapi_sso.sso.generic import create_provider
from . import user_manager, token_manager, User, AUTH_ENABLED, DISABLE_REGISTRATION
@@ -25,11 +27,14 @@ GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
SSO_BASE_REDIRECT_URI = os.getenv("SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback")
SSO_BASE_REDIRECT_URI = os.getenv(
"SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback"
)
# Initialize SSO providers
google_sso = None
github_sso = None
custom_sso = None
if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET:
google_sso = GoogleSSO(
@@ -47,6 +52,154 @@ if GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET:
allow_insecure_http=True, # Set to False in production with HTTPS
)
# Custom/Generic OAuth provider configuration
CUSTOM_SSO_CLIENT_ID = os.getenv("CUSTOM_SSO_CLIENT_ID")
CUSTOM_SSO_CLIENT_SECRET = os.getenv("CUSTOM_SSO_CLIENT_SECRET")
CUSTOM_SSO_AUTHORIZATION_ENDPOINT = os.getenv("CUSTOM_SSO_AUTHORIZATION_ENDPOINT")
CUSTOM_SSO_TOKEN_ENDPOINT = os.getenv("CUSTOM_SSO_TOKEN_ENDPOINT")
CUSTOM_SSO_USERINFO_ENDPOINT = os.getenv("CUSTOM_SSO_USERINFO_ENDPOINT")
CUSTOM_SSO_SCOPE = os.getenv("CUSTOM_SSO_SCOPE") # comma-separated list
CUSTOM_SSO_NAME = os.getenv("CUSTOM_SSO_NAME", "custom")
CUSTOM_SSO_DISPLAY_NAME = os.getenv("CUSTOM_SSO_DISPLAY_NAME", "Custom")
def _default_custom_response_convertor(
userinfo: Dict[str, Any], _client=None
) -> OpenID:
"""Best-effort convertor from generic userinfo to OpenID."""
user_id = (
userinfo.get("sub")
or userinfo.get("id")
or userinfo.get("user_id")
or userinfo.get("uid")
or userinfo.get("uuid")
)
email = userinfo.get("email")
display_name = (
userinfo.get("name")
or userinfo.get("preferred_username")
or userinfo.get("login")
or email
or (str(user_id) if user_id is not None else None)
)
picture = userinfo.get("picture") or userinfo.get("avatar_url")
if not user_id and email:
user_id = email
return OpenID(
id=str(user_id) if user_id is not None else "",
email=email,
display_name=display_name,
picture=picture,
provider=CUSTOM_SSO_NAME,
)
if all(
[
CUSTOM_SSO_CLIENT_ID,
CUSTOM_SSO_CLIENT_SECRET,
CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
CUSTOM_SSO_TOKEN_ENDPOINT,
CUSTOM_SSO_USERINFO_ENDPOINT,
]
):
discovery = {
"authorization_endpoint": CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
"token_endpoint": CUSTOM_SSO_TOKEN_ENDPOINT,
"userinfo_endpoint": CUSTOM_SSO_USERINFO_ENDPOINT,
}
default_scope = (
[s.strip() for s in CUSTOM_SSO_SCOPE.split(",") if s.strip()]
if CUSTOM_SSO_SCOPE
else None
)
CustomProvider = create_provider(
name=CUSTOM_SSO_NAME,
discovery_document=discovery,
response_convertor=_default_custom_response_convertor,
default_scope=default_scope,
)
custom_sso = CustomProvider(
client_id=CUSTOM_SSO_CLIENT_ID,
client_secret=CUSTOM_SSO_CLIENT_SECRET,
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom",
allow_insecure_http=True, # Set to False in production with HTTPS
)
# Support multiple indexed custom providers (CUSTOM_*_i), up to 10
custom_sso_providers: Dict[int, Dict[str, Any]] = {}
def _make_response_convertor(provider_name: str):
def _convert(userinfo: Dict[str, Any], _client=None) -> OpenID:
user_id = (
userinfo.get("sub")
or userinfo.get("id")
or userinfo.get("user_id")
or userinfo.get("uid")
or userinfo.get("uuid")
)
email = userinfo.get("email")
display_name = (
userinfo.get("name")
or userinfo.get("preferred_username")
or userinfo.get("login")
or email
or (str(user_id) if user_id is not None else None)
)
picture = userinfo.get("picture") or userinfo.get("avatar_url")
if not user_id and email:
user_id = email
return OpenID(
id=str(user_id) if user_id is not None else "",
email=email,
display_name=display_name,
picture=picture,
provider=provider_name,
)
return _convert
for i in range(1, 11):
cid = os.getenv(f"CUSTOM_SSO_CLIENT_ID_{i}")
csecret = os.getenv(f"CUSTOM_SSO_CLIENT_SECRET_{i}")
auth_ep = os.getenv(f"CUSTOM_SSO_AUTHORIZATION_ENDPOINT_{i}")
token_ep = os.getenv(f"CUSTOM_SSO_TOKEN_ENDPOINT_{i}")
userinfo_ep = os.getenv(f"CUSTOM_SSO_USERINFO_ENDPOINT_{i}")
scope_raw = os.getenv(f"CUSTOM_SSO_SCOPE_{i}")
name_i = os.getenv(f"CUSTOM_SSO_NAME_{i}", f"custom{i}")
display_name_i = os.getenv(f"CUSTOM_SSO_DISPLAY_NAME_{i}", f"Custom {i}")
if all([cid, csecret, auth_ep, token_ep, userinfo_ep]):
discovery_i = {
"authorization_endpoint": auth_ep,
"token_endpoint": token_ep,
"userinfo_endpoint": userinfo_ep,
}
default_scope_i = (
[s.strip() for s in scope_raw.split(",") if s.strip()]
if scope_raw
else None
)
ProviderClass = create_provider(
name=name_i,
discovery_document=discovery_i,
response_convertor=_make_response_convertor(name_i),
default_scope=default_scope_i,
)
provider_instance = ProviderClass(
client_id=cid,
client_secret=csecret,
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom/{i}",
allow_insecure_http=True, # Set to False in production with HTTPS
)
custom_sso_providers[i] = {
"sso": provider_instance,
"name": name_i,
"display_name": display_name_i,
}
class MessageResponse(BaseModel):
message: str
@@ -70,21 +223,25 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
# Generate username from email or use provider ID
email = openid.email
if not email:
raise HTTPException(status_code=400, detail="Email is required for SSO authentication")
raise HTTPException(
status_code=400, detail="Email is required for SSO authentication"
)
# Use email prefix as username, fallback to provider + id
username = email.split("@")[0]
if not username:
username = f"{provider}_{openid.id}"
# Check if user already exists by email
existing_user = None
users = user_manager.load_users()
for user_data in users.values():
if user_data.get("email") == email:
existing_user = User(**{k: v for k, v in user_data.items() if k != "password_hash"})
existing_user = User(
**{k: v for k, v in user_data.items() if k != "password_hash"}
)
break
if existing_user:
# Update last login for existing user (always allowed)
users[existing_user.username]["last_login"] = datetime.utcnow().isoformat()
@@ -96,10 +253,10 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
# Check if registration is disabled before creating new user
if DISABLE_REGISTRATION:
raise HTTPException(
status_code=403,
detail="Registration is disabled. Contact an administrator to create an account."
status_code=403,
detail="Registration is disabled. Contact an administrator to create an account.",
)
# Create new user
# Ensure username is unique
counter = 1
@@ -107,20 +264,20 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
while username in users:
username = f"{original_username}{counter}"
counter += 1
user = User(
username=username,
email=email,
role="user" # Default role for SSO users
role="user", # Default role for SSO users
)
users[username] = {
**user.to_dict(),
"sso_provider": provider,
"sso_id": openid.id,
"password_hash": None # SSO users don't have passwords
"password_hash": None, # SSO users don't have passwords
}
user_manager.save_users(users)
logger.info(f"Created SSO user: {username} via {provider}")
return user
@@ -130,27 +287,51 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
async def sso_status():
"""Get SSO status and available providers"""
providers = []
if google_sso:
providers.append(SSOProvider(
name="google",
display_name="Google",
enabled=True,
login_url="/api/auth/sso/login/google"
))
providers.append(
SSOProvider(
name="google",
display_name="Google",
enabled=True,
login_url="/api/auth/sso/login/google",
)
)
if github_sso:
providers.append(SSOProvider(
name="github",
display_name="GitHub",
enabled=True,
login_url="/api/auth/sso/login/github"
))
providers.append(
SSOProvider(
name="github",
display_name="GitHub",
enabled=True,
login_url="/api/auth/sso/login/github",
)
)
if custom_sso:
providers.append(
SSOProvider(
name="custom",
display_name=CUSTOM_SSO_DISPLAY_NAME,
enabled=True,
login_url="/api/auth/sso/login/custom",
)
)
for idx, cfg in custom_sso_providers.items():
providers.append(
SSOProvider(
name=cfg["name"],
display_name=cfg.get("display_name", cfg["name"]),
enabled=True,
login_url=f"/api/auth/sso/login/custom/{idx}",
)
)
return SSOStatusResponse(
sso_enabled=SSO_ENABLED and AUTH_ENABLED,
providers=providers,
registration_enabled=not DISABLE_REGISTRATION
registration_enabled=not DISABLE_REGISTRATION,
)
@@ -159,12 +340,14 @@ async def google_login():
"""Initiate Google SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not google_sso:
raise HTTPException(status_code=400, detail="Google SSO is not configured")
async with google_sso:
return await google_sso.get_login_redirect(params={"prompt": "consent", "access_type": "offline"})
return await google_sso.get_login_redirect(
params={"prompt": "consent", "access_type": "offline"}
)
@router.get("/sso/login/github")
@@ -172,37 +355,66 @@ async def github_login():
"""Initiate GitHub SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not github_sso:
raise HTTPException(status_code=400, detail="GitHub SSO is not configured")
async with github_sso:
return await github_sso.get_login_redirect()
@router.get("/sso/login/custom")
async def custom_login():
"""Initiate Custom SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not custom_sso:
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
async with custom_sso:
return await custom_sso.get_login_redirect()
@router.get("/sso/login/custom/{index}")
async def custom_login_indexed(index: int):
"""Initiate indexed Custom SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
cfg = custom_sso_providers.get(index)
if not cfg:
raise HTTPException(
status_code=400, detail="Custom SSO provider not configured"
)
async with cfg["sso"]:
return await cfg["sso"].get_login_redirect()
@router.get("/sso/callback/google")
async def google_callback(request: Request):
"""Handle Google SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not google_sso:
raise HTTPException(status_code=400, detail="Google SSO is not configured")
try:
async with google_sso:
openid = await google_sso.verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, "google")
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
@@ -210,18 +422,18 @@ async def google_callback(request: Request):
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds()
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed"
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Google SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"Google SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
@@ -233,24 +445,24 @@ async def github_callback(request: Request):
"""Handle GitHub SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not github_sso:
raise HTTPException(status_code=400, detail="GitHub SSO is not configured")
try:
async with github_sso:
openid = await github_sso.verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, "github")
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
@@ -258,24 +470,123 @@ async def github_callback(request: Request):
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds()
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed"
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"GitHub SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"GitHub SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.get("/sso/callback/custom")
async def custom_callback(request: Request):
"""Handle Custom SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not custom_sso:
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
try:
async with custom_sso:
openid = await custom_sso.verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, "custom")
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
value=access_token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Custom SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"Custom SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.get("/sso/callback/custom/{index}")
async def custom_callback_indexed(request: Request, index: int):
"""Handle indexed Custom SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
cfg = custom_sso_providers.get(index)
if not cfg:
raise HTTPException(
status_code=400, detail="Custom SSO provider not configured"
)
try:
async with cfg["sso"]:
openid = await cfg["sso"].verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, cfg["name"])
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
value=access_token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Custom[{index}] SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"Custom[{index}] SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.post("/sso/unlink/{provider}", response_model=MessageResponse)
async def unlink_sso_provider(
provider: str,
@@ -284,27 +595,42 @@ async def unlink_sso_provider(
"""Unlink SSO provider from user account"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if provider not in ["google", "github"]:
available = []
if google_sso:
available.append("google")
if github_sso:
available.append("github")
if custom_sso:
available.append("custom")
for cfg in custom_sso_providers.values():
available.append(cfg["name"])
if provider not in available:
raise HTTPException(status_code=400, detail="Invalid SSO provider")
# Get current user from request (avoiding circular imports)
from .middleware import require_auth_from_state
current_user = await require_auth_from_state(request)
if not current_user.sso_provider:
raise HTTPException(status_code=400, detail="User is not linked to any SSO provider")
raise HTTPException(
status_code=400, detail="User is not linked to any SSO provider"
)
if current_user.sso_provider != provider:
raise HTTPException(status_code=400, detail=f"User is not linked to {provider}")
# Update user to remove SSO linkage
users = user_manager.load_users()
if current_user.username in users:
users[current_user.username]["sso_provider"] = None
users[current_user.username]["sso_id"] = None
user_manager.save_users(users)
logger.info(f"Unlinked SSO provider {provider} from user {current_user.username}")
return MessageResponse(message=f"SSO provider {provider} unlinked successfully")
logger.info(
f"Unlinked SSO provider {provider} from user {current_user.username}"
)
return MessageResponse(message=f"SSO provider {provider} unlinked successfully")

View File

@@ -5,11 +5,12 @@ import uuid
import time
from routes.utils.celery_queue_manager import download_queue_manager
from routes.utils.celery_tasks import store_task_info, store_task_status, ProgressState
from routes.utils.get_info import get_spotify_info
from routes.utils.get_info import get_client, get_album
from routes.utils.errors import DuplicateDownloadError
# Import authentication dependencies
from routes.auth.middleware import require_auth_from_state, User
# Config and credentials helpers
router = APIRouter()
@@ -34,7 +35,8 @@ async def handle_download(
# Fetch metadata from Spotify
try:
album_info = get_spotify_info(album_id, "album")
client = get_client()
album_info = get_album(client, album_id)
if (
not album_info
or not album_info.get("name")
@@ -155,6 +157,7 @@ async def get_album_info(
"""
Retrieve Spotify album metadata given a Spotify album ID.
Expects a query parameter 'id' that contains the Spotify album ID.
Returns the raw JSON from get_album in routes.utils.get_info.
"""
spotify_id = request.query_params.get("id")
@@ -162,27 +165,9 @@ async def get_album_info(
return JSONResponse(content={"error": "Missing parameter: id"}, status_code=400)
try:
# Optional pagination params for tracks
limit_param = request.query_params.get("limit")
offset_param = request.query_params.get("offset")
limit = int(limit_param) if limit_param is not None else None
offset = int(offset_param) if offset_param is not None else None
# Fetch album metadata
album_info = get_spotify_info(spotify_id, "album")
# Fetch album tracks with pagination
album_tracks = get_spotify_info(
spotify_id, "album_tracks", limit=limit, offset=offset
)
# Merge tracks into album payload in the same shape Spotify returns on album
album_info["tracks"] = album_tracks
client = get_client()
album_info = get_album(client, spotify_id)
return JSONResponse(content=album_info, status_code=200)
except ValueError as ve:
return JSONResponse(
content={"error": f"Invalid limit/offset: {str(ve)}"}, status_code=400
)
except Exception as e:
error_data = {"error": str(e), "traceback": traceback.format_exc()}
return JSONResponse(content=error_data, status_code=500)

View File

@@ -18,10 +18,9 @@ from routes.utils.watch.db import (
get_watched_artists,
add_specific_albums_to_artist_table,
remove_specific_albums_from_artist_table,
is_album_in_artist_db,
)
from routes.utils.watch.manager import check_watched_artists, get_watch_config
from routes.utils.get_info import get_spotify_info
from routes.utils.get_info import get_client, get_artist, get_album
# Import authentication dependencies
from routes.auth.middleware import require_auth_from_state, User
@@ -66,9 +65,6 @@ async def handle_artist_download(
)
try:
# Import and call the updated download_artist_albums() function.
# from routes.utils.artist import download_artist_albums # Already imported at top
# Delegate to the download_artist_albums function which will handle album filtering
successfully_queued_albums, duplicate_albums = download_artist_albums(
url=url,
@@ -118,13 +114,15 @@ async def cancel_artist_download():
@router.get("/info")
async def get_artist_info(
request: Request, current_user: User = Depends(require_auth_from_state),
limit: int = Query(10, ge=1), # default=10, must be >=1
offset: int = Query(0, ge=0) # default=0, must be >=0
request: Request,
current_user: User = Depends(require_auth_from_state),
limit: int = Query(10, ge=1), # default=10, must be >=1
offset: int = Query(0, ge=0), # default=0, must be >=0
):
"""
Retrieves Spotify artist metadata given a Spotify artist ID.
Expects a query parameter 'id' with the Spotify artist ID.
Returns the raw JSON from get_artist in routes.utils.get_info.
"""
spotify_id = request.query_params.get("id")
@@ -132,37 +130,8 @@ async def get_artist_info(
return JSONResponse(content={"error": "Missing parameter: id"}, status_code=400)
try:
# Get artist metadata first
artist_metadata = get_spotify_info(spotify_id, "artist")
# Get artist discography for albums
artist_discography = get_spotify_info(spotify_id, "artist_discography", limit=limit, offset=offset)
# Combine metadata with discography
artist_info = {**artist_metadata, "albums": artist_discography}
# If artist_info is successfully fetched and has albums,
# check if the artist is watched and augment album items with is_locally_known status
if (
artist_info
and artist_info.get("albums")
and artist_info["albums"].get("items")
):
watched_artist_details = get_watched_artist(
spotify_id
) # spotify_id is the artist ID
if watched_artist_details: # Artist is being watched
for album_item in artist_info["albums"]["items"]:
if album_item and album_item.get("id"):
album_id = album_item["id"]
album_item["is_locally_known"] = is_album_in_artist_db(
spotify_id, album_id
)
elif album_item: # Album object exists but no ID
album_item["is_locally_known"] = False
# If not watched, or no albums, is_locally_known will not be added.
# Frontend should handle absence of this key as false.
client = get_client()
artist_info = get_artist(client, spotify_id)
return JSONResponse(content=artist_info, status_code=200)
except Exception as e:
return JSONResponse(
@@ -191,15 +160,9 @@ async def add_artist_to_watchlist(
if get_watched_artist(artist_spotify_id):
return {"message": f"Artist {artist_spotify_id} is already being watched."}
# Get artist metadata directly for name and basic info
artist_metadata = get_spotify_info(artist_spotify_id, "artist")
client = get_client()
artist_metadata = get_artist(client, artist_spotify_id)
# Get artist discography for album count
artist_album_list_data = get_spotify_info(
artist_spotify_id, "artist_discography"
)
# Check if we got artist metadata
if not artist_metadata or not artist_metadata.get("name"):
logger.error(
f"Could not fetch artist metadata for {artist_spotify_id} from Spotify."
@@ -211,24 +174,22 @@ async def add_artist_to_watchlist(
},
)
# Check if we got album data
if not artist_album_list_data or not isinstance(
artist_album_list_data.get("items"), list
# Derive a rough total album count from groups if present
total_albums = 0
for key in (
"album_group",
"single_group",
"compilation_group",
"appears_on_group",
):
logger.warning(
f"Could not fetch album list details for artist {artist_spotify_id} from Spotify. Proceeding with metadata only."
)
grp = artist_metadata.get(key)
if isinstance(grp, list):
total_albums += len(grp)
# Construct the artist_data object expected by add_artist_db
artist_data_for_db = {
"id": artist_spotify_id,
"name": artist_metadata.get("name", "Unknown Artist"),
"albums": { # Mimic structure if add_artist_db expects it for total_albums
"total": artist_album_list_data.get("total", 0)
if artist_album_list_data
else 0
},
# Add any other fields add_artist_db might expect from a true artist object if necessary
"albums": {"total": total_albums},
}
add_artist_db(artist_data_for_db)
@@ -446,21 +407,25 @@ async def mark_albums_as_known_for_artist(
detail={"error": f"Artist {artist_spotify_id} is not being watched."},
)
client = get_client()
fetched_albums_details = []
for album_id in album_ids:
try:
# We need full album details. get_spotify_info with type "album" should provide this.
album_detail = get_spotify_info(album_id, "album")
if album_detail and album_detail.get("id"):
fetched_albums_details.append(album_detail)
else:
logger.warning(
f"Could not fetch details for album {album_id} when marking as known for artist {artist_spotify_id}."
try:
for album_id in album_ids:
try:
album_detail = get_album(client, album_id)
if album_detail and album_detail.get("id"):
fetched_albums_details.append(album_detail)
else:
logger.warning(
f"Could not fetch details for album {album_id} when marking as known for artist {artist_spotify_id}."
)
except Exception as e:
logger.error(
f"Failed to fetch Spotify details for album {album_id}: {e}"
)
except Exception as e:
logger.error(
f"Failed to fetch Spotify details for album {album_id}: {e}"
)
finally:
# No need to close_client here, as get_client is shared
pass
if not fetched_albums_details:
return {

View File

@@ -11,25 +11,42 @@ from routes.auth.middleware import require_auth_from_state, User
from routes.utils.get_info import get_spotify_info
from routes.utils.celery_queue_manager import download_queue_manager
# Assuming these imports are available for queue management and Spotify info
from routes.utils.get_info import (
get_client,
get_track,
get_album,
get_playlist,
get_artist,
)
router = APIRouter()
logger = logging.getLogger(__name__)
class BulkAddLinksRequest(BaseModel):
links: List[str]
@router.post("/bulk-add-spotify-links")
async def bulk_add_spotify_links(request: BulkAddLinksRequest, req: Request, current_user: User = Depends(require_auth_from_state)):
added_count = 0
failed_links = []
total_links = len(request.links)
client = get_client()
for link in request.links:
# Assuming links are pre-filtered by the frontend,
# but still handle potential errors during info retrieval or unsupported types
# Extract type and ID from the link directly using regex
match = re.match(r"https://open\.spotify\.com(?:/intl-[a-z]{2})?/(track|album|playlist|artist)/([a-zA-Z0-9]+)(?:\?.*)?", link)
match = re.match(
r"https://open\.spotify\.com(?:/intl-[a-z]{2})?/(track|album|playlist|artist)/([a-zA-Z0-9]+)(?:\?.*)?",
link,
)
if not match:
logger.warning(f"Could not parse Spotify link (unexpected format after frontend filter): {link}")
logger.warning(
f"Could not parse Spotify link (unexpected format after frontend filter): {link}"
)
failed_links.append(link)
continue
@@ -39,18 +56,30 @@ async def bulk_add_spotify_links(request: BulkAddLinksRequest, req: Request, cur
try:
# Get basic info to confirm existence and get name/artist
# For playlists, we might want to get full info later when adding to queue
if spotify_type == "playlist":
item_info = get_spotify_info(spotify_id, "playlist_metadata")
item_info = get_playlist(client, spotify_id, expand_items=False)
elif spotify_type == "track":
item_info = get_track(client, spotify_id)
elif spotify_type == "album":
item_info = get_album(client, spotify_id)
elif spotify_type == "artist":
# Not queued below, but fetch to validate link and name if needed
item_info = get_artist(client, spotify_id)
else:
item_info = get_spotify_info(spotify_id, spotify_type)
logger.warning(
f"Unsupported Spotify type: {spotify_type} for link: {link}"
)
failed_links.append(link)
continue
item_name = item_info.get("name", "Unknown Name")
artist_name = ""
if spotify_type in ["track", "album"]:
artists = item_info.get("artists", [])
if artists:
artist_name = ", ".join([a.get("name", "Unknown Artist") for a in artists])
artist_name = ", ".join(
[a.get("name", "Unknown Artist") for a in artists]
)
elif spotify_type == "playlist":
owner = item_info.get("owner", {})
artist_name = owner.get("display_name", "Unknown Owner")
@@ -77,8 +106,16 @@ async def bulk_add_spotify_links(request: BulkAddLinksRequest, req: Request, cur
added_count += 1
logger.debug(f"Added {added_count}/{total_links} {spotify_type} '{item_name}' ({spotify_id}) to queue with task_id: {task_id}.")
else:
logger.warning(f"Failed to add {spotify_type} '{item_name}' ({spotify_id}) to queue.")
logger.warning(
f"Unsupported Spotify type for download: {spotify_type} for link: {link}"
)
failed_links.append(link)
continue
added_count += 1
logger.debug(
f"Added {added_count + 1}/{total_links} {spotify_type} '{item_name}' ({spotify_id}) to queue."
)
except Exception as e:
logger.error(f"Error processing Spotify link {link}: {e}", exc_info=True)

View File

@@ -1,6 +1,5 @@
from fastapi import APIRouter, HTTPException, Request, Depends
from fastapi.responses import JSONResponse
import json
import traceback
import logging # Added logging import
import uuid # For generating error task IDs
@@ -20,10 +19,9 @@ from routes.utils.watch.db import (
get_watched_playlist,
get_watched_playlists,
add_specific_tracks_to_playlist_table,
remove_specific_tracks_from_playlist_table,
is_track_in_playlist_db, # Added import
remove_specific_tracks_from_playlist_table, # Added import
)
from routes.utils.get_info import get_spotify_info # Already used, but ensure it's here
from routes.utils.get_info import get_client, get_playlist, get_track
from routes.utils.watch.manager import (
check_watched_playlists,
get_watch_config,
@@ -31,7 +29,9 @@ from routes.utils.watch.manager import (
from routes.utils.errors import DuplicateDownloadError
# Import authentication dependencies
from routes.auth.middleware import require_auth_from_state, require_admin_from_state, User
from routes.auth.middleware import require_auth_from_state, User
from routes.utils.celery_config import get_config_params
from routes.utils.credentials import get_spotify_blob_path
logger = logging.getLogger(__name__) # Added logger initialization
router = APIRouter()
@@ -43,7 +43,11 @@ def construct_spotify_url(item_id: str, item_type: str = "track") -> str:
@router.get("/download/{playlist_id}")
async def handle_download(playlist_id: str, request: Request, current_user: User = Depends(require_auth_from_state)):
async def handle_download(
playlist_id: str,
request: Request,
current_user: User = Depends(require_auth_from_state),
):
# Retrieve essential parameters from the request.
# name = request.args.get('name') # Removed
# artist = request.args.get('artist') # Removed
@@ -51,11 +55,14 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
# Construct the URL from playlist_id
url = construct_spotify_url(playlist_id, "playlist")
orig_params["original_url"] = str(request.url) # Update original_url to the constructed one
orig_params["original_url"] = str(
request.url
) # Update original_url to the constructed one
# Fetch metadata from Spotify using optimized function
try:
from routes.utils.get_info import get_playlist_metadata
playlist_info = get_playlist_metadata(playlist_id)
if (
not playlist_info
@@ -66,7 +73,7 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
content={
"error": f"Could not retrieve metadata for playlist ID: {playlist_id}"
},
status_code=404
status_code=404,
)
name_from_spotify = playlist_info.get("name")
@@ -79,14 +86,13 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
content={
"error": f"Failed to fetch metadata for playlist {playlist_id}: {str(e)}"
},
status_code=500
status_code=500,
)
# Validate required parameters
if not url: # This check might be redundant now but kept for safety
return JSONResponse(
content={"error": "Missing required parameter: url"},
status_code=400
content={"error": "Missing required parameter: url"}, status_code=400
)
try:
@@ -106,7 +112,7 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
"error": "Duplicate download detected.",
"existing_task": e.existing_task,
},
status_code=409
status_code=409,
)
except Exception as e:
# Generic error handling for other issues during task submission
@@ -136,25 +142,23 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
"error": f"Failed to queue playlist download: {str(e)}",
"task_id": error_task_id,
},
status_code=500
status_code=500,
)
return JSONResponse(
content={"task_id": task_id},
status_code=202
)
return JSONResponse(content={"task_id": task_id}, status_code=202)
@router.get("/download/cancel")
async def cancel_download(request: Request, current_user: User = Depends(require_auth_from_state)):
async def cancel_download(
request: Request, current_user: User = Depends(require_auth_from_state)
):
"""
Cancel a running playlist download process by its task id.
"""
task_id = request.query_params.get("task_id")
if not task_id:
return JSONResponse(
content={"error": "Missing task id (task_id) parameter"},
status_code=400
content={"error": "Missing task id (task_id) parameter"}, status_code=400
)
# Use the queue manager's cancellation method.
@@ -165,124 +169,94 @@ async def cancel_download(request: Request, current_user: User = Depends(require
@router.get("/info")
async def get_playlist_info(request: Request, current_user: User = Depends(require_auth_from_state)):
async def get_playlist_info(
request: Request, current_user: User = Depends(require_auth_from_state)
):
"""
Retrieve Spotify playlist metadata given a Spotify playlist ID.
Expects a query parameter 'id' that contains the Spotify playlist ID.
"""
spotify_id = request.query_params.get("id")
include_tracks = request.query_params.get("include_tracks", "false").lower() == "true"
if not spotify_id:
return JSONResponse(
content={"error": "Missing parameter: id"},
status_code=400
)
try:
# Use the optimized playlist info function
from routes.utils.get_info import get_playlist_info_optimized
playlist_info = get_playlist_info_optimized(spotify_id, include_tracks=include_tracks)
# If playlist_info is successfully fetched, check if it's watched
# and augment track items with is_locally_known status
if playlist_info and playlist_info.get("id"):
watched_playlist_details = get_watched_playlist(playlist_info["id"])
if watched_playlist_details: # Playlist is being watched
if playlist_info.get("tracks") and playlist_info["tracks"].get("items"):
for item in playlist_info["tracks"]["items"]:
if item and item.get("track") and item["track"].get("id"):
track_id = item["track"]["id"]
item["track"]["is_locally_known"] = is_track_in_playlist_db(
playlist_info["id"], track_id
)
elif item and item.get(
"track"
): # Track object exists but no ID
item["track"]["is_locally_known"] = False
# If not watched, or no tracks, is_locally_known will not be added, or tracks won't exist to add it to.
# Frontend should handle absence of this key as false.
return JSONResponse(
content=playlist_info, status_code=200
)
except Exception as e:
error_data = {"error": str(e), "traceback": traceback.format_exc()}
return JSONResponse(content=error_data, status_code=500)
@router.get("/metadata")
async def get_playlist_metadata(request: Request, current_user: User = Depends(require_auth_from_state)):
"""
Retrieve only Spotify playlist metadata (no tracks) to avoid rate limiting.
Expects a query parameter 'id' that contains the Spotify playlist ID.
Always returns the raw JSON from get_playlist with expand_items=False.
"""
spotify_id = request.query_params.get("id")
if not spotify_id:
return JSONResponse(
content={"error": "Missing parameter: id"},
status_code=400
)
return JSONResponse(content={"error": "Missing parameter: id"}, status_code=400)
try:
# Use the optimized playlist metadata function
from routes.utils.get_info import get_playlist_metadata
playlist_metadata = get_playlist_metadata(spotify_id)
# Resolve active account's credentials blob
cfg = get_config_params() or {}
active_account = cfg.get("spotify")
if not active_account:
return JSONResponse(
content={"error": "Active Spotify account not set in configuration."},
status_code=500,
)
blob_path = get_spotify_blob_path(active_account)
if not blob_path.exists():
return JSONResponse(
content={
"error": f"Spotify credentials blob not found for account '{active_account}'"
},
status_code=500,
)
return JSONResponse(
content=playlist_metadata, status_code=200
)
except Exception as e:
error_data = {"error": str(e), "traceback": traceback.format_exc()}
return JSONResponse(content=error_data, status_code=500)
client = get_client()
try:
playlist_info = get_playlist(client, spotify_id, expand_items=False)
finally:
pass
@router.get("/tracks")
async def get_playlist_tracks(request: Request, current_user: User = Depends(require_auth_from_state)):
"""
Retrieve playlist tracks with pagination support for progressive loading.
Expects query parameters: 'id' (playlist ID), 'limit' (optional), 'offset' (optional).
"""
spotify_id = request.query_params.get("id")
limit = int(request.query_params.get("limit", 50))
offset = int(request.query_params.get("offset", 0))
if not spotify_id:
return JSONResponse(
content={"error": "Missing parameter: id"},
status_code=400
)
try:
# Use the optimized playlist tracks function
from routes.utils.get_info import get_playlist_tracks
tracks_data = get_playlist_tracks(spotify_id, limit=limit, offset=offset)
return JSONResponse(
content=tracks_data, status_code=200
)
return JSONResponse(content=playlist_info, status_code=200)
except Exception as e:
error_data = {"error": str(e), "traceback": traceback.format_exc()}
return JSONResponse(content=error_data, status_code=500)
@router.put("/watch/{playlist_spotify_id}")
async def add_to_watchlist(playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)):
async def add_to_watchlist(
playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)
):
"""Adds a playlist to the watchlist."""
watch_config = get_watch_config()
if not watch_config.get("enabled", False):
raise HTTPException(status_code=403, detail={"error": "Watch feature is currently disabled globally."})
raise HTTPException(
status_code=403,
detail={"error": "Watch feature is currently disabled globally."},
)
logger.info(f"Attempting to add playlist {playlist_spotify_id} to watchlist.")
try:
# Check if already watched
if get_watched_playlist(playlist_spotify_id):
return {"message": f"Playlist {playlist_spotify_id} is already being watched."}
return {
"message": f"Playlist {playlist_spotify_id} is already being watched."
}
# Fetch playlist details from Spotify to populate our DB (metadata only)
cfg = get_config_params() or {}
active_account = cfg.get("spotify")
if not active_account:
raise HTTPException(
status_code=500,
detail={"error": "Active Spotify account not set in configuration."},
)
blob_path = get_spotify_blob_path(active_account)
if not blob_path.exists():
raise HTTPException(
status_code=500,
detail={
"error": f"Spotify credentials blob not found for account '{active_account}'"
},
)
client = get_client()
try:
playlist_data = get_playlist(
client, playlist_spotify_id, expand_items=False
)
finally:
pass
# Fetch playlist details from Spotify to populate our DB
from routes.utils.get_info import get_playlist_metadata
playlist_data = get_playlist_metadata(playlist_spotify_id)
if not playlist_data or "id" not in playlist_data:
logger.error(
f"Could not fetch details for playlist {playlist_spotify_id} from Spotify."
@@ -291,19 +265,11 @@ async def add_to_watchlist(playlist_spotify_id: str, current_user: User = Depend
status_code=404,
detail={
"error": f"Could not fetch details for playlist {playlist_spotify_id} from Spotify."
}
},
)
add_playlist_db(playlist_data) # This also creates the tracks table
# REMOVED: Do not add initial tracks directly to DB.
# The playlist watch manager will pick them up as new and queue downloads.
# Tracks will be added to DB only after successful download via Celery task callback.
# initial_track_items = playlist_data.get('tracks', {}).get('items', [])
# if initial_track_items:
# from routes.utils.watch.db import add_tracks_to_playlist_db # Keep local import for clarity
# add_tracks_to_playlist_db(playlist_spotify_id, initial_track_items)
logger.info(
f"Playlist {playlist_spotify_id} added to watchlist. Its tracks will be processed by the watch manager."
)
@@ -317,11 +283,16 @@ async def add_to_watchlist(playlist_spotify_id: str, current_user: User = Depend
f"Error adding playlist {playlist_spotify_id} to watchlist: {e}",
exc_info=True,
)
raise HTTPException(status_code=500, detail={"error": f"Could not add playlist to watchlist: {str(e)}"})
raise HTTPException(
status_code=500,
detail={"error": f"Could not add playlist to watchlist: {str(e)}"},
)
@router.get("/watch/{playlist_spotify_id}/status")
async def get_playlist_watch_status(playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)):
async def get_playlist_watch_status(
playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)
):
"""Checks if a specific playlist is being watched."""
logger.info(f"Checking watch status for playlist {playlist_spotify_id}.")
try:
@@ -337,22 +308,31 @@ async def get_playlist_watch_status(playlist_spotify_id: str, current_user: User
f"Error checking watch status for playlist {playlist_spotify_id}: {e}",
exc_info=True,
)
raise HTTPException(status_code=500, detail={"error": f"Could not check watch status: {str(e)}"})
raise HTTPException(
status_code=500, detail={"error": f"Could not check watch status: {str(e)}"}
)
@router.delete("/watch/{playlist_spotify_id}")
async def remove_from_watchlist(playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)):
async def remove_from_watchlist(
playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)
):
"""Removes a playlist from the watchlist."""
watch_config = get_watch_config()
if not watch_config.get("enabled", False):
raise HTTPException(status_code=403, detail={"error": "Watch feature is currently disabled globally."})
raise HTTPException(
status_code=403,
detail={"error": "Watch feature is currently disabled globally."},
)
logger.info(f"Attempting to remove playlist {playlist_spotify_id} from watchlist.")
try:
if not get_watched_playlist(playlist_spotify_id):
raise HTTPException(
status_code=404,
detail={"error": f"Playlist {playlist_spotify_id} not found in watchlist."}
detail={
"error": f"Playlist {playlist_spotify_id} not found in watchlist."
},
)
remove_playlist_db(playlist_spotify_id)
@@ -369,12 +349,16 @@ async def remove_from_watchlist(playlist_spotify_id: str, current_user: User = D
)
raise HTTPException(
status_code=500,
detail={"error": f"Could not remove playlist from watchlist: {str(e)}"}
detail={"error": f"Could not remove playlist from watchlist: {str(e)}"},
)
@router.post("/watch/{playlist_spotify_id}/tracks")
async def mark_tracks_as_known(playlist_spotify_id: str, request: Request, current_user: User = Depends(require_auth_from_state)):
async def mark_tracks_as_known(
playlist_spotify_id: str,
request: Request,
current_user: User = Depends(require_auth_from_state),
):
"""Fetches details for given track IDs and adds/updates them in the playlist's local DB table."""
watch_config = get_watch_config()
if not watch_config.get("enabled", False):
@@ -382,7 +366,7 @@ async def mark_tracks_as_known(playlist_spotify_id: str, request: Request, curre
status_code=403,
detail={
"error": "Watch feature is currently disabled globally. Cannot mark tracks."
}
},
)
logger.info(
@@ -397,19 +381,22 @@ async def mark_tracks_as_known(playlist_spotify_id: str, request: Request, curre
status_code=400,
detail={
"error": "Invalid request body. Expecting a JSON array of track Spotify IDs."
}
},
)
if not get_watched_playlist(playlist_spotify_id):
raise HTTPException(
status_code=404,
detail={"error": f"Playlist {playlist_spotify_id} is not being watched."}
detail={
"error": f"Playlist {playlist_spotify_id} is not being watched."
},
)
fetched_tracks_details = []
client = get_client()
for track_id in track_ids:
try:
track_detail = get_spotify_info(track_id, "track")
track_detail = get_track(client, track_id)
if track_detail and track_detail.get("id"):
fetched_tracks_details.append(track_detail)
else:
@@ -443,11 +430,18 @@ async def mark_tracks_as_known(playlist_spotify_id: str, request: Request, curre
f"Error marking tracks as known for playlist {playlist_spotify_id}: {e}",
exc_info=True,
)
raise HTTPException(status_code=500, detail={"error": f"Could not mark tracks as known: {str(e)}"})
raise HTTPException(
status_code=500,
detail={"error": f"Could not mark tracks as known: {str(e)}"},
)
@router.delete("/watch/{playlist_spotify_id}/tracks")
async def mark_tracks_as_missing_locally(playlist_spotify_id: str, request: Request, current_user: User = Depends(require_auth_from_state)):
async def mark_tracks_as_missing_locally(
playlist_spotify_id: str,
request: Request,
current_user: User = Depends(require_auth_from_state),
):
"""Removes specified tracks from the playlist's local DB table."""
watch_config = get_watch_config()
if not watch_config.get("enabled", False):
@@ -455,7 +449,7 @@ async def mark_tracks_as_missing_locally(playlist_spotify_id: str, request: Requ
status_code=403,
detail={
"error": "Watch feature is currently disabled globally. Cannot mark tracks."
}
},
)
logger.info(
@@ -470,13 +464,15 @@ async def mark_tracks_as_missing_locally(playlist_spotify_id: str, request: Requ
status_code=400,
detail={
"error": "Invalid request body. Expecting a JSON array of track Spotify IDs."
}
},
)
if not get_watched_playlist(playlist_spotify_id):
raise HTTPException(
status_code=404,
detail={"error": f"Playlist {playlist_spotify_id} is not being watched."}
detail={
"error": f"Playlist {playlist_spotify_id} is not being watched."
},
)
deleted_count = remove_specific_tracks_from_playlist_table(
@@ -495,22 +491,32 @@ async def mark_tracks_as_missing_locally(playlist_spotify_id: str, request: Requ
f"Error marking tracks as missing (deleting locally) for playlist {playlist_spotify_id}: {e}",
exc_info=True,
)
raise HTTPException(status_code=500, detail={"error": f"Could not mark tracks as missing: {str(e)}"})
raise HTTPException(
status_code=500,
detail={"error": f"Could not mark tracks as missing: {str(e)}"},
)
@router.get("/watch/list")
async def list_watched_playlists_endpoint(current_user: User = Depends(require_auth_from_state)):
async def list_watched_playlists_endpoint(
current_user: User = Depends(require_auth_from_state),
):
"""Lists all playlists currently in the watchlist."""
try:
playlists = get_watched_playlists()
return playlists
except Exception as e:
logger.error(f"Error listing watched playlists: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": f"Could not list watched playlists: {str(e)}"})
raise HTTPException(
status_code=500,
detail={"error": f"Could not list watched playlists: {str(e)}"},
)
@router.post("/watch/trigger_check")
async def trigger_playlist_check_endpoint(current_user: User = Depends(require_auth_from_state)):
async def trigger_playlist_check_endpoint(
current_user: User = Depends(require_auth_from_state),
):
"""Manually triggers the playlist checking mechanism for all watched playlists."""
watch_config = get_watch_config()
if not watch_config.get("enabled", False):
@@ -518,7 +524,7 @@ async def trigger_playlist_check_endpoint(current_user: User = Depends(require_a
status_code=403,
detail={
"error": "Watch feature is currently disabled globally. Cannot trigger check."
}
},
)
logger.info("Manual trigger for playlist check received for all playlists.")
@@ -535,12 +541,14 @@ async def trigger_playlist_check_endpoint(current_user: User = Depends(require_a
)
raise HTTPException(
status_code=500,
detail={"error": f"Could not trigger playlist check for all: {str(e)}"}
detail={"error": f"Could not trigger playlist check for all: {str(e)}"},
)
@router.post("/watch/trigger_check/{playlist_spotify_id}")
async def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)):
async def trigger_specific_playlist_check_endpoint(
playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)
):
"""Manually triggers the playlist checking mechanism for a specific playlist."""
watch_config = get_watch_config()
if not watch_config.get("enabled", False):
@@ -548,7 +556,7 @@ async def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str, cur
status_code=403,
detail={
"error": "Watch feature is currently disabled globally. Cannot trigger check."
}
},
)
logger.info(
@@ -565,7 +573,7 @@ async def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str, cur
status_code=404,
detail={
"error": f"Playlist {playlist_spotify_id} is not in the watchlist. Add it first."
}
},
)
# Run check_watched_playlists with the specific ID
@@ -590,5 +598,5 @@ async def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str, cur
status_code=500,
detail={
"error": f"Could not trigger playlist check for {playlist_spotify_id}: {str(e)}"
}
},
)

View File

@@ -1,12 +1,11 @@
from fastapi import APIRouter, HTTPException, Request, Depends
from fastapi import APIRouter, Request, Depends
from fastapi.responses import JSONResponse
import json
import traceback
import uuid
import time
from routes.utils.celery_queue_manager import download_queue_manager
from routes.utils.celery_tasks import store_task_info, store_task_status, ProgressState
from routes.utils.get_info import get_spotify_info
from routes.utils.get_info import get_client, get_track
from routes.utils.errors import DuplicateDownloadError
# Import authentication dependencies
@@ -21,7 +20,11 @@ def construct_spotify_url(item_id: str, item_type: str = "track") -> str:
@router.get("/download/{track_id}")
async def handle_download(track_id: str, request: Request, current_user: User = Depends(require_auth_from_state)):
async def handle_download(
track_id: str,
request: Request,
current_user: User = Depends(require_auth_from_state),
):
# Retrieve essential parameters from the request.
# name = request.args.get('name') # Removed
# artist = request.args.get('artist') # Removed
@@ -31,15 +34,18 @@ async def handle_download(track_id: str, request: Request, current_user: User =
# Fetch metadata from Spotify
try:
track_info = get_spotify_info(track_id, "track")
client = get_client()
track_info = get_track(client, track_id)
if (
not track_info
or not track_info.get("name")
or not track_info.get("artists")
):
return JSONResponse(
content={"error": f"Could not retrieve metadata for track ID: {track_id}"},
status_code=404
content={
"error": f"Could not retrieve metadata for track ID: {track_id}"
},
status_code=404,
)
name_from_spotify = track_info.get("name")
@@ -51,15 +57,16 @@ async def handle_download(track_id: str, request: Request, current_user: User =
except Exception as e:
return JSONResponse(
content={"error": f"Failed to fetch metadata for track {track_id}: {str(e)}"},
status_code=500
content={
"error": f"Failed to fetch metadata for track {track_id}: {str(e)}"
},
status_code=500,
)
# Validate required parameters
if not url:
return JSONResponse(
content={"error": "Missing required parameter: url"},
status_code=400
content={"error": "Missing required parameter: url"}, status_code=400
)
# Add the task to the queue with only essential parameters
@@ -84,7 +91,7 @@ async def handle_download(track_id: str, request: Request, current_user: User =
"error": "Duplicate download detected.",
"existing_task": e.existing_task,
},
status_code=409
status_code=409,
)
except Exception as e:
# Generic error handling for other issues during task submission
@@ -116,25 +123,23 @@ async def handle_download(track_id: str, request: Request, current_user: User =
"error": f"Failed to queue track download: {str(e)}",
"task_id": error_task_id,
},
status_code=500
status_code=500,
)
return JSONResponse(
content={"task_id": task_id},
status_code=202
)
return JSONResponse(content={"task_id": task_id}, status_code=202)
@router.get("/download/cancel")
async def cancel_download(request: Request, current_user: User = Depends(require_auth_from_state)):
async def cancel_download(
request: Request, current_user: User = Depends(require_auth_from_state)
):
"""
Cancel a running download process by its task id.
"""
task_id = request.query_params.get("task_id")
if not task_id:
return JSONResponse(
content={"error": "Missing process id (task_id) parameter"},
status_code=400
content={"error": "Missing process id (task_id) parameter"}, status_code=400
)
# Use the queue manager's cancellation method.
@@ -145,7 +150,9 @@ async def cancel_download(request: Request, current_user: User = Depends(require
@router.get("/info")
async def get_track_info(request: Request, current_user: User = Depends(require_auth_from_state)):
async def get_track_info(
request: Request, current_user: User = Depends(require_auth_from_state)
):
"""
Retrieve Spotify track metadata given a Spotify track ID.
Expects a query parameter 'id' that contains the Spotify track ID.
@@ -153,14 +160,11 @@ async def get_track_info(request: Request, current_user: User = Depends(require_
spotify_id = request.query_params.get("id")
if not spotify_id:
return JSONResponse(
content={"error": "Missing parameter: id"},
status_code=400
)
return JSONResponse(content={"error": "Missing parameter: id"}, status_code=400)
try:
# Use the get_spotify_info function (already imported at top)
track_info = get_spotify_info(spotify_id, "track")
client = get_client()
track_info = get_track(client, spotify_id)
return JSONResponse(content=track_info, status_code=200)
except Exception as e:
error_data = {"error": str(e), "traceback": traceback.format_exc()}

View File

@@ -82,7 +82,7 @@ class SSEBroadcaster:
# Clean up disconnected clients
for client in disconnected:
self.clients.discard(client)
logger.info(
logger.debug(
f"SSE Broadcaster: Successfully sent to {sent_count} clients, removed {len(disconnected)} disconnected clients"
)

View File

@@ -2,9 +2,9 @@ import json
from routes.utils.watch.manager import get_watch_config
import logging
from routes.utils.celery_queue_manager import download_queue_manager
from routes.utils.get_info import get_spotify_info
from routes.utils.credentials import get_credential, _get_global_spotify_api_creds
from routes.utils.errors import DuplicateDownloadError
from routes.utils.get_info import get_spotify_info
from deezspot.libutils.utils import get_ids, link_is_valid

View File

@@ -5,6 +5,9 @@ import threading
import os
import sys
from dotenv import load_dotenv
load_dotenv()
# Import Celery task utilities
from .celery_config import get_config_params, MAX_CONCURRENT_DL
@@ -70,6 +73,12 @@ class CeleryManager:
logger.debug(f"Generated Celery command: {' '.join(command)}")
return command
def _get_worker_env(self):
# Inherit current environment, but set NO_CONSOLE_LOG=1 for subprocess
env = os.environ.copy()
env["NO_CONSOLE_LOG"] = "1"
return env
def _process_output_reader(self, stream, log_prefix, error=False):
logger.debug(f"Log reader thread started for {log_prefix}")
try:
@@ -138,6 +147,7 @@ class CeleryManager:
text=True,
bufsize=1,
universal_newlines=True,
env=self._get_worker_env(),
)
self.download_log_thread_stdout = threading.Thread(
target=self._process_output_reader,
@@ -161,7 +171,7 @@ class CeleryManager:
queues="utility_tasks,default", # Listen to utility and default
concurrency=5, # Increased concurrency for SSE updates and utility tasks
worker_name_suffix="utw", # Utility Worker
log_level_env=os.getenv("LOG_LEVEL", "ERROR").upper(),
log_level_env=os.getenv("LOG_LEVEL", "WARNING").upper(),
)
logger.info(
@@ -174,6 +184,7 @@ class CeleryManager:
text=True,
bufsize=1,
universal_newlines=True,
env=self._get_worker_env(),
)
self.utility_log_thread_stdout = threading.Thread(
target=self._process_output_reader,

View File

@@ -285,9 +285,16 @@ def setup_celery_logging(**kwargs):
"""
This handler ensures Celery uses our application logging settings
instead of its own. Prevents duplicate log configurations.
Also disables console logging if NO_CONSOLE_LOG=1 is set in the environment.
"""
# Using the root logger's handlers and level preserves our config
return logging.getLogger()
root_logger = logging.getLogger()
import os
if os.environ.get("NO_CONSOLE_LOG") == "1":
# Remove all StreamHandlers (console handlers) from the root logger
handlers_to_remove = [h for h in root_logger.handlers if isinstance(h, logging.StreamHandler)]
for h in handlers_to_remove:
root_logger.removeHandler(h)
return root_logger
# The initialization of a worker will log the worker configuration

View File

@@ -1,422 +1,152 @@
import spotipy
from spotipy.oauth2 import SpotifyClientCredentials
from routes.utils.credentials import _get_global_spotify_api_creds
import logging
import time
from typing import Dict, Optional, Any
import os
from typing import Any, Dict, Optional
import threading
# Import Deezer API and logging
from deezspot.deezloader.dee_api import API as DeezerAPI
from deezspot.libutils import LibrespotClient
# Initialize logger
logger = logging.getLogger(__name__)
# Global Spotify client instance for reuse
_spotify_client = None
_last_client_init = 0
_client_init_interval = 3600 # Reinitialize client every hour
# Config helpers to resolve active credentials
from routes.utils.celery_config import get_config_params
from routes.utils.credentials import get_spotify_blob_path
def _get_spotify_client():
"""
Get or create a Spotify client with global credentials.
Implements client reuse and periodic reinitialization.
"""
global _spotify_client, _last_client_init
# -------- Shared Librespot client (process-wide) --------
current_time = time.time()
_shared_client: Optional[LibrespotClient] = None
_shared_blob_path: Optional[str] = None
_client_lock = threading.RLock()
# Reinitialize client if it's been more than an hour or if client doesn't exist
if (
_spotify_client is None
or current_time - _last_client_init > _client_init_interval
):
client_id, client_secret = _get_global_spotify_api_creds()
if not client_id or not client_secret:
raise ValueError(
"Global Spotify API client_id or client_secret not configured in ./data/creds/search.json."
)
# Create new client
_spotify_client = spotipy.Spotify(
client_credentials_manager=SpotifyClientCredentials(
client_id=client_id, client_secret=client_secret
)
def _resolve_blob_path() -> str:
cfg = get_config_params() or {}
active_account = cfg.get("spotify")
if not active_account:
raise RuntimeError("Active Spotify account not set in configuration.")
blob_path = get_spotify_blob_path(active_account)
abs_path = os.path.abspath(str(blob_path))
if not os.path.isfile(abs_path):
raise FileNotFoundError(
f"Spotify credentials blob not found for account '{active_account}' at {abs_path}"
)
_last_client_init = current_time
logger.info("Spotify client initialized/reinitialized")
return _spotify_client
return abs_path
def _rate_limit_handler(func):
def get_client() -> LibrespotClient:
"""
Decorator to handle rate limiting with exponential backoff.
Return a shared LibrespotClient instance initialized from the active account blob.
Re-initializes if the active account changes.
"""
def wrapper(*args, **kwargs):
max_retries = 3
base_delay = 1
for attempt in range(max_retries):
global _shared_client, _shared_blob_path
with _client_lock:
desired_blob = _resolve_blob_path()
if _shared_client is None or _shared_blob_path != desired_blob:
try:
return func(*args, **kwargs)
except Exception as e:
if "429" in str(e) or "rate limit" in str(e).lower():
if attempt < max_retries - 1:
delay = base_delay * (2**attempt)
logger.warning(f"Rate limited, retrying in {delay} seconds...")
time.sleep(delay)
continue
raise e
return func(*args, **kwargs)
return wrapper
if _shared_client is not None:
_shared_client.close()
except Exception:
pass
_shared_client = LibrespotClient(stored_credentials_path=desired_blob)
_shared_blob_path = desired_blob
return _shared_client
@_rate_limit_handler
def get_playlist_metadata(playlist_id: str) -> Dict[str, Any]:
# -------- Thin wrapper API (programmatic use) --------
def create_client(credentials_path: str) -> LibrespotClient:
"""
Get playlist metadata only (no tracks) to avoid rate limiting.
Args:
playlist_id: The Spotify playlist ID
Returns:
Dictionary with playlist metadata (name, description, owner, etc.)
Create a LibrespotClient from a librespot-generated credentials.json file.
"""
client = _get_spotify_client()
try:
# Get basic playlist info without tracks
playlist = client.playlist(
playlist_id,
fields="id,name,description,owner,images,snapshot_id,public,followers,tracks.total",
)
# Add a flag to indicate this is metadata only
playlist["_metadata_only"] = True
playlist["_tracks_loaded"] = False
logger.debug(
f"Retrieved playlist metadata for {playlist_id}: {playlist.get('name', 'Unknown')}"
)
return playlist
except Exception as e:
logger.error(f"Error fetching playlist metadata for {playlist_id}: {e}")
raise
abs_path = os.path.abspath(credentials_path)
if not os.path.isfile(abs_path):
raise FileNotFoundError(f"Credentials file not found: {abs_path}")
return LibrespotClient(stored_credentials_path=abs_path)
@_rate_limit_handler
def get_playlist_tracks(
playlist_id: str, limit: int = 100, offset: int = 0
def close_client(client: LibrespotClient) -> None:
"""
Dispose a LibrespotClient instance.
"""
client.close()
def get_track(client: LibrespotClient, track_in: str) -> Dict[str, Any]:
"""Fetch a track object."""
return client.get_track(track_in)
def get_album(
client: LibrespotClient, album_in: str, include_tracks: bool = False
) -> Dict[str, Any]:
"""
Get playlist tracks with pagination support to handle large playlists efficiently.
Args:
playlist_id: The Spotify playlist ID
limit: Number of tracks to fetch per request (max 100)
offset: Starting position for pagination
Returns:
Dictionary with tracks data
"""
client = _get_spotify_client()
try:
# Get tracks with specified limit and offset
tracks_data = client.playlist_tracks(
playlist_id,
limit=min(limit, 100), # Spotify API max is 100
offset=offset,
fields="items(track(id,name,artists,album,external_urls,preview_url,duration_ms,explicit,popularity)),total,limit,offset",
)
logger.debug(
f"Retrieved {len(tracks_data.get('items', []))} tracks for playlist {playlist_id} (offset: {offset})"
)
return tracks_data
except Exception as e:
logger.error(f"Error fetching playlist tracks for {playlist_id}: {e}")
raise
"""Fetch an album object; optionally include expanded tracks."""
return client.get_album(album_in, include_tracks=include_tracks)
@_rate_limit_handler
def get_playlist_full(playlist_id: str, batch_size: int = 100) -> Dict[str, Any]:
"""
Get complete playlist data with all tracks, using batched requests to avoid rate limiting.
Args:
playlist_id: The Spotify playlist ID
batch_size: Number of tracks to fetch per batch (max 100)
Returns:
Complete playlist data with all tracks
"""
try:
# First get metadata
playlist = get_playlist_metadata(playlist_id)
# Get total track count
total_tracks = playlist.get("tracks", {}).get("total", 0)
if total_tracks == 0:
playlist["tracks"] = {"items": [], "total": 0}
return playlist
# Fetch all tracks in batches
all_tracks = []
offset = 0
while offset < total_tracks:
batch = get_playlist_tracks(playlist_id, limit=batch_size, offset=offset)
batch_items = batch.get("items", [])
all_tracks.extend(batch_items)
offset += len(batch_items)
# Add small delay between batches to be respectful to API
if offset < total_tracks:
time.sleep(0.1)
# Update playlist with complete tracks data
playlist["tracks"] = {
"items": all_tracks,
"total": total_tracks,
"limit": batch_size,
"offset": 0,
}
playlist["_metadata_only"] = False
playlist["_tracks_loaded"] = True
logger.info(
f"Retrieved complete playlist {playlist_id} with {total_tracks} tracks"
)
return playlist
except Exception as e:
logger.error(f"Error fetching complete playlist {playlist_id}: {e}")
raise
def get_artist(client: LibrespotClient, artist_in: str) -> Dict[str, Any]:
"""Fetch an artist object."""
return client.get_artist(artist_in)
def check_playlist_updated(playlist_id: str, last_snapshot_id: str) -> bool:
"""
Check if playlist has been updated by comparing snapshot_id.
This is much more efficient than fetching all tracks.
Args:
playlist_id: The Spotify playlist ID
last_snapshot_id: The last known snapshot_id
Returns:
True if playlist has been updated, False otherwise
"""
try:
metadata = get_playlist_metadata(playlist_id)
current_snapshot_id = metadata.get("snapshot_id")
return current_snapshot_id != last_snapshot_id
except Exception as e:
logger.error(f"Error checking playlist update status for {playlist_id}: {e}")
raise
def get_playlist(
client: LibrespotClient, playlist_in: str, expand_items: bool = False
) -> Dict[str, Any]:
"""Fetch a playlist object; optionally expand track items to full track objects."""
return client.get_playlist(playlist_in, expand_items=expand_items)
@_rate_limit_handler
def get_spotify_info(
spotify_id: str,
spotify_type: str,
limit: Optional[int] = None,
offset: Optional[int] = None,
info_type: str,
limit: int = 50,
offset: int = 0,
) -> Dict[str, Any]:
"""
Get info from Spotify API using Spotipy directly.
Optimized to prevent rate limiting by using appropriate endpoints.
Thin, typed wrapper around common Spotify info lookups using the shared client.
Args:
spotify_id: The Spotify ID of the entity
spotify_type: The type of entity (track, album, playlist, artist, artist_discography, episode, album_tracks)
limit (int, optional): The maximum number of items to return. Used for pagination.
offset (int, optional): The index of the first item to return. Used for pagination.
Currently supports:
- "artist_discography": returns a paginated view over the artist's releases
combined across album_group/single_group/compilation_group/appears_on_group.
Returns:
Dictionary with the entity information
Returns a mapping with at least: items, total, limit, offset.
Also includes a truthy "next" key when more pages are available.
"""
client = _get_spotify_client()
client = get_client()
try:
if spotify_type == "track":
return client.track(spotify_id)
if info_type == "artist_discography":
artist = client.get_artist(spotify_id)
all_items = []
for key in (
"album_group",
"single_group",
"compilation_group",
"appears_on_group",
):
grp = artist.get(key)
if isinstance(grp, list):
all_items.extend(grp)
elif isinstance(grp, dict):
items = grp.get("items") or grp.get("releases") or []
if isinstance(items, list):
all_items.extend(items)
total = len(all_items)
start = max(0, offset or 0)
page_limit = max(1, limit or 50)
end = min(total, start + page_limit)
page_items = all_items[start:end]
has_more = end < total
return {
"items": page_items,
"total": total,
"limit": page_limit,
"offset": start,
"next": bool(has_more),
}
elif spotify_type == "album":
return client.album(spotify_id)
elif spotify_type == "album_tracks":
# Fetch album's tracks with pagination support
return client.album_tracks(
spotify_id, limit=limit or 20, offset=offset or 0
)
elif spotify_type == "playlist":
# Use optimized playlist fetching
return get_playlist_full(spotify_id)
elif spotify_type == "playlist_metadata":
# Get only metadata for playlists
return get_playlist_metadata(spotify_id)
elif spotify_type == "artist":
return client.artist(spotify_id)
elif spotify_type == "artist_discography":
# Get artist's albums with pagination
albums = client.artist_albums(
spotify_id,
limit=limit or 20,
offset=offset or 0,
include_groups="single,album,appears_on",
)
return albums
elif spotify_type == "episode":
return client.episode(spotify_id)
else:
raise ValueError(f"Unsupported Spotify type: {spotify_type}")
except Exception as e:
logger.error(f"Error fetching {spotify_type} {spotify_id}: {e}")
raise
raise ValueError(f"Unsupported info_type: {info_type}")
# Cache for playlist metadata to reduce API calls
_playlist_metadata_cache: Dict[str, tuple[Dict[str, Any], float]] = {}
_cache_ttl = 300 # 5 minutes cache
def get_cached_playlist_metadata(playlist_id: str) -> Optional[Dict[str, Any]]:
def get_playlist_metadata(playlist_id: str) -> Dict[str, Any]:
"""
Get playlist metadata from cache if available and not expired.
Args:
playlist_id: The Spotify playlist ID
Returns:
Cached metadata or None if not available/expired
Fetch playlist metadata using the shared client without expanding items.
"""
if playlist_id in _playlist_metadata_cache:
cached_data, timestamp = _playlist_metadata_cache[playlist_id]
if time.time() - timestamp < _cache_ttl:
return cached_data
return None
def cache_playlist_metadata(playlist_id: str, metadata: Dict[str, Any]):
"""
Cache playlist metadata with timestamp.
Args:
playlist_id: The Spotify playlist ID
metadata: The metadata to cache
"""
_playlist_metadata_cache[playlist_id] = (metadata, time.time())
def get_playlist_info_optimized(
playlist_id: str, include_tracks: bool = False
) -> Dict[str, Any]:
"""
Optimized playlist info function that uses caching and selective loading.
Args:
playlist_id: The Spotify playlist ID
include_tracks: Whether to include track data (default: False to save API calls)
Returns:
Playlist data with or without tracks
"""
# Check cache first
cached_metadata = get_cached_playlist_metadata(playlist_id)
if cached_metadata and not include_tracks:
logger.debug(f"Returning cached metadata for playlist {playlist_id}")
return cached_metadata
if include_tracks:
# Get complete playlist data
playlist_data = get_playlist_full(playlist_id)
# Cache the metadata portion
metadata_only = {k: v for k, v in playlist_data.items() if k != "tracks"}
metadata_only["_metadata_only"] = True
metadata_only["_tracks_loaded"] = False
cache_playlist_metadata(playlist_id, metadata_only)
return playlist_data
else:
# Get metadata only
metadata = get_playlist_metadata(playlist_id)
cache_playlist_metadata(playlist_id, metadata)
return metadata
# Keep the existing Deezer functions unchanged
def get_deezer_info(deezer_id, deezer_type, limit=None):
"""
Get info from Deezer API.
Args:
deezer_id: The Deezer ID of the entity.
deezer_type: The type of entity (track, album, playlist, artist, episode,
artist_top_tracks, artist_albums, artist_related,
artist_radio, artist_playlists).
limit (int, optional): The maximum number of items to return. Used for
artist_top_tracks, artist_albums, artist_playlists.
Deezer API methods usually have their own defaults (e.g., 25)
if limit is not provided or None is passed to them.
Returns:
Dictionary with the entity information.
Raises:
ValueError: If deezer_type is unsupported.
Various exceptions from DeezerAPI (NoDataApi, QuotaExceeded, requests.exceptions.RequestException, etc.)
"""
logger.debug(
f"Fetching Deezer info for ID {deezer_id}, type {deezer_type}, limit {limit}"
)
# DeezerAPI uses class methods; its @classmethod __init__ handles setup.
# No specific ARL or account handling here as DeezerAPI seems to use general endpoints.
if deezer_type == "track":
return DeezerAPI.get_track(deezer_id)
elif deezer_type == "album":
return DeezerAPI.get_album(deezer_id)
elif deezer_type == "playlist":
return DeezerAPI.get_playlist(deezer_id)
elif deezer_type == "artist":
return DeezerAPI.get_artist(deezer_id)
elif deezer_type == "episode":
return DeezerAPI.get_episode(deezer_id)
elif deezer_type == "artist_top_tracks":
if limit is not None:
return DeezerAPI.get_artist_top_tracks(deezer_id, limit=limit)
return DeezerAPI.get_artist_top_tracks(deezer_id) # Use API default limit
elif deezer_type == "artist_albums": # Maps to get_artist_top_albums
if limit is not None:
return DeezerAPI.get_artist_top_albums(deezer_id, limit=limit)
return DeezerAPI.get_artist_top_albums(deezer_id) # Use API default limit
elif deezer_type == "artist_related":
return DeezerAPI.get_artist_related(deezer_id)
elif deezer_type == "artist_radio":
return DeezerAPI.get_artist_radio(deezer_id)
elif deezer_type == "artist_playlists":
if limit is not None:
return DeezerAPI.get_artist_top_playlists(deezer_id, limit=limit)
return DeezerAPI.get_artist_top_playlists(deezer_id) # Use API default limit
else:
logger.error(f"Unsupported Deezer type: {deezer_type}")
raise ValueError(f"Unsupported Deezer type: {deezer_type}")
client = get_client()
return get_playlist(client, playlist_id, expand_items=False)

View File

@@ -27,15 +27,9 @@ from routes.utils.watch.db import (
get_artist_batch_next_offset,
set_artist_batch_next_offset,
)
from routes.utils.get_info import (
get_spotify_info,
get_playlist_metadata,
get_playlist_tracks,
) # To fetch playlist, track, artist, and album details
from routes.utils.celery_queue_manager import download_queue_manager
# Added import to fetch base formatting config
from routes.utils.celery_queue_manager import get_config_params
from routes.utils.celery_queue_manager import download_queue_manager, get_config_params
from routes.utils.get_info import get_client
logger = logging.getLogger(__name__)
MAIN_CONFIG_FILE_PATH = Path("./data/config/main.json")
@@ -358,7 +352,7 @@ def find_tracks_in_playlist(
while not_found_tracks and offset < 10000: # Safety limit
try:
tracks_batch = get_playlist_tracks(
tracks_batch = _fetch_playlist_tracks_page(
playlist_spotify_id, limit=limit, offset=offset
)
@@ -459,7 +453,9 @@ def check_watched_playlists(specific_playlist_id: str = None):
ensure_playlist_table_schema(playlist_spotify_id)
# First, get playlist metadata to check if it has changed
current_playlist_metadata = get_playlist_metadata(playlist_spotify_id)
current_playlist_metadata = _fetch_playlist_metadata(
playlist_spotify_id
)
if not current_playlist_metadata:
logger.error(
f"Playlist Watch Manager: Failed to fetch metadata from Spotify for playlist {playlist_spotify_id}."
@@ -507,7 +503,7 @@ def check_watched_playlists(specific_playlist_id: str = None):
progress_offset, _ = get_playlist_batch_progress(
playlist_spotify_id
)
tracks_batch = get_playlist_tracks(
tracks_batch = _fetch_playlist_tracks_page(
playlist_spotify_id,
limit=batch_limit,
offset=progress_offset,
@@ -573,7 +569,7 @@ def check_watched_playlists(specific_playlist_id: str = None):
logger.info(
f"Playlist Watch Manager: Fetching one batch (limit={batch_limit}, offset={progress_offset}) for playlist '{playlist_name}'."
)
tracks_batch = get_playlist_tracks(
tracks_batch = _fetch_playlist_tracks_page(
playlist_spotify_id, limit=batch_limit, offset=progress_offset
)
batch_items = tracks_batch.get("items", []) if tracks_batch else []
@@ -734,8 +730,8 @@ def check_watched_artists(specific_artist_id: str = None):
logger.debug(
f"Artist Watch Manager: Fetching albums for {artist_spotify_id}. Limit: {limit}, Offset: {offset}"
)
artist_albums_page = get_spotify_info(
artist_spotify_id, "artist_discography", limit=limit, offset=offset
artist_albums_page = _fetch_artist_discography_page(
artist_spotify_id, limit=limit, offset=offset
)
current_page_albums = (
@@ -911,7 +907,8 @@ def run_playlist_check_over_intervals(playlist_spotify_id: str) -> None:
# Determine if we are done: no active processing snapshot and no pending sync
cfg = get_watch_config()
interval = cfg.get("watchPollIntervalSeconds", 3600)
metadata = get_playlist_metadata(playlist_spotify_id)
# Use local helper that leverages Librespot client
metadata = _fetch_playlist_metadata(playlist_spotify_id)
if not metadata:
logger.warning(
f"Manual Playlist Runner: Could not load metadata for {playlist_spotify_id}. Stopping."
@@ -1167,3 +1164,84 @@ def update_playlist_m3u_file(playlist_spotify_id: str):
f"Error updating m3u file for playlist {playlist_spotify_id}: {e}",
exc_info=True,
)
# Helper to build a Librespot client from active account
def _build_librespot_client():
try:
# Reuse shared client managed in routes.utils.get_info
return get_client()
except Exception as e:
raise RuntimeError(f"Failed to initialize Librespot client: {e}")
def _fetch_playlist_metadata(playlist_id: str) -> dict:
client = _build_librespot_client()
return client.get_playlist(playlist_id, expand_items=False)
def _fetch_playlist_tracks_page(playlist_id: str, limit: int, offset: int) -> dict:
client = _build_librespot_client()
# Fetch playlist with minimal items to avoid expanding all tracks unnecessarily
pl = client.get_playlist(playlist_id, expand_items=False)
items = (pl.get("tracks", {}) or {}).get("items", [])
total = (pl.get("tracks", {}) or {}).get("total", len(items))
start = max(0, offset or 0)
end = start + max(1, limit or 50)
page_items_minimal = items[start:end]
# Expand only the tracks in this page using client cache for efficiency
page_items_expanded = []
for item in page_items_minimal:
track_stub = (item or {}).get("track") or {}
track_id = track_stub.get("id")
expanded_track = None
if track_id:
try:
expanded_track = client.get_track(track_id)
except Exception:
expanded_track = None
if expanded_track is None:
# Keep stub as fallback; ensure structure
expanded_track = {
k: v
for k, v in track_stub.items()
if k in ("id", "uri", "type", "external_urls")
}
# Propagate local flag onto track for downstream checks
if item and isinstance(item, dict) and item.get("is_local"):
expanded_track["is_local"] = True
# Rebuild item with expanded track
new_item = dict(item)
new_item["track"] = expanded_track
page_items_expanded.append(new_item)
return {
"items": page_items_expanded,
"total": total,
"limit": end - start,
"offset": start,
}
def _fetch_artist_discography_page(artist_id: str, limit: int, offset: int) -> dict:
# LibrespotClient.get_artist returns a pruned mapping; flatten common discography groups
client = _build_librespot_client()
artist = client.get_artist(artist_id)
all_items = []
# Collect from known groups; also support nested structures if present
for key in ("album_group", "single_group", "compilation_group", "appears_on_group"):
grp = artist.get(key)
if isinstance(grp, list):
all_items.extend(grp)
elif isinstance(grp, dict):
items = grp.get("items") or grp.get("releases") or []
if isinstance(items, list):
all_items.extend(items)
total = len(all_items)
start = max(0, offset or 0)
end = start + max(1, limit or 50)
page_items = all_items[start:end]
return {"items": page_items, "total": total, "limit": limit, "offset": start}