feat: implement generic oauth provider
This commit is contained in:
@@ -30,6 +30,24 @@ Location: project `.env`. Minimal reference for server admins.
|
||||
- FRONTEND_URL: Public UI base (e.g. `http://127.0.0.1:7171`)
|
||||
- GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET
|
||||
- GITHUB_CLIENT_ID / GITHUB_CLIENT_SECRET
|
||||
- Custom/Generic OAuth (set all to enable a custom provider):
|
||||
- CUSTOM_SSO_CLIENT_ID / CUSTOM_SSO_CLIENT_SECRET
|
||||
- CUSTOM_SSO_AUTHORIZATION_ENDPOINT
|
||||
- CUSTOM_SSO_TOKEN_ENDPOINT
|
||||
- CUSTOM_SSO_USERINFO_ENDPOINT
|
||||
- CUSTOM_SSO_SCOPE: Comma-separated scopes (optional)
|
||||
- CUSTOM_SSO_NAME: Internal provider name (optional, default `custom`)
|
||||
- CUSTOM_SSO_DISPLAY_NAME: UI name (optional, default `Custom`)
|
||||
- Multiple Custom/Generic OAuth providers (up to 10):
|
||||
- For provider index `i` (1..10), set:
|
||||
- CUSTOM_SSO_CLIENT_ID_i / CUSTOM_SSO_CLIENT_SECRET_i
|
||||
- CUSTOM_SSO_AUTHORIZATION_ENDPOINT_i
|
||||
- CUSTOM_SSO_TOKEN_ENDPOINT_i
|
||||
- CUSTOM_SSO_USERINFO_ENDPOINT_i
|
||||
- CUSTOM_SSO_SCOPE_i (optional)
|
||||
- CUSTOM_SSO_NAME_i (optional, default `custom{i}`)
|
||||
- CUSTOM_SSO_DISPLAY_NAME_i (optional, default `Custom {i}`)
|
||||
- Login URLs will be `/api/auth/sso/login/custom/i` and callback `/api/auth/sso/callback/custom/i`.
|
||||
|
||||
### Tips
|
||||
- If running behind a reverse proxy, set `FRONTEND_URL` and `SSO_BASE_REDIRECT_URI` to public URLs.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
@@ -14,6 +14,7 @@ security = HTTPBearer(auto_error=False)
|
||||
# Include SSO sub-router
|
||||
try:
|
||||
from .sso import router as sso_router
|
||||
|
||||
router.include_router(sso_router, tags=["sso"])
|
||||
logging.info("SSO sub-router included in auth router")
|
||||
except ImportError as e:
|
||||
@@ -34,6 +35,7 @@ class RegisterRequest(BaseModel):
|
||||
|
||||
class CreateUserRequest(BaseModel):
|
||||
"""Admin-only request to create users when registration is disabled"""
|
||||
|
||||
username: str
|
||||
password: str
|
||||
email: Optional[str] = None
|
||||
@@ -42,17 +44,20 @@ class CreateUserRequest(BaseModel):
|
||||
|
||||
class RoleUpdateRequest(BaseModel):
|
||||
"""Request to update user role"""
|
||||
|
||||
role: str
|
||||
|
||||
|
||||
class PasswordChangeRequest(BaseModel):
|
||||
"""Request to change user password"""
|
||||
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class AdminPasswordResetRequest(BaseModel):
|
||||
"""Request for admin to reset user password"""
|
||||
|
||||
new_password: str
|
||||
|
||||
|
||||
@@ -87,7 +92,7 @@ class AuthStatusResponse(BaseModel):
|
||||
|
||||
# Dependency to get current user
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security)
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> Optional[User]:
|
||||
"""Get current user from JWT token"""
|
||||
if not AUTH_ENABLED:
|
||||
@@ -123,10 +128,7 @@ async def require_auth(current_user: User = Depends(get_current_user)) -> User:
|
||||
async def require_admin(current_user: User = Depends(require_auth)) -> User:
|
||||
"""Require admin role - raises HTTPException if not admin"""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Admin access required"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
return current_user
|
||||
|
||||
@@ -141,11 +143,20 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
|
||||
|
||||
try:
|
||||
from . import sso
|
||||
|
||||
sso_enabled = sso.SSO_ENABLED and AUTH_ENABLED
|
||||
if sso.google_sso:
|
||||
sso_providers.append("google")
|
||||
if sso.github_sso:
|
||||
sso_providers.append("github")
|
||||
if getattr(sso, "custom_sso", None):
|
||||
sso_providers.append("custom")
|
||||
if getattr(sso, "custom_sso_providers", None):
|
||||
if (
|
||||
len(getattr(sso, "custom_sso_providers", {})) > 0
|
||||
and "custom" not in sso_providers
|
||||
):
|
||||
sso_providers.append("custom")
|
||||
except ImportError:
|
||||
pass # SSO module not available
|
||||
|
||||
@@ -155,7 +166,7 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
|
||||
user=UserResponse(**current_user.to_public_dict()) if current_user else None,
|
||||
registration_enabled=AUTH_ENABLED and not DISABLE_REGISTRATION,
|
||||
sso_enabled=sso_enabled,
|
||||
sso_providers=sso_providers
|
||||
sso_providers=sso_providers,
|
||||
)
|
||||
|
||||
|
||||
@@ -163,23 +174,16 @@ async def auth_status(current_user: Optional[User] = Depends(get_current_user)):
|
||||
async def login(request: LoginRequest):
|
||||
"""Authenticate user and return access token"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
user = user_manager.authenticate_user(request.username, request.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid username or password"
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
|
||||
access_token = token_manager.create_token(user)
|
||||
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
user=UserResponse(**user.to_public_dict())
|
||||
access_token=access_token, user=UserResponse(**user.to_public_dict())
|
||||
)
|
||||
|
||||
|
||||
@@ -187,15 +191,12 @@ async def login(request: LoginRequest):
|
||||
async def register(request: RegisterRequest):
|
||||
"""Register a new user"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
if DISABLE_REGISTRATION:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Public registration is disabled. Contact an administrator to create an account."
|
||||
detail="Public registration is disabled. Contact an administrator to create an account.",
|
||||
)
|
||||
|
||||
# Check if this is the first user (should be admin)
|
||||
@@ -206,7 +207,7 @@ async def register(request: RegisterRequest):
|
||||
username=request.username,
|
||||
password=request.password,
|
||||
email=request.email,
|
||||
role=role
|
||||
role=role,
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -233,10 +234,7 @@ async def list_users(current_user: User = Depends(require_admin)):
|
||||
async def delete_user(username: str, current_user: User = Depends(require_admin)):
|
||||
"""Delete a user (admin only)"""
|
||||
if username == current_user.username:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete your own account"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Cannot delete your own account")
|
||||
|
||||
success, message = user_manager.delete_user(username)
|
||||
if not success:
|
||||
@@ -249,20 +247,14 @@ async def delete_user(username: str, current_user: User = Depends(require_admin)
|
||||
async def update_user_role(
|
||||
username: str,
|
||||
request: RoleUpdateRequest,
|
||||
current_user: User = Depends(require_admin)
|
||||
current_user: User = Depends(require_admin),
|
||||
):
|
||||
"""Update user role (admin only)"""
|
||||
if request.role not in ["user", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Role must be 'user' or 'admin'"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
|
||||
|
||||
if username == current_user.username:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot change your own role"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Cannot change your own role")
|
||||
|
||||
success, message = user_manager.update_user_role(username, request.role)
|
||||
if not success:
|
||||
@@ -272,26 +264,22 @@ async def update_user_role(
|
||||
|
||||
|
||||
@router.post("/users/create", response_model=MessageResponse)
|
||||
async def create_user_admin(request: CreateUserRequest, current_user: User = Depends(require_admin)):
|
||||
async def create_user_admin(
|
||||
request: CreateUserRequest, current_user: User = Depends(require_admin)
|
||||
):
|
||||
"""Create a new user (admin only) - for use when registration is disabled"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
# Validate role
|
||||
if request.role not in ["user", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Role must be 'user' or 'admin'"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Role must be 'user' or 'admin'")
|
||||
|
||||
success, message = user_manager.create_user(
|
||||
username=request.username,
|
||||
password=request.password,
|
||||
email=request.email,
|
||||
role=request.role
|
||||
role=request.role,
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -309,20 +297,16 @@ async def get_profile(current_user: User = Depends(require_auth)):
|
||||
|
||||
@router.put("/profile/password", response_model=MessageResponse)
|
||||
async def change_password(
|
||||
request: PasswordChangeRequest,
|
||||
current_user: User = Depends(require_auth)
|
||||
request: PasswordChangeRequest, current_user: User = Depends(require_auth)
|
||||
):
|
||||
"""Change current user's password"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
success, message = user_manager.change_password(
|
||||
username=current_user.username,
|
||||
current_password=request.current_password,
|
||||
new_password=request.new_password
|
||||
new_password=request.new_password,
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -343,18 +327,14 @@ async def change_password(
|
||||
async def admin_reset_password(
|
||||
username: str,
|
||||
request: AdminPasswordResetRequest,
|
||||
current_user: User = Depends(require_admin)
|
||||
current_user: User = Depends(require_admin),
|
||||
):
|
||||
"""Admin reset user password (admin only)"""
|
||||
if not AUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Authentication is disabled"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Authentication is disabled")
|
||||
|
||||
success, message = user_manager.admin_reset_password(
|
||||
username=username,
|
||||
new_password=request.new_password
|
||||
username=username, new_password=request.new_password
|
||||
)
|
||||
|
||||
if not success:
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
"""
|
||||
SSO (Single Sign-On) implementation for Google and GitHub authentication
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, Depends
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
from fastapi_sso.sso.github import GithubSSO
|
||||
from fastapi_sso.sso.base import OpenID
|
||||
from pydantic import BaseModel
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
from . import user_manager, token_manager, User, AUTH_ENABLED, DISABLE_REGISTRATION
|
||||
|
||||
@@ -25,11 +27,14 @@ GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
|
||||
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
|
||||
GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
|
||||
GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
|
||||
SSO_BASE_REDIRECT_URI = os.getenv("SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback")
|
||||
SSO_BASE_REDIRECT_URI = os.getenv(
|
||||
"SSO_BASE_REDIRECT_URI", "http://localhost:7171/api/auth/sso/callback"
|
||||
)
|
||||
|
||||
# Initialize SSO providers
|
||||
google_sso = None
|
||||
github_sso = None
|
||||
custom_sso = None
|
||||
|
||||
if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET:
|
||||
google_sso = GoogleSSO(
|
||||
@@ -47,6 +52,154 @@ if GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET:
|
||||
allow_insecure_http=True, # Set to False in production with HTTPS
|
||||
)
|
||||
|
||||
# Custom/Generic OAuth provider configuration
|
||||
CUSTOM_SSO_CLIENT_ID = os.getenv("CUSTOM_SSO_CLIENT_ID")
|
||||
CUSTOM_SSO_CLIENT_SECRET = os.getenv("CUSTOM_SSO_CLIENT_SECRET")
|
||||
CUSTOM_SSO_AUTHORIZATION_ENDPOINT = os.getenv("CUSTOM_SSO_AUTHORIZATION_ENDPOINT")
|
||||
CUSTOM_SSO_TOKEN_ENDPOINT = os.getenv("CUSTOM_SSO_TOKEN_ENDPOINT")
|
||||
CUSTOM_SSO_USERINFO_ENDPOINT = os.getenv("CUSTOM_SSO_USERINFO_ENDPOINT")
|
||||
CUSTOM_SSO_SCOPE = os.getenv("CUSTOM_SSO_SCOPE") # comma-separated list
|
||||
CUSTOM_SSO_NAME = os.getenv("CUSTOM_SSO_NAME", "custom")
|
||||
CUSTOM_SSO_DISPLAY_NAME = os.getenv("CUSTOM_SSO_DISPLAY_NAME", "Custom")
|
||||
|
||||
|
||||
def _default_custom_response_convertor(
|
||||
userinfo: Dict[str, Any], _client=None
|
||||
) -> OpenID:
|
||||
"""Best-effort convertor from generic userinfo to OpenID."""
|
||||
user_id = (
|
||||
userinfo.get("sub")
|
||||
or userinfo.get("id")
|
||||
or userinfo.get("user_id")
|
||||
or userinfo.get("uid")
|
||||
or userinfo.get("uuid")
|
||||
)
|
||||
email = userinfo.get("email")
|
||||
display_name = (
|
||||
userinfo.get("name")
|
||||
or userinfo.get("preferred_username")
|
||||
or userinfo.get("login")
|
||||
or email
|
||||
or (str(user_id) if user_id is not None else None)
|
||||
)
|
||||
picture = userinfo.get("picture") or userinfo.get("avatar_url")
|
||||
if not user_id and email:
|
||||
user_id = email
|
||||
return OpenID(
|
||||
id=str(user_id) if user_id is not None else "",
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
picture=picture,
|
||||
provider=CUSTOM_SSO_NAME,
|
||||
)
|
||||
|
||||
|
||||
if all(
|
||||
[
|
||||
CUSTOM_SSO_CLIENT_ID,
|
||||
CUSTOM_SSO_CLIENT_SECRET,
|
||||
CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
|
||||
CUSTOM_SSO_TOKEN_ENDPOINT,
|
||||
CUSTOM_SSO_USERINFO_ENDPOINT,
|
||||
]
|
||||
):
|
||||
discovery = {
|
||||
"authorization_endpoint": CUSTOM_SSO_AUTHORIZATION_ENDPOINT,
|
||||
"token_endpoint": CUSTOM_SSO_TOKEN_ENDPOINT,
|
||||
"userinfo_endpoint": CUSTOM_SSO_USERINFO_ENDPOINT,
|
||||
}
|
||||
default_scope = (
|
||||
[s.strip() for s in CUSTOM_SSO_SCOPE.split(",") if s.strip()]
|
||||
if CUSTOM_SSO_SCOPE
|
||||
else None
|
||||
)
|
||||
CustomProvider = create_provider(
|
||||
name=CUSTOM_SSO_NAME,
|
||||
discovery_document=discovery,
|
||||
response_convertor=_default_custom_response_convertor,
|
||||
default_scope=default_scope,
|
||||
)
|
||||
custom_sso = CustomProvider(
|
||||
client_id=CUSTOM_SSO_CLIENT_ID,
|
||||
client_secret=CUSTOM_SSO_CLIENT_SECRET,
|
||||
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom",
|
||||
allow_insecure_http=True, # Set to False in production with HTTPS
|
||||
)
|
||||
|
||||
# Support multiple indexed custom providers (CUSTOM_*_i), up to 10
|
||||
custom_sso_providers: Dict[int, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _make_response_convertor(provider_name: str):
|
||||
def _convert(userinfo: Dict[str, Any], _client=None) -> OpenID:
|
||||
user_id = (
|
||||
userinfo.get("sub")
|
||||
or userinfo.get("id")
|
||||
or userinfo.get("user_id")
|
||||
or userinfo.get("uid")
|
||||
or userinfo.get("uuid")
|
||||
)
|
||||
email = userinfo.get("email")
|
||||
display_name = (
|
||||
userinfo.get("name")
|
||||
or userinfo.get("preferred_username")
|
||||
or userinfo.get("login")
|
||||
or email
|
||||
or (str(user_id) if user_id is not None else None)
|
||||
)
|
||||
picture = userinfo.get("picture") or userinfo.get("avatar_url")
|
||||
if not user_id and email:
|
||||
user_id = email
|
||||
return OpenID(
|
||||
id=str(user_id) if user_id is not None else "",
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
picture=picture,
|
||||
provider=provider_name,
|
||||
)
|
||||
|
||||
return _convert
|
||||
|
||||
|
||||
for i in range(1, 11):
|
||||
cid = os.getenv(f"CUSTOM_SSO_CLIENT_ID_{i}")
|
||||
csecret = os.getenv(f"CUSTOM_SSO_CLIENT_SECRET_{i}")
|
||||
auth_ep = os.getenv(f"CUSTOM_SSO_AUTHORIZATION_ENDPOINT_{i}")
|
||||
token_ep = os.getenv(f"CUSTOM_SSO_TOKEN_ENDPOINT_{i}")
|
||||
userinfo_ep = os.getenv(f"CUSTOM_SSO_USERINFO_ENDPOINT_{i}")
|
||||
scope_raw = os.getenv(f"CUSTOM_SSO_SCOPE_{i}")
|
||||
name_i = os.getenv(f"CUSTOM_SSO_NAME_{i}", f"custom{i}")
|
||||
display_name_i = os.getenv(f"CUSTOM_SSO_DISPLAY_NAME_{i}", f"Custom {i}")
|
||||
|
||||
if all([cid, csecret, auth_ep, token_ep, userinfo_ep]):
|
||||
discovery_i = {
|
||||
"authorization_endpoint": auth_ep,
|
||||
"token_endpoint": token_ep,
|
||||
"userinfo_endpoint": userinfo_ep,
|
||||
}
|
||||
default_scope_i = (
|
||||
[s.strip() for s in scope_raw.split(",") if s.strip()]
|
||||
if scope_raw
|
||||
else None
|
||||
)
|
||||
ProviderClass = create_provider(
|
||||
name=name_i,
|
||||
discovery_document=discovery_i,
|
||||
response_convertor=_make_response_convertor(name_i),
|
||||
default_scope=default_scope_i,
|
||||
)
|
||||
provider_instance = ProviderClass(
|
||||
client_id=cid,
|
||||
client_secret=csecret,
|
||||
redirect_uri=f"{SSO_BASE_REDIRECT_URI}/custom/{i}",
|
||||
allow_insecure_http=True, # Set to False in production with HTTPS
|
||||
)
|
||||
custom_sso_providers[i] = {
|
||||
"sso": provider_instance,
|
||||
"name": name_i,
|
||||
"display_name": display_name_i,
|
||||
}
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
@@ -70,7 +223,9 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
|
||||
# Generate username from email or use provider ID
|
||||
email = openid.email
|
||||
if not email:
|
||||
raise HTTPException(status_code=400, detail="Email is required for SSO authentication")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Email is required for SSO authentication"
|
||||
)
|
||||
|
||||
# Use email prefix as username, fallback to provider + id
|
||||
username = email.split("@")[0]
|
||||
@@ -82,7 +237,9 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
|
||||
users = user_manager.load_users()
|
||||
for user_data in users.values():
|
||||
if user_data.get("email") == email:
|
||||
existing_user = User(**{k: v for k, v in user_data.items() if k != "password_hash"})
|
||||
existing_user = User(
|
||||
**{k: v for k, v in user_data.items() if k != "password_hash"}
|
||||
)
|
||||
break
|
||||
|
||||
if existing_user:
|
||||
@@ -97,7 +254,7 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
|
||||
if DISABLE_REGISTRATION:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Registration is disabled. Contact an administrator to create an account."
|
||||
detail="Registration is disabled. Contact an administrator to create an account.",
|
||||
)
|
||||
|
||||
# Create new user
|
||||
@@ -111,14 +268,14 @@ def create_or_update_sso_user(openid: OpenID, provider: str) -> User:
|
||||
user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
role="user" # Default role for SSO users
|
||||
role="user", # Default role for SSO users
|
||||
)
|
||||
|
||||
users[username] = {
|
||||
**user.to_dict(),
|
||||
"sso_provider": provider,
|
||||
"sso_id": openid.id,
|
||||
"password_hash": None # SSO users don't have passwords
|
||||
"password_hash": None, # SSO users don't have passwords
|
||||
}
|
||||
|
||||
user_manager.save_users(users)
|
||||
@@ -132,25 +289,49 @@ async def sso_status():
|
||||
providers = []
|
||||
|
||||
if google_sso:
|
||||
providers.append(SSOProvider(
|
||||
name="google",
|
||||
display_name="Google",
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/google"
|
||||
))
|
||||
providers.append(
|
||||
SSOProvider(
|
||||
name="google",
|
||||
display_name="Google",
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/google",
|
||||
)
|
||||
)
|
||||
|
||||
if github_sso:
|
||||
providers.append(SSOProvider(
|
||||
name="github",
|
||||
display_name="GitHub",
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/github"
|
||||
))
|
||||
providers.append(
|
||||
SSOProvider(
|
||||
name="github",
|
||||
display_name="GitHub",
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/github",
|
||||
)
|
||||
)
|
||||
|
||||
if custom_sso:
|
||||
providers.append(
|
||||
SSOProvider(
|
||||
name="custom",
|
||||
display_name=CUSTOM_SSO_DISPLAY_NAME,
|
||||
enabled=True,
|
||||
login_url="/api/auth/sso/login/custom",
|
||||
)
|
||||
)
|
||||
|
||||
for idx, cfg in custom_sso_providers.items():
|
||||
providers.append(
|
||||
SSOProvider(
|
||||
name=cfg["name"],
|
||||
display_name=cfg.get("display_name", cfg["name"]),
|
||||
enabled=True,
|
||||
login_url=f"/api/auth/sso/login/custom/{idx}",
|
||||
)
|
||||
)
|
||||
|
||||
return SSOStatusResponse(
|
||||
sso_enabled=SSO_ENABLED and AUTH_ENABLED,
|
||||
providers=providers,
|
||||
registration_enabled=not DISABLE_REGISTRATION
|
||||
registration_enabled=not DISABLE_REGISTRATION,
|
||||
)
|
||||
|
||||
|
||||
@@ -164,7 +345,9 @@ async def google_login():
|
||||
raise HTTPException(status_code=400, detail="Google SSO is not configured")
|
||||
|
||||
async with google_sso:
|
||||
return await google_sso.get_login_redirect(params={"prompt": "consent", "access_type": "offline"})
|
||||
return await google_sso.get_login_redirect(
|
||||
params={"prompt": "consent", "access_type": "offline"}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sso/login/github")
|
||||
@@ -180,6 +363,35 @@ async def github_login():
|
||||
return await github_sso.get_login_redirect()
|
||||
|
||||
|
||||
@router.get("/sso/login/custom")
|
||||
async def custom_login():
|
||||
"""Initiate Custom SSO login"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
if not custom_sso:
|
||||
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
|
||||
|
||||
async with custom_sso:
|
||||
return await custom_sso.get_login_redirect()
|
||||
|
||||
|
||||
@router.get("/sso/login/custom/{index}")
|
||||
async def custom_login_indexed(index: int):
|
||||
"""Initiate indexed Custom SSO login"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
cfg = custom_sso_providers.get(index)
|
||||
if not cfg:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Custom SSO provider not configured"
|
||||
)
|
||||
|
||||
async with cfg["sso"]:
|
||||
return await cfg["sso"].get_login_redirect()
|
||||
|
||||
|
||||
@router.get("/sso/callback/google")
|
||||
async def google_callback(request: Request):
|
||||
"""Handle Google SSO callback"""
|
||||
@@ -210,7 +422,7 @@ async def google_callback(request: Request):
|
||||
httponly=True,
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
samesite="lax",
|
||||
max_age=timedelta(hours=24).total_seconds()
|
||||
max_age=timedelta(hours=24).total_seconds(),
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -218,7 +430,7 @@ async def google_callback(request: Request):
|
||||
except HTTPException as e:
|
||||
# Handle specific HTTP exceptions (like registration disabled)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed"
|
||||
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
|
||||
logger.warning(f"Google SSO callback error: {error_msg}")
|
||||
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
|
||||
|
||||
@@ -258,7 +470,7 @@ async def github_callback(request: Request):
|
||||
httponly=True,
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
samesite="lax",
|
||||
max_age=timedelta(hours=24).total_seconds()
|
||||
max_age=timedelta(hours=24).total_seconds(),
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -266,7 +478,7 @@ async def github_callback(request: Request):
|
||||
except HTTPException as e:
|
||||
# Handle specific HTTP exceptions (like registration disabled)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
error_msg = e.detail if hasattr(e, 'detail') else "Authentication failed"
|
||||
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
|
||||
logger.warning(f"GitHub SSO callback error: {error_msg}")
|
||||
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
|
||||
|
||||
@@ -276,6 +488,105 @@ async def github_callback(request: Request):
|
||||
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
|
||||
|
||||
|
||||
@router.get("/sso/callback/custom")
|
||||
async def custom_callback(request: Request):
|
||||
"""Handle Custom SSO callback"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
if not custom_sso:
|
||||
raise HTTPException(status_code=400, detail="Custom SSO is not configured")
|
||||
|
||||
try:
|
||||
async with custom_sso:
|
||||
openid = await custom_sso.verify_and_process(request)
|
||||
|
||||
# Create or update user
|
||||
user = create_or_update_sso_user(openid, "custom")
|
||||
|
||||
# Create JWT token
|
||||
access_token = token_manager.create_token(user)
|
||||
|
||||
# Redirect to frontend with token (you might want to customize this)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
|
||||
|
||||
# Also set as HTTP-only cookie
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
samesite="lax",
|
||||
max_age=timedelta(hours=24).total_seconds(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
# Handle specific HTTP exceptions (like registration disabled)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
|
||||
logger.warning(f"Custom SSO callback error: {error_msg}")
|
||||
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom SSO callback error: {e}")
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
|
||||
|
||||
|
||||
@router.get("/sso/callback/custom/{index}")
|
||||
async def custom_callback_indexed(request: Request, index: int):
|
||||
"""Handle indexed Custom SSO callback"""
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
cfg = custom_sso_providers.get(index)
|
||||
if not cfg:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Custom SSO provider not configured"
|
||||
)
|
||||
|
||||
try:
|
||||
async with cfg["sso"]:
|
||||
openid = await cfg["sso"].verify_and_process(request)
|
||||
|
||||
# Create or update user
|
||||
user = create_or_update_sso_user(openid, cfg["name"])
|
||||
|
||||
# Create JWT token
|
||||
access_token = token_manager.create_token(user)
|
||||
|
||||
# Redirect to frontend with token (you might want to customize this)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
response = RedirectResponse(url=f"{frontend_url}?token={access_token}")
|
||||
|
||||
# Also set as HTTP-only cookie
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
samesite="lax",
|
||||
max_age=timedelta(hours=24).total_seconds(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
# Handle specific HTTP exceptions (like registration disabled)
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
error_msg = e.detail if hasattr(e, "detail") else "Authentication failed"
|
||||
logger.warning(f"Custom[{index}] SSO callback error: {error_msg}")
|
||||
return RedirectResponse(url=f"{frontend_url}?error={error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom[{index}] SSO callback error: {e}")
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
return RedirectResponse(url=f"{frontend_url}?error=Authentication failed")
|
||||
|
||||
|
||||
@router.post("/sso/unlink/{provider}", response_model=MessageResponse)
|
||||
async def unlink_sso_provider(
|
||||
provider: str,
|
||||
@@ -285,7 +596,18 @@ async def unlink_sso_provider(
|
||||
if not SSO_ENABLED or not AUTH_ENABLED:
|
||||
raise HTTPException(status_code=400, detail="SSO is disabled")
|
||||
|
||||
if provider not in ["google", "github"]:
|
||||
available = []
|
||||
if google_sso:
|
||||
available.append("google")
|
||||
if github_sso:
|
||||
available.append("github")
|
||||
if custom_sso:
|
||||
available.append("custom")
|
||||
|
||||
for cfg in custom_sso_providers.values():
|
||||
available.append(cfg["name"])
|
||||
|
||||
if provider not in available:
|
||||
raise HTTPException(status_code=400, detail="Invalid SSO provider")
|
||||
|
||||
# Get current user from request (avoiding circular imports)
|
||||
@@ -294,7 +616,9 @@ async def unlink_sso_provider(
|
||||
current_user = await require_auth_from_state(request)
|
||||
|
||||
if not current_user.sso_provider:
|
||||
raise HTTPException(status_code=400, detail="User is not linked to any SSO provider")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="User is not linked to any SSO provider"
|
||||
)
|
||||
|
||||
if current_user.sso_provider != provider:
|
||||
raise HTTPException(status_code=400, detail=f"User is not linked to {provider}")
|
||||
@@ -305,6 +629,8 @@ async def unlink_sso_provider(
|
||||
users[current_user.username]["sso_provider"] = None
|
||||
users[current_user.username]["sso_id"] = None
|
||||
user_manager.save_users(users)
|
||||
logger.info(f"Unlinked SSO provider {provider} from user {current_user.username}")
|
||||
logger.info(
|
||||
f"Unlinked SSO provider {provider} from user {current_user.username}"
|
||||
)
|
||||
|
||||
return MessageResponse(message=f"SSO provider {provider} unlinked successfully")
|
||||
Reference in New Issue
Block a user