feat: implement generic oauth provider

This commit is contained in:
Xoconoch
2025-08-25 08:03:59 -06:00
parent dc4a4f506f
commit c54a441228
3 changed files with 489 additions and 165 deletions

View File

@@ -30,6 +30,24 @@ Location: project `.env`. Minimal reference for server admins.
- FRONTEND_URL: Public UI base (e.g. `http://127.0.0.1:7171`) - FRONTEND_URL: Public UI base (e.g. `http://127.0.0.1:7171`)
- GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET - GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET
- GITHUB_CLIENT_ID / GITHUB_CLIENT_SECRET - GITHUB_CLIENT_ID / GITHUB_CLIENT_SECRET
- Custom/Generic OAuth (set all to enable a custom provider):
- CUSTOM_SSO_CLIENT_ID / CUSTOM_SSO_CLIENT_SECRET
- CUSTOM_SSO_AUTHORIZATION_ENDPOINT
- CUSTOM_SSO_TOKEN_ENDPOINT
- CUSTOM_SSO_USERINFO_ENDPOINT
- CUSTOM_SSO_SCOPE: Comma-separated scopes (optional)
- CUSTOM_SSO_NAME: Internal provider name (optional, default `custom`)
- CUSTOM_SSO_DISPLAY_NAME: UI name (optional, default `Custom`)
- Multiple Custom/Generic OAuth providers (up to 10):
- For provider index `i` (1..10), set:
- CUSTOM_SSO_CLIENT_ID_i / CUSTOM_SSO_CLIENT_SECRET_i
- CUSTOM_SSO_AUTHORIZATION_ENDPOINT_i
- CUSTOM_SSO_TOKEN_ENDPOINT_i
- CUSTOM_SSO_USERINFO_ENDPOINT_i
- CUSTOM_SSO_SCOPE_i (optional)
- CUSTOM_SSO_NAME_i (optional, default `custom{i}`)
- CUSTOM_SSO_DISPLAY_NAME_i (optional, default `Custom {i}`)
- Login URLs will be `/api/auth/sso/login/custom/i` and callback `/api/auth/sso/callback/custom/i`.
### Tips ### Tips
- If running behind a reverse proxy, set `FRONTEND_URL` and `SSO_BASE_REDIRECT_URI` to public URLs. - If running behind a reverse proxy, set `FRONTEND_URL` and `SSO_BASE_REDIRECT_URI` to public URLs.

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends, Request from fastapi import APIRouter, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List from typing import Optional, List
@@ -14,6 +14,7 @@ security = HTTPBearer(auto_error=False)
# Include SSO sub-router # Include SSO sub-router
try: try:
from .sso import router as sso_router from .sso import router as sso_router
router.include_router(sso_router, tags=["sso"]) router.include_router(sso_router, tags=["sso"])
logging.info("SSO sub-router included in auth router") logging.info("SSO sub-router included in auth router")
except ImportError as e: except ImportError as e:
@@ -34,6 +35,7 @@ class RegisterRequest(BaseModel):
class CreateUserRequest(BaseModel): class CreateUserRequest(BaseModel):
"""Admin-only request to create users when registration is disabled""" """Admin-only request to create users when registration is disabled"""
username: str username: str
password: str password: str
email: Optional[str] = None email: Optional[str] = None
@@ -42,17 +44,20 @@ class CreateUserRequest(BaseModel):
class RoleUpdateRequest(BaseModel): class RoleUpdateRequest(BaseModel):
"""Request to update user role""" """Request to update user role"""
role: str role: str
class PasswordChangeRequest(BaseModel): class PasswordChangeRequest(BaseModel):
"""Request to change user password""" """Request to change user password"""
current_password: str current_password: str
new_password: str new_password: str
class AdminPasswordResetRequest(BaseModel): class AdminPasswordResetRequest(BaseModel):
"""Request for admin to reset user password""" """Request for admin to reset user password"""
new_password: str new_password: str
@@ -87,7 +92,7 @@ class AuthStatusResponse(BaseModel):
# Dependency to get current user # Dependency to get current user
async def get_current_user( async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security) credentials: HTTPAuthorizationCredentials = Depends(security),
) -> Optional[User]: ) -> Optional[User]:
"""Get current user from JWT token""" """Get current user from JWT token"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
@@ -123,10 +128,7 @@ async def require_auth(current_user: User = Depends(get_current_user)) -> User:
async def require_admin(current_user: User = Depends(require_auth)) -> User: async def require_admin(current_user: User = Depends(require_auth)) -> User:
"""Require admin role - raises HTTPException if not admin""" """Require admin role - raises HTTPException if not admin"""
if current_user.role != "admin": if current_user.role != "admin":
raise HTTPException( raise HTTPException(status_code=403, detail="Admin access required")
status_code=403,
detail="Admin access required"
)
return current_user return current_user
@@ -141,11 +143,20 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
try: try:
from . import sso from . import sso
sso_enabled = sso.SSO_ENABLED and AUTH_ENABLED sso_enabled = sso.SSO_ENABLED and AUTH_ENABLED
if sso.google_sso: if sso.google_sso:
sso_providers.append("google") sso_providers.append("google")
if sso.github_sso: if sso.github_sso:
sso_providers.append("github") sso_providers.append("github")
if getattr(sso, "custom_sso", None):
sso_providers.append("custom")
if getattr(sso, "custom_sso_providers", None):
if (
len(getattr(sso, "custom_sso_providers", {})) > 0
and "custom" not in sso_providers
):
sso_providers.append("custom")
except ImportError: except ImportError:
pass # SSO module not available pass # SSO module not available
@@ -155,7 +166,7 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
user=UserResponse(**current_user.to_public_dict()) if current_user else None, user=UserResponse(**current_user.to_public_dict()) if current_user else None,
registration_enabled=AUTH_ENABLED and not DISABLE_REGISTRATION, registration_enabled=AUTH_ENABLED and not DISABLE_REGISTRATION,
sso_enabled=sso_enabled, sso_enabled=sso_enabled,
sso_providers=sso_providers sso_providers=sso_providers,
) )
@@ -163,23 +174,16 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
async def login(request: LoginRequest): async def login(request: LoginRequest):
"""Authenticate user and return access token""" """Authenticate user and return access token"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
user = user_manager.authenticate_user(request.username, request.password) user = user_manager.authenticate_user(request.username, request.password)
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=401, detail="Invalid username or password")
status_code=401,
detail="Invalid username or password"
)
access_token = token_manager.create_token(user) access_token = token_manager.create_token(user)
return LoginResponse( return LoginResponse(
access_token=access_token, access_token=access_token, user=UserResponse(**user.to_public_dict())
user=UserResponse(**user.to_public_dict())
) )
@@ -187,15 +191,12 @@ async def login(request: LoginRequest):
async def register(request: RegisterRequest): async def register(request: RegisterRequest):
"""Register a new user""" """Register a new user"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
if DISABLE_REGISTRATION: if DISABLE_REGISTRATION:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail="Public registration is disabled. Contact an administrator to create an account." detail="Public registration is disabled. Contact an administrator to create an account.",
) )
# Check if this is the first user (should be admin) # Check if this is the first user (should be admin)
@@ -206,7 +207,7 @@ async def register(request: RegisterRequest):
username=request.username, username=request.username,
password=request.password, password=request.password,
email=request.email, email=request.email,
role=role role=role,
) )
if not success: if not success:
@@ -233,10 +234,7 @@ async def list_users(current_user: User = Depends(require_admin)):
async def delete_user(username: str, current_user: User = Depends(require_admin)): async def delete_user(username: str, current_user: User = Depends(require_admin)):
"""Delete a user (admin only)""" """Delete a user (admin only)"""
if username == current_user.username: if username == current_user.username:
raise HTTPException( raise HTTPException(status_code=400, detail="Cannot delete your own account")
status_code=400,
detail="Cannot delete your own account"
)
success, message = user_manager.delete_user(username) success, message = user_manager.delete_user(username)
if not success: if not success:
@@ -249,20 +247,14 @@ async def delete_user(username: str, current_user: User = Depends(require_admin)
async def update_user_role( async def update_user_role(
username: str, username: str,
request: RoleUpdateRequest, request: RoleUpdateRequest,
current_user: User = Depends(require_admin) current_user: User = Depends(require_admin),
): ):
"""Update user role (admin only)""" """Update user role (admin only)"""
if request.role not in ["user", "admin"]: if request.role not in ["user", "admin"]:
raise HTTPException( raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
status_code=400,
detail="Role must be 'user' or 'admin'"
)
if username == current_user.username: if username == current_user.username:
raise HTTPException( raise HTTPException(status_code=400, detail="Cannot change your own role")
status_code=400,
detail="Cannot change your own role"
)
success, message = user_manager.update_user_role(username, request.role) success, message = user_manager.update_user_role(username, request.role)
if not success: if not success:
@@ -272,26 +264,22 @@ async def update_user_role(
@router.post("/users/create", response_model=MessageResponse) @router.post("/users/create", response_model=MessageResponse)
async def create_user_admin(request: CreateUserRequest, current_user: User = Depends(require_admin)): async def create_user_admin(
request: CreateUserRequest, current_user: User = Depends(require_admin)
):
"""Create a new user (admin only) - for use when registration is disabled""" """Create a new user (admin only) - for use when registration is disabled"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
# Validate role # Validate role
if request.role not in ["user", "admin"]: if request.role not in ["user", "admin"]:
raise HTTPException( raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
status_code=400,
detail="Role must be 'user' or 'admin'"
)
success, message = user_manager.create_user( success, message = user_manager.create_user(
username=request.username, username=request.username,
password=request.password, password=request.password,
email=request.email, email=request.email,
role=request.role role=request.role,
) )
if not success: if not success:
@@ -309,20 +297,16 @@ async def get_profile(current_user: User = Depends(require_auth)):
@router.put("/profile/password", response_model=MessageResponse) @router.put("/profile/password", response_model=MessageResponse)
async def change_password( async def change_password(
request: PasswordChangeRequest, request: PasswordChangeRequest, current_user: User = Depends(require_auth)
current_user: User = Depends(require_auth)
): ):
"""Change current user's password""" """Change current user's password"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
success, message = user_manager.change_password( success, message = user_manager.change_password(
username=current_user.username, username=current_user.username,
current_password=request.current_password, current_password=request.current_password,
new_password=request.new_password new_password=request.new_password,
) )
if not success: if not success:
@@ -343,18 +327,14 @@ async def change_password(
async def admin_reset_password( async def admin_reset_password(
username: str, username: str,
request: AdminPasswordResetRequest, request: AdminPasswordResetRequest,
current_user: User = Depends(require_admin) current_user: User = Depends(require_admin),
): ):
"""Admin reset user password (admin only)""" """Admin reset user password (admin only)"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
success, message = user_manager.admin_reset_password( success, message = user_manager.admin_reset_password(
username=username, username=username, new_password=request.new_password
new_password=request.new_password
) )
if not success: if not success:

