diff --git a/docs/user/environment.md b/docs/user/environment.md index 7d0df4c..1d913f1 100644 --- a/docs/user/environment.md +++ b/docs/user/environment.md @@ -30,6 +30,24 @@ Location: project `.env`. Minimal reference for server admins. - FRONTEND_URL: Public UI base (e.g. `http://127.0.0.1:7171`) - GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET - GITHUB_CLIENT_ID / GITHUB_CLIENT_SECRET +- Custom/Generic OAuth (set all to enable a custom provider): + - CUSTOM_SSO_CLIENT_ID / CUSTOM_SSO_CLIENT_SECRET + - CUSTOM_SSO_AUTHORIZATION_ENDPOINT + - CUSTOM_SSO_TOKEN_ENDPOINT + - CUSTOM_SSO_USERINFO_ENDPOINT + - CUSTOM_SSO_SCOPE: Comma-separated scopes (optional) + - CUSTOM_SSO_NAME: Internal provider name (optional, default `custom`) + - CUSTOM_SSO_DISPLAY_NAME: UI name (optional, default `Custom`) +- Multiple Custom/Generic OAuth providers (up to 10): + - For provider index `i` (1..10), set: + - CUSTOM_SSO_CLIENT_ID_i / CUSTOM_SSO_CLIENT_SECRET_i + - CUSTOM_SSO_AUTHORIZATION_ENDPOINT_i + - CUSTOM_SSO_TOKEN_ENDPOINT_i + - CUSTOM_SSO_USERINFO_ENDPOINT_i + - CUSTOM_SSO_SCOPE_i (optional) + - CUSTOM_SSO_NAME_i (optional, default `custom{i}`) + - CUSTOM_SSO_DISPLAY_NAME_i (optional, default `Custom {i}`) + - Login URLs will be `/api/auth/sso/login/custom/i` and callback `/api/auth/sso/callback/custom/i`. ### Tips - If running behind a reverse proxy, set `FRONTEND_URL` and `SSO_BASE_REDIRECT_URI` to public URLs. diff --git a/routes/auth/auth.py b/routes/auth/auth.py index fa41290..c29fb77 100644 --- a/routes/auth/auth.py +++ b/routes/auth/auth.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Depends, Request +from fastapi import APIRouter, HTTPException, Depends from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from pydantic import BaseModel from typing import Optional, List @@ -14,6 +14,7 @@ security = HTTPBearer(auto_error=False) # Include SSO sub-router try: from .sso import router as sso_router + router.include_router(sso_router, tags=["sso"]) logging.info("SSO sub-router included in auth router") except ImportError as e: @@ -34,6 +35,7 @@ class RegisterRequest(BaseModel): class CreateUserRequest(BaseModel): """Admin-only request to create users when registration is disabled""" + username: str password: str email: Optional[str] = None @@ -42,17 +44,20 @@ class CreateUserRequest(BaseModel): class RoleUpdateRequest(BaseModel): """Request to update user role""" + role: str class PasswordChangeRequest(BaseModel): """Request to change user password""" + current_password: str new_password: str class AdminPasswordResetRequest(BaseModel): """Request for admin to reset user password""" + new_password: str @@ -87,20 +92,20 @@ class AuthStatusResponse(BaseModel): # Dependency to get current user async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security) + credentials: HTTPAuthorizationCredentials = Depends(security), ) -> Optional[User]: """Get current user from JWT token""" if not AUTH_ENABLED: # When auth is disabled, return a mock admin user return User(username="system", role="admin") - + if not credentials: return None - + payload = token_manager.verify_token(credentials.credentials) if not payload: return None - + user = user_manager.get_user(payload["username"]) return user @@ -109,25 +114,22 @@ async def require_auth(current_user: User = Depends(get_current_user)) -> User: """Require authentication - raises HTTPException if not authenticated""" if not AUTH_ENABLED: return User(username="system", role="admin") - + if not current_user: raise HTTPException( status_code=401, detail="Authentication required", headers={"WWW-Authenticate": "Bearer"}, ) - + return current_user async def require_admin(current_user: User = Depends(require_auth)) -> User: """Require admin role - raises HTTPException if not admin""" if current_user.role != "admin": - raise HTTPException( - status_code=403, - detail="Admin access required" - ) - + raise HTTPException(status_code=403, detail="Admin access required") + return current_user @@ -138,24 +140,33 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)): # Check if SSO is enabled and get available providers sso_enabled = False sso_providers = [] - + try: from . import sso + sso_enabled = sso.SSO_ENABLED and AUTH_ENABLED if sso.google_sso: sso_providers.append("google") if sso.github_sso: sso_providers.append("github") + if getattr(sso, "custom_sso", None): + sso_providers.append("custom") + if getattr(sso, "custom_sso_providers", None): + if ( + len(getattr(sso, "custom_sso_providers", {})) > 0 + and "custom" not in sso_providers + ): + sso_providers.append("custom") except ImportError: pass # SSO module not available - + return AuthStatusResponse( auth_enabled=AUTH_ENABLED, authenticated=current_user is not None, user=UserResponse(**current_user.to_public_dict()) if current_user else None, registration_enabled=AUTH_ENABLED and not DISABLE_REGISTRATION, sso_enabled=sso_enabled, - sso_providers=sso_providers + sso_providers=sso_providers, ) @@ -163,23 +174,16 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)): async def login(request: LoginRequest): """Authenticate user and return access token""" if not AUTH_ENABLED: - raise HTTPException( - status_code=400, - detail="Authentication is disabled" - ) - + raise HTTPException(status_code=400, detail="Authentication is disabled") + user = user_manager.authenticate_user(request.username, request.password) if not user: - raise HTTPException( - status_code=401, - detail="Invalid username or password" - ) - + raise HTTPException(status_code=401, detail="Invalid username or password") + access_token = token_manager.create_token(user) - + return LoginResponse( - access_token=access_token, - user=UserResponse(**user.to_public_dict()) + access_token=access_token, user=UserResponse(**user.to_public_dict()) ) @@ -187,31 +191,28 @@ async def login(request: LoginRequest): async def register(request: RegisterRequest): """Register a new user""" if not AUTH_ENABLED: - raise HTTPException( - status_code=400, - detail="Authentication is disabled" - ) - + raise HTTPException(status_code=400, detail="Authentication is disabled") + if DISABLE_REGISTRATION: raise HTTPException( status_code=403, - detail="Public registration is disabled. Contact an administrator to create an account." + detail="Public registration is disabled. Contact an administrator to create an account.", ) - + # Check if this is the first user (should be admin) existing_users = user_manager.list_users() role = "admin" if len(existing_users) == 0 else "user" - + success, message = user_manager.create_user( username=request.username, password=request.password, email=request.email, - role=role + role=role, ) - + if not success: raise HTTPException(status_code=400, detail=message) - + return MessageResponse(message=message) @@ -233,70 +234,57 @@ async def list_users(current_user: User = Depends(require_admin)): async def delete_user(username: str, current_user: User = Depends(require_admin)): """Delete a user (admin only)""" if username == current_user.username: - raise HTTPException( - status_code=400, - detail="Cannot delete your own account" - ) - + raise HTTPException(status_code=400, detail="Cannot delete your own account") + success, message = user_manager.delete_user(username) if not success: raise HTTPException(status_code=404, detail=message) - + return MessageResponse(message=message) @router.put("/users/{username}/role", response_model=MessageResponse) async def update_user_role( - username: str, - request: RoleUpdateRequest, - current_user: User = Depends(require_admin) + username: str, + request: RoleUpdateRequest, + current_user: User = Depends(require_admin), ): """Update user role (admin only)""" if request.role not in ["user", "admin"]: - raise HTTPException( - status_code=400, - detail="Role must be 'user' or 'admin'" - ) - + raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'") + if username == current_user.username: - raise HTTPException( - status_code=400, - detail="Cannot change your own role" - ) - + raise HTTPException(status_code=400, detail="Cannot change your own role") + success, message = user_manager.update_user_role(username, request.role) if not success: raise HTTPException(status_code=404, detail=message) - + return MessageResponse(message=message) @router.post("/users/create", response_model=MessageResponse) -async def create_user_admin(request: CreateUserRequest, current_user: User = Depends(require_admin)): +async def create_user_admin( + request: CreateUserRequest, current_user: User = Depends(require_admin) +): """Create a new user (admin only) - for use when registration is disabled""" if not AUTH_ENABLED: - raise HTTPException( - status_code=400, - detail="Authentication is disabled" - ) - + raise HTTPException(status_code=400, detail="Authentication is disabled") + # Validate role if request.role not in ["user", "admin"]: - raise HTTPException( - status_code=400, - detail="Role must be 'user' or 'admin'" - ) - + raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'") + success, message = user_manager.create_user( username=request.username, password=request.password, email=request.email, - role=request.role + role=request.role, ) - + if not success: raise HTTPException(status_code=400, detail=message) - + return MessageResponse(message=message) @@ -309,22 +297,18 @@ async def get_profile(current_user: User = Depends(require_auth)): @router.put("/profile/password", response_model=MessageResponse) async def change_password( - request: PasswordChangeRequest, - current_user: User = Depends(require_auth) + request: PasswordChangeRequest, current_user: User = Depends(require_auth) ): """Change current user's password""" if not AUTH_ENABLED: - raise HTTPException( - status_code=400, - detail="Authentication is disabled" - ) - + raise HTTPException(status_code=400, detail="Authentication is disabled") + success, message = user_manager.change_password( username=current_user.username, current_password=request.current_password, - new_password=request.new_password + new_password=request.new_password, ) - + if not success: # Determine appropriate HTTP status code based on error message if "Current password is incorrect" in message: @@ -333,9 +317,9 @@ async def change_password( status_code = 404 else: status_code = 400 - + raise HTTPException(status_code=status_code, detail=message) - + return MessageResponse(message=message) @@ -343,30 +327,26 @@ async def change_password( async def admin_reset_password( username: str, request: AdminPasswordResetRequest, - current_user: User = Depends(require_admin) + current_user: User = Depends(require_admin), ): """Admin reset user password (admin only)""" if not AUTH_ENABLED: - raise HTTPException( - status_code=400, - detail="Authentication is disabled" - ) - + raise HTTPException(status_code=400, detail="Authentication is disabled") + success, message = user_manager.admin_reset_password( - username=username, - new_password=request.new_password + username=username, new_password=request.new_password ) - + if not success: # Determine appropriate HTTP status code based on error message if "User not found" in message: status_code = 404 else: status_code = 400 - + raise HTTPException(status_code=status_code, detail=message) - + return MessageResponse(message=message) -# Note: SSO routes are included in the main app, not here to avoid circular imports \ No newline at end of file +# Note: SSO routes are included in the main app, not here to avoid circular imports diff --git a/routes/auth/sso.py b/routes/auth/sso.py index f7ae7e5..f5ad728 100644 --- a/routes/auth/sso.py +++ b/routes/auth/sso.py @@ -1,17 +1,19 @@ """ SSO (Single Sign-On) implementation for Google and GitHub authentication """ + import os import logging from typing import Optional, Dict, Any from datetime import datetime, timedelta -from fastapi import APIRouter, Request, HTTPException, Depends +from fastapi import APIRouter, Request, HTTPException from fastapi.responses import RedirectResponse from fastapi_sso.sso.google import GoogleSSO from fastapi_sso.sso.github import GithubSSO from fastapi_sso.sso.base import OpenID from pydantic import BaseModel +from fastapi_sso.sso.generic import create_provider from . import user_manager, token_manager, User, AUTH_ENABLED, DISABLE_REGISTRATION @@ -25,11 +27,14 @@ GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID") GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET") GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID") GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET") -SSO_BASE_REDIRECT_URI = os.getenv("SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback") +SSO_BASE_REDIRECT_URI = os.getenv( + "SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback" +) # Initialize SSO providers google_sso = None github_sso = None +custom_sso = None if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET: google_sso = GoogleSSO( @@ -47,6 +52,154 @@ if GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET: allow_insecure_http=True, # Set to False in production with HTTPS ) +# Custom/Generic OAuth provider configuration +CUSTOM_SSO_CLIENT_ID = os.getenv("CUSTOM_SSO_CLIENT_ID") +CUSTOM_SSO_CLIENT_SECRET = os.getenv("CUSTOM_SSO_CLIENT_SECRET") +CUSTOM_SSO_AUTHORIZATION_ENDPOINT = os.getenv("CUSTOM_SSO_AUTHORIZATION_ENDPOINT") +CUSTOM_SSO_TOKEN_ENDPOINT = os.getenv("CUSTOM_SSO_TOKEN_ENDPOINT") +CUSTOM_SSO_USERINFO_ENDPOINT = os.getenv("CUSTOM_SSO_USERINFO_ENDPOINT") +CUSTOM_SSO_SCOPE = os.getenv("CUSTOM_SSO_SCOPE") # comma-separated list +CUSTOM_SSO_NAME = os.getenv("CUSTOM_SSO_NAME", "custom") +CUSTOM_SSO_DISPLAY_NAME = os.getenv("CUSTOM_SSO_DISPLAY_NAME", "Custom") + + +def _default_custom_response_convertor( + userinfo: Dict[str, Any], _client=None +) -> OpenID: + """Best-effort convertor from generic userinfo to OpenID.""" + user_id = ( + userinfo.get("sub") + or userinfo.get("id") + or userinfo.get("user_id") + or userinfo.get("uid") + or userinfo.get("uuid") + ) + email = userinfo.get("email") + display_name = ( + userinfo.get("name") + or userinfo.get("preferred_username") + or userinfo.get("login") + or email + or (str(user_id) if user_id is not None else None) + ) + picture = userinfo.get("picture") or userinfo.get("avatar_url") + if not user_id and email: + user_id = email + return OpenID( + id=str(user_id) if user_id is not None else "", + email=email, + display_name=display_name, + picture=picture, + provider=CUSTOM_SSO_NAME, + ) + + +if all( + [ + CUSTOM_SSO_CLIENT_ID, + CUSTOM_SSO_CLIENT_SECRET, + CUSTOM_SSO_AUTHORIZATION_ENDPOINT, + CUSTOM_SSO_TOKEN_ENDPOINT, + CUSTOM_SSO_USERINFO_ENDPOINT, + ] +): + discovery = { + "authorization_endpoint": CUSTOM_SSO_AUTHORIZATION_ENDPOINT, + "token_endpoint": CUSTOM_SSO_TOKEN_ENDPOINT, + "userinfo_endpoint": CUSTOM_SSO_USERINFO_ENDPOINT, + } + default_scope = ( + [s.strip() for s in CUSTOM_SSO_SCOPE.split(",") if s.strip()] + if CUSTOM_SSO_SCOPE + else None + ) + CustomProvider = create_provider( + name=CUSTOM_SSO_NAME, + discovery_document=discovery, + response_convertor=_default_custom_response_convertor, + default_scope=default_scope, + ) + custom_sso = CustomProvider( + client_id=CUSTOM_SSO_CLIENT_ID, + client_secret=CUSTOM_SSO_CLIENT_SECRET, + redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom", + allow_insecure_http=True, # Set to False in production with HTTPS + ) + +# Support multiple indexed custom providers (CUSTOM_*_i), up to 10 +custom_sso_providers: Dict[int, Dict[str, Any]] = {} + + +def _make_response_convertor(provider_name: str): + def _convert(userinfo: Dict[str, Any], _client=None) -> OpenID: + user_id = ( + userinfo.get("sub") + or userinfo.get("id") + or userinfo.get("user_id") + or userinfo.get("uid") + or userinfo.get("uuid") + ) + email = userinfo.get("email") + display_name = ( + userinfo.get("name") + or userinfo.get("preferred_username") + or userinfo.get("login") + or email + or (str(user_id) if user_id is not None else None) + ) + picture = userinfo.get("picture") or userinfo.get("avatar_url") + if not user_id and email: + user_id = email + return OpenID( + id=str(user_id) if user_id is not None else "", + email=email, + display_name=display_name, + picture=picture, + provider=provider_name, + ) + + return _convert + + +for i in range(1, 11): + cid = os.getenv(f"CUSTOM_SSO_CLIENT_ID_{i}") + csecret = os.getenv(f"CUSTOM_SSO_CLIENT_SECRET_{i}") + auth_ep = os.getenv(f"CUSTOM_SSO_AUTHORIZATION_ENDPOINT_{i}") + token_ep = os.getenv(f"CUSTOM_SSO_TOKEN_ENDPOINT_{i}") + userinfo_ep = os.getenv(f"CUSTOM_SSO_USERINFO_ENDPOINT_{i}") + scope_raw = os.getenv(f"CUSTOM_SSO_SCOPE_{i}") + name_i = os.getenv(f"CUSTOM_SSO_NAME_{i}", f"custom{i}") + display_name_i = os.getenv(f"CUSTOM_SSO_DISPLAY_NAME_{i}", f"Custom {i}") + + if all([cid, csecret, auth_ep, token_ep, userinfo_ep]): + discovery_i = { + "authorization_endpoint": auth_ep, + "token_endpoint": token_ep, + "userinfo_endpoint": userinfo_ep, + } + default_scope_i = ( + [s.strip() for s in scope_raw.split(",") if s.strip()] + if scope_raw + else None + ) + ProviderClass = create_provider( + name=name_i, + discovery_document=discovery_i, + response_convertor=_make_response_convertor(name_i), + default_scope=default_scope_i, + ) + provider_instance = ProviderClass( + client_id=cid, + client_secret=csecret, + redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom/{i}", + allow_insecure_http=True, # Set to False in production with HTTPS + ) + custom_sso_providers[i] = { + "sso": provider_instance, + "name": name_i, + "display_name": display_name_i, + } + class MessageResponse(BaseModel): message: str @@ -70,21 +223,25 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User: # Generate username from email or use provider ID email = openid.email if not email: - raise HTTPException(status_code=400, detail="Email is required for SSO authentication") - + raise HTTPException( + status_code=400, detail="Email is required for SSO authentication" + ) + # Use email prefix as username, fallback to provider + id username = email.split("@")[0] if not username: username = f"{provider}_{openid.id}" - + # Check if user already exists by email existing_user = None users = user_manager.load_users() for user_data in users.values(): if user_data.get("email") == email: - existing_user = User(**{k: v for k, v in user_data.items() if k != "password_hash"}) + existing_user = User( + **{k: v for k, v in user_data.items() if k != "password_hash"} + ) break - + if existing_user: # Update last login for existing user (always allowed) users[existing_user.username]["last_login"] = datetime.utcnow().isoformat() @@ -96,10 +253,10 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User: # Check if registration is disabled before creating new user if DISABLE_REGISTRATION: raise HTTPException( - status_code=403, - detail="Registration is disabled. Contact an administrator to create an account." + status_code=403, + detail="Registration is disabled. Contact an administrator to create an account.", ) - + # Create new user # Ensure username is unique counter = 1 @@ -107,20 +264,20 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User: while username in users: username = f"{original_username}{counter}" counter += 1 - + user = User( username=username, email=email, - role="user" # Default role for SSO users + role="user", # Default role for SSO users ) - + users[username] = { **user.to_dict(), "sso_provider": provider, "sso_id": openid.id, - "password_hash": None # SSO users don't have passwords + "password_hash": None, # SSO users don't have passwords } - + user_manager.save_users(users) logger.info(f"Created SSO user: {username} via {provider}") return user @@ -130,27 +287,51 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User: async def sso_status(): """Get SSO status and available providers""" providers = [] - + if google_sso: - providers.append(SSOProvider( - name="google", - display_name="Google", - enabled=True, - login_url="/api/auth/sso/login/google" - )) - + providers.append( + SSOProvider( + name="google", + display_name="Google", + enabled=True, + login_url="/api/auth/sso/login/google", + ) + ) + if github_sso: - providers.append(SSOProvider( - name="github", - display_name="GitHub", - enabled=True, - login_url="/api/auth/sso/login/github" - )) - + providers.append( + SSOProvider( + name="github", + display_name="GitHub", + enabled=True, + login_url="/api/auth/sso/login/github", + ) + ) + + if custom_sso: + providers.append( + SSOProvider( + name="custom", + display_name=CUSTOM_SSO_DISPLAY_NAME, + enabled=True, + login_url="/api/auth/sso/login/custom", + ) + ) + + for idx, cfg in custom_sso_providers.items(): + providers.append( + SSOProvider( + name=cfg["name"], + display_name=cfg.get("display_name", cfg["name"]), + enabled=True, + login_url=f"/api/auth/sso/login/custom/{idx}", + ) + ) + return SSOStatusResponse( sso_enabled=SSO_ENABLED and AUTH_ENABLED, providers=providers, - registration_enabled=not DISABLE_REGISTRATION + registration_enabled=not DISABLE_REGISTRATION, ) @@ -159,12 +340,14 @@ async def google_login(): """Initiate Google SSO login""" if not SSO_ENABLED or not AUTH_ENABLED: raise HTTPException(status_code=400, detail="SSO is disabled") - + if not google_sso: raise HTTPException(status_code=400, detail="Google SSO is not configured") - + async with google_sso: - return await google_sso.get_login_redirect(params={"prompt": "consent", "access_type": "offline"}) + return await google_sso.get_login_redirect( + params={"prompt": "consent", "access_type": "offline"} + ) @router.get("/sso/login/github") @@ -172,37 +355,66 @@ async def github_login(): """Initiate GitHub SSO login""" if not SSO_ENABLED or not AUTH_ENABLED: raise HTTPException(status_code=400, detail="SSO is disabled") - + if not github_sso: raise HTTPException(status_code=400, detail="GitHub SSO is not configured") - + async with github_sso: return await github_sso.get_login_redirect() +@router.get("/sso/login/custom") +async def custom_login(): + """Initiate Custom SSO login""" + if not SSO_ENABLED or not AUTH_ENABLED: + raise HTTPException(status_code=400, detail="SSO is disabled") + + if not custom_sso: + raise HTTPException(status_code=400, detail="Custom SSO is not configured") + + async with custom_sso: + return await custom_sso.get_login_redirect() + + +@router.get("/sso/login/custom/{index}") +async def custom_login_indexed(index: int): + """Initiate indexed Custom SSO login""" + if not SSO_ENABLED or not AUTH_ENABLED: + raise HTTPException(status_code=400, detail="SSO is disabled") + + cfg = custom_sso_providers.get(index) + if not cfg: + raise HTTPException( + status_code=400, detail="Custom SSO provider not configured" + ) + + async with cfg["sso"]: + return await cfg["sso"].get_login_redirect() + + @router.get("/sso/callback/google") async def google_callback(request: Request): """Handle Google SSO callback""" if not SSO_ENABLED or not AUTH_ENABLED: raise HTTPException(status_code=400, detail="SSO is disabled") - + if not google_sso: raise HTTPException(status_code=400, detail="Google SSO is not configured") - + try: async with google_sso: openid = await google_sso.verify_and_process(request) - + # Create or update user user = create_or_update_sso_user(openid, "google") - + # Create JWT token access_token = token_manager.create_token(user) - + # Redirect to frontend with token (you might want to customize this) frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") response = RedirectResponse(url=f"{frontend_url}?token={access_token}") - + # Also set as HTTP-only cookie response.set_cookie( key="access_token", @@ -210,18 +422,18 @@ async def google_callback(request: Request): httponly=True, secure=False, # Set to True in production with HTTPS samesite="lax", - max_age=timedelta(hours=24).total_seconds() + max_age=timedelta(hours=24).total_seconds(), ) - + return response - + except HTTPException as e: # Handle specific HTTP exceptions (like registration disabled) frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") - error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed" + error_msg = e.detail if hasattr(e, "detail") else "Authentication failed" logger.warning(f"Google SSO callback error: {error_msg}") return RedirectResponse(url=f"{frontend_url}?error={error_msg}") - + except Exception as e: logger.error(f"Google SSO callback error: {e}") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") @@ -233,24 +445,24 @@ async def github_callback(request: Request): """Handle GitHub SSO callback""" if not SSO_ENABLED or not AUTH_ENABLED: raise HTTPException(status_code=400, detail="SSO is disabled") - + if not github_sso: raise HTTPException(status_code=400, detail="GitHub SSO is not configured") - + try: async with github_sso: openid = await github_sso.verify_and_process(request) - + # Create or update user user = create_or_update_sso_user(openid, "github") - + # Create JWT token access_token = token_manager.create_token(user) - + # Redirect to frontend with token (you might want to customize this) frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") response = RedirectResponse(url=f"{frontend_url}?token={access_token}") - + # Also set as HTTP-only cookie response.set_cookie( key="access_token", @@ -258,24 +470,123 @@ async def github_callback(request: Request): httponly=True, secure=False, # Set to True in production with HTTPS samesite="lax", - max_age=timedelta(hours=24).total_seconds() + max_age=timedelta(hours=24).total_seconds(), ) - + return response - + except HTTPException as e: # Handle specific HTTP exceptions (like registration disabled) frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") - error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed" + error_msg = e.detail if hasattr(e, "detail") else "Authentication failed" logger.warning(f"GitHub SSO callback error: {error_msg}") return RedirectResponse(url=f"{frontend_url}?error={error_msg}") - + except Exception as e: logger.error(f"GitHub SSO callback error: {e}") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") return RedirectResponse(url=f"{frontend_url}?error=Authentication failed") +@router.get("/sso/callback/custom") +async def custom_callback(request: Request): + """Handle Custom SSO callback""" + if not SSO_ENABLED or not AUTH_ENABLED: + raise HTTPException(status_code=400, detail="SSO is disabled") + + if not custom_sso: + raise HTTPException(status_code=400, detail="Custom SSO is not configured") + + try: + async with custom_sso: + openid = await custom_sso.verify_and_process(request) + + # Create or update user + user = create_or_update_sso_user(openid, "custom") + + # Create JWT token + access_token = token_manager.create_token(user) + + # Redirect to frontend with token (you might want to customize this) + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + response = RedirectResponse(url=f"{frontend_url}?token={access_token}") + + # Also set as HTTP-only cookie + response.set_cookie( + key="access_token", + value=access_token, + httponly=True, + secure=False, # Set to True in production with HTTPS + samesite="lax", + max_age=timedelta(hours=24).total_seconds(), + ) + + return response + + except HTTPException as e: + # Handle specific HTTP exceptions (like registration disabled) + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + error_msg = e.detail if hasattr(e, "detail") else "Authentication failed" + logger.warning(f"Custom SSO callback error: {error_msg}") + return RedirectResponse(url=f"{frontend_url}?error={error_msg}") + + except Exception as e: + logger.error(f"Custom SSO callback error: {e}") + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + return RedirectResponse(url=f"{frontend_url}?error=Authentication failed") + + +@router.get("/sso/callback/custom/{index}") +async def custom_callback_indexed(request: Request, index: int): + """Handle indexed Custom SSO callback""" + if not SSO_ENABLED or not AUTH_ENABLED: + raise HTTPException(status_code=400, detail="SSO is disabled") + + cfg = custom_sso_providers.get(index) + if not cfg: + raise HTTPException( + status_code=400, detail="Custom SSO provider not configured" + ) + + try: + async with cfg["sso"]: + openid = await cfg["sso"].verify_and_process(request) + + # Create or update user + user = create_or_update_sso_user(openid, cfg["name"]) + + # Create JWT token + access_token = token_manager.create_token(user) + + # Redirect to frontend with token (you might want to customize this) + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + response = RedirectResponse(url=f"{frontend_url}?token={access_token}") + + # Also set as HTTP-only cookie + response.set_cookie( + key="access_token", + value=access_token, + httponly=True, + secure=False, # Set to True in production with HTTPS + samesite="lax", + max_age=timedelta(hours=24).total_seconds(), + ) + + return response + + except HTTPException as e: + # Handle specific HTTP exceptions (like registration disabled) + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + error_msg = e.detail if hasattr(e, "detail") else "Authentication failed" + logger.warning(f"Custom[{index}] SSO callback error: {error_msg}") + return RedirectResponse(url=f"{frontend_url}?error={error_msg}") + + except Exception as e: + logger.error(f"Custom[{index}] SSO callback error: {e}") + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + return RedirectResponse(url=f"{frontend_url}?error=Authentication failed") + + @router.post("/sso/unlink/{provider}", response_model=MessageResponse) async def unlink_sso_provider( provider: str, @@ -284,27 +595,42 @@ async def unlink_sso_provider( """Unlink SSO provider from user account""" if not SSO_ENABLED or not AUTH_ENABLED: raise HTTPException(status_code=400, detail="SSO is disabled") - - if provider not in ["google", "github"]: + + available = [] + if google_sso: + available.append("google") + if github_sso: + available.append("github") + if custom_sso: + available.append("custom") + + for cfg in custom_sso_providers.values(): + available.append(cfg["name"]) + + if provider not in available: raise HTTPException(status_code=400, detail="Invalid SSO provider") - + # Get current user from request (avoiding circular imports) from .middleware import require_auth_from_state - + current_user = await require_auth_from_state(request) - + if not current_user.sso_provider: - raise HTTPException(status_code=400, detail="User is not linked to any SSO provider") - + raise HTTPException( + status_code=400, detail="User is not linked to any SSO provider" + ) + if current_user.sso_provider != provider: raise HTTPException(status_code=400, detail=f"User is not linked to {provider}") - + # Update user to remove SSO linkage users = user_manager.load_users() if current_user.username in users: users[current_user.username]["sso_provider"] = None users[current_user.username]["sso_id"] = None user_manager.save_users(users) - logger.info(f"Unlinked SSO provider {provider} from user {current_user.username}") - - return MessageResponse(message=f"SSO provider {provider} unlinked successfully") \ No newline at end of file + logger.info( + f"Unlinked SSO provider {provider} from user {current_user.username}" + ) + + return MessageResponse(message=f"SSO provider {provider} unlinked successfully")