feat: implement generic oauth provider

This commit is contained in:
Xoconoch
2025-08-25 08:03:59 -06:00
parent dc4a4f506f
commit c54a441228
3 changed files with 489 additions and 165 deletions

View File

@@ -30,6 +30,24 @@ Location: project `.env`. Minimal reference for server admins.
- FRONTEND_URL: Public UI base (e.g. `http://127.0.0.1:7171`) - FRONTEND_URL: Public UI base (e.g. `http://127.0.0.1:7171`)
- GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET - GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET
- GITHUB_CLIENT_ID / GITHUB_CLIENT_SECRET - GITHUB_CLIENT_ID / GITHUB_CLIENT_SECRET
- Custom/Generic OAuth (set all to enable a custom provider):
- CUSTOM_SSO_CLIENT_ID / CUSTOM_SSO_CLIENT_SECRET
- CUSTOM_SSO_AUTHORIZATION_ENDPOINT
- CUSTOM_SSO_TOKEN_ENDPOINT
- CUSTOM_SSO_USERINFO_ENDPOINT
- CUSTOM_SSO_SCOPE: Comma-separated scopes (optional)
- CUSTOM_SSO_NAME: Internal provider name (optional, default `custom`)
- CUSTOM_SSO_DISPLAY_NAME: UI name (optional, default `Custom`)
- Multiple Custom/Generic OAuth providers (up to 10):
- For provider index `i` (1..10), set:
- CUSTOM_SSO_CLIENT_ID_i / CUSTOM_SSO_CLIENT_SECRET_i
- CUSTOM_SSO_AUTHORIZATION_ENDPOINT_i
- CUSTOM_SSO_TOKEN_ENDPOINT_i
- CUSTOM_SSO_USERINFO_ENDPOINT_i
- CUSTOM_SSO_SCOPE_i (optional)
- CUSTOM_SSO_NAME_i (optional, default `custom{i}`)
- CUSTOM_SSO_DISPLAY_NAME_i (optional, default `Custom {i}`)
- Login URLs will be `/api/auth/sso/login/custom/i` and callback `/api/auth/sso/callback/custom/i`.
### Tips ### Tips
- If running behind a reverse proxy, set `FRONTEND_URL` and `SSO_BASE_REDIRECT_URI` to public URLs. - If running behind a reverse proxy, set `FRONTEND_URL` and `SSO_BASE_REDIRECT_URI` to public URLs.

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends, Request from fastapi import APIRouter, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List from typing import Optional, List
@@ -14,6 +14,7 @@ security = HTTPBearer(auto_error=False)
# Include SSO sub-router # Include SSO sub-router
try: try:
from .sso import router as sso_router from .sso import router as sso_router
router.include_router(sso_router, tags=["sso"]) router.include_router(sso_router, tags=["sso"])
logging.info("SSO sub-router included in auth router") logging.info("SSO sub-router included in auth router")
except ImportError as e: except ImportError as e:
@@ -34,6 +35,7 @@ class RegisterRequest(BaseModel):
class CreateUserRequest(BaseModel): class CreateUserRequest(BaseModel):
"""Admin-only request to create users when registration is disabled""" """Admin-only request to create users when registration is disabled"""
username: str username: str
password: str password: str
email: Optional[str] = None email: Optional[str] = None
@@ -42,17 +44,20 @@ class CreateUserRequest(BaseModel):
class RoleUpdateRequest(BaseModel): class RoleUpdateRequest(BaseModel):
"""Request to update user role""" """Request to update user role"""
role: str role: str
class PasswordChangeRequest(BaseModel): class PasswordChangeRequest(BaseModel):
"""Request to change user password""" """Request to change user password"""
current_password: str current_password: str
new_password: str new_password: str
class AdminPasswordResetRequest(BaseModel): class AdminPasswordResetRequest(BaseModel):
"""Request for admin to reset user password""" """Request for admin to reset user password"""
new_password: str new_password: str
@@ -87,20 +92,20 @@ class AuthStatusResponse(BaseModel):
# Dependency to get current user # Dependency to get current user
async def get_current_user( async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security) credentials: HTTPAuthorizationCredentials = Depends(security),
) -> Optional[User]: ) -> Optional[User]:
"""Get current user from JWT token""" """Get current user from JWT token"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
# When auth is disabled, return a mock admin user # When auth is disabled, return a mock admin user
return User(username="system", role="admin") return User(username="system", role="admin")
if not credentials: if not credentials:
return None return None
payload = token_manager.verify_token(credentials.credentials) payload = token_manager.verify_token(credentials.credentials)
if not payload: if not payload:
return None return None
user = user_manager.get_user(payload["username"]) user = user_manager.get_user(payload["username"])
return user return user
@@ -109,25 +114,22 @@ async def require_auth(current_user: User = Depends(get_current_user)) -> User:
"""Require authentication - raises HTTPException if not authenticated""" """Require authentication - raises HTTPException if not authenticated"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
return User(username="system", role="admin") return User(username="system", role="admin")
if not current_user: if not current_user:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Authentication required", detail="Authentication required",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
return current_user return current_user
async def require_admin(current_user: User = Depends(require_auth)) -> User: async def require_admin(current_user: User = Depends(require_auth)) -> User:
"""Require admin role - raises HTTPException if not admin""" """Require admin role - raises HTTPException if not admin"""
if current_user.role != "admin": if current_user.role != "admin":
raise HTTPException( raise HTTPException(status_code=403, detail="Admin access required")
status_code=403,
detail="Admin access required"
)
return current_user return current_user
@@ -138,24 +140,33 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
# Check if SSO is enabled and get available providers # Check if SSO is enabled and get available providers
sso_enabled = False sso_enabled = False
sso_providers = [] sso_providers = []
try: try:
from . import sso from . import sso
sso_enabled = sso.SSO_ENABLED and AUTH_ENABLED sso_enabled = sso.SSO_ENABLED and AUTH_ENABLED
if sso.google_sso: if sso.google_sso:
sso_providers.append("google") sso_providers.append("google")
if sso.github_sso: if sso.github_sso:
sso_providers.append("github") sso_providers.append("github")
if getattr(sso, "custom_sso", None):
sso_providers.append("custom")
if getattr(sso, "custom_sso_providers", None):
if (
len(getattr(sso, "custom_sso_providers", {})) > 0
and "custom" not in sso_providers
):
sso_providers.append("custom")
except ImportError: except ImportError:
pass # SSO module not available pass # SSO module not available
return AuthStatusResponse( return AuthStatusResponse(
auth_enabled=AUTH_ENABLED, auth_enabled=AUTH_ENABLED,
authenticated=current_user is not None, authenticated=current_user is not None,
user=UserResponse(**current_user.to_public_dict()) if current_user else None, user=UserResponse(**current_user.to_public_dict()) if current_user else None,
registration_enabled=AUTH_ENABLED and not DISABLE_REGISTRATION, registration_enabled=AUTH_ENABLED and not DISABLE_REGISTRATION,
sso_enabled=sso_enabled, sso_enabled=sso_enabled,
sso_providers=sso_providers sso_providers=sso_providers,
) )
@@ -163,23 +174,16 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
async def login(request: LoginRequest): async def login(request: LoginRequest):
"""Authenticate user and return access token""" """Authenticate user and return access token"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
user = user_manager.authenticate_user(request.username, request.password) user = user_manager.authenticate_user(request.username, request.password)
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=401, detail="Invalid username or password")
status_code=401,
detail="Invalid username or password"
)
access_token = token_manager.create_token(user) access_token = token_manager.create_token(user)
return LoginResponse( return LoginResponse(
access_token=access_token, access_token=access_token, user=UserResponse(**user.to_public_dict())
user=UserResponse(**user.to_public_dict())
) )
@@ -187,31 +191,28 @@ async def login(request: LoginRequest):
async def register(request: RegisterRequest): async def register(request: RegisterRequest):
"""Register a new user""" """Register a new user"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
if DISABLE_REGISTRATION: if DISABLE_REGISTRATION:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail="Public registration is disabled. Contact an administrator to create an account." detail="Public registration is disabled. Contact an administrator to create an account.",
) )
# Check if this is the first user (should be admin) # Check if this is the first user (should be admin)
existing_users = user_manager.list_users() existing_users = user_manager.list_users()
role = "admin" if len(existing_users) == 0 else "user" role = "admin" if len(existing_users) == 0 else "user"
success, message = user_manager.create_user( success, message = user_manager.create_user(
username=request.username, username=request.username,
password=request.password, password=request.password,
email=request.email, email=request.email,
role=role role=role,
) )
if not success: if not success:
raise HTTPException(status_code=400, detail=message) raise HTTPException(status_code=400, detail=message)
return MessageResponse(message=message) return MessageResponse(message=message)
@@ -233,70 +234,57 @@ async def list_users(current_user: User = Depends(require_admin)):
async def delete_user(username: str, current_user: User = Depends(require_admin)): async def delete_user(username: str, current_user: User = Depends(require_admin)):
"""Delete a user (admin only)""" """Delete a user (admin only)"""
if username == current_user.username: if username == current_user.username:
raise HTTPException( raise HTTPException(status_code=400, detail="Cannot delete your own account")
status_code=400,
detail="Cannot delete your own account"
)
success, message = user_manager.delete_user(username) success, message = user_manager.delete_user(username)
if not success: if not success:
raise HTTPException(status_code=404, detail=message) raise HTTPException(status_code=404, detail=message)
return MessageResponse(message=message) return MessageResponse(message=message)
@router.put("/users/{username}/role", response_model=MessageResponse) @router.put("/users/{username}/role", response_model=MessageResponse)
async def update_user_role( async def update_user_role(
username: str, username: str,
request: RoleUpdateRequest, request: RoleUpdateRequest,
current_user: User = Depends(require_admin) current_user: User = Depends(require_admin),
): ):
"""Update user role (admin only)""" """Update user role (admin only)"""
if request.role not in ["user", "admin"]: if request.role not in ["user", "admin"]:
raise HTTPException( raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
status_code=400,
detail="Role must be 'user' or 'admin'"
)
if username == current_user.username: if username == current_user.username:
raise HTTPException( raise HTTPException(status_code=400, detail="Cannot change your own role")
status_code=400,
detail="Cannot change your own role"
)
success, message = user_manager.update_user_role(username, request.role) success, message = user_manager.update_user_role(username, request.role)
if not success: if not success:
raise HTTPException(status_code=404, detail=message) raise HTTPException(status_code=404, detail=message)
return MessageResponse(message=message) return MessageResponse(message=message)
@router.post("/users/create", response_model=MessageResponse) @router.post("/users/create", response_model=MessageResponse)
async def create_user_admin(request: CreateUserRequest, current_user: User = Depends(require_admin)): async def create_user_admin(
request: CreateUserRequest, current_user: User = Depends(require_admin)
):
"""Create a new user (admin only) - for use when registration is disabled""" """Create a new user (admin only) - for use when registration is disabled"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
# Validate role # Validate role
if request.role not in ["user", "admin"]: if request.role not in ["user", "admin"]:
raise HTTPException( raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
status_code=400,
detail="Role must be 'user' or 'admin'"
)
success, message = user_manager.create_user( success, message = user_manager.create_user(
username=request.username, username=request.username,
password=request.password, password=request.password,
email=request.email, email=request.email,
role=request.role role=request.role,
) )
if not success: if not success:
raise HTTPException(status_code=400, detail=message) raise HTTPException(status_code=400, detail=message)
return MessageResponse(message=message) return MessageResponse(message=message)
@@ -309,22 +297,18 @@ async def get_profile(current_user: User = Depends(require_auth)):
@router.put("/profile/password", response_model=MessageResponse) @router.put("/profile/password", response_model=MessageResponse)
async def change_password( async def change_password(
request: PasswordChangeRequest, request: PasswordChangeRequest, current_user: User = Depends(require_auth)
current_user: User = Depends(require_auth)
): ):
"""Change current user's password""" """Change current user's password"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
success, message = user_manager.change_password( success, message = user_manager.change_password(
username=current_user.username, username=current_user.username,
current_password=request.current_password, current_password=request.current_password,
new_password=request.new_password new_password=request.new_password,
) )
if not success: if not success:
# Determine appropriate HTTP status code based on error message # Determine appropriate HTTP status code based on error message
if "Current password is incorrect" in message: if "Current password is incorrect" in message:
@@ -333,9 +317,9 @@ async def change_password(
status_code = 404 status_code = 404
else: else:
status_code = 400 status_code = 400
raise HTTPException(status_code=status_code, detail=message) raise HTTPException(status_code=status_code, detail=message)
return MessageResponse(message=message) return MessageResponse(message=message)
@@ -343,30 +327,26 @@ async def change_password(
async def admin_reset_password( async def admin_reset_password(
username: str, username: str,
request: AdminPasswordResetRequest, request: AdminPasswordResetRequest,
current_user: User = Depends(require_admin) current_user: User = Depends(require_admin),
): ):
"""Admin reset user password (admin only)""" """Admin reset user password (admin only)"""
if not AUTH_ENABLED: if not AUTH_ENABLED:
raise HTTPException( raise HTTPException(status_code=400, detail="Authentication is disabled")
status_code=400,
detail="Authentication is disabled"
)
success, message = user_manager.admin_reset_password( success, message = user_manager.admin_reset_password(
username=username, username=username, new_password=request.new_password
new_password=request.new_password
) )
if not success: if not success:
# Determine appropriate HTTP status code based on error message # Determine appropriate HTTP status code based on error message
if "User not found" in message: if "User not found" in message:
status_code = 404 status_code = 404
else: else:
status_code = 400 status_code = 400
raise HTTPException(status_code=status_code, detail=message) raise HTTPException(status_code=status_code, detail=message)
return MessageResponse(message=message) return MessageResponse(message=message)
# Note: SSO routes are included in the main app, not here to avoid circular imports # Note: SSO routes are included in the main app, not here to avoid circular imports

View File

@@ -1,17 +1,19 @@
""" """
SSO (Single Sign-On) implementation for Google and GitHub authentication SSO (Single Sign-On) implementation for Google and GitHub authentication
""" """
import os import os
import logging import logging
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from datetime import datetime, timedelta from datetime import datetime, timedelta
from fastapi import APIRouter, Request, HTTPException, Depends from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi_sso.sso.google import GoogleSSO from fastapi_sso.sso.google import GoogleSSO
from fastapi_sso.sso.github import GithubSSO from fastapi_sso.sso.github import GithubSSO
from fastapi_sso.sso.base import OpenID from fastapi_sso.sso.base import OpenID
from pydantic import BaseModel from pydantic import BaseModel
from fastapi_sso.sso.generic import create_provider
from . import user_manager, token_manager, User, AUTH_ENABLED, DISABLE_REGISTRATION from . import user_manager, token_manager, User, AUTH_ENABLED, DISABLE_REGISTRATION
@@ -25,11 +27,14 @@ GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET") GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID") GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET") GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
SSO_BASE_REDIRECT_URI = os.getenv("SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback") SSO_BASE_REDIRECT_URI = os.getenv(
"SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback"
)
# Initialize SSO providers # Initialize SSO providers
google_sso = None google_sso = None
github_sso = None github_sso = None
custom_sso = None
if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET: if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET:
google_sso = GoogleSSO( google_sso = GoogleSSO(
@@ -47,6 +52,154 @@ if GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET:
allow_insecure_http=True, # Set to False in production with HTTPS allow_insecure_http=True, # Set to False in production with HTTPS
) )
# Custom/Generic OAuth provider configuration
CUSTOM_SSO_CLIENT_ID = os.getenv("CUSTOM_SSO_CLIENT_ID")
CUSTOM_SSO_CLIENT_SECRET = os.getenv("CUSTOM_SSO_CLIENT_SECRET")
CUSTOM_SSO_AUTHORIZATION_ENDPOINT = os.getenv("CUSTOM_SSO_AUTHORIZATION_ENDPOINT")
CUSTOM_SSO_TOKEN_ENDPOINT = os.getenv("CUSTOM_SSO_TOKEN_ENDPOINT")
CUSTOM_SSO_USERINFO_ENDPOINT = os.getenv("CUSTOM_SSO_USERINFO_ENDPOINT")
CUSTOM_SSO_SCOPE = os.getenv("CUSTOM_SSO_SCOPE") # comma-separated list
CUSTOM_SSO_NAME = os.getenv("CUSTOM_SSO_NAME", "custom")
CUSTOM_SSO_DISPLAY_NAME = os.getenv("CUSTOM_SSO_DISPLAY_NAME", "Custom")
def _default_custom_response_convertor(
userinfo: Dict[str, Any], _client=None
) -> OpenID:
"""Best-effort convertor from generic userinfo to OpenID."""
user_id = (
userinfo.get("sub")
or userinfo.get("id")
or userinfo.get("user_id")
or userinfo.get("uid")
or userinfo.get("uuid")
)
email = userinfo.get("email")
display_name = (
userinfo.get("name")
or userinfo.get("preferred_username")
or userinfo.get("login")
or email
or (str(user_id) if user_id is not None else None)
)
picture = userinfo.get("picture") or userinfo.get("avatar_url")
if not user_id and email:
user_id = email
return OpenID(
id=str(user_id) if user_id is not None else "",
email=email,
display_name=display_name,
picture=picture,
provider=CUSTOM_SSO_NAME,
)
if all(
[
CUSTOM_SSO_CLIENT_ID,
CUSTOM_SSO_CLIENT_SECRET,
CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
CUSTOM_SSO_TOKEN_ENDPOINT,
CUSTOM_SSO_USERINFO_ENDPOINT,
]
):
discovery = {
"authorization_endpoint": CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
"token_endpoint": CUSTOM_SSO_TOKEN_ENDPOINT,
"userinfo_endpoint": CUSTOM_SSO_USERINFO_ENDPOINT,
}
default_scope = (
[s.strip() for s in CUSTOM_SSO_SCOPE.split(",") if s.strip()]
if CUSTOM_SSO_SCOPE
else None
)
CustomProvider = create_provider(
name=CUSTOM_SSO_NAME,
discovery_document=discovery,
response_convertor=_default_custom_response_convertor,
default_scope=default_scope,
)
custom_sso = CustomProvider(
client_id=CUSTOM_SSO_CLIENT_ID,
client_secret=CUSTOM_SSO_CLIENT_SECRET,
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom",
allow_insecure_http=True, # Set to False in production with HTTPS
)
# Support multiple indexed custom providers (CUSTOM_*_i), up to 10
custom_sso_providers: Dict[int, Dict[str, Any]] = {}
def _make_response_convertor(provider_name: str):
def _convert(userinfo: Dict[str, Any], _client=None) -> OpenID:
user_id = (
userinfo.get("sub")
or userinfo.get("id")
or userinfo.get("user_id")
or userinfo.get("uid")
or userinfo.get("uuid")
)
email = userinfo.get("email")
display_name = (
userinfo.get("name")
or userinfo.get("preferred_username")
or userinfo.get("login")
or email
or (str(user_id) if user_id is not None else None)
)
picture = userinfo.get("picture") or userinfo.get("avatar_url")
if not user_id and email:
user_id = email
return OpenID(
id=str(user_id) if user_id is not None else "",
email=email,
display_name=display_name,
picture=picture,
provider=provider_name,
)
return _convert
for i in range(1, 11):
cid = os.getenv(f"CUSTOM_SSO_CLIENT_ID_{i}")
csecret = os.getenv(f"CUSTOM_SSO_CLIENT_SECRET_{i}")
auth_ep = os.getenv(f"CUSTOM_SSO_AUTHORIZATION_ENDPOINT_{i}")
token_ep = os.getenv(f"CUSTOM_SSO_TOKEN_ENDPOINT_{i}")
userinfo_ep = os.getenv(f"CUSTOM_SSO_USERINFO_ENDPOINT_{i}")
scope_raw = os.getenv(f"CUSTOM_SSO_SCOPE_{i}")
name_i = os.getenv(f"CUSTOM_SSO_NAME_{i}", f"custom{i}")
display_name_i = os.getenv(f"CUSTOM_SSO_DISPLAY_NAME_{i}", f"Custom {i}")
if all([cid, csecret, auth_ep, token_ep, userinfo_ep]):
discovery_i = {
"authorization_endpoint": auth_ep,
"token_endpoint": token_ep,
"userinfo_endpoint": userinfo_ep,
}
default_scope_i = (
[s.strip() for s in scope_raw.split(",") if s.strip()]
if scope_raw
else None
)
ProviderClass = create_provider(
name=name_i,
discovery_document=discovery_i,
response_convertor=_make_response_convertor(name_i),
default_scope=default_scope_i,
)
provider_instance = ProviderClass(
client_id=cid,
client_secret=csecret,
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom/{i}",
allow_insecure_http=True, # Set to False in production with HTTPS
)
custom_sso_providers[i] = {
"sso": provider_instance,
"name": name_i,
"display_name": display_name_i,
}
class MessageResponse(BaseModel): class MessageResponse(BaseModel):
message: str message: str
@@ -70,21 +223,25 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
# Generate username from email or use provider ID # Generate username from email or use provider ID
email = openid.email email = openid.email
if not email: if not email:
raise HTTPException(status_code=400, detail="Email is required for SSO authentication") raise HTTPException(
status_code=400, detail="Email is required for SSO authentication"
)
# Use email prefix as username, fallback to provider + id # Use email prefix as username, fallback to provider + id
username = email.split("@")[0] username = email.split("@")[0]
if not username: if not username:
username = f"{provider}_{openid.id}" username = f"{provider}_{openid.id}"
# Check if user already exists by email # Check if user already exists by email
existing_user = None existing_user = None
users = user_manager.load_users() users = user_manager.load_users()
for user_data in users.values(): for user_data in users.values():
if user_data.get("email") == email: if user_data.get("email") == email:
existing_user = User(**{k: v for k, v in user_data.items() if k != "password_hash"}) existing_user = User(
**{k: v for k, v in user_data.items() if k != "password_hash"}
)
break break
if existing_user: if existing_user:
# Update last login for existing user (always allowed) # Update last login for existing user (always allowed)
users[existing_user.username]["last_login"] = datetime.utcnow().isoformat() users[existing_user.username]["last_login"] = datetime.utcnow().isoformat()
@@ -96,10 +253,10 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
# Check if registration is disabled before creating new user # Check if registration is disabled before creating new user
if DISABLE_REGISTRATION: if DISABLE_REGISTRATION:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail="Registration is disabled. Contact an administrator to create an account." detail="Registration is disabled. Contact an administrator to create an account.",
) )
# Create new user # Create new user
# Ensure username is unique # Ensure username is unique
counter = 1 counter = 1
@@ -107,20 +264,20 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
while username in users: while username in users:
username = f"{original_username}{counter}" username = f"{original_username}{counter}"
counter += 1 counter += 1
user = User( user = User(
username=username, username=username,
email=email, email=email,
role="user" # Default role for SSO users role="user", # Default role for SSO users
) )
users[username] = { users[username] = {
**user.to_dict(), **user.to_dict(),
"sso_provider": provider, "sso_provider": provider,
"sso_id": openid.id, "sso_id": openid.id,
"password_hash": None # SSO users don't have passwords "password_hash": None, # SSO users don't have passwords
} }
user_manager.save_users(users) user_manager.save_users(users)
logger.info(f"Created SSO user: {username} via {provider}") logger.info(f"Created SSO user: {username} via {provider}")
return user return user
@@ -130,27 +287,51 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
async def sso_status(): async def sso_status():
"""Get SSO status and available providers""" """Get SSO status and available providers"""
providers = [] providers = []
if google_sso: if google_sso:
providers.append(SSOProvider( providers.append(
name="google", SSOProvider(
display_name="Google", name="google",
enabled=True, display_name="Google",
login_url="/api/auth/sso/login/google" enabled=True,
)) login_url="/api/auth/sso/login/google",
)
)
if github_sso: if github_sso:
providers.append(SSOProvider( providers.append(
name="github", SSOProvider(
display_name="GitHub", name="github",
enabled=True, display_name="GitHub",
login_url="/api/auth/sso/login/github" enabled=True,
)) login_url="/api/auth/sso/login/github",
)
)
if custom_sso:
providers.append(
SSOProvider(
name="custom",
display_name=CUSTOM_SSO_DISPLAY_NAME,
enabled=True,
login_url="/api/auth/sso/login/custom",
)
)
for idx, cfg in custom_sso_providers.items():
providers.append(
SSOProvider(
name=cfg["name"],
display_name=cfg.get("display_name", cfg["name"]),
enabled=True,
login_url=f"/api/auth/sso/login/custom/{idx}",
)
)
return SSOStatusResponse( return SSOStatusResponse(
sso_enabled=SSO_ENABLED and AUTH_ENABLED, sso_enabled=SSO_ENABLED and AUTH_ENABLED,
providers=providers, providers=providers,
registration_enabled=not DISABLE_REGISTRATION registration_enabled=not DISABLE_REGISTRATION,
) )
@@ -159,12 +340,14 @@ async def google_login():
"""Initiate Google SSO login""" """Initiate Google SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED: if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled") raise HTTPException(status_code=400, detail="SSO is disabled")
if not google_sso: if not google_sso:
raise HTTPException(status_code=400, detail="Google SSO is not configured") raise HTTPException(status_code=400, detail="Google SSO is not configured")
async with google_sso: async with google_sso:
return await google_sso.get_login_redirect(params={"prompt": "consent", "access_type": "offline"}) return await google_sso.get_login_redirect(
params={"prompt": "consent", "access_type": "offline"}
)
@router.get("/sso/login/github") @router.get("/sso/login/github")
@@ -172,37 +355,66 @@ async def github_login():
"""Initiate GitHub SSO login""" """Initiate GitHub SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED: if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled") raise HTTPException(status_code=400, detail="SSO is disabled")
if not github_sso: if not github_sso:
raise HTTPException(status_code=400, detail="GitHub SSO is not configured") raise HTTPException(status_code=400, detail="GitHub SSO is not configured")
async with github_sso: async with github_sso:
return await github_sso.get_login_redirect() return await github_sso.get_login_redirect()
@router.get("/sso/login/custom")
async def custom_login():
"""Initiate Custom SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not custom_sso:
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
async with custom_sso:
return await custom_sso.get_login_redirect()
@router.get("/sso/login/custom/{index}")
async def custom_login_indexed(index: int):
"""Initiate indexed Custom SSO login"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
cfg = custom_sso_providers.get(index)
if not cfg:
raise HTTPException(
status_code=400, detail="Custom SSO provider not configured"
)
async with cfg["sso"]:
return await cfg["sso"].get_login_redirect()
@router.get("/sso/callback/google") @router.get("/sso/callback/google")
async def google_callback(request: Request): async def google_callback(request: Request):
"""Handle Google SSO callback""" """Handle Google SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED: if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled") raise HTTPException(status_code=400, detail="SSO is disabled")
if not google_sso: if not google_sso:
raise HTTPException(status_code=400, detail="Google SSO is not configured") raise HTTPException(status_code=400, detail="Google SSO is not configured")
try: try:
async with google_sso: async with google_sso:
openid = await google_sso.verify_and_process(request) openid = await google_sso.verify_and_process(request)
# Create or update user # Create or update user
user = create_or_update_sso_user(openid, "google") user = create_or_update_sso_user(openid, "google")
# Create JWT token # Create JWT token
access_token = token_manager.create_token(user) access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this) # Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}") response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie # Also set as HTTP-only cookie
response.set_cookie( response.set_cookie(
key="access_token", key="access_token",
@@ -210,18 +422,18 @@ async def google_callback(request: Request):
httponly=True, httponly=True,
secure=False, # Set to True in production with HTTPS secure=False, # Set to True in production with HTTPS
samesite="lax", samesite="lax",
max_age=timedelta(hours=24).total_seconds() max_age=timedelta(hours=24).total_seconds(),
) )
return response return response
except HTTPException as e: except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled) # Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed" error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Google SSO callback error: {error_msg}") logger.warning(f"Google SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}") return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e: except Exception as e:
logger.error(f"Google SSO callback error: {e}") logger.error(f"Google SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
@@ -233,24 +445,24 @@ async def github_callback(request: Request):
"""Handle GitHub SSO callback""" """Handle GitHub SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED: if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled") raise HTTPException(status_code=400, detail="SSO is disabled")
if not github_sso: if not github_sso:
raise HTTPException(status_code=400, detail="GitHub SSO is not configured") raise HTTPException(status_code=400, detail="GitHub SSO is not configured")
try: try:
async with github_sso: async with github_sso:
openid = await github_sso.verify_and_process(request) openid = await github_sso.verify_and_process(request)
# Create or update user # Create or update user
user = create_or_update_sso_user(openid, "github") user = create_or_update_sso_user(openid, "github")
# Create JWT token # Create JWT token
access_token = token_manager.create_token(user) access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this) # Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}") response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie # Also set as HTTP-only cookie
response.set_cookie( response.set_cookie(
key="access_token", key="access_token",
@@ -258,24 +470,123 @@ async def github_callback(request: Request):
httponly=True, httponly=True,
secure=False, # Set to True in production with HTTPS secure=False, # Set to True in production with HTTPS
samesite="lax", samesite="lax",
max_age=timedelta(hours=24).total_seconds() max_age=timedelta(hours=24).total_seconds(),
) )
return response return response
except HTTPException as e: except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled) # Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed" error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"GitHub SSO callback error: {error_msg}") logger.warning(f"GitHub SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}") return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e: except Exception as e:
logger.error(f"GitHub SSO callback error: {e}") logger.error(f"GitHub SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed") return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.get("/sso/callback/custom")
async def custom_callback(request: Request):
"""Handle Custom SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
if not custom_sso:
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
try:
async with custom_sso:
openid = await custom_sso.verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, "custom")
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
value=access_token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Custom SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"Custom SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.get("/sso/callback/custom/{index}")
async def custom_callback_indexed(request: Request, index: int):
"""Handle indexed Custom SSO callback"""
if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled")
cfg = custom_sso_providers.get(index)
if not cfg:
raise HTTPException(
status_code=400, detail="Custom SSO provider not configured"
)
try:
async with cfg["sso"]:
openid = await cfg["sso"].verify_and_process(request)
# Create or update user
user = create_or_update_sso_user(openid, cfg["name"])
# Create JWT token
access_token = token_manager.create_token(user)
# Redirect to frontend with token (you might want to customize this)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
# Also set as HTTP-only cookie
response.set_cookie(
key="access_token",
value=access_token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=timedelta(hours=24).total_seconds(),
)
return response
except HTTPException as e:
# Handle specific HTTP exceptions (like registration disabled)
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
logger.warning(f"Custom[{index}] SSO callback error: {error_msg}")
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
except Exception as e:
logger.error(f"Custom[{index}] SSO callback error: {e}")
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
@router.post("/sso/unlink/{provider}", response_model=MessageResponse) @router.post("/sso/unlink/{provider}", response_model=MessageResponse)
async def unlink_sso_provider( async def unlink_sso_provider(
provider: str, provider: str,
@@ -284,27 +595,42 @@ async def unlink_sso_provider(
"""Unlink SSO provider from user account""" """Unlink SSO provider from user account"""
if not SSO_ENABLED or not AUTH_ENABLED: if not SSO_ENABLED or not AUTH_ENABLED:
raise HTTPException(status_code=400, detail="SSO is disabled") raise HTTPException(status_code=400, detail="SSO is disabled")
if provider not in ["google", "github"]: available = []
if google_sso:
available.append("google")
if github_sso:
available.append("github")
if custom_sso:
available.append("custom")
for cfg in custom_sso_providers.values():
available.append(cfg["name"])
if provider not in available:
raise HTTPException(status_code=400, detail="Invalid SSO provider") raise HTTPException(status_code=400, detail="Invalid SSO provider")
# Get current user from request (avoiding circular imports) # Get current user from request (avoiding circular imports)
from .middleware import require_auth_from_state from .middleware import require_auth_from_state
current_user = await require_auth_from_state(request) current_user = await require_auth_from_state(request)
if not current_user.sso_provider: if not current_user.sso_provider:
raise HTTPException(status_code=400, detail="User is not linked to any SSO provider") raise HTTPException(
status_code=400, detail="User is not linked to any SSO provider"
)
if current_user.sso_provider != provider: if current_user.sso_provider != provider:
raise HTTPException(status_code=400, detail=f"User is not linked to {provider}") raise HTTPException(status_code=400, detail=f"User is not linked to {provider}")
# Update user to remove SSO linkage # Update user to remove SSO linkage
users = user_manager.load_users() users = user_manager.load_users()
if current_user.username in users: if current_user.username in users:
users[current_user.username]["sso_provider"] = None users[current_user.username]["sso_provider"] = None
users[current_user.username]["sso_id"] = None users[current_user.username]["sso_id"] = None
user_manager.save_users(users) user_manager.save_users(users)
logger.info(f"Unlinked SSO provider {provider} from user {current_user.username}") logger.info(
f"Unlinked SSO provider {provider} from user {current_user.username}"
return MessageResponse(message=f"SSO provider {provider} unlinked successfully") )
return MessageResponse(message=f"SSO provider {provider} unlinked successfully")