View File

@@ -1,17 +1,19 @@
""" """
SSO (Single Sign-On) implementation for Google and GitHub authentication SSO (Single Sign-On) implementation for Google and GitHub authentication
""" """
import os import os
import logging import logging
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from datetime import datetime, timedelta from datetime import datetime, timedelta
from fastapi import APIRouter, Request, HTTPException, Depends from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi_sso.sso.google import GoogleSSO from fastapi_sso.sso.google import GoogleSSO
from fastapi_sso.sso.github import GithubSSO from fastapi_sso.sso.github import GithubSSO
from fastapi_sso.sso.base import OpenID from fastapi_sso.sso.base import OpenID
from pydantic import BaseModel from pydantic import BaseModel
from fastapi_sso.sso.generic import create_provider
from . import user_manager, token_manager, User, AUTH_ENABLED, DISABLE_REGISTRATION from . import user_manager, token_manager, User, AUTH_ENABLED, DISABLE_REGISTRATION
@@ -25,11 +27,14 @@ GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET") GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID") GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET") GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
SSO_BASE_REDIRECT_URI = os.getenv("SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback") SSO_BASE_REDIRECT_URI = os.getenv(
"SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback"
)
# Initialize SSO providers # Initialize SSO providers
google_sso = None google_sso = None
github_sso = None github_sso = None
custom_sso = None
if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET: if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET:
google_sso = GoogleSSO( google_sso = GoogleSSO(
@@ -47,6 +52,154 @@ if GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET:
allow_insecure_http=True, # Set to False in production with HTTPS allow_insecure_http=True, # Set to False in production with HTTPS
) )
# Custom/Generic OAuth provider configuration
CUSTOM_SSO_CLIENT_ID = os.getenv("CUSTOM_SSO_CLIENT_ID")
CUSTOM_SSO_CLIENT_SECRET = os.getenv("CUSTOM_SSO_CLIENT_SECRET")
CUSTOM_SSO_AUTHORIZATION_ENDPOINT = os.getenv("CUSTOM_SSO_AUTHORIZATION_ENDPOINT")
CUSTOM_SSO_TOKEN_ENDPOINT = os.getenv("CUSTOM_SSO_TOKEN_ENDPOINT")
CUSTOM_SSO_USERINFO_ENDPOINT = os.getenv("CUSTOM_SSO_USERINFO_ENDPOINT")
CUSTOM_SSO_SCOPE = os.getenv("CUSTOM_SSO_SCOPE") # comma-separated list
CUSTOM_SSO_NAME = os.getenv("CUSTOM_SSO_NAME", "custom")
CUSTOM_SSO_DISPLAY_NAME = os.getenv("CUSTOM_SSO_DISPLAY_NAME", "Custom")
def _default_custom_response_convertor(
userinfo: Dict[str, Any], _client=None
) -> OpenID:
"""Best-effort convertor from generic userinfo to OpenID."""
user_id = (
userinfo.get("sub")
or userinfo.get("id")
or userinfo.get("user_id")
or userinfo.get("uid")
or userinfo.get("uuid")
)
email = userinfo.get("email")
display_name = (
userinfo.get("name")
or userinfo.get("preferred_username")
or userinfo.get("login")
or email
or (str(user_id) if user_id is not None else None)
)
picture = userinfo.get("picture") or userinfo.get("avatar_url")
if not user_id and email:
user_id = email
return OpenID(
id=str(user_id) if user_id is not None else "",
email=email,
display_name=display_name,
picture=picture,
provider=CUSTOM_SSO_NAME,
)
if all(
[
CUSTOM_SSO_CLIENT_ID,
CUSTOM_SSO_CLIENT_SECRET,
CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
CUSTOM_SSO_TOKEN_ENDPOINT,
CUSTOM_SSO_USERINFO_ENDPOINT,
]
):
discovery = {
"authorization_endpoint": CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
"token_endpoint": CUSTOM_SSO_TOKEN_ENDPOINT,
"userinfo_endpoint": CUSTOM_SSO_USERINFO_ENDPOINT,
}
default_scope = (
[s.strip() for s in CUSTOM_SSO_SCOPE.split(",") if s.strip()]
if CUSTOM_SSO_SCOPE
else None
)
CustomProvider = create_provider(
name=CUSTOM_SSO_NAME,
discovery_document=discovery,
response_convertor=_default_custom_response_convertor,
default_scope=default_scope,
)
custom_sso = CustomProvider(
client_id=CUSTOM_SSO_CLIENT_ID,
client_secret=CUSTOM_SSO_CLIENT_SECRET,
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom",
allow_insecure_http=True, # Set to False in production with HTTPS
)
# Support multiple indexed custom providers (CUSTOM_*_i), up to 10
custom_sso_providers: Dict[int, Dict[str, Any]] = {}
def _make_response_convertor(provider_name: str):
def _convert(userinfo: Dict[str, Any], _client=None) -> OpenID:
user_id = (
userinfo.get("sub")
or userinfo.get("id")
or userinfo.get("user_id")
or userinfo.get("uid")
or userinfo.get("uuid")
)
email = userinfo.get("email")
display_name = (
userinfo.get("name")
or userinfo.get("preferred_username")
or userinfo.get("login")
or email
or (str(user_id) if user_id is not None else None)
)
picture = userinfo.get("picture") or userinfo.get("avatar_url")
if not user_id and email:
user_id = email
return OpenID(
id=str(user_id) if user_id is not None else "",
email=email,
display_name=display_name,
picture=picture,
provider=provider_name,
)
return _convert
for i in range(1, 11):
cid = os.getenv(f"CUSTOM_SSO_CLIENT_ID_{i}")
csecret = os.getenv(f"CUSTOM_SSO_CLIENT_SECRET_{i}")
auth_ep = os.getenv(f"CUSTOM_SSO_AUTHORIZATION_ENDPOINT_{i}")
token_ep = os.getenv(f"CUSTOM_SSO_TOKEN_ENDPOINT_{i}")
userinfo_ep = os.getenv(f"CUSTOM_SSO_USERINFO_ENDPOINT_{i}")
scope_raw = os.getenv(f"CUSTOM_SSO_SCOPE_{i}")
name_i = os.getenv(f"CUSTOM_SSO_NAME_{i}", f"custom{i}")
display_name_i = os.getenv(f"CUSTOM_SSO_DISPLAY_NAME_{i}", f"Custom {i}")
if all([cid, csecret, auth_ep, token_ep, userinfo_ep]):
discovery_i = {
"authorization_endpoint": auth_ep,
"token_endpoint": token_ep,
"userinfo_endpoint": userinfo_ep,
}
default_scope_i = (
[s.strip() for s in scope_raw.split(",") if s.strip()]
if scope_raw
else None
)
ProviderClass = create_provider(
name=name_i,
discovery_document=discovery_i,
response_convertor=_make_response_convertor(name_i),
default_scope=default_scope_i,
)
provider_instance = ProviderClass(
client_id=cid,
client_secret=csecret,
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom/{i}",
allow_insecure_http=True, # Set to False in production with HTTPS
)
custom_sso_providers[i] = {
"sso": provider_instance,
"name": name_i,
"display_name": display_name_i,
}
class MessageResponse(BaseModel): class MessageResponse(BaseModel):
message: str message: str
@@ -70,7 +223,9 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
# Generate username from email or use provider ID # Generate username from email or use provider ID
email = openid.email email = openid.email
if not email: if not email:
raise HTTPException(status_code=400, detail="Email is required for SSO authentication") raise HTTPException(
status_code=400, detail="Email is required for SSO authentication"
)
# Use email prefix as username, fallback to provider + id # Use email prefix as username, fallback to provider + id
username = email.split("@")[0] username = email.split("@")[0]
@@ -82,7 +237,9 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
users = user_manager.load_users() users = user_manager.load_users()
for user_data in users.values(): for user_data in users.values():
if user_data.get("email") == email: if user_data.get("email") == email:
existing_user = User(**{k: v for k, v in user_data.items() if k != "password_hash"}) existing_user = User(
**{k: v for k, v in user_data.items() if k != "password_hash"}
)
break break
if existing_user: if existing_user:
@@ -97,7 +254,7 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
if DISABLE_REGISTRATION: if DISABLE_REGISTRATION:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail="Registration is disabled. Contact an administrator to create an account." detail="Registration is disabled. Contact an administrator to create an account.",
) )
# Create new user # Create new user
@@ -111,14 +268,14 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
user = User( user = User(
username=username, username=username,
email=email, email=email,
role="user" # Default role for SSO users role="user", # Default role for SSO users
) )
users[username] = { users[username] = {
**user.to_dict(), **user.to_dict(),
"sso_provider": provider, "sso_provider": provider,
"sso_id": openid.id, "sso_id": openid.id,
"password_hash": None # SSO users don't have passwords "password_hash": None, # SSO users don't have passwords
} }
user_manager.save_users(users) user_manager.save_users(users)
@@ -132,25 +289,49 @@ async def sso_status():
providers = [] providers = []
if google_sso: if google_sso:
providers.append(SSOProvider( providers.append(
name="google", SSOProvider(
display_name="Google", name="google",
enabled=True, display_name="Google",
login_url="/api/auth/sso/login/google" enabled=True,
)) login_url="/api/auth/sso/login/google",
)
)
if github_sso: if github_sso:
providers.append(SSOProvider( providers.append(
name="github", SSOProvider(
display_name="GitHub", name="github",
enabled=True, display_name="GitHub",
login_url="/api/auth/sso/login/github" enabled=True,
)) login_url="/api/auth/sso/login/github",
)
)
if custom_sso:
providers.append(
SSOProvider(
name="custom",
display_name=CUSTOM_SSO_DISPLAY_NAME,
enabled=True,
login_url="/api/auth/sso/login/custom",
)
)
for idx, cfg in custom_sso_providers.items():
providers.append(
SSOProvider(
name=cfg["name"],
display_name=cfg.get("display_name", cfg["name"]),
enabled=True,
login_url=f"/api/auth/sso/login/custom/{idx}",
)
)
return SSOStatusResponse( return SSOStatusResponse(
sso_enabled=SSO_ENABLED and AUTH_ENABLED, sso_enabled=SSO_ENABLED and AUTH_ENABLED,
providers=providers, providers=providers,
registration_enabled=not DISABLE_REGISTRATION registration_enabled=not DISABLE_REGISTRATION,
) )
@@ -164,7 +345,9 @@ async def google_login():
raise HTTPException(status_code=400, detail="Google SSO is not configured") raise HTTPException(status_code=400, detail="Google SSO is not configured")
async with google_sso: async with google_sso:
return await google_sso.get_login_redirect(params={"prompt": "consent", "access_type": "offline"}) return await google_sso.get_login_redirect(
params={"prompt": "consent", "access_type": "offline"}
)
@router.get("/sso/login/github") @router.get("/sso/login/github")
@@ -180,6 +363,35 @@ async def github_login():
return await github_sso.get_login_redirect() return await github_sso.get_login_redirect()
@router.get("/sso/login/custom")
async def custom_login():
"""Initiate Custom SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not custom_sso:
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
async with custom_sso:
return await custom_sso.get_login_redirect()
@router.get("/sso/login/custom/{index}")
async def custom_login_indexed(index: int):
"""Initiate indexed Custom SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
cfg = custom_sso_providers.get(index)
if not cfg:
raise HTTPException(
status_code=400, detail="Custom SSO provider not configured"
)
async with cfg["sso"]:
return await cfg["sso"].get_login_redirect()
@router.get("/sso/callback/google") @router.get("/sso/callback/google")
async def google_callback(request: Request): async def google_callback(request: Request):
"""Handle Google SSO callback""" """Handle Google SSO callback"""
@@ -210,7 +422,7 @@ async def google_callback(request: Request):
httponly=True, httponly=True,
secure=False, # Set to True in production with HTTPS secure=False, # Set to True in production with HTTPS
samesite="lax", samesite="lax",
max_age=timedelta(hours=24).total_seconds() max_age=timedelta(hours=24).total_seconds(),
) )
return response return response
@@ -218,7 +430,7 @@ async def google_callback(request: Request):
except HTTPException as e: except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled) # Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed" error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Google SSO callback error: {error_msg}") logger.warning(f"Google SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}") return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
@@ -258,7 +470,7 @@ async def github_callback(request: Request):
httponly=True, httponly=True,
secure=False, # Set to True in production with HTTPS secure=False, # Set to True in production with HTTPS
samesite="lax", samesite="lax",
max_age=timedelta(hours=24).total_seconds() max_age=timedelta(hours=24).total_seconds(),
) )
return response return response
@@ -266,7 +478,7 @@ async def github_callback(request: Request):
except HTTPException as e: except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled) # Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed" error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"GitHub SSO callback error: {error_msg}") logger.warning(f"GitHub SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}") return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
@@ -276,6 +488,105 @@ async def github_callback(request: Request):
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed") return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.get("/sso/callback/custom")
async def custom_callback(request: Request):
"""Handle Custom SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not custom_sso:
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
try:
async with custom_sso:
openid = await custom_sso.verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, "custom")
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
value=access_token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Custom SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"Custom SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.get("/sso/callback/custom/{index}")
async def custom_callback_indexed(request: Request, index: int):
"""Handle indexed Custom SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
cfg = custom_sso_providers.get(index)
if not cfg:
raise HTTPException(
status_code=400, detail="Custom SSO provider not configured"
)
try:
async with cfg["sso"]:
openid = await cfg["sso"].verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, cfg["name"])
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
value=access_token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Custom[{index}] SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"Custom[{index}] SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.post("/sso/unlink/{provider}", response_model=MessageResponse) @router.post("/sso/unlink/{provider}", response_model=MessageResponse)
async def unlink_sso_provider( async def unlink_sso_provider(
provider: str, provider: str,
@@ -285,7 +596,18 @@ async def unlink_sso_provider(
if not SSO_ENABLED or not AUTH_ENABLED: if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled") raise HTTPException(status_code=400, detail="SSO is disabled")
if provider not in ["google", "github"]: available = []
if google_sso:
available.append("google")
if github_sso:
available.append("github")
if custom_sso:
available.append("custom")
for cfg in custom_sso_providers.values():
available.append(cfg["name"])
if provider not in available:
raise HTTPException(status_code=400, detail="Invalid SSO provider") raise HTTPException(status_code=400, detail="Invalid SSO provider")
# Get current user from request (avoiding circular imports) # Get current user from request (avoiding circular imports)
@@ -294,7 +616,9 @@ async def unlink_sso_provider(
current_user = await require_auth_from_state(request) current_user = await require_auth_from_state(request)
if not current_user.sso_provider: if not current_user.sso_provider:
raise HTTPException(status_code=400, detail="User is not linked to any SSO provider") raise HTTPException(
status_code=400, detail="User is not linked to any SSO provider"
)
if current_user.sso_provider != provider: if current_user.sso_provider != provider:
raise HTTPException(status_code=400, detail=f"User is not linked to {provider}") raise HTTPException(status_code=400, detail=f"User is not linked to {provider}")
@@ -305,6 +629,8 @@ async def unlink_sso_provider(
users[current_user.username]["sso_provider"] = None users[current_user.username]["sso_provider"] = None
users[current_user.username]["sso_id"] = None users[current_user.username]["sso_id"] = None
user_manager.save_users(users) user_manager.save_users(users)
logger.info(f"Unlinked SSO provider {provider} from user {current_user.username}") logger.info(
f"Unlinked SSO provider {provider} from user {current_user.username}"
)
return MessageResponse(message=f"SSO provider {provider} unlinked successfully") return MessageResponse(message=f"SSO provider {provider} unlinked successfully")