Files
zsglpt/routes/api_social.py
2026-05-27 20:39:46 +08:00

334 lines
12 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import timedelta
import database
from app_logger import get_logger
from db.utils import get_cst_now, get_cst_now_str
from flask import Blueprint, jsonify, request
from flask_login import current_user, login_required, login_user
from services.accounts_service import load_user_accounts
from services.models import User
from services.social_login import (
BIND_TOKEN_TTL_SECONDS,
PROVIDER_LABELS,
SocialLoginError,
SpaceProfile,
admin_social_config_out,
encrypt_social_appkey,
fetch_social_login_url,
fetch_space_profile,
new_bind_token,
normalize_social_endpoint,
parse_providers,
poll_social_scan,
provider_label,
public_social_config,
)
# Module logger: this blueprint's messages go to the "app" logger.
logger = get_logger("app")
# Blueprint holding the public social-login endpoints (/api/auth/social/*),
# the per-user binding endpoints (/api/user/social-bindings) and the admin
# configuration endpoints (/yuyx/api/social-login/*).
api_social_bp = Blueprint("api_social", __name__)
def _get_json_payload() -> dict:
    """Return the request's JSON body when it is an object, else an empty dict.

    ``get_json(silent=True)`` yields ``None`` for a missing or malformed body.
    That case, and any non-object JSON value (list, string, number), collapses
    to ``{}`` so handlers can call ``.get`` unconditionally.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return {}
    return payload
def _social_error(error: SocialLoginError):
    """Turn a SocialLoginError into a Flask ``(json_response, status)`` reply."""
    body = {"error": error.message}
    return jsonify(body), error.status_code
def _hosts_from_header(raw_host) -> set[str]:
    """Return the allow-list entries derived from one Host-style value.

    The value is stripped and lowercased. Both the full ``host[:port]`` form
    and the bare hostname (port removed) are returned. A bracketed IPv6
    literal such as ``[::1]:5000`` keeps its brackets (``[::1]``) rather than
    being cut at its first colon. A blank value yields an empty set.
    """
    host = str(raw_host or "").strip().lower()
    if not host:
        return set()
    hosts = {host}
    if host.startswith("["):
        # IPv6 literal: any port follows the closing bracket.
        end = host.find("]")
        hostname = host[: end + 1] if end != -1 else ""
    else:
        hostname = host.split(":", 1)[0]
    if hostname:
        hosts.add(hostname)
    return hosts


def _allowed_redirect_hosts(headers=None) -> set[str]:
    """Return the hosts a social-login ``redirect_uri`` may point at.

    The host is taken from the first entry of ``X-Forwarded-Host`` when that
    header is present and non-blank, otherwise from ``Host``. The returned set
    holds both the ``host:port`` form and the bare hostname, lowercased.

    Args:
        headers: Optional header mapping (anything with ``.get``). Defaults to
            the current Flask request's headers; mainly a testing hook.

    NOTE(review): X-Forwarded-Host is read from the client verbatim. Unless
    the reverse proxy in front of this app overwrites it, a caller can choose
    which host is allowed. Confirm the deployment sets or strips this header.
    """
    if headers is None:
        headers = request.headers
    forwarded = str(headers.get("X-Forwarded-Host") or "")
    # Proxies may append to X-Forwarded-Host ("a.example, b.internal:80"), and
    # such a raw list would never match a redirect host. The first entry is
    # the host the client originally requested.
    host = forwarded.split(",", 1)[0].strip() or headers.get("Host") or ""
    return _hosts_from_header(host)
def _create_pending_from_profile(profile: SpaceProfile) -> dict:
    """Store a short-lived pending-bind record for a not-yet-bound social profile.

    The record is keyed by a fresh bind token and expires
    BIND_TOKEN_TTL_SECONDS after the current CST time. Returns whatever
    ``database.create_social_pending_bind`` returns; callers read ``token``,
    ``nickname`` and ``avatar_url`` from it.
    """
    deadline = get_cst_now() + timedelta(seconds=BIND_TOKEN_TTL_SECONDS)
    return database.create_social_pending_bind(
        token=new_bind_token(),
        provider=profile.provider,
        social_uid=profile.social_uid,
        nickname=profile.nickname,
        avatar_url=profile.avatar_url,
        expires_at=deadline.strftime("%Y-%m-%d %H:%M:%S"),
    )
def _login_user_id(user_id: int) -> None:
    """Log *user_id* into the Flask-Login session, then call load_user_accounts for it."""
    login_user(User(user_id))
    load_user_accounts(user_id)
def _binding_row(provider: str, binding: dict | None) -> dict:
    """Serialize one provider's binding state for API responses.

    *binding* is a stored binding record, or ``None``/empty when the user has
    not bound *provider*. Missing nickname/avatar become ``""``. Timestamps are
    passed through as stored (``None`` when absent).
    """
    info = binding or {}
    return {
        "provider": provider,
        "provider_label": provider_label(provider),
        "bound": bool(binding),
        "nickname": info.get("nickname") or "",
        "avatar_url": info.get("avatar_url") or "",
        "last_login_at": info.get("last_login_at"),
        "created_at": info.get("created_at"),
    }
@api_social_bp.route("/api/auth/social/config", methods=["GET"])
def social_public_config():
    """Return the client-facing social-login config as built by public_social_config."""
    system_config = database.get_system_config()
    public_view = public_social_config(system_config)
    return jsonify(public_view)
@api_social_bp.route("/api/auth/social/login-url", methods=["POST"])
def social_login_url():
    """Request an authorization URL for a social provider.

    JSON body fields:

    * ``provider``
    * ``mode`` ("bind"; anything else means "login")
    * ``redirect_uri``

    The hosts from ``_allowed_redirect_hosts`` are passed along to
    fetch_social_login_url, presumably to validate the redirect target.
    Failures are logged and returned with the error's own status code.
    """
    body = _get_json_payload()
    provider = str(body.get("provider") or "").strip().lower()
    requested_mode = str(body.get("mode") or "").strip().lower()
    mode = "login" if requested_mode != "bind" else "bind"
    redirect_uri = str(body.get("redirect_uri") or "").strip()
    try:
        url_info = fetch_social_login_url(
            database.get_system_config(),
            provider=provider,
            mode=mode,
            redirect_uri=redirect_uri,
            allowed_hosts=_allowed_redirect_hosts(),
        )
    except SocialLoginError as exc:
        logger.warning(f"[social/login-url] provider={provider or '-'} mode={mode} failed: {exc.message}")
        return _social_error(exc)
    else:
        return jsonify(url_info)
@api_social_bp.route("/api/auth/social/poll", methods=["POST"])
def social_poll():
    """Poll the outcome of a pending social scan for ``provider``/``state``.

    The result of poll_social_scan is returned as-is. A SocialLoginError is
    logged and converted into an error reply.
    """
    body = _get_json_payload()
    state = str(body.get("state") or "").strip()
    provider = str(body.get("provider") or "").strip().lower()
    try:
        outcome = poll_social_scan(database.get_system_config(), provider=provider, state=state)
    except SocialLoginError as exc:
        logger.warning(f"[social/poll] provider={provider or '-'} failed: {exc.message}")
        return _social_error(exc)
    else:
        return jsonify(outcome)
@api_social_bp.route("/api/auth/social/callback", methods=["POST"])
def social_callback():
    """Exchange an authorization ``code`` for a profile, then log in or start a bind.

    JSON body fields:

    * ``provider`` (or legacy ``type``)
    * ``code``
    * ``mode`` ("bind"; anything else means "login")

    If the social identity is not bound yet, a pending-bind token is issued.
    The client must then register (login mode) or confirm the bind while
    logged in (bind mode).

    If the identity is already bound:

    * In bind mode it must belong to the logged-in user (otherwise 409).
    * The owning user must exist with status "approved" (otherwise 401).
    * The stored nickname/avatar are refreshed and that user is logged in.
    """
    body = _get_json_payload()
    provider = str(body.get("provider") or body.get("type") or "").strip().lower()
    code = str(body.get("code") or "").strip()
    mode = "bind" if str(body.get("mode") or "").strip().lower() == "bind" else "login"
    try:
        profile = fetch_space_profile(database.get_system_config(), provider=provider, code=code)
    except SocialLoginError as exc:
        logger.warning(f"[social/callback] provider={provider or '-'} mode={mode} failed: {exc.message}")
        return _social_error(exc)

    binding = database.find_social_login_binding(profile.provider, profile.social_uid)
    label = provider_label(profile.provider)

    if not binding:
        # Unknown identity: hand back a bind token instead of logging in.
        pending = _create_pending_from_profile(profile)
        return jsonify(
            {
                "success": True,
                "mode": mode,
                "provider": profile.provider,
                "provider_label": label,
                "requires_register": mode == "login",
                "requires_existing_login": mode == "bind",
                "bind_token": pending.get("token"),
                "expires_in": BIND_TOKEN_TTL_SECONDS,
                "nickname": pending.get("nickname") or "",
                "avatar_url": pending.get("avatar_url") or "",
            }
        )

    if mode == "bind":
        # Binding is only a no-op re-login when the identity is already ours.
        current_id = int(getattr(current_user, "id", 0) or 0)
        if not current_id or current_id != int(binding.get("user_id") or 0):
            return jsonify({"error": "该第三方账号已绑定其他用户"}), 409

    owner = database.get_user_by_id(int(binding["user_id"]))
    if not owner or owner.get("status") != "approved":
        return jsonify({"error": "绑定账号不可用"}), 401

    database.update_social_login_binding_profile(
        int(binding["id"]),
        nickname=profile.nickname,
        avatar_url=profile.avatar_url,
    )
    _login_user_id(int(owner["id"]))
    return jsonify(
        {
            "success": True,
            "mode": mode,
            "provider": profile.provider,
            "provider_label": label,
            "bound": True,
            "username": owner.get("username") or "",
        }
    )
@api_social_bp.route("/api/user/social-bindings", methods=["GET"])
@login_required
def list_social_bindings():
    """List the logged-in user's binding state for every configured provider.

    Providers come from ``social_login_providers`` in the system config. When
    that setting parses to nothing, every provider in PROVIDER_LABELS is listed.
    """
    system_config = database.get_system_config()
    configured = parse_providers(system_config.get("social_login_providers"))
    providers = configured or list(PROVIDER_LABELS.keys())
    by_provider = {}
    for row in database.list_social_login_bindings(int(current_user.id)):
        by_provider[row["provider"]] = row
    rows = [_binding_row(name, by_provider.get(name)) for name in providers]
    return jsonify({"items": rows})
@api_social_bp.route("/api/user/social-bindings", methods=["POST"])
@login_required
def bind_social_account():
    """Attach the social identity behind a pending-bind token to the current user.

    Responses:

    * 404 when the token has no pending record.
    * 409 when the identity already belongs to another user.
    * 409 when this user already has a different identity for the same
      provider.
    * 409 when the upsert reports nothing was stored.

    On success the pending record is deleted and the new binding row returned.
    """
    taken_msg = "该第三方账号已绑定其他用户"
    token = str(_get_json_payload().get("bind_token") or "").strip()
    pending = database.get_social_pending_bind(token)
    if not pending:
        return jsonify({"error": "绑定凭证已过期,请重新授权"}), 404

    provider = str(pending.get("provider") or "").strip().lower()
    social_uid = str(pending.get("social_uid") or "").strip()
    user_id = int(current_user.id)

    owner_binding = database.find_social_login_binding(provider, social_uid)
    if owner_binding and int(owner_binding.get("user_id") or 0) != user_id:
        return jsonify({"error": taken_msg}), 409

    own_binding = database.find_user_social_login_binding(user_id, provider)
    if own_binding and str(own_binding.get("social_uid") or "") != social_uid:
        return jsonify({"error": f"当前账号已绑定{provider_label(provider)}"}), 409

    stored = database.upsert_social_login_binding(
        user_id=user_id,
        provider=provider,
        social_uid=social_uid,
        nickname=pending.get("nickname") or "",
        avatar_url=pending.get("avatar_url") or "",
    )
    if not stored:
        return jsonify({"error": taken_msg}), 409

    database.delete_social_pending_bind(token)
    return jsonify({"success": True, "item": _binding_row(provider, stored)})
@api_social_bp.route("/api/user/social-bindings/<provider>", methods=["DELETE"])
@login_required
def unbind_social_account(provider):
    """Remove the current user's binding for the URL's *provider*.

    Returns 400 for a provider not in PROVIDER_LABELS, and 404 when the
    delete reports nothing was removed.
    """
    name = str(provider or "").strip().lower()
    if name not in PROVIDER_LABELS:
        return jsonify({"error": "不支持的登录方式"}), 400
    removed = database.delete_social_login_binding(int(current_user.id), name)
    if removed:
        return jsonify({"success": True})
    return jsonify({"error": "绑定记录不存在"}), 404
@api_social_bp.route("/yuyx/api/social-login/config", methods=["GET"])
def admin_social_config():
    """Return the admin view of the social-login configuration (admin only)."""
    # Imported inside the handler like the other admin routes here
    # (presumably to avoid an import cycle).
    from routes.decorators import admin_required

    @admin_required
    def _inner():
        return jsonify(admin_social_config_out(database.get_system_config()))

    return _inner()
@api_social_bp.route("/yuyx/api/social-login/test", methods=["POST"])
def test_admin_social_config():
    """Dry-run submitted social-login settings without saving them (admin only).

    The submitted endpoint, APPID, APPKEY and providers are overlaid on the
    stored config; a blank APPKEY keeps the stored one. Social login is forced
    on, and a login URL is requested for ``provider`` (default "wx"). Only
    whether ``url``/``scan_url`` came back is reported.

    A ValueError raised while building or using the trial config maps to 400.
    A SocialLoginError maps to its own status.
    """
    from routes.decorators import admin_required

    @admin_required
    def _inner():
        body = _get_json_payload()
        provider = str(body.get("provider") or "wx").strip().lower()
        submitted_key = str(body.get("social_login_appkey") or "").strip()
        stored = database.get_system_config() or {}
        try:
            # Computed in this order so the first invalid field reports first.
            endpoint = normalize_social_endpoint(str(body.get("social_login_endpoint") or ""))
            appid = str(body.get("social_login_appid") or "").strip()
            if submitted_key:
                appkey_value = encrypt_social_appkey(submitted_key)
            else:
                appkey_value = str(stored.get("social_login_appkey") or "")
            providers_csv = ",".join(parse_providers(body.get("social_login_providers")))
            trial_config = {**stored}
            trial_config.update(
                {
                    "social_login_enabled": 1,
                    "social_login_endpoint": endpoint,
                    "social_login_appid": appid,
                    "social_login_appkey": appkey_value,
                    "social_login_providers": providers_csv,
                }
            )
            redirect_uri = str(body.get("redirect_uri") or "").strip()
            url_info = fetch_social_login_url(
                trial_config,
                provider=provider,
                mode="login",
                redirect_uri=redirect_uri,
                allowed_hosts=_allowed_redirect_hosts(),
            )
        except ValueError as exc:
            return jsonify({"error": str(exc)}), 400
        except SocialLoginError as exc:
            return _social_error(exc)
        summary = {
            "success": True,
            "provider": url_info.get("provider"),
            "has_url": bool(url_info.get("url")),
            "has_scan_url": bool(url_info.get("scan_url")),
        }
        return jsonify(summary)

    return _inner()
def _parse_enabled_flag(value) -> int:
    """Interpret a submitted on/off value as 1 (enabled) or 0 (disabled).

    Strings are matched case-insensitively after stripping whitespace against
    "1", "true", "on" and "yes". Non-string values are enabled when they
    compare equal to 1 or True (e.g. JSON ``true`` or ``1``). Everything else
    is 0.
    """
    if isinstance(value, str):
        return 1 if value.strip().lower() in {"1", "true", "on", "yes"} else 0
    return 1 if value in (1, True) else 0


@api_social_bp.route("/yuyx/api/social-login/config", methods=["POST"])
def update_admin_social_config():
    """Validate and persist the social-login configuration (admin only).

    JSON body fields: ``social_login_enabled``, ``social_login_endpoint``,
    ``social_login_appid``, ``social_login_appkey`` and
    ``social_login_providers``.

    A blank APPKEY keeps the stored key unchanged; a non-blank one is passed
    through encrypt_social_appkey before saving. When enabling, at least one
    provider, an APPID and an APPKEY (new or stored) are required.

    Returns the refreshed admin view of the config. An invalid endpoint, a
    failed validation or a failed save returns 400.
    """
    from routes.decorators import admin_required

    @admin_required
    def _inner():
        data = _get_json_payload()
        # Case/whitespace-insensitive, so "True", "ON" or " true " do not
        # silently disable social login.
        enabled = _parse_enabled_flag(data.get("social_login_enabled"))
        endpoint_raw = str(data.get("social_login_endpoint") or "").strip()
        appid = str(data.get("social_login_appid") or "").strip()
        appkey = str(data.get("social_login_appkey") or "").strip()
        providers = parse_providers(data.get("social_login_providers"))
        try:
            endpoint = normalize_social_endpoint(endpoint_raw)
        except ValueError as exc:
            return jsonify({"error": str(exc)}), 400
        old_config = database.get_system_config() or {}
        old_key = str(old_config.get("social_login_appkey") or "")
        if enabled:
            if not providers:
                return jsonify({"error": "启用聚合登录时至少选择一种登录方式"}), 400
            if not appid:
                return jsonify({"error": "启用聚合登录时必须填写 APPID"}), 400
            if not appkey and not old_key:
                return jsonify({"error": "启用聚合登录时必须填写 APPKEY"}), 400
        encrypted_key = encrypt_social_appkey(appkey) if appkey else old_key
        ok = database.update_system_config(
            social_login_enabled=enabled,
            social_login_endpoint=endpoint,
            social_login_appid=appid,
            social_login_appkey=encrypted_key,
            social_login_providers=",".join(providers),
        )
        if not ok:
            return jsonify({"error": "更新失败"}), 400
        logger.info(f"[social/config] updated enabled={enabled} providers={','.join(providers)} at={get_cst_now_str()}")
        return jsonify({"message": "聚合登录配置已更新", "config": admin_social_config_out(database.get_system_config())})

    return _inner()