feat: implement generic oauth provider

This commit is contained in:
Xoconoch
2025-08-25 08:03:59 -06:00
parent dc4a4f506f
commit c54a441228
3 changed files with 489 additions and 165 deletions

View File

@@ -30,6 +30,24 @@ Location: project `.env`. Minimal reference for server admins.
- FRONTEND_URL: Public UI base (e.g. `http://127.0.0.1:7171`)
- GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET
- GITHUB_CLIENT_ID / GITHUB_CLIENT_SECRET
- Custom/Generic OAuth (set all to enable a custom provider):
- CUSTOM_SSO_CLIENT_ID / CUSTOM_SSO_CLIENT_SECRET
- CUSTOM_SSO_AUTHORIZATION_ENDPOINT
- CUSTOM_SSO_TOKEN_ENDPOINT
- CUSTOM_SSO_USERINFO_ENDPOINT
- CUSTOM_SSO_SCOPE: Comma-separated scopes (optional)
- CUSTOM_SSO_NAME: Internal provider name (optional, default `custom`)
- CUSTOM_SSO_DISPLAY_NAME: UI name (optional, default `Custom`)
- Multiple Custom/Generic OAuth providers (up to 10):
- For provider index `i` (1..10), set:
- CUSTOM_SSO_CLIENT_ID_i / CUSTOM_SSO_CLIENT_SECRET_i
- CUSTOM_SSO_AUTHORIZATION_ENDPOINT_i
- CUSTOM_SSO_TOKEN_ENDPOINT_i
- CUSTOM_SSO_USERINFO_ENDPOINT_i
- CUSTOM_SSO_SCOPE_i (optional)
- CUSTOM_SSO_NAME_i (optional, default `custom{i}`)
- CUSTOM_SSO_DISPLAY_NAME_i (optional, default `Custom {i}`)
- Login URLs will be `/api/auth/sso/login/custom/i` and callback `/api/auth/sso/callback/custom/i`.
### Tips
- If running behind a reverse proxy, set `FRONTEND_URL` and `SSO_BASE_REDIRECT_URI` to public URLs.

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi import APIRouter, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import Optional, List
@@ -14,6 +14,7 @@ security = HTTPBearer(auto_error=False)
# Include SSO sub-router
try:
from .sso import router as sso_router
router.include_router(sso_router, tags=["sso"])
logging.info("SSO sub-router included in auth router")
except ImportError as e:
@@ -34,6 +35,7 @@ class RegisterRequest(BaseModel):
class CreateUserRequest(BaseModel):
"""Admin-only request to create users when registration is disabled"""
username: str
password: str
email: Optional[str] = None
@@ -42,17 +44,20 @@ class CreateUserRequest(BaseModel):
class RoleUpdateRequest(BaseModel):
"""Request to update user role"""
role: str
class PasswordChangeRequest(BaseModel):
"""Request to change user password"""
current_password: str
new_password: str
class AdminPasswordResetRequest(BaseModel):
"""Request for admin to reset user password"""
new_password: str
@@ -87,7 +92,7 @@ class AuthStatusResponse(BaseModel):
# Dependency to get current user
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security)
credentials: HTTPAuthorizationCredentials = Depends(security),
) -> Optional[User]:
"""Get current user from JWT token"""
if not AUTH_ENABLED:
@@ -123,10 +128,7 @@ async def require_auth(current_user: User = Depends(get_current_user)) -> User:
async def require_admin(current_user: User = Depends(require_auth)) -> User:
"""Require admin role - raises HTTPException if not admin"""
if current_user.role != "admin":
raise HTTPException(
status_code=403,
detail="Admin access required"
)
raise HTTPException(status_code=403, detail="Admin access required")
return current_user
@@ -141,11 +143,20 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
try:
from . import sso
sso_enabled = sso.SSO_ENABLED and AUTH_ENABLED
if sso.google_sso:
sso_providers.append("google")
if sso.github_sso:
sso_providers.append("github")
if getattr(sso, "custom_sso", None):
sso_providers.append("custom")
if getattr(sso, "custom_sso_providers", None):
if (
len(getattr(sso, "custom_sso_providers", {})) > 0
and "custom" not in sso_providers
):
sso_providers.append("custom")
except ImportError:
pass # SSO module not available
@@ -155,7 +166,7 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
user=UserResponse(**current_user.to_public_dict()) if current_user else None,
registration_enabled=AUTH_ENABLED and not DISABLE_REGISTRATION,
sso_enabled=sso_enabled,
sso_providers=sso_providers
sso_providers=sso_providers,
)
@@ -163,23 +174,16 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
async def login(request: LoginRequest):
"""Authenticate user and return access token"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
user = user_manager.authenticate_user(request.username, request.password)
if not user:
raise HTTPException(
status_code=401,
detail="Invalid username or password"
)
raise HTTPException(status_code=401, detail="Invalid username or password")
access_token = token_manager.create_token(user)
return LoginResponse(
access_token=access_token,
user=UserResponse(**user.to_public_dict())
access_token=access_token, user=UserResponse(**user.to_public_dict())
)
@@ -187,15 +191,12 @@ async def login(request: LoginRequest):
async def register(request: RegisterRequest):
"""Register a new user"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
if DISABLE_REGISTRATION:
raise HTTPException(
status_code=403,
detail="Public registration is disabled. Contact an administrator to create an account."
detail="Public registration is disabled. Contact an administrator to create an account.",
)
# Check if this is the first user (should be admin)
@@ -206,7 +207,7 @@ async def register(request: RegisterRequest):
username=request.username,
password=request.password,
email=request.email,
role=role
role=role,
)
if not success:
@@ -233,10 +234,7 @@ async def list_users(current_user: User = Depends(require_admin)):
async def delete_user(username: str, current_user: User = Depends(require_admin)):
"""Delete a user (admin only)"""
if username == current_user.username:
raise HTTPException(
status_code=400,
detail="Cannot delete your own account"
)
raise HTTPException(status_code=400, detail="Cannot delete your own account")
success, message = user_manager.delete_user(username)
if not success:
@@ -249,20 +247,14 @@ async def delete_user(username: str, current_user: User = Depends(require_admin)
async def update_user_role(
username: str,
request: RoleUpdateRequest,
current_user: User = Depends(require_admin)
current_user: User = Depends(require_admin),
):
"""Update user role (admin only)"""
if request.role not in ["user", "admin"]:
raise HTTPException(
status_code=400,
detail="Role must be 'user' or 'admin'"
)
raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
if username == current_user.username:
raise HTTPException(
status_code=400,
detail="Cannot change your own role"
)
raise HTTPException(status_code=400, detail="Cannot change your own role")
success, message = user_manager.update_user_role(username, request.role)
if not success:
@@ -272,26 +264,22 @@ async def update_user_role(
@router.post("/users/create", response_model=MessageResponse)
async def create_user_admin(request: CreateUserRequest, current_user: User = Depends(require_admin)):
async def create_user_admin(
request: CreateUserRequest, current_user: User = Depends(require_admin)
):
"""Create a new user (admin only) - for use when registration is disabled"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
# Validate role
if request.role not in ["user", "admin"]:
raise HTTPException(
status_code=400,
detail="Role must be 'user' or 'admin'"
)
raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
success, message = user_manager.create_user(
username=request.username,
password=request.password,
email=request.email,
role=request.role
role=request.role,
)
if not success:
@@ -309,20 +297,16 @@ async def get_profile(current_user: User = Depends(require_auth)):
@router.put("/profile/password", response_model=MessageResponse)
async def change_password(
request: PasswordChangeRequest,
current_user: User = Depends(require_auth)
request: PasswordChangeRequest, current_user: User = Depends(require_auth)
):
"""Change current user's password"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
success, message = user_manager.change_password(
username=current_user.username,
current_password=request.current_password,
new_password=request.new_password
new_password=request.new_password,
)
if not success:
@@ -343,18 +327,14 @@ async def change_password(
async def admin_reset_password(
username: str,
request: AdminPasswordResetRequest,
current_user: User = Depends(require_admin)
current_user: User = Depends(require_admin),
):
"""Admin reset user password (admin only)"""
if not AUTH_ENABLED:
raise HTTPException(
status_code=400,
detail="Authentication is disabled"
)
raise HTTPException(status_code=400, detail="Authentication is disabled")
success, message = user_manager.admin_reset_password(
username=username,
new_password=request.new_password
username=username, new_password=request.new_password
)
if not success:

View File

@@ -1,17 +1,19 @@
"""
SSO (Single Sign-On) implementation for Google and GitHub authentication
"""
import os
import logging
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
from fastapi import APIRouter, Request, HTTPException, Depends
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import RedirectResponse
from fastapi_sso.sso.google import GoogleSSO
from fastapi_sso.sso.github import GithubSSO
from fastapi_sso.sso.base import OpenID
from pydantic import BaseModel
from fastapi_sso.sso.generic import create_provider
from . import user_manager, token_manager, User, AUTH_ENABLED, DISABLE_REGISTRATION
@@ -25,11 +27,14 @@ GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
SSO_BASE_REDIRECT_URI = os.getenv("SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback")
SSO_BASE_REDIRECT_URI = os.getenv(
"SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback"
)
# Initialize SSO providers
google_sso = None
github_sso = None
custom_sso = None
if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET:
google_sso = GoogleSSO(
@@ -47,6 +52,154 @@ if GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET:
allow_insecure_http=True, # Set to False in production with HTTPS
)
# Custom/Generic OAuth provider configuration
CUSTOM_SSO_CLIENT_ID = os.getenv("CUSTOM_SSO_CLIENT_ID")
CUSTOM_SSO_CLIENT_SECRET = os.getenv("CUSTOM_SSO_CLIENT_SECRET")
CUSTOM_SSO_AUTHORIZATION_ENDPOINT = os.getenv("CUSTOM_SSO_AUTHORIZATION_ENDPOINT")
CUSTOM_SSO_TOKEN_ENDPOINT = os.getenv("CUSTOM_SSO_TOKEN_ENDPOINT")
CUSTOM_SSO_USERINFO_ENDPOINT = os.getenv("CUSTOM_SSO_USERINFO_ENDPOINT")
CUSTOM_SSO_SCOPE = os.getenv("CUSTOM_SSO_SCOPE") # comma-separated list
CUSTOM_SSO_NAME = os.getenv("CUSTOM_SSO_NAME", "custom")
CUSTOM_SSO_DISPLAY_NAME = os.getenv("CUSTOM_SSO_DISPLAY_NAME", "Custom")
def _default_custom_response_convertor(
userinfo: Dict[str, Any], _client=None
) -> OpenID:
"""Best-effort convertor from generic userinfo to OpenID."""
user_id = (
userinfo.get("sub")
or userinfo.get("id")
or userinfo.get("user_id")
or userinfo.get("uid")
or userinfo.get("uuid")
)
email = userinfo.get("email")
display_name = (
userinfo.get("name")
or userinfo.get("preferred_username")
or userinfo.get("login")
or email
or (str(user_id) if user_id is not None else None)
)
picture = userinfo.get("picture") or userinfo.get("avatar_url")
if not user_id and email:
user_id = email
return OpenID(
id=str(user_id) if user_id is not None else "",
email=email,
display_name=display_name,
picture=picture,
provider=CUSTOM_SSO_NAME,
)
if all(
[
CUSTOM_SSO_CLIENT_ID,
CUSTOM_SSO_CLIENT_SECRET,
CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
CUSTOM_SSO_TOKEN_ENDPOINT,
CUSTOM_SSO_USERINFO_ENDPOINT,
]
):
discovery = {
"authorization_endpoint": CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
"token_endpoint": CUSTOM_SSO_TOKEN_ENDPOINT,
"userinfo_endpoint": CUSTOM_SSO_USERINFO_ENDPOINT,
}
default_scope = (
[s.strip() for s in CUSTOM_SSO_SCOPE.split(",") if s.strip()]
if CUSTOM_SSO_SCOPE
else None
)
CustomProvider = create_provider(
name=CUSTOM_SSO_NAME,
discovery_document=discovery,
response_convertor=_default_custom_response_convertor,
default_scope=default_scope,
)
custom_sso = CustomProvider(
client_id=CUSTOM_SSO_CLIENT_ID,
client_secret=CUSTOM_SSO_CLIENT_SECRET,
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom",
allow_insecure_http=True, # Set to False in production with HTTPS
)
# Support multiple indexed custom providers (CUSTOM_*_i), up to 10
custom_sso_providers: Dict[int, Dict[str, Any]] = {}
def _make_response_convertor(provider_name: str):
def _convert(userinfo: Dict[str, Any], _client=None) -> OpenID:
user_id = (
userinfo.get("sub")
or userinfo.get("id")
or userinfo.get("user_id")
or userinfo.get("uid")
or userinfo.get("uuid")
)
email = userinfo.get("email")
display_name = (
userinfo.get("name")
or userinfo.get("preferred_username")
or userinfo.get("login")
or email
or (str(user_id) if user_id is not None else None)
)
picture = userinfo.get("picture") or userinfo.get("avatar_url")
if not user_id and email:
user_id = email
return OpenID(
id=str(user_id) if user_id is not None else "",
email=email,
display_name=display_name,
picture=picture,
provider=provider_name,
)
return _convert
for i in range(1, 11):
cid = os.getenv(f"CUSTOM_SSO_CLIENT_ID_{i}")
csecret = os.getenv(f"CUSTOM_SSO_CLIENT_SECRET_{i}")
auth_ep = os.getenv(f"CUSTOM_SSO_AUTHORIZATION_ENDPOINT_{i}")
token_ep = os.getenv(f"CUSTOM_SSO_TOKEN_ENDPOINT_{i}")
userinfo_ep = os.getenv(f"CUSTOM_SSO_USERINFO_ENDPOINT_{i}")
scope_raw = os.getenv(f"CUSTOM_SSO_SCOPE_{i}")
name_i = os.getenv(f"CUSTOM_SSO_NAME_{i}", f"custom{i}")
display_name_i = os.getenv(f"CUSTOM_SSO_DISPLAY_NAME_{i}", f"Custom {i}")
if all([cid, csecret, auth_ep, token_ep, userinfo_ep]):
discovery_i = {
"authorization_endpoint": auth_ep,
"token_endpoint": token_ep,
"userinfo_endpoint": userinfo_ep,
}
default_scope_i = (
[s.strip() for s in scope_raw.split(",") if s.strip()]
if scope_raw
else None
)
ProviderClass = create_provider(
name=name_i,
discovery_document=discovery_i,
response_convertor=_make_response_convertor(name_i),
default_scope=default_scope_i,
)
provider_instance = ProviderClass(
client_id=cid,
client_secret=csecret,
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom/{i}",
allow_insecure_http=True, # Set to False in production with HTTPS
)
custom_sso_providers[i] = {
"sso": provider_instance,
"name": name_i,
"display_name": display_name_i,
}
class MessageResponse(BaseModel):
message: str
@@ -70,7 +223,9 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
# Generate username from email or use provider ID
email = openid.email
if not email:
raise HTTPException(status_code=400, detail="Email is required for SSO authentication")
raise HTTPException(
status_code=400, detail="Email is required for SSO authentication"
)
# Use email prefix as username, fallback to provider + id
username = email.split("@")[0]
@@ -82,7 +237,9 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
users = user_manager.load_users()
for user_data in users.values():
if user_data.get("email") == email:
existing_user = User(**{k: v for k, v in user_data.items() if k != "password_hash"})
existing_user = User(
**{k: v for k, v in user_data.items() if k != "password_hash"}
)
break
if existing_user:
@@ -97,7 +254,7 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
if DISABLE_REGISTRATION:
raise HTTPException(
status_code=403,
detail="Registration is disabled. Contact an administrator to create an account."
detail="Registration is disabled. Contact an administrator to create an account.",
)
# Create new user
@@ -111,14 +268,14 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
user = User(
username=username,
email=email,
role="user" # Default role for SSO users
role="user", # Default role for SSO users
)
users[username] = {
**user.to_dict(),
"sso_provider": provider,
"sso_id": openid.id,
"password_hash": None # SSO users don't have passwords
"password_hash": None, # SSO users don't have passwords
}
user_manager.save_users(users)
@@ -132,25 +289,49 @@ async def sso_status():
providers = []
if google_sso:
providers.append(SSOProvider(
providers.append(
SSOProvider(
name="google",
display_name="Google",
enabled=True,
login_url="/api/auth/sso/login/google"
))
login_url="/api/auth/sso/login/google",
)
)
if github_sso:
providers.append(SSOProvider(
providers.append(
SSOProvider(
name="github",
display_name="GitHub",
enabled=True,
login_url="/api/auth/sso/login/github"
))
login_url="/api/auth/sso/login/github",
)
)
if custom_sso:
providers.append(
SSOProvider(
name="custom",
display_name=CUSTOM_SSO_DISPLAY_NAME,
enabled=True,
login_url="/api/auth/sso/login/custom",
)
)
for idx, cfg in custom_sso_providers.items():
providers.append(
SSOProvider(
name=cfg["name"],
display_name=cfg.get("display_name", cfg["name"]),
enabled=True,
login_url=f"/api/auth/sso/login/custom/{idx}",
)
)
return SSOStatusResponse(
sso_enabled=SSO_ENABLED and AUTH_ENABLED,
providers=providers,
registration_enabled=not DISABLE_REGISTRATION
registration_enabled=not DISABLE_REGISTRATION,
)
@@ -164,7 +345,9 @@ async def google_login():
raise HTTPException(status_code=400, detail="Google SSO is not configured")
async with google_sso:
return await google_sso.get_login_redirect(params={"prompt": "consent", "access_type": "offline"})
return await google_sso.get_login_redirect(
params={"prompt": "consent", "access_type": "offline"}
)
@router.get("/sso/login/github")
@@ -180,6 +363,35 @@ async def github_login():
return await github_sso.get_login_redirect()
@router.get("/sso/login/custom")
async def custom_login():
"""Initiate Custom SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not custom_sso:
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
async with custom_sso:
return await custom_sso.get_login_redirect()
@router.get("/sso/login/custom/{index}")
async def custom_login_indexed(index: int):
"""Initiate indexed Custom SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
cfg = custom_sso_providers.get(index)
if not cfg:
raise HTTPException(
status_code=400, detail="Custom SSO provider not configured"
)
async with cfg["sso"]:
return await cfg["sso"].get_login_redirect()
@router.get("/sso/callback/google")
async def google_callback(request: Request):
"""Handle Google SSO callback"""
@@ -210,7 +422,7 @@ async def google_callback(request: Request):
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds()
max_age=timedelta(hours=24).total_seconds(),
)
return response
@@ -218,7 +430,7 @@ async def google_callback(request: Request):
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed"
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Google SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
@@ -258,7 +470,7 @@ async def github_callback(request: Request):
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds()
max_age=timedelta(hours=24).total_seconds(),
)
return response
@@ -266,7 +478,7 @@ async def github_callback(request: Request):
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed"
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"GitHub SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
@@ -276,6 +488,105 @@ async def github_callback(request: Request):
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.get("/sso/callback/custom")
async def custom_callback(request: Request):
"""Handle Custom SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not custom_sso:
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
try:
async with custom_sso:
openid = await custom_sso.verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, "custom")
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
value=access_token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Custom SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"Custom SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.get("/sso/callback/custom/{index}")
async def custom_callback_indexed(request: Request, index: int):
"""Handle indexed Custom SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
cfg = custom_sso_providers.get(index)
if not cfg:
raise HTTPException(
status_code=400, detail="Custom SSO provider not configured"
)
try:
async with cfg["sso"]:
openid = await cfg["sso"].verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, cfg["name"])
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
value=access_token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Custom[{index}] SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"Custom[{index}] SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.post("/sso/unlink/{provider}", response_model=MessageResponse)
async def unlink_sso_provider(
provider: str,
@@ -285,7 +596,18 @@ async def unlink_sso_provider(
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if provider not in ["google", "github"]:
available = []
if google_sso:
available.append("google")
if github_sso:
available.append("github")
if custom_sso:
available.append("custom")
for cfg in custom_sso_providers.values():
available.append(cfg["name"])
if provider not in available:
raise HTTPException(status_code=400, detail="Invalid SSO provider")
# Get current user from request (avoiding circular imports)
@@ -294,7 +616,9 @@ async def unlink_sso_provider(
current_user = await require_auth_from_state(request)
if not current_user.sso_provider:
raise HTTPException(status_code=400, detail="User is not linked to any SSO provider")
raise HTTPException(
status_code=400, detail="User is not linked to any SSO provider"
)
if current_user.sso_provider != provider:
raise HTTPException(status_code=400, detail=f"User is not linked to {provider}")
@@ -305,6 +629,8 @@ async def unlink_sso_provider(
users[current_user.username]["sso_provider"] = None
users[current_user.username]["sso_id"] = None
user_manager.save_users(users)
logger.info(f"Unlinked SSO provider {provider} from user {current_user.username}")
logger.info(
f"Unlinked SSO provider {provider} from user {current_user.username}"
)
return MessageResponse(message=f"SSO provider {provider} unlinked successfully")