Merge branch 'dev' into fixup-bulk-add-celery
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
@@ -14,6 +14,7 @@ security = HTTPBearer(auto_error=False)
|
||||
# Include SSO sub-router
|
||||
try:
|
||||
from .sso import router as sso_router
|
||||
|
||||
router.include_router(sso_router, tags=["sso"])
|
||||
logging.info("SSO sub-router included in auth router")
|
||||
except ImportError as e:
|
||||
@@ -34,6 +35,7 @@ class RegisterRequest(BaseModel):
|
||||
|
||||
class CreateUserRequest(BaseModel):
|
||||
"""Admin-only request to create users when registration is disabled"""
|
||||
|
||||
username: str
|
||||
password: str
|
||||
email: Optional[str] = None
|
||||
@@ -42,17 +44,20 @@ class CreateUserRequest(BaseModel):
|
||||
|
||||
class RoleUpdateRequest(BaseModel):
|
||||
"""Request to update user role"""
|
||||
|
||||
role: str
|
||||
|
||||
|
||||
class PasswordChangeRequest(BaseModel):
|
||||
"""Request to change user password"""
|
||||
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class AdminPasswordResetRequest(BaseModel):
|
||||
"""Request for admin to reset user password"""
|
||||
|
||||
new_password: str
|
||||
|
||||
|
||||
@@ -87,20 +92,20 @@ class AuthStatusResponse(BaseModel):
|
||||
|
||||
# Dependency to get current user
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security)
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> Optional[User]:
|
||||
"""Get current user from JWT token"""
|
||||
if not AUTH_ENABLED:
|
||||
# When auth is disabled, return a mock admin user
|
||||
return User(username="system", role="admin")
|
||||
|
||||
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
|
||||
payload = token_manager.verify_token(credentials.credentials)
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
|
||||
user = user_manager.get_user(payload["username"])
|
||||
return user
|
||||
|
||||
@@ -109,25 +114,22 @@ async def require_auth(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require authentication - raises HTTPException if not authenticated"""
|
||||
if not AUTH_ENABLED:
|
||||
return User(username="system", role="admin")
|
||||
|
||||
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
return current_user
|
||||
|
||||
|
||||
async def require_admin(current_user: User = Depends(require_auth)) -> User:
|
||||
"""Require admin role - raises HTTPException if not admin"""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Admin access required"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
return current_user
|
||||
|
||||
|
||||
@@ -138,24 +140,33 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
|
||||
# Check if SSO is enabled and get available providers
|
||||
sso_enabled = False
|
||||
sso_providers = []
|
||||
|
||||
|
||||
try:
|
||||
from . import sso
|
||||
|
||||
sso_enabled = sso.SSO_ENABLED and AUTH_ENABLED
|
||||
if sso.google_sso:
|
||||
sso_providers.append("google")
|
||||
if sso.github_sso:
|
||||
sso_providers.append("github")
|
||||
if getattr(sso, "custom_sso", None):
|
||||
sso_providers.append("custom")
|
||||
if getattr(sso, "custom_sso_providers", None):
|
||||
if (
|
||||
len(getattr(sso, "custom_sso_providers", {})) > 0
|
||||
and "custom" not in sso_providers
|
||||
):
|
||||
sso_providers.append("custom")
|
||||
except ImportError:
|
||||
pass # SSO module not available
|
||||
|
||||
|
||||
return AuthStatusResponse(
|
||||
auth_enabled=AUTH_ENABLED,
|
||||
authenticated=current_user is not None,
|
||||
user=UserResponse(**current_user.to_public_dict()) if current_user else None,
|
||||
registration_enabled=AUTH_ENABLED and not DISABLE_REGISTRATION,
|
||||
sso_enabled=sso_enabled,
|
||||
sso_providers=sso_providers
|
||||
sso_providers=sso_providers,
|
||||
)
|
||||
|
||||
|
||||
@@ -163,23 +174,16 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
|
||||
async def login(request: LoginRequest):
|
||||
"""Authenticate user and return access token"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
user = user_manager.authenticate_user(request.username, request.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid username or password"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
|
||||
access_token = token_manager.create_token(user)
|
||||
|
||||
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
user=UserResponse(**user.to_public_dict())
|
||||
access_token=access_token, user=UserResponse(**user.to_public_dict())
|
||||
)
|
||||
|
||||
|
||||
@@ -187,31 +191,28 @@ async def login(request: LoginRequest):
|
||||
async def register(request: RegisterRequest):
|
||||
"""Register a new user"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
if DISABLE_REGISTRATION:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Public registration is disabled. Contact an administrator to create an account."
|
||||
detail="Public registration is disabled. Contact an administrator to create an account.",
|
||||
)
|
||||
|
||||
|
||||
# Check if this is the first user (should be admin)
|
||||
existing_users = user_manager.list_users()
|
||||
role = "admin" if len(existing_users) == 0 else "user"
|
||||
|
||||
|
||||
success, message = user_manager.create_user(
|
||||
username=request.username,
|
||||
password=request.password,
|
||||
email=request.email,
|
||||
role=role
|
||||
role=role,
|
||||
)
|
||||
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
|
||||
return MessageResponse(message=message)
|
||||
|
||||
|
||||
@@ -233,70 +234,57 @@ async def list_users(current_user: User = Depends(require_admin)):
|
||||
async def delete_user(username: str, current_user: User = Depends(require_admin)):
|
||||
"""Delete a user (admin only)"""
|
||||
if username == current_user.username:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete your own account"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Cannot delete your own account")
|
||||
|
||||
success, message = user_manager.delete_user(username)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=message)
|
||||
|
||||
|
||||
return MessageResponse(message=message)
|
||||
|
||||
|
||||
@router.put("/users/{username}/role", response_model=MessageResponse)
|
||||
async def update_user_role(
|
||||
username: str,
|
||||
request: RoleUpdateRequest,
|
||||
current_user: User = Depends(require_admin)
|
||||
username: str,
|
||||
request: RoleUpdateRequest,
|
||||
current_user: User = Depends(require_admin),
|
||||
):
|
||||
"""Update user role (admin only)"""
|
||||
if request.role not in ["user", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Role must be 'user' or 'admin'"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
|
||||
|
||||
if username == current_user.username:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot change your own role"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Cannot change your own role")
|
||||
|
||||
success, message = user_manager.update_user_role(username, request.role)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=message)
|
||||
|
||||
|
||||
return MessageResponse(message=message)
|
||||
|
||||
|
||||
@router.post("/users/create", response_model=MessageResponse)
|
||||
async def create_user_admin(request: CreateUserRequest, current_user: User = Depends(require_admin)):
|
||||
async def create_user_admin(
|
||||
request: CreateUserRequest, current_user: User = Depends(require_admin)
|
||||
):
|
||||
"""Create a new user (admin only) - for use when registration is disabled"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
# Validate role
|
||||
if request.role not in ["user", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Role must be 'user' or 'admin'"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
|
||||
|
||||
success, message = user_manager.create_user(
|
||||
username=request.username,
|
||||
password=request.password,
|
||||
email=request.email,
|
||||
role=request.role
|
||||
role=request.role,
|
||||
)
|
||||
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail=message)
|
||||
|
||||
|
||||
return MessageResponse(message=message)
|
||||
|
||||
|
||||
@@ -309,22 +297,18 @@ async def get_profile(current_user: User = Depends(require_auth)):
|
||||
|
||||
@router.put("/profile/password", response_model=MessageResponse)
|
||||
async def change_password(
|
||||
request: PasswordChangeRequest,
|
||||
current_user: User = Depends(require_auth)
|
||||
request: PasswordChangeRequest, current_user: User = Depends(require_auth)
|
||||
):
|
||||
"""Change current user's password"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
success, message = user_manager.change_password(
|
||||
username=current_user.username,
|
||||
current_password=request.current_password,
|
||||
new_password=request.new_password
|
||||
new_password=request.new_password,
|
||||
)
|
||||
|
||||
|
||||
if not success:
|
||||
# Determine appropriate HTTP status code based on error message
|
||||
if "Current password is incorrect" in message:
|
||||
@@ -333,9 +317,9 @@ async def change_password(
|
||||
status_code = 404
|
||||
else:
|
||||
status_code = 400
|
||||
|
||||
|
||||
raise HTTPException(status_code=status_code, detail=message)
|
||||
|
||||
|
||||
return MessageResponse(message=message)
|
||||
|
||||
|
||||
@@ -343,30 +327,26 @@ async def change_password(
|
||||
async def admin_reset_password(
|
||||
username: str,
|
||||
request: AdminPasswordResetRequest,
|
||||
current_user: User = Depends(require_admin)
|
||||
current_user: User = Depends(require_admin),
|
||||
):
|
||||
"""Admin reset user password (admin only)"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
success, message = user_manager.admin_reset_password(
|
||||
username=username,
|
||||
new_password=request.new_password
|
||||
username=username, new_password=request.new_password
|
||||
)
|
||||
|
||||
|
||||
if not success:
|
||||
# Determine appropriate HTTP status code based on error message
|
||||
if "User not found" in message:
|
||||
status_code = 404
|
||||
else:
|
||||
status_code = 400
|
||||
|
||||
|
||||
raise HTTPException(status_code=status_code, detail=message)
|
||||
|
||||
|
||||
return MessageResponse(message=message)
|
||||
|
||||
|
||||
# Note: SSO routes are included in the main app, not here to avoid circular imports
|
||||
# Note: SSO routes are included in the main app, not here to avoid circular imports
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
"""
|
||||
SSO (Single Sign-On) implementation for Google and GitHub authentication
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, Depends
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
from fastapi_sso.sso.github import GithubSSO
|
||||
from fastapi_sso.sso.base import OpenID
|
||||
from pydantic import BaseModel
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
from . import user_manager, token_manager, User, AUTH_ENABLED, DISABLE_REGISTRATION
|
||||
|
||||
@@ -25,11 +27,14 @@ GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
|
||||
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
|
||||
GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
|
||||
GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
|
||||
SSO_BASE_REDIRECT_URI = os.getenv("SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback")
|
||||
SSO_BASE_REDIRECT_URI = os.getenv(
|
||||
"SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback"
|
||||
)
|
||||
|
||||
# Initialize SSO providers
|
||||
google_sso = None
|
||||
github_sso = None
|
||||
custom_sso = None
|
||||
|
||||
if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET:
|
||||
google_sso = GoogleSSO(
|
||||
@@ -47,6 +52,154 @@ if GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET:
|
||||
allow_insecure_http=True, # Set to False in production with HTTPS
|
||||
)
|
||||
|
||||
# Custom/Generic OAuth provider configuration
|
||||
CUSTOM_SSO_CLIENT_ID = os.getenv("CUSTOM_SSO_CLIENT_ID")
|
||||
CUSTOM_SSO_CLIENT_SECRET = os.getenv("CUSTOM_SSO_CLIENT_SECRET")
|
||||
CUSTOM_SSO_AUTHORIZATION_ENDPOINT = os.getenv("CUSTOM_SSO_AUTHORIZATION_ENDPOINT")
|
||||
CUSTOM_SSO_TOKEN_ENDPOINT = os.getenv("CUSTOM_SSO_TOKEN_ENDPOINT")
|
||||
CUSTOM_SSO_USERINFO_ENDPOINT = os.getenv("CUSTOM_SSO_USERINFO_ENDPOINT")
|
||||
CUSTOM_SSO_SCOPE = os.getenv("CUSTOM_SSO_SCOPE") # comma-separated list
|
||||
CUSTOM_SSO_NAME = os.getenv("CUSTOM_SSO_NAME", "custom")
|
||||
CUSTOM_SSO_DISPLAY_NAME = os.getenv("CUSTOM_SSO_DISPLAY_NAME", "Custom")
|
||||
|
||||
|
||||
def _default_custom_response_convertor(
|
||||
userinfo: Dict[str, Any], _client=None
|
||||
) -> OpenID:
|
||||
"""Best-effort convertor from generic userinfo to OpenID."""
|
||||
user_id = (
|
||||
userinfo.get("sub")
|
||||
or userinfo.get("id")
|
||||
or userinfo.get("user_id")
|
||||
or userinfo.get("uid")
|
||||
or userinfo.get("uuid")
|
||||
)
|
||||
email = userinfo.get("email")
|
||||
display_name = (
|
||||
userinfo.get("name")
|
||||
or userinfo.get("preferred_username")
|
||||
or userinfo.get("login")
|
||||
or email
|
||||
or (str(user_id) if user_id is not None else None)
|
||||
)
|
||||
picture = userinfo.get("picture") or userinfo.get("avatar_url")
|
||||
if not user_id and email:
|
||||
user_id = email
|
||||
return OpenID(
|
||||
id=str(user_id) if user_id is not None else "",
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
picture=picture,
|
||||
provider=CUSTOM_SSO_NAME,
|
||||
)
|
||||
|
||||
|
||||
if all(
|
||||
[
|
||||
CUSTOM_SSO_CLIENT_ID,
|
||||
CUSTOM_SSO_CLIENT_SECRET,
|
||||
CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
|
||||
CUSTOM_SSO_TOKEN_ENDPOINT,
|
||||
CUSTOM_SSO_USERINFO_ENDPOINT,
|
||||
]
|
||||
):
|
||||
discovery = {
|
||||
"authorization_endpoint": CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
|
||||
"token_endpoint": CUSTOM_SSO_TOKEN_ENDPOINT,
|
||||
"userinfo_endpoint": CUSTOM_SSO_USERINFO_ENDPOINT,
|
||||
}
|
||||
default_scope = (
|
||||
[s.strip() for s in CUSTOM_SSO_SCOPE.split(",") if s.strip()]
|
||||
if CUSTOM_SSO_SCOPE
|
||||
else None
|
||||
)
|
||||
CustomProvider = create_provider(
|
||||
name=CUSTOM_SSO_NAME,
|
||||
discovery_document=discovery,
|
||||
response_convertor=_default_custom_response_convertor,
|
||||
default_scope=default_scope,
|
||||
)
|
||||
custom_sso = CustomProvider(
|
||||
client_id=CUSTOM_SSO_CLIENT_ID,
|
||||
client_secret=CUSTOM_SSO_CLIENT_SECRET,
|
||||
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom",
|
||||
allow_insecure_http=True, # Set to False in production with HTTPS
|
||||
)
|
||||
|
||||
# Support multiple indexed custom providers (CUSTOM_*_i), up to 10
|
||||
custom_sso_providers: Dict[int, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _make_response_convertor(provider_name: str):
|
||||
def _convert(userinfo: Dict[str, Any], _client=None) -> OpenID:
|
||||
user_id = (
|
||||
userinfo.get("sub")
|
||||
or userinfo.get("id")
|
||||
or userinfo.get("user_id")
|
||||
or userinfo.get("uid")
|
||||
or userinfo.get("uuid")
|
||||
)
|
||||
email = userinfo.get("email")
|
||||
display_name = (
|
||||
userinfo.get("name")
|
||||
or userinfo.get("preferred_username")
|
||||
or userinfo.get("login")
|
||||
or email
|
||||
or (str(user_id) if user_id is not None else None)
|
||||
)
|
||||
picture = userinfo.get("picture") or userinfo.get("avatar_url")
|
||||
if not user_id and email:
|
||||
user_id = email
|
||||
return OpenID(
|
||||
id=str(user_id) if user_id is not None else "",
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
picture=picture,
|
||||
provider=provider_name,
|
||||
)
|
||||
|
||||
return _convert
|
||||
|
||||
|
||||
for i in range(1, 11):
|
||||
cid = os.getenv(f"CUSTOM_SSO_CLIENT_ID_{i}")
|
||||
csecret = os.getenv(f"CUSTOM_SSO_CLIENT_SECRET_{i}")
|
||||
auth_ep = os.getenv(f"CUSTOM_SSO_AUTHORIZATION_ENDPOINT_{i}")
|
||||
token_ep = os.getenv(f"CUSTOM_SSO_TOKEN_ENDPOINT_{i}")
|
||||
userinfo_ep = os.getenv(f"CUSTOM_SSO_USERINFO_ENDPOINT_{i}")
|
||||
scope_raw = os.getenv(f"CUSTOM_SSO_SCOPE_{i}")
|
||||
name_i = os.getenv(f"CUSTOM_SSO_NAME_{i}", f"custom{i}")
|
||||
display_name_i = os.getenv(f"CUSTOM_SSO_DISPLAY_NAME_{i}", f"Custom {i}")
|
||||
|
||||
if all([cid, csecret, auth_ep, token_ep, userinfo_ep]):
|
||||
discovery_i = {
|
||||
"authorization_endpoint": auth_ep,
|
||||
"token_endpoint": token_ep,
|
||||
"userinfo_endpoint": userinfo_ep,
|
||||
}
|
||||
default_scope_i = (
|
||||
[s.strip() for s in scope_raw.split(",") if s.strip()]
|
||||
if scope_raw
|
||||
else None
|
||||
)
|
||||
ProviderClass = create_provider(
|
||||
name=name_i,
|
||||
discovery_document=discovery_i,
|
||||
response_convertor=_make_response_convertor(name_i),
|
||||
default_scope=default_scope_i,
|
||||
)
|
||||
provider_instance = ProviderClass(
|
||||
client_id=cid,
|
||||
client_secret=csecret,
|
||||
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom/{i}",
|
||||
allow_insecure_http=True, # Set to False in production with HTTPS
|
||||
)
|
||||
custom_sso_providers[i] = {
|
||||
"sso": provider_instance,
|
||||
"name": name_i,
|
||||
"display_name": display_name_i,
|
||||
}
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
@@ -70,21 +223,25 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
|
||||
# Generate username from email or use provider ID
|
||||
email = openid.email
|
||||
if not email:
|
||||
raise HTTPException(status_code=400, detail="Email is required for SSO authentication")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Email is required for SSO authentication"
|
||||
)
|
||||
|
||||
# Use email prefix as username, fallback to provider + id
|
||||
username = email.split("@")[0]
|
||||
if not username:
|
||||
username = f"{provider}_{openid.id}"
|
||||
|
||||
|
||||
# Check if user already exists by email
|
||||
existing_user = None
|
||||
users = user_manager.load_users()
|
||||
for user_data in users.values():
|
||||
if user_data.get("email") == email:
|
||||
existing_user = User(**{k: v for k, v in user_data.items() if k != "password_hash"})
|
||||
existing_user = User(
|
||||
**{k: v for k, v in user_data.items() if k != "password_hash"}
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
if existing_user:
|
||||
# Update last login for existing user (always allowed)
|
||||
users[existing_user.username]["last_login"] = datetime.utcnow().isoformat()
|
||||
@@ -96,10 +253,10 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
|
||||
# Check if registration is disabled before creating new user
|
||||
if DISABLE_REGISTRATION:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Registration is disabled. Contact an administrator to create an account."
|
||||
status_code=403,
|
||||
detail="Registration is disabled. Contact an administrator to create an account.",
|
||||
)
|
||||
|
||||
|
||||
# Create new user
|
||||
# Ensure username is unique
|
||||
counter = 1
|
||||
@@ -107,20 +264,20 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
|
||||
while username in users:
|
||||
username = f"{original_username}{counter}"
|
||||
counter += 1
|
||||
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
role="user" # Default role for SSO users
|
||||
role="user", # Default role for SSO users
|
||||
)
|
||||
|
||||
|
||||
users[username] = {
|
||||
**user.to_dict(),
|
||||
"sso_provider": provider,
|
||||
"sso_id": openid.id,
|
||||
"password_hash": None # SSO users don't have passwords
|
||||
"password_hash": None, # SSO users don't have passwords
|
||||
}
|
||||
|
||||
|
||||
user_manager.save_users(users)
|
||||
logger.info(f"Created SSO user: {username} via {provider}")
|
||||
return user
|
||||
@@ -130,27 +287,51 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
|
||||
async def sso_status():
|
||||
"""Get SSO status and available providers"""
|
||||
providers = []
|
||||
|
||||
|
||||
if google_sso:
|
||||
providers.append(SSOProvider(
|
||||
name="google",
|
||||
display_name="Google",
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/google"
|
||||
))
|
||||
|
||||
providers.append(
|
||||
SSOProvider(
|
||||
name="google",
|
||||
display_name="Google",
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/google",
|
||||
)
|
||||
)
|
||||
|
||||
if github_sso:
|
||||
providers.append(SSOProvider(
|
||||
name="github",
|
||||
display_name="GitHub",
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/github"
|
||||
))
|
||||
|
||||
providers.append(
|
||||
SSOProvider(
|
||||
name="github",
|
||||
display_name="GitHub",
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/github",
|
||||
)
|
||||
)
|
||||
|
||||
if custom_sso:
|
||||
providers.append(
|
||||
SSOProvider(
|
||||
name="custom",
|
||||
display_name=CUSTOM_SSO_DISPLAY_NAME,
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/custom",
|
||||
)
|
||||
)
|
||||
|
||||
for idx, cfg in custom_sso_providers.items():
|
||||
providers.append(
|
||||
SSOProvider(
|
||||
name=cfg["name"],
|
||||
display_name=cfg.get("display_name", cfg["name"]),
|
||||
enabled=True,
|
||||
login_url=f"/api/auth/sso/login/custom/{idx}",
|
||||
)
|
||||
)
|
||||
|
||||
return SSOStatusResponse(
|
||||
sso_enabled=SSO_ENABLED and AUTH_ENABLED,
|
||||
providers=providers,
|
||||
registration_enabled=not DISABLE_REGISTRATION
|
||||
registration_enabled=not DISABLE_REGISTRATION,
|
||||
)
|
||||
|
||||
|
||||
@@ -159,12 +340,14 @@ async def google_login():
|
||||
"""Initiate Google SSO login"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
|
||||
if not google_sso:
|
||||
raise HTTPException(status_code=400, detail="Google SSO is not configured")
|
||||
|
||||
|
||||
async with google_sso:
|
||||
return await google_sso.get_login_redirect(params={"prompt": "consent", "access_type": "offline"})
|
||||
return await google_sso.get_login_redirect(
|
||||
params={"prompt": "consent", "access_type": "offline"}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sso/login/github")
|
||||
@@ -172,37 +355,66 @@ async def github_login():
|
||||
"""Initiate GitHub SSO login"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
|
||||
if not github_sso:
|
||||
raise HTTPException(status_code=400, detail="GitHub SSO is not configured")
|
||||
|
||||
|
||||
async with github_sso:
|
||||
return await github_sso.get_login_redirect()
|
||||
|
||||
|
||||
@router.get("/sso/login/custom")
|
||||
async def custom_login():
|
||||
"""Initiate Custom SSO login"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
if not custom_sso:
|
||||
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
|
||||
|
||||
async with custom_sso:
|
||||
return await custom_sso.get_login_redirect()
|
||||
|
||||
|
||||
@router.get("/sso/login/custom/{index}")
|
||||
async def custom_login_indexed(index: int):
|
||||
"""Initiate indexed Custom SSO login"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
cfg = custom_sso_providers.get(index)
|
||||
if not cfg:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Custom SSO provider not configured"
|
||||
)
|
||||
|
||||
async with cfg["sso"]:
|
||||
return await cfg["sso"].get_login_redirect()
|
||||
|
||||
|
||||
@router.get("/sso/callback/google")
|
||||
async def google_callback(request: Request):
|
||||
"""Handle Google SSO callback"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
|
||||
if not google_sso:
|
||||
raise HTTPException(status_code=400, detail="Google SSO is not configured")
|
||||
|
||||
|
||||
try:
|
||||
async with google_sso:
|
||||
openid = await google_sso.verify_and_process(request)
|
||||
|
||||
|
||||
# Create or update user
|
||||
user = create_or_update_sso_user(openid, "google")
|
||||
|
||||
|
||||
# Create JWT token
|
||||
access_token = token_manager.create_token(user)
|
||||
|
||||
|
||||
# Redirect to frontend with token (you might want to customize this)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
|
||||
|
||||
|
||||
# Also set as HTTP-only cookie
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
@@ -210,18 +422,18 @@ async def google_callback(request: Request):
|
||||
httponly=True,
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
samesite="lax",
|
||||
max_age=timedelta(hours=24).total_seconds()
|
||||
max_age=timedelta(hours=24).total_seconds(),
|
||||
)
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
except HTTPException as e:
|
||||
# Handle specific HTTP exceptions (like registration disabled)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed"
|
||||
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
|
||||
logger.warning(f"Google SSO callback error: {error_msg}")
|
||||
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Google SSO callback error: {e}")
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
@@ -233,24 +445,24 @@ async def github_callback(request: Request):
|
||||
"""Handle GitHub SSO callback"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
|
||||
if not github_sso:
|
||||
raise HTTPException(status_code=400, detail="GitHub SSO is not configured")
|
||||
|
||||
|
||||
try:
|
||||
async with github_sso:
|
||||
openid = await github_sso.verify_and_process(request)
|
||||
|
||||
|
||||
# Create or update user
|
||||
user = create_or_update_sso_user(openid, "github")
|
||||
|
||||
|
||||
# Create JWT token
|
||||
access_token = token_manager.create_token(user)
|
||||
|
||||
|
||||
# Redirect to frontend with token (you might want to customize this)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
|
||||
|
||||
|
||||
# Also set as HTTP-only cookie
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
@@ -258,24 +470,123 @@ async def github_callback(request: Request):
|
||||
httponly=True,
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
samesite="lax",
|
||||
max_age=timedelta(hours=24).total_seconds()
|
||||
max_age=timedelta(hours=24).total_seconds(),
|
||||
)
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
except HTTPException as e:
|
||||
# Handle specific HTTP exceptions (like registration disabled)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed"
|
||||
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
|
||||
logger.warning(f"GitHub SSO callback error: {error_msg}")
|
||||
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GitHub SSO callback error: {e}")
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
|
||||
|
||||
|
||||
@router.get("/sso/callback/custom")
|
||||
async def custom_callback(request: Request):
|
||||
"""Handle Custom SSO callback"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
if not custom_sso:
|
||||
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
|
||||
|
||||
try:
|
||||
async with custom_sso:
|
||||
openid = await custom_sso.verify_and_process(request)
|
||||
|
||||
# Create or update user
|
||||
user = create_or_update_sso_user(openid, "custom")
|
||||
|
||||
# Create JWT token
|
||||
access_token = token_manager.create_token(user)
|
||||
|
||||
# Redirect to frontend with token (you might want to customize this)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
|
||||
|
||||
# Also set as HTTP-only cookie
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
samesite="lax",
|
||||
max_age=timedelta(hours=24).total_seconds(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
# Handle specific HTTP exceptions (like registration disabled)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
|
||||
logger.warning(f"Custom SSO callback error: {error_msg}")
|
||||
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom SSO callback error: {e}")
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
|
||||
|
||||
|
||||
@router.get("/sso/callback/custom/{index}")
|
||||
async def custom_callback_indexed(request: Request, index: int):
|
||||
"""Handle indexed Custom SSO callback"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
cfg = custom_sso_providers.get(index)
|
||||
if not cfg:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Custom SSO provider not configured"
|
||||
)
|
||||
|
||||
try:
|
||||
async with cfg["sso"]:
|
||||
openid = await cfg["sso"].verify_and_process(request)
|
||||
|
||||
# Create or update user
|
||||
user = create_or_update_sso_user(openid, cfg["name"])
|
||||
|
||||
# Create JWT token
|
||||
access_token = token_manager.create_token(user)
|
||||
|
||||
# Redirect to frontend with token (you might want to customize this)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
|
||||
|
||||
# Also set as HTTP-only cookie
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
samesite="lax",
|
||||
max_age=timedelta(hours=24).total_seconds(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
# Handle specific HTTP exceptions (like registration disabled)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
|
||||
logger.warning(f"Custom[{index}] SSO callback error: {error_msg}")
|
||||
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom[{index}] SSO callback error: {e}")
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
|
||||
|
||||
|
||||
@router.post("/sso/unlink/{provider}", response_model=MessageResponse)
|
||||
async def unlink_sso_provider(
|
||||
provider: str,
|
||||
@@ -284,27 +595,42 @@ async def unlink_sso_provider(
|
||||
"""Unlink SSO provider from user account"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
if provider not in ["google", "github"]:
|
||||
|
||||
available = []
|
||||
if google_sso:
|
||||
available.append("google")
|
||||
if github_sso:
|
||||
available.append("github")
|
||||
if custom_sso:
|
||||
available.append("custom")
|
||||
|
||||
for cfg in custom_sso_providers.values():
|
||||
available.append(cfg["name"])
|
||||
|
||||
if provider not in available:
|
||||
raise HTTPException(status_code=400, detail="Invalid SSO provider")
|
||||
|
||||
|
||||
# Get current user from request (avoiding circular imports)
|
||||
from .middleware import require_auth_from_state
|
||||
|
||||
|
||||
current_user = await require_auth_from_state(request)
|
||||
|
||||
|
||||
if not current_user.sso_provider:
|
||||
raise HTTPException(status_code=400, detail="User is not linked to any SSO provider")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400, detail="User is not linked to any SSO provider"
|
||||
)
|
||||
|
||||
if current_user.sso_provider != provider:
|
||||
raise HTTPException(status_code=400, detail=f"User is not linked to {provider}")
|
||||
|
||||
|
||||
# Update user to remove SSO linkage
|
||||
users = user_manager.load_users()
|
||||
if current_user.username in users:
|
||||
users[current_user.username]["sso_provider"] = None
|
||||
users[current_user.username]["sso_id"] = None
|
||||
user_manager.save_users(users)
|
||||
logger.info(f"Unlinked SSO provider {provider} from user {current_user.username}")
|
||||
|
||||
return MessageResponse(message=f"SSO provider {provider} unlinked successfully")
|
||||
logger.info(
|
||||
f"Unlinked SSO provider {provider} from user {current_user.username}"
|
||||
)
|
||||
|
||||
return MessageResponse(message=f"SSO provider {provider} unlinked successfully")
|
||||
|
||||
@@ -5,11 +5,12 @@ import uuid
|
||||
import time
|
||||
from routes.utils.celery_queue_manager import download_queue_manager
|
||||
from routes.utils.celery_tasks import store_task_info, store_task_status, ProgressState
|
||||
from routes.utils.get_info import get_spotify_info
|
||||
from routes.utils.get_info import get_client, get_album
|
||||
from routes.utils.errors import DuplicateDownloadError
|
||||
|
||||
# Import authentication dependencies
|
||||
from routes.auth.middleware import require_auth_from_state, User
|
||||
# Config and credentials helpers
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -34,7 +35,8 @@ async def handle_download(
|
||||
|
||||
# Fetch metadata from Spotify
|
||||
try:
|
||||
album_info = get_spotify_info(album_id, "album")
|
||||
client = get_client()
|
||||
album_info = get_album(client, album_id)
|
||||
if (
|
||||
not album_info
|
||||
or not album_info.get("name")
|
||||
@@ -155,6 +157,7 @@ async def get_album_info(
|
||||
"""
|
||||
Retrieve Spotify album metadata given a Spotify album ID.
|
||||
Expects a query parameter 'id' that contains the Spotify album ID.
|
||||
Returns the raw JSON from get_album in routes.utils.get_info.
|
||||
"""
|
||||
spotify_id = request.query_params.get("id")
|
||||
|
||||
@@ -162,27 +165,9 @@ async def get_album_info(
|
||||
return JSONResponse(content={"error": "Missing parameter: id"}, status_code=400)
|
||||
|
||||
try:
|
||||
# Optional pagination params for tracks
|
||||
limit_param = request.query_params.get("limit")
|
||||
offset_param = request.query_params.get("offset")
|
||||
limit = int(limit_param) if limit_param is not None else None
|
||||
offset = int(offset_param) if offset_param is not None else None
|
||||
|
||||
# Fetch album metadata
|
||||
album_info = get_spotify_info(spotify_id, "album")
|
||||
# Fetch album tracks with pagination
|
||||
album_tracks = get_spotify_info(
|
||||
spotify_id, "album_tracks", limit=limit, offset=offset
|
||||
)
|
||||
|
||||
# Merge tracks into album payload in the same shape Spotify returns on album
|
||||
album_info["tracks"] = album_tracks
|
||||
|
||||
client = get_client()
|
||||
album_info = get_album(client, spotify_id)
|
||||
return JSONResponse(content=album_info, status_code=200)
|
||||
except ValueError as ve:
|
||||
return JSONResponse(
|
||||
content={"error": f"Invalid limit/offset: {str(ve)}"}, status_code=400
|
||||
)
|
||||
except Exception as e:
|
||||
error_data = {"error": str(e), "traceback": traceback.format_exc()}
|
||||
return JSONResponse(content=error_data, status_code=500)
|
||||
|
||||
@@ -18,10 +18,9 @@ from routes.utils.watch.db import (
|
||||
get_watched_artists,
|
||||
add_specific_albums_to_artist_table,
|
||||
remove_specific_albums_from_artist_table,
|
||||
is_album_in_artist_db,
|
||||
)
|
||||
from routes.utils.watch.manager import check_watched_artists, get_watch_config
|
||||
from routes.utils.get_info import get_spotify_info
|
||||
from routes.utils.get_info import get_client, get_artist, get_album
|
||||
|
||||
# Import authentication dependencies
|
||||
from routes.auth.middleware import require_auth_from_state, User
|
||||
@@ -66,9 +65,6 @@ async def handle_artist_download(
|
||||
)
|
||||
|
||||
try:
|
||||
# Import and call the updated download_artist_albums() function.
|
||||
# from routes.utils.artist import download_artist_albums # Already imported at top
|
||||
|
||||
# Delegate to the download_artist_albums function which will handle album filtering
|
||||
successfully_queued_albums, duplicate_albums = download_artist_albums(
|
||||
url=url,
|
||||
@@ -118,13 +114,15 @@ async def cancel_artist_download():
|
||||
|
||||
@router.get("/info")
|
||||
async def get_artist_info(
|
||||
request: Request, current_user: User = Depends(require_auth_from_state),
|
||||
limit: int = Query(10, ge=1), # default=10, must be >=1
|
||||
offset: int = Query(0, ge=0) # default=0, must be >=0
|
||||
request: Request,
|
||||
current_user: User = Depends(require_auth_from_state),
|
||||
limit: int = Query(10, ge=1), # default=10, must be >=1
|
||||
offset: int = Query(0, ge=0), # default=0, must be >=0
|
||||
):
|
||||
"""
|
||||
Retrieves Spotify artist metadata given a Spotify artist ID.
|
||||
Expects a query parameter 'id' with the Spotify artist ID.
|
||||
Returns the raw JSON from get_artist in routes.utils.get_info.
|
||||
"""
|
||||
spotify_id = request.query_params.get("id")
|
||||
|
||||
@@ -132,37 +130,8 @@ async def get_artist_info(
|
||||
return JSONResponse(content={"error": "Missing parameter: id"}, status_code=400)
|
||||
|
||||
try:
|
||||
# Get artist metadata first
|
||||
artist_metadata = get_spotify_info(spotify_id, "artist")
|
||||
|
||||
# Get artist discography for albums
|
||||
artist_discography = get_spotify_info(spotify_id, "artist_discography", limit=limit, offset=offset)
|
||||
|
||||
# Combine metadata with discography
|
||||
artist_info = {**artist_metadata, "albums": artist_discography}
|
||||
|
||||
# If artist_info is successfully fetched and has albums,
|
||||
# check if the artist is watched and augment album items with is_locally_known status
|
||||
if (
|
||||
artist_info
|
||||
and artist_info.get("albums")
|
||||
and artist_info["albums"].get("items")
|
||||
):
|
||||
watched_artist_details = get_watched_artist(
|
||||
spotify_id
|
||||
) # spotify_id is the artist ID
|
||||
if watched_artist_details: # Artist is being watched
|
||||
for album_item in artist_info["albums"]["items"]:
|
||||
if album_item and album_item.get("id"):
|
||||
album_id = album_item["id"]
|
||||
album_item["is_locally_known"] = is_album_in_artist_db(
|
||||
spotify_id, album_id
|
||||
)
|
||||
elif album_item: # Album object exists but no ID
|
||||
album_item["is_locally_known"] = False
|
||||
# If not watched, or no albums, is_locally_known will not be added.
|
||||
# Frontend should handle absence of this key as false.
|
||||
|
||||
client = get_client()
|
||||
artist_info = get_artist(client, spotify_id)
|
||||
return JSONResponse(content=artist_info, status_code=200)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
@@ -191,15 +160,9 @@ async def add_artist_to_watchlist(
|
||||
if get_watched_artist(artist_spotify_id):
|
||||
return {"message": f"Artist {artist_spotify_id} is already being watched."}
|
||||
|
||||
# Get artist metadata directly for name and basic info
|
||||
artist_metadata = get_spotify_info(artist_spotify_id, "artist")
|
||||
client = get_client()
|
||||
artist_metadata = get_artist(client, artist_spotify_id)
|
||||
|
||||
# Get artist discography for album count
|
||||
artist_album_list_data = get_spotify_info(
|
||||
artist_spotify_id, "artist_discography"
|
||||
)
|
||||
|
||||
# Check if we got artist metadata
|
||||
if not artist_metadata or not artist_metadata.get("name"):
|
||||
logger.error(
|
||||
f"Could not fetch artist metadata for {artist_spotify_id} from Spotify."
|
||||
@@ -211,24 +174,22 @@ async def add_artist_to_watchlist(
|
||||
},
|
||||
)
|
||||
|
||||
# Check if we got album data
|
||||
if not artist_album_list_data or not isinstance(
|
||||
artist_album_list_data.get("items"), list
|
||||
# Derive a rough total album count from groups if present
|
||||
total_albums = 0
|
||||
for key in (
|
||||
"album_group",
|
||||
"single_group",
|
||||
"compilation_group",
|
||||
"appears_on_group",
|
||||
):
|
||||
logger.warning(
|
||||
f"Could not fetch album list details for artist {artist_spotify_id} from Spotify. Proceeding with metadata only."
|
||||
)
|
||||
grp = artist_metadata.get(key)
|
||||
if isinstance(grp, list):
|
||||
total_albums += len(grp)
|
||||
|
||||
# Construct the artist_data object expected by add_artist_db
|
||||
artist_data_for_db = {
|
||||
"id": artist_spotify_id,
|
||||
"name": artist_metadata.get("name", "Unknown Artist"),
|
||||
"albums": { # Mimic structure if add_artist_db expects it for total_albums
|
||||
"total": artist_album_list_data.get("total", 0)
|
||||
if artist_album_list_data
|
||||
else 0
|
||||
},
|
||||
# Add any other fields add_artist_db might expect from a true artist object if necessary
|
||||
"albums": {"total": total_albums},
|
||||
}
|
||||
|
||||
add_artist_db(artist_data_for_db)
|
||||
@@ -446,21 +407,25 @@ async def mark_albums_as_known_for_artist(
|
||||
detail={"error": f"Artist {artist_spotify_id} is not being watched."},
|
||||
)
|
||||
|
||||
client = get_client()
|
||||
fetched_albums_details = []
|
||||
for album_id in album_ids:
|
||||
try:
|
||||
# We need full album details. get_spotify_info with type "album" should provide this.
|
||||
album_detail = get_spotify_info(album_id, "album")
|
||||
if album_detail and album_detail.get("id"):
|
||||
fetched_albums_details.append(album_detail)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not fetch details for album {album_id} when marking as known for artist {artist_spotify_id}."
|
||||
try:
|
||||
for album_id in album_ids:
|
||||
try:
|
||||
album_detail = get_album(client, album_id)
|
||||
if album_detail and album_detail.get("id"):
|
||||
fetched_albums_details.append(album_detail)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not fetch details for album {album_id} when marking as known for artist {artist_spotify_id}."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to fetch Spotify details for album {album_id}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to fetch Spotify details for album {album_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
# No need to close_client here, as get_client is shared
|
||||
pass
|
||||
|
||||
if not fetched_albums_details:
|
||||
return {
|
||||
|
||||
@@ -11,25 +11,42 @@ from routes.auth.middleware import require_auth_from_state, User
|
||||
from routes.utils.get_info import get_spotify_info
|
||||
from routes.utils.celery_queue_manager import download_queue_manager
|
||||
|
||||
# Assuming these imports are available for queue management and Spotify info
|
||||
from routes.utils.get_info import (
|
||||
get_client,
|
||||
get_track,
|
||||
get_album,
|
||||
get_playlist,
|
||||
get_artist,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BulkAddLinksRequest(BaseModel):
|
||||
links: List[str]
|
||||
|
||||
|
||||
@router.post("/bulk-add-spotify-links")
|
||||
async def bulk_add_spotify_links(request: BulkAddLinksRequest, req: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
added_count = 0
|
||||
failed_links = []
|
||||
total_links = len(request.links)
|
||||
|
||||
|
||||
client = get_client()
|
||||
for link in request.links:
|
||||
# Assuming links are pre-filtered by the frontend,
|
||||
# but still handle potential errors during info retrieval or unsupported types
|
||||
# Extract type and ID from the link directly using regex
|
||||
match = re.match(r"https://open\.spotify\.com(?:/intl-[a-z]{2})?/(track|album|playlist|artist)/([a-zA-Z0-9]+)(?:\?.*)?", link)
|
||||
match = re.match(
|
||||
r"https://open\.spotify\.com(?:/intl-[a-z]{2})?/(track|album|playlist|artist)/([a-zA-Z0-9]+)(?:\?.*)?",
|
||||
link,
|
||||
)
|
||||
if not match:
|
||||
logger.warning(f"Could not parse Spotify link (unexpected format after frontend filter): {link}")
|
||||
logger.warning(
|
||||
f"Could not parse Spotify link (unexpected format after frontend filter): {link}"
|
||||
)
|
||||
failed_links.append(link)
|
||||
continue
|
||||
|
||||
@@ -39,18 +56,30 @@ async def bulk_add_spotify_links(request: BulkAddLinksRequest, req: Request, cur
|
||||
|
||||
try:
|
||||
# Get basic info to confirm existence and get name/artist
|
||||
# For playlists, we might want to get full info later when adding to queue
|
||||
if spotify_type == "playlist":
|
||||
item_info = get_spotify_info(spotify_id, "playlist_metadata")
|
||||
item_info = get_playlist(client, spotify_id, expand_items=False)
|
||||
elif spotify_type == "track":
|
||||
item_info = get_track(client, spotify_id)
|
||||
elif spotify_type == "album":
|
||||
item_info = get_album(client, spotify_id)
|
||||
elif spotify_type == "artist":
|
||||
# Not queued below, but fetch to validate link and name if needed
|
||||
item_info = get_artist(client, spotify_id)
|
||||
else:
|
||||
item_info = get_spotify_info(spotify_id, spotify_type)
|
||||
|
||||
logger.warning(
|
||||
f"Unsupported Spotify type: {spotify_type} for link: {link}"
|
||||
)
|
||||
failed_links.append(link)
|
||||
continue
|
||||
|
||||
item_name = item_info.get("name", "Unknown Name")
|
||||
artist_name = ""
|
||||
if spotify_type in ["track", "album"]:
|
||||
artists = item_info.get("artists", [])
|
||||
if artists:
|
||||
artist_name = ", ".join([a.get("name", "Unknown Artist") for a in artists])
|
||||
artist_name = ", ".join(
|
||||
[a.get("name", "Unknown Artist") for a in artists]
|
||||
)
|
||||
elif spotify_type == "playlist":
|
||||
owner = item_info.get("owner", {})
|
||||
artist_name = owner.get("display_name", "Unknown Owner")
|
||||
@@ -77,8 +106,16 @@ async def bulk_add_spotify_links(request: BulkAddLinksRequest, req: Request, cur
|
||||
added_count += 1
|
||||
logger.debug(f"Added {added_count}/{total_links} {spotify_type} '{item_name}' ({spotify_id}) to queue with task_id: {task_id}.")
|
||||
else:
|
||||
logger.warning(f"Failed to add {spotify_type} '{item_name}' ({spotify_id}) to queue.")
|
||||
logger.warning(
|
||||
f"Unsupported Spotify type for download: {spotify_type} for link: {link}"
|
||||
)
|
||||
failed_links.append(link)
|
||||
continue
|
||||
|
||||
added_count += 1
|
||||
logger.debug(
|
||||
f"Added {added_count + 1}/{total_links} {spotify_type} '{item_name}' ({spotify_id}) to queue."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Spotify link {link}: {e}", exc_info=True)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
import json
|
||||
import traceback
|
||||
import logging # Added logging import
|
||||
import uuid # For generating error task IDs
|
||||
@@ -20,10 +19,9 @@ from routes.utils.watch.db import (
|
||||
get_watched_playlist,
|
||||
get_watched_playlists,
|
||||
add_specific_tracks_to_playlist_table,
|
||||
remove_specific_tracks_from_playlist_table,
|
||||
is_track_in_playlist_db, # Added import
|
||||
remove_specific_tracks_from_playlist_table, # Added import
|
||||
)
|
||||
from routes.utils.get_info import get_spotify_info # Already used, but ensure it's here
|
||||
from routes.utils.get_info import get_client, get_playlist, get_track
|
||||
from routes.utils.watch.manager import (
|
||||
check_watched_playlists,
|
||||
get_watch_config,
|
||||
@@ -31,7 +29,9 @@ from routes.utils.watch.manager import (
|
||||
from routes.utils.errors import DuplicateDownloadError
|
||||
|
||||
# Import authentication dependencies
|
||||
from routes.auth.middleware import require_auth_from_state, require_admin_from_state, User
|
||||
from routes.auth.middleware import require_auth_from_state, User
|
||||
from routes.utils.celery_config import get_config_params
|
||||
from routes.utils.credentials import get_spotify_blob_path
|
||||
|
||||
logger = logging.getLogger(__name__) # Added logger initialization
|
||||
router = APIRouter()
|
||||
@@ -43,7 +43,11 @@ def construct_spotify_url(item_id: str, item_type: str = "track") -> str:
|
||||
|
||||
|
||||
@router.get("/download/{playlist_id}")
|
||||
async def handle_download(playlist_id: str, request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
async def handle_download(
|
||||
playlist_id: str,
|
||||
request: Request,
|
||||
current_user: User = Depends(require_auth_from_state),
|
||||
):
|
||||
# Retrieve essential parameters from the request.
|
||||
# name = request.args.get('name') # Removed
|
||||
# artist = request.args.get('artist') # Removed
|
||||
@@ -51,11 +55,14 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
|
||||
|
||||
# Construct the URL from playlist_id
|
||||
url = construct_spotify_url(playlist_id, "playlist")
|
||||
orig_params["original_url"] = str(request.url) # Update original_url to the constructed one
|
||||
orig_params["original_url"] = str(
|
||||
request.url
|
||||
) # Update original_url to the constructed one
|
||||
|
||||
# Fetch metadata from Spotify using optimized function
|
||||
try:
|
||||
from routes.utils.get_info import get_playlist_metadata
|
||||
|
||||
playlist_info = get_playlist_metadata(playlist_id)
|
||||
if (
|
||||
not playlist_info
|
||||
@@ -66,7 +73,7 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
|
||||
content={
|
||||
"error": f"Could not retrieve metadata for playlist ID: {playlist_id}"
|
||||
},
|
||||
status_code=404
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
name_from_spotify = playlist_info.get("name")
|
||||
@@ -79,14 +86,13 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
|
||||
content={
|
||||
"error": f"Failed to fetch metadata for playlist {playlist_id}: {str(e)}"
|
||||
},
|
||||
status_code=500
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
# Validate required parameters
|
||||
if not url: # This check might be redundant now but kept for safety
|
||||
return JSONResponse(
|
||||
content={"error": "Missing required parameter: url"},
|
||||
status_code=400
|
||||
content={"error": "Missing required parameter: url"}, status_code=400
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -106,7 +112,7 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
|
||||
"error": "Duplicate download detected.",
|
||||
"existing_task": e.existing_task,
|
||||
},
|
||||
status_code=409
|
||||
status_code=409,
|
||||
)
|
||||
except Exception as e:
|
||||
# Generic error handling for other issues during task submission
|
||||
@@ -136,25 +142,23 @@ async def handle_download(playlist_id: str, request: Request, current_user: User
|
||||
"error": f"Failed to queue playlist download: {str(e)}",
|
||||
"task_id": error_task_id,
|
||||
},
|
||||
status_code=500
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={"task_id": task_id},
|
||||
status_code=202
|
||||
)
|
||||
return JSONResponse(content={"task_id": task_id}, status_code=202)
|
||||
|
||||
|
||||
@router.get("/download/cancel")
|
||||
async def cancel_download(request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
async def cancel_download(
|
||||
request: Request, current_user: User = Depends(require_auth_from_state)
|
||||
):
|
||||
"""
|
||||
Cancel a running playlist download process by its task id.
|
||||
"""
|
||||
task_id = request.query_params.get("task_id")
|
||||
if not task_id:
|
||||
return JSONResponse(
|
||||
content={"error": "Missing task id (task_id) parameter"},
|
||||
status_code=400
|
||||
content={"error": "Missing task id (task_id) parameter"}, status_code=400
|
||||
)
|
||||
|
||||
# Use the queue manager's cancellation method.
|
||||
@@ -165,124 +169,94 @@ async def cancel_download(request: Request, current_user: User = Depends(require
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_playlist_info(request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
async def get_playlist_info(
|
||||
request: Request, current_user: User = Depends(require_auth_from_state)
|
||||
):
|
||||
"""
|
||||
Retrieve Spotify playlist metadata given a Spotify playlist ID.
|
||||
Expects a query parameter 'id' that contains the Spotify playlist ID.
|
||||
"""
|
||||
spotify_id = request.query_params.get("id")
|
||||
include_tracks = request.query_params.get("include_tracks", "false").lower() == "true"
|
||||
|
||||
if not spotify_id:
|
||||
return JSONResponse(
|
||||
content={"error": "Missing parameter: id"},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
try:
|
||||
# Use the optimized playlist info function
|
||||
from routes.utils.get_info import get_playlist_info_optimized
|
||||
playlist_info = get_playlist_info_optimized(spotify_id, include_tracks=include_tracks)
|
||||
|
||||
# If playlist_info is successfully fetched, check if it's watched
|
||||
# and augment track items with is_locally_known status
|
||||
if playlist_info and playlist_info.get("id"):
|
||||
watched_playlist_details = get_watched_playlist(playlist_info["id"])
|
||||
if watched_playlist_details: # Playlist is being watched
|
||||
if playlist_info.get("tracks") and playlist_info["tracks"].get("items"):
|
||||
for item in playlist_info["tracks"]["items"]:
|
||||
if item and item.get("track") and item["track"].get("id"):
|
||||
track_id = item["track"]["id"]
|
||||
item["track"]["is_locally_known"] = is_track_in_playlist_db(
|
||||
playlist_info["id"], track_id
|
||||
)
|
||||
elif item and item.get(
|
||||
"track"
|
||||
): # Track object exists but no ID
|
||||
item["track"]["is_locally_known"] = False
|
||||
# If not watched, or no tracks, is_locally_known will not be added, or tracks won't exist to add it to.
|
||||
# Frontend should handle absence of this key as false.
|
||||
|
||||
return JSONResponse(
|
||||
content=playlist_info, status_code=200
|
||||
)
|
||||
except Exception as e:
|
||||
error_data = {"error": str(e), "traceback": traceback.format_exc()}
|
||||
return JSONResponse(content=error_data, status_code=500)
|
||||
|
||||
|
||||
@router.get("/metadata")
|
||||
async def get_playlist_metadata(request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
"""
|
||||
Retrieve only Spotify playlist metadata (no tracks) to avoid rate limiting.
|
||||
Expects a query parameter 'id' that contains the Spotify playlist ID.
|
||||
Always returns the raw JSON from get_playlist with expand_items=False.
|
||||
"""
|
||||
spotify_id = request.query_params.get("id")
|
||||
|
||||
if not spotify_id:
|
||||
return JSONResponse(
|
||||
content={"error": "Missing parameter: id"},
|
||||
status_code=400
|
||||
)
|
||||
return JSONResponse(content={"error": "Missing parameter: id"}, status_code=400)
|
||||
|
||||
try:
|
||||
# Use the optimized playlist metadata function
|
||||
from routes.utils.get_info import get_playlist_metadata
|
||||
playlist_metadata = get_playlist_metadata(spotify_id)
|
||||
# Resolve active account's credentials blob
|
||||
cfg = get_config_params() or {}
|
||||
active_account = cfg.get("spotify")
|
||||
if not active_account:
|
||||
return JSONResponse(
|
||||
content={"error": "Active Spotify account not set in configuration."},
|
||||
status_code=500,
|
||||
)
|
||||
blob_path = get_spotify_blob_path(active_account)
|
||||
if not blob_path.exists():
|
||||
return JSONResponse(
|
||||
content={
|
||||
"error": f"Spotify credentials blob not found for account '{active_account}'"
|
||||
},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content=playlist_metadata, status_code=200
|
||||
)
|
||||
except Exception as e:
|
||||
error_data = {"error": str(e), "traceback": traceback.format_exc()}
|
||||
return JSONResponse(content=error_data, status_code=500)
|
||||
client = get_client()
|
||||
try:
|
||||
playlist_info = get_playlist(client, spotify_id, expand_items=False)
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/tracks")
|
||||
async def get_playlist_tracks(request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
"""
|
||||
Retrieve playlist tracks with pagination support for progressive loading.
|
||||
Expects query parameters: 'id' (playlist ID), 'limit' (optional), 'offset' (optional).
|
||||
"""
|
||||
spotify_id = request.query_params.get("id")
|
||||
limit = int(request.query_params.get("limit", 50))
|
||||
offset = int(request.query_params.get("offset", 0))
|
||||
|
||||
if not spotify_id:
|
||||
return JSONResponse(
|
||||
content={"error": "Missing parameter: id"},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
try:
|
||||
# Use the optimized playlist tracks function
|
||||
from routes.utils.get_info import get_playlist_tracks
|
||||
tracks_data = get_playlist_tracks(spotify_id, limit=limit, offset=offset)
|
||||
|
||||
return JSONResponse(
|
||||
content=tracks_data, status_code=200
|
||||
)
|
||||
return JSONResponse(content=playlist_info, status_code=200)
|
||||
except Exception as e:
|
||||
error_data = {"error": str(e), "traceback": traceback.format_exc()}
|
||||
return JSONResponse(content=error_data, status_code=500)
|
||||
|
||||
|
||||
@router.put("/watch/{playlist_spotify_id}")
|
||||
async def add_to_watchlist(playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)):
|
||||
async def add_to_watchlist(
|
||||
playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)
|
||||
):
|
||||
"""Adds a playlist to the watchlist."""
|
||||
watch_config = get_watch_config()
|
||||
if not watch_config.get("enabled", False):
|
||||
raise HTTPException(status_code=403, detail={"error": "Watch feature is currently disabled globally."})
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": "Watch feature is currently disabled globally."},
|
||||
)
|
||||
|
||||
logger.info(f"Attempting to add playlist {playlist_spotify_id} to watchlist.")
|
||||
try:
|
||||
# Check if already watched
|
||||
if get_watched_playlist(playlist_spotify_id):
|
||||
return {"message": f"Playlist {playlist_spotify_id} is already being watched."}
|
||||
return {
|
||||
"message": f"Playlist {playlist_spotify_id} is already being watched."
|
||||
}
|
||||
|
||||
# Fetch playlist details from Spotify to populate our DB (metadata only)
|
||||
cfg = get_config_params() or {}
|
||||
active_account = cfg.get("spotify")
|
||||
if not active_account:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": "Active Spotify account not set in configuration."},
|
||||
)
|
||||
blob_path = get_spotify_blob_path(active_account)
|
||||
if not blob_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": f"Spotify credentials blob not found for account '{active_account}'"
|
||||
},
|
||||
)
|
||||
|
||||
client = get_client()
|
||||
try:
|
||||
playlist_data = get_playlist(
|
||||
client, playlist_spotify_id, expand_items=False
|
||||
)
|
||||
finally:
|
||||
pass
|
||||
|
||||
# Fetch playlist details from Spotify to populate our DB
|
||||
from routes.utils.get_info import get_playlist_metadata
|
||||
playlist_data = get_playlist_metadata(playlist_spotify_id)
|
||||
if not playlist_data or "id" not in playlist_data:
|
||||
logger.error(
|
||||
f"Could not fetch details for playlist {playlist_spotify_id} from Spotify."
|
||||
@@ -291,19 +265,11 @@ async def add_to_watchlist(playlist_spotify_id: str, current_user: User = Depend
|
||||
status_code=404,
|
||||
detail={
|
||||
"error": f"Could not fetch details for playlist {playlist_spotify_id} from Spotify."
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
add_playlist_db(playlist_data) # This also creates the tracks table
|
||||
|
||||
# REMOVED: Do not add initial tracks directly to DB.
|
||||
# The playlist watch manager will pick them up as new and queue downloads.
|
||||
# Tracks will be added to DB only after successful download via Celery task callback.
|
||||
# initial_track_items = playlist_data.get('tracks', {}).get('items', [])
|
||||
# if initial_track_items:
|
||||
# from routes.utils.watch.db import add_tracks_to_playlist_db # Keep local import for clarity
|
||||
# add_tracks_to_playlist_db(playlist_spotify_id, initial_track_items)
|
||||
|
||||
logger.info(
|
||||
f"Playlist {playlist_spotify_id} added to watchlist. Its tracks will be processed by the watch manager."
|
||||
)
|
||||
@@ -317,11 +283,16 @@ async def add_to_watchlist(playlist_spotify_id: str, current_user: User = Depend
|
||||
f"Error adding playlist {playlist_spotify_id} to watchlist: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail={"error": f"Could not add playlist to watchlist: {str(e)}"})
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Could not add playlist to watchlist: {str(e)}"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/watch/{playlist_spotify_id}/status")
|
||||
async def get_playlist_watch_status(playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)):
|
||||
async def get_playlist_watch_status(
|
||||
playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)
|
||||
):
|
||||
"""Checks if a specific playlist is being watched."""
|
||||
logger.info(f"Checking watch status for playlist {playlist_spotify_id}.")
|
||||
try:
|
||||
@@ -337,22 +308,31 @@ async def get_playlist_watch_status(playlist_spotify_id: str, current_user: User
|
||||
f"Error checking watch status for playlist {playlist_spotify_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail={"error": f"Could not check watch status: {str(e)}"})
|
||||
raise HTTPException(
|
||||
status_code=500, detail={"error": f"Could not check watch status: {str(e)}"}
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/watch/{playlist_spotify_id}")
|
||||
async def remove_from_watchlist(playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)):
|
||||
async def remove_from_watchlist(
|
||||
playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)
|
||||
):
|
||||
"""Removes a playlist from the watchlist."""
|
||||
watch_config = get_watch_config()
|
||||
if not watch_config.get("enabled", False):
|
||||
raise HTTPException(status_code=403, detail={"error": "Watch feature is currently disabled globally."})
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": "Watch feature is currently disabled globally."},
|
||||
)
|
||||
|
||||
logger.info(f"Attempting to remove playlist {playlist_spotify_id} from watchlist.")
|
||||
try:
|
||||
if not get_watched_playlist(playlist_spotify_id):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": f"Playlist {playlist_spotify_id} not found in watchlist."}
|
||||
detail={
|
||||
"error": f"Playlist {playlist_spotify_id} not found in watchlist."
|
||||
},
|
||||
)
|
||||
|
||||
remove_playlist_db(playlist_spotify_id)
|
||||
@@ -369,12 +349,16 @@ async def remove_from_watchlist(playlist_spotify_id: str, current_user: User = D
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Could not remove playlist from watchlist: {str(e)}"}
|
||||
detail={"error": f"Could not remove playlist from watchlist: {str(e)}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/watch/{playlist_spotify_id}/tracks")
|
||||
async def mark_tracks_as_known(playlist_spotify_id: str, request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
async def mark_tracks_as_known(
|
||||
playlist_spotify_id: str,
|
||||
request: Request,
|
||||
current_user: User = Depends(require_auth_from_state),
|
||||
):
|
||||
"""Fetches details for given track IDs and adds/updates them in the playlist's local DB table."""
|
||||
watch_config = get_watch_config()
|
||||
if not watch_config.get("enabled", False):
|
||||
@@ -382,7 +366,7 @@ async def mark_tracks_as_known(playlist_spotify_id: str, request: Request, curre
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Watch feature is currently disabled globally. Cannot mark tracks."
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -397,19 +381,22 @@ async def mark_tracks_as_known(playlist_spotify_id: str, request: Request, curre
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid request body. Expecting a JSON array of track Spotify IDs."
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if not get_watched_playlist(playlist_spotify_id):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": f"Playlist {playlist_spotify_id} is not being watched."}
|
||||
detail={
|
||||
"error": f"Playlist {playlist_spotify_id} is not being watched."
|
||||
},
|
||||
)
|
||||
|
||||
fetched_tracks_details = []
|
||||
client = get_client()
|
||||
for track_id in track_ids:
|
||||
try:
|
||||
track_detail = get_spotify_info(track_id, "track")
|
||||
track_detail = get_track(client, track_id)
|
||||
if track_detail and track_detail.get("id"):
|
||||
fetched_tracks_details.append(track_detail)
|
||||
else:
|
||||
@@ -443,11 +430,18 @@ async def mark_tracks_as_known(playlist_spotify_id: str, request: Request, curre
|
||||
f"Error marking tracks as known for playlist {playlist_spotify_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail={"error": f"Could not mark tracks as known: {str(e)}"})
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Could not mark tracks as known: {str(e)}"},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/watch/{playlist_spotify_id}/tracks")
|
||||
async def mark_tracks_as_missing_locally(playlist_spotify_id: str, request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
async def mark_tracks_as_missing_locally(
|
||||
playlist_spotify_id: str,
|
||||
request: Request,
|
||||
current_user: User = Depends(require_auth_from_state),
|
||||
):
|
||||
"""Removes specified tracks from the playlist's local DB table."""
|
||||
watch_config = get_watch_config()
|
||||
if not watch_config.get("enabled", False):
|
||||
@@ -455,7 +449,7 @@ async def mark_tracks_as_missing_locally(playlist_spotify_id: str, request: Requ
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Watch feature is currently disabled globally. Cannot mark tracks."
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -470,13 +464,15 @@ async def mark_tracks_as_missing_locally(playlist_spotify_id: str, request: Requ
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid request body. Expecting a JSON array of track Spotify IDs."
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if not get_watched_playlist(playlist_spotify_id):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": f"Playlist {playlist_spotify_id} is not being watched."}
|
||||
detail={
|
||||
"error": f"Playlist {playlist_spotify_id} is not being watched."
|
||||
},
|
||||
)
|
||||
|
||||
deleted_count = remove_specific_tracks_from_playlist_table(
|
||||
@@ -495,22 +491,32 @@ async def mark_tracks_as_missing_locally(playlist_spotify_id: str, request: Requ
|
||||
f"Error marking tracks as missing (deleting locally) for playlist {playlist_spotify_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail={"error": f"Could not mark tracks as missing: {str(e)}"})
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Could not mark tracks as missing: {str(e)}"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/watch/list")
|
||||
async def list_watched_playlists_endpoint(current_user: User = Depends(require_auth_from_state)):
|
||||
async def list_watched_playlists_endpoint(
|
||||
current_user: User = Depends(require_auth_from_state),
|
||||
):
|
||||
"""Lists all playlists currently in the watchlist."""
|
||||
try:
|
||||
playlists = get_watched_playlists()
|
||||
return playlists
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing watched playlists: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail={"error": f"Could not list watched playlists: {str(e)}"})
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Could not list watched playlists: {str(e)}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/watch/trigger_check")
|
||||
async def trigger_playlist_check_endpoint(current_user: User = Depends(require_auth_from_state)):
|
||||
async def trigger_playlist_check_endpoint(
|
||||
current_user: User = Depends(require_auth_from_state),
|
||||
):
|
||||
"""Manually triggers the playlist checking mechanism for all watched playlists."""
|
||||
watch_config = get_watch_config()
|
||||
if not watch_config.get("enabled", False):
|
||||
@@ -518,7 +524,7 @@ async def trigger_playlist_check_endpoint(current_user: User = Depends(require_a
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Watch feature is currently disabled globally. Cannot trigger check."
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("Manual trigger for playlist check received for all playlists.")
|
||||
@@ -535,12 +541,14 @@ async def trigger_playlist_check_endpoint(current_user: User = Depends(require_a
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Could not trigger playlist check for all: {str(e)}"}
|
||||
detail={"error": f"Could not trigger playlist check for all: {str(e)}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/watch/trigger_check/{playlist_spotify_id}")
|
||||
async def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)):
|
||||
async def trigger_specific_playlist_check_endpoint(
|
||||
playlist_spotify_id: str, current_user: User = Depends(require_auth_from_state)
|
||||
):
|
||||
"""Manually triggers the playlist checking mechanism for a specific playlist."""
|
||||
watch_config = get_watch_config()
|
||||
if not watch_config.get("enabled", False):
|
||||
@@ -548,7 +556,7 @@ async def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str, cur
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Watch feature is currently disabled globally. Cannot trigger check."
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -565,7 +573,7 @@ async def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str, cur
|
||||
status_code=404,
|
||||
detail={
|
||||
"error": f"Playlist {playlist_spotify_id} is not in the watchlist. Add it first."
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Run check_watched_playlists with the specific ID
|
||||
@@ -590,5 +598,5 @@ async def trigger_specific_playlist_check_endpoint(playlist_spotify_id: str, cur
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": f"Could not trigger playlist check for {playlist_spotify_id}: {str(e)}"
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from fastapi import APIRouter, Request, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
import json
|
||||
import traceback
|
||||
import uuid
|
||||
import time
|
||||
from routes.utils.celery_queue_manager import download_queue_manager
|
||||
from routes.utils.celery_tasks import store_task_info, store_task_status, ProgressState
|
||||
from routes.utils.get_info import get_spotify_info
|
||||
from routes.utils.get_info import get_client, get_track
|
||||
from routes.utils.errors import DuplicateDownloadError
|
||||
|
||||
# Import authentication dependencies
|
||||
@@ -21,7 +20,11 @@ def construct_spotify_url(item_id: str, item_type: str = "track") -> str:
|
||||
|
||||
|
||||
@router.get("/download/{track_id}")
|
||||
async def handle_download(track_id: str, request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
async def handle_download(
|
||||
track_id: str,
|
||||
request: Request,
|
||||
current_user: User = Depends(require_auth_from_state),
|
||||
):
|
||||
# Retrieve essential parameters from the request.
|
||||
# name = request.args.get('name') # Removed
|
||||
# artist = request.args.get('artist') # Removed
|
||||
@@ -31,15 +34,18 @@ async def handle_download(track_id: str, request: Request, current_user: User =
|
||||
|
||||
# Fetch metadata from Spotify
|
||||
try:
|
||||
track_info = get_spotify_info(track_id, "track")
|
||||
client = get_client()
|
||||
track_info = get_track(client, track_id)
|
||||
if (
|
||||
not track_info
|
||||
or not track_info.get("name")
|
||||
or not track_info.get("artists")
|
||||
):
|
||||
return JSONResponse(
|
||||
content={"error": f"Could not retrieve metadata for track ID: {track_id}"},
|
||||
status_code=404
|
||||
content={
|
||||
"error": f"Could not retrieve metadata for track ID: {track_id}"
|
||||
},
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
name_from_spotify = track_info.get("name")
|
||||
@@ -51,15 +57,16 @@ async def handle_download(track_id: str, request: Request, current_user: User =
|
||||
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to fetch metadata for track {track_id}: {str(e)}"},
|
||||
status_code=500
|
||||
content={
|
||||
"error": f"Failed to fetch metadata for track {track_id}: {str(e)}"
|
||||
},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
# Validate required parameters
|
||||
if not url:
|
||||
return JSONResponse(
|
||||
content={"error": "Missing required parameter: url"},
|
||||
status_code=400
|
||||
content={"error": "Missing required parameter: url"}, status_code=400
|
||||
)
|
||||
|
||||
# Add the task to the queue with only essential parameters
|
||||
@@ -84,7 +91,7 @@ async def handle_download(track_id: str, request: Request, current_user: User =
|
||||
"error": "Duplicate download detected.",
|
||||
"existing_task": e.existing_task,
|
||||
},
|
||||
status_code=409
|
||||
status_code=409,
|
||||
)
|
||||
except Exception as e:
|
||||
# Generic error handling for other issues during task submission
|
||||
@@ -116,25 +123,23 @@ async def handle_download(track_id: str, request: Request, current_user: User =
|
||||
"error": f"Failed to queue track download: {str(e)}",
|
||||
"task_id": error_task_id,
|
||||
},
|
||||
status_code=500
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={"task_id": task_id},
|
||||
status_code=202
|
||||
)
|
||||
return JSONResponse(content={"task_id": task_id}, status_code=202)
|
||||
|
||||
|
||||
@router.get("/download/cancel")
|
||||
async def cancel_download(request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
async def cancel_download(
|
||||
request: Request, current_user: User = Depends(require_auth_from_state)
|
||||
):
|
||||
"""
|
||||
Cancel a running download process by its task id.
|
||||
"""
|
||||
task_id = request.query_params.get("task_id")
|
||||
if not task_id:
|
||||
return JSONResponse(
|
||||
content={"error": "Missing process id (task_id) parameter"},
|
||||
status_code=400
|
||||
content={"error": "Missing process id (task_id) parameter"}, status_code=400
|
||||
)
|
||||
|
||||
# Use the queue manager's cancellation method.
|
||||
@@ -145,7 +150,9 @@ async def cancel_download(request: Request, current_user: User = Depends(require
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_track_info(request: Request, current_user: User = Depends(require_auth_from_state)):
|
||||
async def get_track_info(
|
||||
request: Request, current_user: User = Depends(require_auth_from_state)
|
||||
):
|
||||
"""
|
||||
Retrieve Spotify track metadata given a Spotify track ID.
|
||||
Expects a query parameter 'id' that contains the Spotify track ID.
|
||||
@@ -153,14 +160,11 @@ async def get_track_info(request: Request, current_user: User = Depends(require_
|
||||
spotify_id = request.query_params.get("id")
|
||||
|
||||
if not spotify_id:
|
||||
return JSONResponse(
|
||||
content={"error": "Missing parameter: id"},
|
||||
status_code=400
|
||||
)
|
||||
return JSONResponse(content={"error": "Missing parameter: id"}, status_code=400)
|
||||
|
||||
try:
|
||||
# Use the get_spotify_info function (already imported at top)
|
||||
track_info = get_spotify_info(spotify_id, "track")
|
||||
client = get_client()
|
||||
track_info = get_track(client, spotify_id)
|
||||
return JSONResponse(content=track_info, status_code=200)
|
||||
except Exception as e:
|
||||
error_data = {"error": str(e), "traceback": traceback.format_exc()}
|
||||
|
||||
@@ -82,7 +82,7 @@ class SSEBroadcaster:
|
||||
# Clean up disconnected clients
|
||||
for client in disconnected:
|
||||
self.clients.discard(client)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"SSE Broadcaster: Successfully sent to {sent_count} clients, removed {len(disconnected)} disconnected clients"
|
||||
)
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ import json
|
||||
from routes.utils.watch.manager import get_watch_config
|
||||
import logging
|
||||
from routes.utils.celery_queue_manager import download_queue_manager
|
||||
from routes.utils.get_info import get_spotify_info
|
||||
from routes.utils.credentials import get_credential, _get_global_spotify_api_creds
|
||||
from routes.utils.errors import DuplicateDownloadError
|
||||
from routes.utils.get_info import get_spotify_info
|
||||
|
||||
from deezspot.libutils.utils import get_ids, link_is_valid
|
||||
|
||||
|
||||
@@ -5,6 +5,9 @@ import threading
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Import Celery task utilities
|
||||
from .celery_config import get_config_params, MAX_CONCURRENT_DL
|
||||
|
||||
@@ -70,6 +73,12 @@ class CeleryManager:
|
||||
logger.debug(f"Generated Celery command: {' '.join(command)}")
|
||||
return command
|
||||
|
||||
def _get_worker_env(self):
|
||||
# Inherit current environment, but set NO_CONSOLE_LOG=1 for subprocess
|
||||
env = os.environ.copy()
|
||||
env["NO_CONSOLE_LOG"] = "1"
|
||||
return env
|
||||
|
||||
def _process_output_reader(self, stream, log_prefix, error=False):
|
||||
logger.debug(f"Log reader thread started for {log_prefix}")
|
||||
try:
|
||||
@@ -138,6 +147,7 @@ class CeleryManager:
|
||||
text=True,
|
||||
bufsize=1,
|
||||
universal_newlines=True,
|
||||
env=self._get_worker_env(),
|
||||
)
|
||||
self.download_log_thread_stdout = threading.Thread(
|
||||
target=self._process_output_reader,
|
||||
@@ -161,7 +171,7 @@ class CeleryManager:
|
||||
queues="utility_tasks,default", # Listen to utility and default
|
||||
concurrency=5, # Increased concurrency for SSE updates and utility tasks
|
||||
worker_name_suffix="utw", # Utility Worker
|
||||
log_level_env=os.getenv("LOG_LEVEL", "ERROR").upper(),
|
||||
log_level_env=os.getenv("LOG_LEVEL", "WARNING").upper(),
|
||||
|
||||
)
|
||||
logger.info(
|
||||
@@ -174,6 +184,7 @@ class CeleryManager:
|
||||
text=True,
|
||||
bufsize=1,
|
||||
universal_newlines=True,
|
||||
env=self._get_worker_env(),
|
||||
)
|
||||
self.utility_log_thread_stdout = threading.Thread(
|
||||
target=self._process_output_reader,
|
||||
|
||||
@@ -285,9 +285,16 @@ def setup_celery_logging(**kwargs):
|
||||
"""
|
||||
This handler ensures Celery uses our application logging settings
|
||||
instead of its own. Prevents duplicate log configurations.
|
||||
Also disables console logging if NO_CONSOLE_LOG=1 is set in the environment.
|
||||
"""
|
||||
# Using the root logger's handlers and level preserves our config
|
||||
return logging.getLogger()
|
||||
root_logger = logging.getLogger()
|
||||
import os
|
||||
if os.environ.get("NO_CONSOLE_LOG") == "1":
|
||||
# Remove all StreamHandlers (console handlers) from the root logger
|
||||
handlers_to_remove = [h for h in root_logger.handlers if isinstance(h, logging.StreamHandler)]
|
||||
for h in handlers_to_remove:
|
||||
root_logger.removeHandler(h)
|
||||
return root_logger
|
||||
|
||||
|
||||
# The initialization of a worker will log the worker configuration
|
||||
|
||||
@@ -1,422 +1,152 @@
|
||||
import spotipy
|
||||
from spotipy.oauth2 import SpotifyClientCredentials
|
||||
from routes.utils.credentials import _get_global_spotify_api_creds
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
import threading
|
||||
|
||||
# Import Deezer API and logging
|
||||
from deezspot.deezloader.dee_api import API as DeezerAPI
|
||||
from deezspot.libutils import LibrespotClient
|
||||
|
||||
# Initialize logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global Spotify client instance for reuse
|
||||
_spotify_client = None
|
||||
_last_client_init = 0
|
||||
_client_init_interval = 3600 # Reinitialize client every hour
|
||||
# Config helpers to resolve active credentials
|
||||
from routes.utils.celery_config import get_config_params
|
||||
from routes.utils.credentials import get_spotify_blob_path
|
||||
|
||||
|
||||
def _get_spotify_client():
|
||||
"""
|
||||
Get or create a Spotify client with global credentials.
|
||||
Implements client reuse and periodic reinitialization.
|
||||
"""
|
||||
global _spotify_client, _last_client_init
|
||||
# -------- Shared Librespot client (process-wide) --------
|
||||
|
||||
current_time = time.time()
|
||||
_shared_client: Optional[LibrespotClient] = None
|
||||
_shared_blob_path: Optional[str] = None
|
||||
_client_lock = threading.RLock()
|
||||
|
||||
# Reinitialize client if it's been more than an hour or if client doesn't exist
|
||||
if (
|
||||
_spotify_client is None
|
||||
or current_time - _last_client_init > _client_init_interval
|
||||
):
|
||||
client_id, client_secret = _get_global_spotify_api_creds()
|
||||
|
||||
if not client_id or not client_secret:
|
||||
raise ValueError(
|
||||
"Global Spotify API client_id or client_secret not configured in ./data/creds/search.json."
|
||||
)
|
||||
|
||||
# Create new client
|
||||
_spotify_client = spotipy.Spotify(
|
||||
client_credentials_manager=SpotifyClientCredentials(
|
||||
client_id=client_id, client_secret=client_secret
|
||||
)
|
||||
def _resolve_blob_path() -> str:
|
||||
cfg = get_config_params() or {}
|
||||
active_account = cfg.get("spotify")
|
||||
if not active_account:
|
||||
raise RuntimeError("Active Spotify account not set in configuration.")
|
||||
blob_path = get_spotify_blob_path(active_account)
|
||||
abs_path = os.path.abspath(str(blob_path))
|
||||
if not os.path.isfile(abs_path):
|
||||
raise FileNotFoundError(
|
||||
f"Spotify credentials blob not found for account '{active_account}' at {abs_path}"
|
||||
)
|
||||
_last_client_init = current_time
|
||||
logger.info("Spotify client initialized/reinitialized")
|
||||
|
||||
return _spotify_client
|
||||
return abs_path
|
||||
|
||||
|
||||
def _rate_limit_handler(func):
|
||||
def get_client() -> LibrespotClient:
|
||||
"""
|
||||
Decorator to handle rate limiting with exponential backoff.
|
||||
Return a shared LibrespotClient instance initialized from the active account blob.
|
||||
Re-initializes if the active account changes.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
max_retries = 3
|
||||
base_delay = 1
|
||||
|
||||
for attempt in range(max_retries):
|
||||
global _shared_client, _shared_blob_path
|
||||
with _client_lock:
|
||||
desired_blob = _resolve_blob_path()
|
||||
if _shared_client is None or _shared_blob_path != desired_blob:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if "429" in str(e) or "rate limit" in str(e).lower():
|
||||
if attempt < max_retries - 1:
|
||||
delay = base_delay * (2**attempt)
|
||||
logger.warning(f"Rate limited, retrying in {delay} seconds...")
|
||||
time.sleep(delay)
|
||||
continue
|
||||
raise e
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
if _shared_client is not None:
|
||||
_shared_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_shared_client = LibrespotClient(stored_credentials_path=desired_blob)
|
||||
_shared_blob_path = desired_blob
|
||||
return _shared_client
|
||||
|
||||
|
||||
@_rate_limit_handler
|
||||
def get_playlist_metadata(playlist_id: str) -> Dict[str, Any]:
|
||||
# -------- Thin wrapper API (programmatic use) --------
|
||||
|
||||
|
||||
def create_client(credentials_path: str) -> LibrespotClient:
|
||||
"""
|
||||
Get playlist metadata only (no tracks) to avoid rate limiting.
|
||||
|
||||
Args:
|
||||
playlist_id: The Spotify playlist ID
|
||||
|
||||
Returns:
|
||||
Dictionary with playlist metadata (name, description, owner, etc.)
|
||||
Create a LibrespotClient from a librespot-generated credentials.json file.
|
||||
"""
|
||||
client = _get_spotify_client()
|
||||
|
||||
try:
|
||||
# Get basic playlist info without tracks
|
||||
playlist = client.playlist(
|
||||
playlist_id,
|
||||
fields="id,name,description,owner,images,snapshot_id,public,followers,tracks.total",
|
||||
)
|
||||
|
||||
# Add a flag to indicate this is metadata only
|
||||
playlist["_metadata_only"] = True
|
||||
playlist["_tracks_loaded"] = False
|
||||
|
||||
logger.debug(
|
||||
f"Retrieved playlist metadata for {playlist_id}: {playlist.get('name', 'Unknown')}"
|
||||
)
|
||||
return playlist
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching playlist metadata for {playlist_id}: {e}")
|
||||
raise
|
||||
abs_path = os.path.abspath(credentials_path)
|
||||
if not os.path.isfile(abs_path):
|
||||
raise FileNotFoundError(f"Credentials file not found: {abs_path}")
|
||||
return LibrespotClient(stored_credentials_path=abs_path)
|
||||
|
||||
|
||||
@_rate_limit_handler
|
||||
def get_playlist_tracks(
|
||||
playlist_id: str, limit: int = 100, offset: int = 0
|
||||
def close_client(client: LibrespotClient) -> None:
|
||||
"""
|
||||
Dispose a LibrespotClient instance.
|
||||
"""
|
||||
client.close()
|
||||
|
||||
|
||||
def get_track(client: LibrespotClient, track_in: str) -> Dict[str, Any]:
|
||||
"""Fetch a track object."""
|
||||
return client.get_track(track_in)
|
||||
|
||||
|
||||
def get_album(
|
||||
client: LibrespotClient, album_in: str, include_tracks: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get playlist tracks with pagination support to handle large playlists efficiently.
|
||||
|
||||
Args:
|
||||
playlist_id: The Spotify playlist ID
|
||||
limit: Number of tracks to fetch per request (max 100)
|
||||
offset: Starting position for pagination
|
||||
|
||||
Returns:
|
||||
Dictionary with tracks data
|
||||
"""
|
||||
client = _get_spotify_client()
|
||||
|
||||
try:
|
||||
# Get tracks with specified limit and offset
|
||||
tracks_data = client.playlist_tracks(
|
||||
playlist_id,
|
||||
limit=min(limit, 100), # Spotify API max is 100
|
||||
offset=offset,
|
||||
fields="items(track(id,name,artists,album,external_urls,preview_url,duration_ms,explicit,popularity)),total,limit,offset",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Retrieved {len(tracks_data.get('items', []))} tracks for playlist {playlist_id} (offset: {offset})"
|
||||
)
|
||||
return tracks_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching playlist tracks for {playlist_id}: {e}")
|
||||
raise
|
||||
"""Fetch an album object; optionally include expanded tracks."""
|
||||
return client.get_album(album_in, include_tracks=include_tracks)
|
||||
|
||||
|
||||
@_rate_limit_handler
|
||||
def get_playlist_full(playlist_id: str, batch_size: int = 100) -> Dict[str, Any]:
|
||||
"""
|
||||
Get complete playlist data with all tracks, using batched requests to avoid rate limiting.
|
||||
|
||||
Args:
|
||||
playlist_id: The Spotify playlist ID
|
||||
batch_size: Number of tracks to fetch per batch (max 100)
|
||||
|
||||
Returns:
|
||||
Complete playlist data with all tracks
|
||||
"""
|
||||
try:
|
||||
# First get metadata
|
||||
playlist = get_playlist_metadata(playlist_id)
|
||||
|
||||
# Get total track count
|
||||
total_tracks = playlist.get("tracks", {}).get("total", 0)
|
||||
|
||||
if total_tracks == 0:
|
||||
playlist["tracks"] = {"items": [], "total": 0}
|
||||
return playlist
|
||||
|
||||
# Fetch all tracks in batches
|
||||
all_tracks = []
|
||||
offset = 0
|
||||
|
||||
while offset < total_tracks:
|
||||
batch = get_playlist_tracks(playlist_id, limit=batch_size, offset=offset)
|
||||
batch_items = batch.get("items", [])
|
||||
all_tracks.extend(batch_items)
|
||||
|
||||
offset += len(batch_items)
|
||||
|
||||
# Add small delay between batches to be respectful to API
|
||||
if offset < total_tracks:
|
||||
time.sleep(0.1)
|
||||
|
||||
# Update playlist with complete tracks data
|
||||
playlist["tracks"] = {
|
||||
"items": all_tracks,
|
||||
"total": total_tracks,
|
||||
"limit": batch_size,
|
||||
"offset": 0,
|
||||
}
|
||||
playlist["_metadata_only"] = False
|
||||
playlist["_tracks_loaded"] = True
|
||||
|
||||
logger.info(
|
||||
f"Retrieved complete playlist {playlist_id} with {total_tracks} tracks"
|
||||
)
|
||||
return playlist
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching complete playlist {playlist_id}: {e}")
|
||||
raise
|
||||
def get_artist(client: LibrespotClient, artist_in: str) -> Dict[str, Any]:
|
||||
"""Fetch an artist object."""
|
||||
return client.get_artist(artist_in)
|
||||
|
||||
|
||||
def check_playlist_updated(playlist_id: str, last_snapshot_id: str) -> bool:
|
||||
"""
|
||||
Check if playlist has been updated by comparing snapshot_id.
|
||||
This is much more efficient than fetching all tracks.
|
||||
|
||||
Args:
|
||||
playlist_id: The Spotify playlist ID
|
||||
last_snapshot_id: The last known snapshot_id
|
||||
|
||||
Returns:
|
||||
True if playlist has been updated, False otherwise
|
||||
"""
|
||||
try:
|
||||
metadata = get_playlist_metadata(playlist_id)
|
||||
current_snapshot_id = metadata.get("snapshot_id")
|
||||
|
||||
return current_snapshot_id != last_snapshot_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking playlist update status for {playlist_id}: {e}")
|
||||
raise
|
||||
def get_playlist(
|
||||
client: LibrespotClient, playlist_in: str, expand_items: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Fetch a playlist object; optionally expand track items to full track objects."""
|
||||
return client.get_playlist(playlist_in, expand_items=expand_items)
|
||||
|
||||
|
||||
@_rate_limit_handler
|
||||
def get_spotify_info(
|
||||
spotify_id: str,
|
||||
spotify_type: str,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
info_type: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get info from Spotify API using Spotipy directly.
|
||||
Optimized to prevent rate limiting by using appropriate endpoints.
|
||||
Thin, typed wrapper around common Spotify info lookups using the shared client.
|
||||
|
||||
Args:
|
||||
spotify_id: The Spotify ID of the entity
|
||||
spotify_type: The type of entity (track, album, playlist, artist, artist_discography, episode, album_tracks)
|
||||
limit (int, optional): The maximum number of items to return. Used for pagination.
|
||||
offset (int, optional): The index of the first item to return. Used for pagination.
|
||||
Currently supports:
|
||||
- "artist_discography": returns a paginated view over the artist's releases
|
||||
combined across album_group/single_group/compilation_group/appears_on_group.
|
||||
|
||||
Returns:
|
||||
Dictionary with the entity information
|
||||
Returns a mapping with at least: items, total, limit, offset.
|
||||
Also includes a truthy "next" key when more pages are available.
|
||||
"""
|
||||
client = _get_spotify_client()
|
||||
client = get_client()
|
||||
|
||||
try:
|
||||
if spotify_type == "track":
|
||||
return client.track(spotify_id)
|
||||
if info_type == "artist_discography":
|
||||
artist = client.get_artist(spotify_id)
|
||||
all_items = []
|
||||
for key in (
|
||||
"album_group",
|
||||
"single_group",
|
||||
"compilation_group",
|
||||
"appears_on_group",
|
||||
):
|
||||
grp = artist.get(key)
|
||||
if isinstance(grp, list):
|
||||
all_items.extend(grp)
|
||||
elif isinstance(grp, dict):
|
||||
items = grp.get("items") or grp.get("releases") or []
|
||||
if isinstance(items, list):
|
||||
all_items.extend(items)
|
||||
total = len(all_items)
|
||||
start = max(0, offset or 0)
|
||||
page_limit = max(1, limit or 50)
|
||||
end = min(total, start + page_limit)
|
||||
page_items = all_items[start:end]
|
||||
has_more = end < total
|
||||
return {
|
||||
"items": page_items,
|
||||
"total": total,
|
||||
"limit": page_limit,
|
||||
"offset": start,
|
||||
"next": bool(has_more),
|
||||
}
|
||||
|
||||
elif spotify_type == "album":
|
||||
return client.album(spotify_id)
|
||||
|
||||
elif spotify_type == "album_tracks":
|
||||
# Fetch album's tracks with pagination support
|
||||
return client.album_tracks(
|
||||
spotify_id, limit=limit or 20, offset=offset or 0
|
||||
)
|
||||
|
||||
elif spotify_type == "playlist":
|
||||
# Use optimized playlist fetching
|
||||
return get_playlist_full(spotify_id)
|
||||
|
||||
elif spotify_type == "playlist_metadata":
|
||||
# Get only metadata for playlists
|
||||
return get_playlist_metadata(spotify_id)
|
||||
|
||||
elif spotify_type == "artist":
|
||||
return client.artist(spotify_id)
|
||||
|
||||
elif spotify_type == "artist_discography":
|
||||
# Get artist's albums with pagination
|
||||
albums = client.artist_albums(
|
||||
spotify_id,
|
||||
limit=limit or 20,
|
||||
offset=offset or 0,
|
||||
include_groups="single,album,appears_on",
|
||||
)
|
||||
return albums
|
||||
|
||||
elif spotify_type == "episode":
|
||||
return client.episode(spotify_id)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported Spotify type: {spotify_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching {spotify_type} {spotify_id}: {e}")
|
||||
raise
|
||||
raise ValueError(f"Unsupported info_type: {info_type}")
|
||||
|
||||
|
||||
# Cache for playlist metadata to reduce API calls
|
||||
_playlist_metadata_cache: Dict[str, tuple[Dict[str, Any], float]] = {}
|
||||
_cache_ttl = 300 # 5 minutes cache
|
||||
|
||||
|
||||
def get_cached_playlist_metadata(playlist_id: str) -> Optional[Dict[str, Any]]:
|
||||
def get_playlist_metadata(playlist_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get playlist metadata from cache if available and not expired.
|
||||
|
||||
Args:
|
||||
playlist_id: The Spotify playlist ID
|
||||
|
||||
Returns:
|
||||
Cached metadata or None if not available/expired
|
||||
Fetch playlist metadata using the shared client without expanding items.
|
||||
"""
|
||||
if playlist_id in _playlist_metadata_cache:
|
||||
cached_data, timestamp = _playlist_metadata_cache[playlist_id]
|
||||
if time.time() - timestamp < _cache_ttl:
|
||||
return cached_data
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def cache_playlist_metadata(playlist_id: str, metadata: Dict[str, Any]):
|
||||
"""
|
||||
Cache playlist metadata with timestamp.
|
||||
|
||||
Args:
|
||||
playlist_id: The Spotify playlist ID
|
||||
metadata: The metadata to cache
|
||||
"""
|
||||
_playlist_metadata_cache[playlist_id] = (metadata, time.time())
|
||||
|
||||
|
||||
def get_playlist_info_optimized(
|
||||
playlist_id: str, include_tracks: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Optimized playlist info function that uses caching and selective loading.
|
||||
|
||||
Args:
|
||||
playlist_id: The Spotify playlist ID
|
||||
include_tracks: Whether to include track data (default: False to save API calls)
|
||||
|
||||
Returns:
|
||||
Playlist data with or without tracks
|
||||
"""
|
||||
# Check cache first
|
||||
cached_metadata = get_cached_playlist_metadata(playlist_id)
|
||||
|
||||
if cached_metadata and not include_tracks:
|
||||
logger.debug(f"Returning cached metadata for playlist {playlist_id}")
|
||||
return cached_metadata
|
||||
|
||||
if include_tracks:
|
||||
# Get complete playlist data
|
||||
playlist_data = get_playlist_full(playlist_id)
|
||||
# Cache the metadata portion
|
||||
metadata_only = {k: v for k, v in playlist_data.items() if k != "tracks"}
|
||||
metadata_only["_metadata_only"] = True
|
||||
metadata_only["_tracks_loaded"] = False
|
||||
cache_playlist_metadata(playlist_id, metadata_only)
|
||||
return playlist_data
|
||||
else:
|
||||
# Get metadata only
|
||||
metadata = get_playlist_metadata(playlist_id)
|
||||
cache_playlist_metadata(playlist_id, metadata)
|
||||
return metadata
|
||||
|
||||
|
||||
# Keep the existing Deezer functions unchanged
|
||||
def get_deezer_info(deezer_id, deezer_type, limit=None):
|
||||
"""
|
||||
Get info from Deezer API.
|
||||
|
||||
Args:
|
||||
deezer_id: The Deezer ID of the entity.
|
||||
deezer_type: The type of entity (track, album, playlist, artist, episode,
|
||||
artist_top_tracks, artist_albums, artist_related,
|
||||
artist_radio, artist_playlists).
|
||||
limit (int, optional): The maximum number of items to return. Used for
|
||||
artist_top_tracks, artist_albums, artist_playlists.
|
||||
Deezer API methods usually have their own defaults (e.g., 25)
|
||||
if limit is not provided or None is passed to them.
|
||||
|
||||
Returns:
|
||||
Dictionary with the entity information.
|
||||
Raises:
|
||||
ValueError: If deezer_type is unsupported.
|
||||
Various exceptions from DeezerAPI (NoDataApi, QuotaExceeded, requests.exceptions.RequestException, etc.)
|
||||
"""
|
||||
logger.debug(
|
||||
f"Fetching Deezer info for ID {deezer_id}, type {deezer_type}, limit {limit}"
|
||||
)
|
||||
|
||||
# DeezerAPI uses class methods; its @classmethod __init__ handles setup.
|
||||
# No specific ARL or account handling here as DeezerAPI seems to use general endpoints.
|
||||
|
||||
if deezer_type == "track":
|
||||
return DeezerAPI.get_track(deezer_id)
|
||||
elif deezer_type == "album":
|
||||
return DeezerAPI.get_album(deezer_id)
|
||||
elif deezer_type == "playlist":
|
||||
return DeezerAPI.get_playlist(deezer_id)
|
||||
elif deezer_type == "artist":
|
||||
return DeezerAPI.get_artist(deezer_id)
|
||||
elif deezer_type == "episode":
|
||||
return DeezerAPI.get_episode(deezer_id)
|
||||
elif deezer_type == "artist_top_tracks":
|
||||
if limit is not None:
|
||||
return DeezerAPI.get_artist_top_tracks(deezer_id, limit=limit)
|
||||
return DeezerAPI.get_artist_top_tracks(deezer_id) # Use API default limit
|
||||
elif deezer_type == "artist_albums": # Maps to get_artist_top_albums
|
||||
if limit is not None:
|
||||
return DeezerAPI.get_artist_top_albums(deezer_id, limit=limit)
|
||||
return DeezerAPI.get_artist_top_albums(deezer_id) # Use API default limit
|
||||
elif deezer_type == "artist_related":
|
||||
return DeezerAPI.get_artist_related(deezer_id)
|
||||
elif deezer_type == "artist_radio":
|
||||
return DeezerAPI.get_artist_radio(deezer_id)
|
||||
elif deezer_type == "artist_playlists":
|
||||
if limit is not None:
|
||||
return DeezerAPI.get_artist_top_playlists(deezer_id, limit=limit)
|
||||
return DeezerAPI.get_artist_top_playlists(deezer_id) # Use API default limit
|
||||
else:
|
||||
logger.error(f"Unsupported Deezer type: {deezer_type}")
|
||||
raise ValueError(f"Unsupported Deezer type: {deezer_type}")
|
||||
client = get_client()
|
||||
return get_playlist(client, playlist_id, expand_items=False)
|
||||
|
||||
@@ -27,15 +27,9 @@ from routes.utils.watch.db import (
|
||||
get_artist_batch_next_offset,
|
||||
set_artist_batch_next_offset,
|
||||
)
|
||||
from routes.utils.get_info import (
|
||||
get_spotify_info,
|
||||
get_playlist_metadata,
|
||||
get_playlist_tracks,
|
||||
) # To fetch playlist, track, artist, and album details
|
||||
from routes.utils.celery_queue_manager import download_queue_manager
|
||||
|
||||
# Added import to fetch base formatting config
|
||||
from routes.utils.celery_queue_manager import get_config_params
|
||||
from routes.utils.celery_queue_manager import download_queue_manager, get_config_params
|
||||
from routes.utils.get_info import get_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MAIN_CONFIG_FILE_PATH = Path("./data/config/main.json")
|
||||
@@ -358,7 +352,7 @@ def find_tracks_in_playlist(
|
||||
|
||||
while not_found_tracks and offset < 10000: # Safety limit
|
||||
try:
|
||||
tracks_batch = get_playlist_tracks(
|
||||
tracks_batch = _fetch_playlist_tracks_page(
|
||||
playlist_spotify_id, limit=limit, offset=offset
|
||||
)
|
||||
|
||||
@@ -459,7 +453,9 @@ def check_watched_playlists(specific_playlist_id: str = None):
|
||||
ensure_playlist_table_schema(playlist_spotify_id)
|
||||
|
||||
# First, get playlist metadata to check if it has changed
|
||||
current_playlist_metadata = get_playlist_metadata(playlist_spotify_id)
|
||||
current_playlist_metadata = _fetch_playlist_metadata(
|
||||
playlist_spotify_id
|
||||
)
|
||||
if not current_playlist_metadata:
|
||||
logger.error(
|
||||
f"Playlist Watch Manager: Failed to fetch metadata from Spotify for playlist {playlist_spotify_id}."
|
||||
@@ -507,7 +503,7 @@ def check_watched_playlists(specific_playlist_id: str = None):
|
||||
progress_offset, _ = get_playlist_batch_progress(
|
||||
playlist_spotify_id
|
||||
)
|
||||
tracks_batch = get_playlist_tracks(
|
||||
tracks_batch = _fetch_playlist_tracks_page(
|
||||
playlist_spotify_id,
|
||||
limit=batch_limit,
|
||||
offset=progress_offset,
|
||||
@@ -573,7 +569,7 @@ def check_watched_playlists(specific_playlist_id: str = None):
|
||||
logger.info(
|
||||
f"Playlist Watch Manager: Fetching one batch (limit={batch_limit}, offset={progress_offset}) for playlist '{playlist_name}'."
|
||||
)
|
||||
tracks_batch = get_playlist_tracks(
|
||||
tracks_batch = _fetch_playlist_tracks_page(
|
||||
playlist_spotify_id, limit=batch_limit, offset=progress_offset
|
||||
)
|
||||
batch_items = tracks_batch.get("items", []) if tracks_batch else []
|
||||
@@ -734,8 +730,8 @@ def check_watched_artists(specific_artist_id: str = None):
|
||||
logger.debug(
|
||||
f"Artist Watch Manager: Fetching albums for {artist_spotify_id}. Limit: {limit}, Offset: {offset}"
|
||||
)
|
||||
artist_albums_page = get_spotify_info(
|
||||
artist_spotify_id, "artist_discography", limit=limit, offset=offset
|
||||
artist_albums_page = _fetch_artist_discography_page(
|
||||
artist_spotify_id, limit=limit, offset=offset
|
||||
)
|
||||
|
||||
current_page_albums = (
|
||||
@@ -911,7 +907,8 @@ def run_playlist_check_over_intervals(playlist_spotify_id: str) -> None:
|
||||
# Determine if we are done: no active processing snapshot and no pending sync
|
||||
cfg = get_watch_config()
|
||||
interval = cfg.get("watchPollIntervalSeconds", 3600)
|
||||
metadata = get_playlist_metadata(playlist_spotify_id)
|
||||
# Use local helper that leverages Librespot client
|
||||
metadata = _fetch_playlist_metadata(playlist_spotify_id)
|
||||
if not metadata:
|
||||
logger.warning(
|
||||
f"Manual Playlist Runner: Could not load metadata for {playlist_spotify_id}. Stopping."
|
||||
@@ -1167,3 +1164,84 @@ def update_playlist_m3u_file(playlist_spotify_id: str):
|
||||
f"Error updating m3u file for playlist {playlist_spotify_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
# Helper to build a Librespot client from active account
|
||||
|
||||
|
||||
def _build_librespot_client():
|
||||
try:
|
||||
# Reuse shared client managed in routes.utils.get_info
|
||||
return get_client()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize Librespot client: {e}")
|
||||
|
||||
|
||||
def _fetch_playlist_metadata(playlist_id: str) -> dict:
|
||||
client = _build_librespot_client()
|
||||
return client.get_playlist(playlist_id, expand_items=False)
|
||||
|
||||
|
||||
def _fetch_playlist_tracks_page(playlist_id: str, limit: int, offset: int) -> dict:
|
||||
client = _build_librespot_client()
|
||||
# Fetch playlist with minimal items to avoid expanding all tracks unnecessarily
|
||||
pl = client.get_playlist(playlist_id, expand_items=False)
|
||||
items = (pl.get("tracks", {}) or {}).get("items", [])
|
||||
total = (pl.get("tracks", {}) or {}).get("total", len(items))
|
||||
start = max(0, offset or 0)
|
||||
end = start + max(1, limit or 50)
|
||||
page_items_minimal = items[start:end]
|
||||
|
||||
# Expand only the tracks in this page using client cache for efficiency
|
||||
page_items_expanded = []
|
||||
for item in page_items_minimal:
|
||||
track_stub = (item or {}).get("track") or {}
|
||||
track_id = track_stub.get("id")
|
||||
expanded_track = None
|
||||
if track_id:
|
||||
try:
|
||||
expanded_track = client.get_track(track_id)
|
||||
except Exception:
|
||||
expanded_track = None
|
||||
if expanded_track is None:
|
||||
# Keep stub as fallback; ensure structure
|
||||
expanded_track = {
|
||||
k: v
|
||||
for k, v in track_stub.items()
|
||||
if k in ("id", "uri", "type", "external_urls")
|
||||
}
|
||||
# Propagate local flag onto track for downstream checks
|
||||
if item and isinstance(item, dict) and item.get("is_local"):
|
||||
expanded_track["is_local"] = True
|
||||
# Rebuild item with expanded track
|
||||
new_item = dict(item)
|
||||
new_item["track"] = expanded_track
|
||||
page_items_expanded.append(new_item)
|
||||
|
||||
return {
|
||||
"items": page_items_expanded,
|
||||
"total": total,
|
||||
"limit": end - start,
|
||||
"offset": start,
|
||||
}
|
||||
|
||||
|
||||
def _fetch_artist_discography_page(artist_id: str, limit: int, offset: int) -> dict:
|
||||
# LibrespotClient.get_artist returns a pruned mapping; flatten common discography groups
|
||||
client = _build_librespot_client()
|
||||
artist = client.get_artist(artist_id)
|
||||
all_items = []
|
||||
# Collect from known groups; also support nested structures if present
|
||||
for key in ("album_group", "single_group", "compilation_group", "appears_on_group"):
|
||||
grp = artist.get(key)
|
||||
if isinstance(grp, list):
|
||||
all_items.extend(grp)
|
||||
elif isinstance(grp, dict):
|
||||
items = grp.get("items") or grp.get("releases") or []
|
||||
if isinstance(items, list):
|
||||
all_items.extend(items)
|
||||
total = len(all_items)
|
||||
start = max(0, offset or 0)
|
||||
end = start + max(1, limit or 50)
|
||||
page_items = all_items[start:end]
|
||||
return {"items": page_items, "total": total, "limit": limit, "offset": start}
|
||||
|
||||
Reference in New Issue
Block a user