334 lines
12 KiB
Python
334 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
from __future__ import annotations
|
|
|
|
from datetime import timedelta
|
|
import database
|
|
from app_logger import get_logger
|
|
from db.utils import get_cst_now, get_cst_now_str
|
|
from flask import Blueprint, jsonify, request
|
|
from flask_login import current_user, login_required, login_user
|
|
from services.accounts_service import load_user_accounts
|
|
from services.models import User
|
|
from services.social_login import (
|
|
BIND_TOKEN_TTL_SECONDS,
|
|
PROVIDER_LABELS,
|
|
SocialLoginError,
|
|
SpaceProfile,
|
|
admin_social_config_out,
|
|
encrypt_social_appkey,
|
|
fetch_social_login_url,
|
|
fetch_space_profile,
|
|
new_bind_token,
|
|
normalize_social_endpoint,
|
|
parse_providers,
|
|
poll_social_scan,
|
|
provider_label,
|
|
public_social_config,
|
|
)
|
|
|
|
# Logger obtained from the project's app_logger under the name "app".
logger = get_logger("app")

# Blueprint that carries every social-login route defined in this module.
api_social_bp = Blueprint(name="api_social", import_name=__name__)
|
|
|
|
|
|
def _get_json_payload() -> dict:
    """Return the request's JSON body if it is an object, otherwise ``{}``.

    ``silent=True`` makes Flask return ``None`` instead of raising on a
    missing or malformed body. Non-object JSON (lists, strings, numbers) is
    treated the same way, so callers can always use ``.get`` on the result.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return {}
    return payload
|
|
|
|
|
|
def _social_error(error: SocialLoginError):
    """Convert a ``SocialLoginError`` into a ``(json_response, status)`` pair.

    The body is ``{"error": error.message}`` and the status is
    ``error.status_code``.
    """
    body = {"error": error.message}
    status = error.status_code
    return jsonify(body), status
|
|
|
|
|
|
def _allowed_redirect_hosts() -> set[str]:
    """Return the hosts a social-login ``redirect_uri`` is allowed to target.

    The host is taken from ``X-Forwarded-Host`` when present, otherwise from
    ``Host``. The returned set contains the lower-cased ``host[:port]`` form
    and the bare host name (text before the first ``:``). It is empty when
    neither header yields a host.

    NOTE(review): ``X-Forwarded-Host`` is only trustworthy when a reverse
    proxy in front of the app overwrites it. If clients can reach the app
    directly, they can choose this value and therefore the allowed redirect
    host. Confirm the deployment sets or strips the header.
    """
    hosts: set[str] = set()
    raw = request.headers.get("X-Forwarded-Host") or request.headers.get("Host") or ""
    # Chained proxies append to X-Forwarded-Host as a comma-separated list
    # ("app.example.com, proxy.internal"). The first non-empty entry is the
    # host the browser originally requested, which is the host the client
    # builds redirect_uri from. Using the whole list would never match.
    host = next((part.strip() for part in raw.split(",") if part.strip()), "").lower()
    if host:
        hosts.add(host)
        hostname = host.split(":", 1)[0]
        if hostname:
            hosts.add(hostname)
    return hosts
|
|
|
|
|
|
def _create_pending_from_profile(profile: SpaceProfile) -> dict:
    """Persist a pending-bind record for ``profile`` and return what the DB layer returns.

    The record gets a fresh token from ``new_bind_token()``. It expires
    ``BIND_TOKEN_TTL_SECONDS`` after ``get_cst_now()``, with the expiry
    formatted as ``%Y-%m-%d %H:%M:%S``.
    """
    expiry = get_cst_now() + timedelta(seconds=BIND_TOKEN_TTL_SECONDS)
    expiry_text = expiry.strftime("%Y-%m-%d %H:%M:%S")
    record_fields = {
        "token": new_bind_token(),
        "provider": profile.provider,
        "social_uid": profile.social_uid,
        "nickname": profile.nickname,
        "avatar_url": profile.avatar_url,
        "expires_at": expiry_text,
    }
    return database.create_social_pending_bind(**record_fields)
|
|
|
|
|
|
def _login_user_id(user_id: int) -> None:
    """Log ``user_id`` into the Flask-Login session, then run ``load_user_accounts`` for it."""
    login_user(User(user_id))
    load_user_accounts(user_id)
|
|
|
|
|
|
def _binding_row(provider: str, binding: dict | None) -> dict:
    """Build the API row describing ``provider`` and the user's binding for it.

    ``binding`` is the stored binding record, or ``None``/empty when the
    provider is not bound. Missing display fields fall back to ``""``.
    Missing timestamps are passed through as ``None``.
    """
    record = binding or {}
    row = {
        "provider": provider,
        "provider_label": provider_label(provider),
        "bound": bool(binding),
    }
    row["nickname"] = record.get("nickname") or ""
    row["avatar_url"] = record.get("avatar_url") or ""
    row["last_login_at"] = record.get("last_login_at")
    row["created_at"] = record.get("created_at")
    return row
|
|
|
|
|
|
@api_social_bp.route("/api/auth/social/config", methods=["GET"])
def social_public_config():
    """Return ``public_social_config`` of the current system config as JSON."""
    system_config = database.get_system_config()
    public_view = public_social_config(system_config)
    return jsonify(public_view)
|
|
|
|
|
|
@api_social_bp.route("/api/auth/social/login-url", methods=["POST"])
def social_login_url():
    """Request an authorization URL from ``fetch_social_login_url``.

    JSON body fields:
    - ``provider``: lower-cased.
    - ``mode``: ``"bind"``; any other value means ``"login"``.
    - ``redirect_uri``.

    The hosts from ``_allowed_redirect_hosts()`` are passed along as
    ``allowed_hosts``. A ``SocialLoginError`` is logged and returned as a
    JSON error.
    """
    payload = _get_json_payload()
    provider = str(payload.get("provider") or "").strip().lower()
    requested_mode = str(payload.get("mode") or "").strip().lower()
    mode = "bind" if requested_mode == "bind" else "login"
    redirect_uri = str(payload.get("redirect_uri") or "").strip()

    try:
        result = fetch_social_login_url(
            database.get_system_config(),
            provider=provider,
            mode=mode,
            redirect_uri=redirect_uri,
            allowed_hosts=_allowed_redirect_hosts(),
        )
    except SocialLoginError as error:
        logger.warning(f"[social/login-url] provider={provider or '-'} mode={mode} failed: {error.message}")
        return _social_error(error)
    else:
        return jsonify(result)
|
|
|
|
|
|
@api_social_bp.route("/api/auth/social/poll", methods=["POST"])
def social_poll():
    """Poll ``poll_social_scan`` for the status of a pending scan login.

    JSON body fields: ``provider`` (lower-cased) and ``state``. The
    service result is returned as JSON. A ``SocialLoginError`` is logged and
    returned as a JSON error.
    """
    payload = _get_json_payload()
    provider = str(payload.get("provider") or "").strip().lower()
    state = str(payload.get("state") or "").strip()
    try:
        scan_status = poll_social_scan(database.get_system_config(), provider=provider, state=state)
    except SocialLoginError as error:
        logger.warning(f"[social/poll] provider={provider or '-'} failed: {error.message}")
        return _social_error(error)
    else:
        return jsonify(scan_status)
|
|
|
|
|
|
@api_social_bp.route("/api/auth/social/callback", methods=["POST"])
def social_callback():
    """Exchange an authorization ``code`` for a profile, then log in or stage a bind.

    JSON body fields:
    - ``provider``: falls back to ``type``.
    - ``code``.
    - ``mode``: ``"bind"``; anything else means ``"login"``.

    - Profile not yet bound: a pending-bind record is created and its token
      returned. The client is then expected to register (``login`` mode) or
      attach the identity to an existing account (``bind`` mode).
    - Profile already bound: in ``bind`` mode the binding must belong to the
      current user (otherwise 409). The bound user must exist and be
      ``approved`` (otherwise 401). The stored nickname/avatar are refreshed
      and that user is logged in.
    """
    payload = _get_json_payload()
    provider = str(payload.get("provider") or payload.get("type") or "").strip().lower()
    code = str(payload.get("code") or "").strip()
    requested_mode = str(payload.get("mode") or "").strip().lower()
    mode = "bind" if requested_mode == "bind" else "login"

    try:
        profile = fetch_space_profile(database.get_system_config(), provider=provider, code=code)
    except SocialLoginError as error:
        logger.warning(f"[social/callback] provider={provider or '-'} mode={mode} failed: {error.message}")
        return _social_error(error)

    binding = database.find_social_login_binding(profile.provider, profile.social_uid)

    if not binding:
        # Unknown identity: hand back a short-lived token for register/bind.
        pending = _create_pending_from_profile(profile)
        return jsonify(
            {
                "success": True,
                "mode": mode,
                "provider": profile.provider,
                "provider_label": provider_label(profile.provider),
                "requires_register": mode == "login",
                "requires_existing_login": mode == "bind",
                "bind_token": pending.get("token"),
                "expires_in": BIND_TOKEN_TTL_SECONDS,
                "nickname": pending.get("nickname") or "",
                "avatar_url": pending.get("avatar_url") or "",
            }
        )

    if mode == "bind":
        # Anonymous users have no ``id``; getattr falls back to 0 so they are rejected too.
        current_id = int(getattr(current_user, "id", 0) or 0)
        if not current_id or int(binding.get("user_id") or 0) != current_id:
            return jsonify({"error": "该第三方账号已绑定其他用户"}), 409

    user = database.get_user_by_id(int(binding["user_id"]))
    if not user or user.get("status") != "approved":
        return jsonify({"error": "绑定账号不可用"}), 401

    database.update_social_login_binding_profile(
        int(binding["id"]),
        nickname=profile.nickname,
        avatar_url=profile.avatar_url,
    )
    _login_user_id(int(user["id"]))
    return jsonify(
        {
            "success": True,
            "mode": mode,
            "provider": profile.provider,
            "provider_label": provider_label(profile.provider),
            "bound": True,
            "username": user.get("username") or "",
        }
    )
|
|
|
|
|
|
@api_social_bp.route("/api/user/social-bindings", methods=["GET"])
@login_required
def list_social_bindings():
    """List one row per configured provider, marked with the current user's binding status.

    Providers come from the ``social_login_providers`` system setting. When
    that parses to nothing, every provider in ``PROVIDER_LABELS`` is listed.
    Each row is built by ``_binding_row``.

    NOTE(review): bindings for providers that are no longer configured are
    not listed, so users cannot see them here. Confirm that is intended.
    """
    # get_system_config() may be falsy; other handlers in this module guard it
    # with ``or {}``. Without the guard, ``.get`` below raises AttributeError.
    cfg = database.get_system_config() or {}
    providers = parse_providers(cfg.get("social_login_providers")) or list(PROVIDER_LABELS.keys())
    user_id = int(current_user.id)
    existing = {item["provider"]: item for item in database.list_social_login_bindings(user_id)}
    return jsonify({"items": [_binding_row(provider, existing.get(provider)) for provider in providers]})
|
|
|
|
|
|
@api_social_bp.route("/api/user/social-bindings", methods=["POST"])
@login_required
def bind_social_account():
    """Attach the identity held by a pending ``bind_token`` to the logged-in user.

    Responses:
    - 404 when ``get_social_pending_bind`` finds no record for the token.
    - 409 when the identity is bound to another user.
    - 409 when this user already has a different identity for the same provider.
    - 409 when the upsert returns nothing.
    - On success, the pending record is deleted and the new row is returned.
    """
    payload = _get_json_payload()
    bind_token = str(payload.get("bind_token") or "").strip()
    pending = database.get_social_pending_bind(bind_token)
    if not pending:
        return jsonify({"error": "绑定凭证已过期,请重新授权"}), 404

    provider = str(pending.get("provider") or "").strip().lower()
    social_uid = str(pending.get("social_uid") or "").strip()

    # The same third-party identity must not end up on two local users.
    owner = database.find_social_login_binding(provider, social_uid)
    if owner and int(owner.get("user_id") or 0) != int(current_user.id):
        return jsonify({"error": "该第三方账号已绑定其他用户"}), 409

    # A user may hold only one identity per provider.
    already_bound = database.find_user_social_login_binding(int(current_user.id), provider)
    if already_bound and str(already_bound.get("social_uid") or "") != social_uid:
        return jsonify({"error": f"当前账号已绑定{provider_label(provider)}"}), 409

    saved = database.upsert_social_login_binding(
        user_id=int(current_user.id),
        provider=provider,
        social_uid=social_uid,
        nickname=pending.get("nickname") or "",
        avatar_url=pending.get("avatar_url") or "",
    )
    if not saved:
        return jsonify({"error": "该第三方账号已绑定其他用户"}), 409

    database.delete_social_pending_bind(bind_token)
    return jsonify({"success": True, "item": _binding_row(provider, saved)})
|
|
|
|
|
|
@api_social_bp.route("/api/user/social-bindings/<provider>", methods=["DELETE"])
@login_required
def unbind_social_account(provider):
    """Delete the logged-in user's binding for ``provider``.

    Returns 400 for a provider not in ``PROVIDER_LABELS``, and 404 when
    ``delete_social_login_binding`` reports nothing deleted.
    """
    normalized = str(provider or "").strip().lower()
    if normalized not in PROVIDER_LABELS:
        return jsonify({"error": "不支持的登录方式"}), 400
    removed = database.delete_social_login_binding(int(current_user.id), normalized)
    if removed:
        return jsonify({"success": True})
    return jsonify({"error": "绑定记录不存在"}), 404
|
|
|
|
|
|
@api_social_bp.route("/yuyx/api/social-login/config", methods=["GET"])
def admin_social_config():
    """Return ``admin_social_config_out`` of the system config, guarded by ``admin_required``.

    ``admin_required`` is imported at call time, presumably to avoid an
    import cycle with ``routes.decorators``. TODO confirm.
    """
    from routes.decorators import admin_required

    @admin_required
    def _inner():
        return jsonify(admin_social_config_out(database.get_system_config()))

    return _inner()
|
|
|
|
|
|
@api_social_bp.route("/yuyx/api/social-login/test", methods=["POST"])
def test_admin_social_config():
    """Dry-run submitted social-login settings without saving them (admin only).

    Steps:
    - Overlay the submitted endpoint, APPID, APPKEY and providers onto the
      stored config. When no APPKEY is sent, the stored value is reused.
    - Force ``social_login_enabled`` to 1 and request a ``login`` URL for
      ``provider`` (default ``"wx"``).

    Only whether ``url``/``scan_url`` came back is reported. A ``ValueError``
    raised while building the config or fetching becomes a 400.
    """
    from routes.decorators import admin_required

    @admin_required
    def _inner():
        payload = _get_json_payload()
        provider = str(payload.get("provider") or "wx").strip().lower()
        appkey = str(payload.get("social_login_appkey") or "").strip()
        stored = database.get_system_config() or {}

        try:
            # Evaluated in the same order as the stored config keys are overridden,
            # so the first ValueError raised is the same one reported.
            endpoint = normalize_social_endpoint(str(payload.get("social_login_endpoint") or ""))
            appid = str(payload.get("social_login_appid") or "").strip()
            if appkey:
                effective_key = encrypt_social_appkey(appkey)
            else:
                effective_key = str(stored.get("social_login_appkey") or "")
            provider_list = ",".join(parse_providers(payload.get("social_login_providers")))
            trial_config = {
                **stored,
                "social_login_enabled": 1,
                "social_login_endpoint": endpoint,
                "social_login_appid": appid,
                "social_login_appkey": effective_key,
                "social_login_providers": provider_list,
            }
            result = fetch_social_login_url(
                trial_config,
                provider=provider,
                mode="login",
                redirect_uri=str(payload.get("redirect_uri") or "").strip(),
                allowed_hosts=_allowed_redirect_hosts(),
            )
        except ValueError as exc:
            return jsonify({"error": str(exc)}), 400
        except SocialLoginError as error:
            return _social_error(error)

        summary = {
            "success": True,
            "provider": result.get("provider"),
            "has_url": bool(result.get("url")),
            "has_scan_url": bool(result.get("scan_url")),
        }
        return jsonify(summary)

    return _inner()
|
|
|
|
|
|
@api_social_bp.route("/yuyx/api/social-login/config", methods=["POST"])
def update_admin_social_config():
    """Validate and save the social-login settings (admin only).

    Rules:
    - An empty ``social_login_appkey`` keeps the stored (already-encrypted)
      key; a non-empty one is encrypted before saving.
    - When enabling, at least one provider, an APPID and an APPKEY (new or
      stored) are required.
    - An invalid endpoint (``ValueError`` from ``normalize_social_endpoint``)
      or a failed save returns 400.
    """
    from routes.decorators import admin_required

    @admin_required
    def _inner():
        payload = _get_json_payload()
        # Kept as a tuple: ``in`` on a tuple uses ==, so unhashable JSON values
        # (lists/objects) simply don't match instead of raising TypeError.
        truthy_flags = (1, True, "1", "true", "on")
        enabled = 0
        if payload.get("social_login_enabled") in truthy_flags:
            enabled = 1
        endpoint_raw = str(payload.get("social_login_endpoint") or "").strip()
        appid = str(payload.get("social_login_appid") or "").strip()
        appkey = str(payload.get("social_login_appkey") or "").strip()
        providers = parse_providers(payload.get("social_login_providers"))

        try:
            endpoint = normalize_social_endpoint(endpoint_raw)
        except ValueError as exc:
            return jsonify({"error": str(exc)}), 400

        stored = database.get_system_config() or {}
        stored_key = str(stored.get("social_login_appkey") or "")

        if enabled:
            required_checks = (
                (not providers, "启用聚合登录时至少选择一种登录方式"),
                (not appid, "启用聚合登录时必须填写 APPID"),
                (not appkey and not stored_key, "启用聚合登录时必须填写 APPKEY"),
            )
            for failed, message in required_checks:
                if failed:
                    return jsonify({"error": message}), 400

        key_to_store = encrypt_social_appkey(appkey) if appkey else stored_key
        saved = database.update_system_config(
            social_login_enabled=enabled,
            social_login_endpoint=endpoint,
            social_login_appid=appid,
            social_login_appkey=key_to_store,
            social_login_providers=",".join(providers),
        )
        if not saved:
            return jsonify({"error": "更新失败"}), 400

        logger.info(f"[social/config] updated enabled={enabled} providers={','.join(providers)} at={get_cst_now_str()}")
        return jsonify({"message": "聚合登录配置已更新", "config": admin_social_config_out(database.get_system_config())})

    return _inner()
|