#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import annotations from datetime import timedelta import database from app_logger import get_logger from db.utils import get_cst_now, get_cst_now_str from flask import Blueprint, jsonify, request from flask_login import current_user, login_required, login_user from services.accounts_service import load_user_accounts from services.models import User from services.social_login import ( BIND_TOKEN_TTL_SECONDS, PROVIDER_LABELS, SocialLoginError, SpaceProfile, admin_social_config_out, encrypt_social_appkey, fetch_social_login_url, fetch_space_profile, new_bind_token, normalize_social_endpoint, parse_providers, poll_social_scan, provider_label, public_social_config, ) logger = get_logger("app") api_social_bp = Blueprint("api_social", __name__) def _get_json_payload() -> dict: data = request.get_json(silent=True) return data if isinstance(data, dict) else {} def _social_error(error: SocialLoginError): return jsonify({"error": error.message}), error.status_code def _allowed_redirect_hosts() -> set[str]: hosts: set[str] = set() host = request.headers.get("X-Forwarded-Host") or request.headers.get("Host") or "" if host: hosts.add(host.lower()) hostname = host.split(":", 1)[0].lower() if hostname: hosts.add(hostname) return hosts def _create_pending_from_profile(profile: SpaceProfile) -> dict: expires_at = (get_cst_now() + timedelta(seconds=BIND_TOKEN_TTL_SECONDS)).strftime("%Y-%m-%d %H:%M:%S") return database.create_social_pending_bind( token=new_bind_token(), provider=profile.provider, social_uid=profile.social_uid, nickname=profile.nickname, avatar_url=profile.avatar_url, expires_at=expires_at, ) def _login_user_id(user_id: int) -> None: user_obj = User(user_id) login_user(user_obj) load_user_accounts(user_id) def _binding_row(provider: str, binding: dict | None) -> dict: return { "provider": provider, "provider_label": provider_label(provider), "bound": bool(binding), "nickname": (binding or {}).get("nickname") or "", "avatar_url": (binding or {}).get("avatar_url") or "", "last_login_at": (binding or {}).get("last_login_at"), "created_at": (binding or {}).get("created_at"), } @api_social_bp.route("/api/auth/social/config", methods=["GET"]) def social_public_config(): return jsonify(public_social_config(database.get_system_config())) @api_social_bp.route("/api/auth/social/login-url", methods=["POST"]) def social_login_url(): data = _get_json_payload() provider = str(data.get("provider") or "").strip().lower() mode = "bind" if str(data.get("mode") or "").strip().lower() == "bind" else "login" redirect_uri = str(data.get("redirect_uri") or "").strip() try: result = fetch_social_login_url( database.get_system_config(), provider=provider, mode=mode, redirect_uri=redirect_uri, allowed_hosts=_allowed_redirect_hosts(), ) except SocialLoginError as error: logger.warning(f"[social/login-url] provider={provider or '-'} mode={mode} failed: {error.message}") return _social_error(error) return jsonify(result) @api_social_bp.route("/api/auth/social/poll", methods=["POST"]) def social_poll(): data = _get_json_payload() provider = str(data.get("provider") or "").strip().lower() state = str(data.get("state") or "").strip() try: result = poll_social_scan(database.get_system_config(), provider=provider, state=state) except SocialLoginError as error: logger.warning(f"[social/poll] provider={provider or '-'} failed: {error.message}") return _social_error(error) return jsonify(result) @api_social_bp.route("/api/auth/social/callback", methods=["POST"]) def social_callback(): data = _get_json_payload() provider = str(data.get("provider") or data.get("type") or "").strip().lower() code = str(data.get("code") or "").strip() mode = "bind" if str(data.get("mode") or "").strip().lower() == "bind" else "login" try: profile = fetch_space_profile(database.get_system_config(), provider=provider, code=code) except SocialLoginError as error: logger.warning(f"[social/callback] provider={provider or '-'} mode={mode} failed: {error.message}") return _social_error(error) binding = database.find_social_login_binding(profile.provider, profile.social_uid) if binding: if mode == "bind": current_id = int(getattr(current_user, "id", 0) or 0) if not current_id or int(binding.get("user_id") or 0) != current_id: return jsonify({"error": "该第三方账号已绑定其他用户"}), 409 user = database.get_user_by_id(int(binding["user_id"])) if not user or user.get("status") != "approved": return jsonify({"error": "绑定账号不可用"}), 401 database.update_social_login_binding_profile( int(binding["id"]), nickname=profile.nickname, avatar_url=profile.avatar_url, ) _login_user_id(int(user["id"])) return jsonify( { "success": True, "mode": mode, "provider": profile.provider, "provider_label": provider_label(profile.provider), "bound": True, "username": user.get("username") or "", } ) pending = _create_pending_from_profile(profile) return jsonify( { "success": True, "mode": mode, "provider": profile.provider, "provider_label": provider_label(profile.provider), "requires_register": mode == "login", "requires_existing_login": mode == "bind", "bind_token": pending.get("token"), "expires_in": BIND_TOKEN_TTL_SECONDS, "nickname": pending.get("nickname") or "", "avatar_url": pending.get("avatar_url") or "", } ) @api_social_bp.route("/api/user/social-bindings", methods=["GET"]) @login_required def list_social_bindings(): cfg = database.get_system_config() providers = parse_providers(cfg.get("social_login_providers")) or list(PROVIDER_LABELS.keys()) existing = { item["provider"]: item for item in database.list_social_login_bindings(int(current_user.id)) } return jsonify({"items": [_binding_row(provider, existing.get(provider)) for provider in providers]}) @api_social_bp.route("/api/user/social-bindings", methods=["POST"]) @login_required def bind_social_account(): data = _get_json_payload() token = str(data.get("bind_token") or "").strip() pending = database.get_social_pending_bind(token) if not pending: return jsonify({"error": "绑定凭证已过期,请重新授权"}), 404 provider = str(pending.get("provider") or "").strip().lower() social_uid = str(pending.get("social_uid") or "").strip() existing_identity = database.find_social_login_binding(provider, social_uid) if existing_identity and int(existing_identity.get("user_id") or 0) != int(current_user.id): return jsonify({"error": "该第三方账号已绑定其他用户"}), 409 existing_provider = database.find_user_social_login_binding(int(current_user.id), provider) if existing_provider and str(existing_provider.get("social_uid") or "") != social_uid: return jsonify({"error": f"当前账号已绑定{provider_label(provider)}"}), 409 binding = database.upsert_social_login_binding( user_id=int(current_user.id), provider=provider, social_uid=social_uid, nickname=pending.get("nickname") or "", avatar_url=pending.get("avatar_url") or "", ) if not binding: return jsonify({"error": "该第三方账号已绑定其他用户"}), 409 database.delete_social_pending_bind(token) return jsonify({"success": True, "item": _binding_row(provider, binding)}) @api_social_bp.route("/api/user/social-bindings/", methods=["DELETE"]) @login_required def unbind_social_account(provider): provider = str(provider or "").strip().lower() if provider not in PROVIDER_LABELS: return jsonify({"error": "不支持的登录方式"}), 400 if not database.delete_social_login_binding(int(current_user.id), provider): return jsonify({"error": "绑定记录不存在"}), 404 return jsonify({"success": True}) @api_social_bp.route("/yuyx/api/social-login/config", methods=["GET"]) def admin_social_config(): from routes.decorators import admin_required protected = admin_required(lambda: jsonify(admin_social_config_out(database.get_system_config()))) return protected() @api_social_bp.route("/yuyx/api/social-login/test", methods=["POST"]) def test_admin_social_config(): from routes.decorators import admin_required @admin_required def _inner(): data = _get_json_payload() provider = str(data.get("provider") or "wx").strip().lower() appkey = str(data.get("social_login_appkey") or "").strip() old_config = database.get_system_config() or {} try: temp_config = { **old_config, "social_login_enabled": 1, "social_login_endpoint": normalize_social_endpoint(str(data.get("social_login_endpoint") or "")), "social_login_appid": str(data.get("social_login_appid") or "").strip(), "social_login_appkey": encrypt_social_appkey(appkey) if appkey else str(old_config.get("social_login_appkey") or ""), "social_login_providers": ",".join(parse_providers(data.get("social_login_providers"))), } redirect_uri = str(data.get("redirect_uri") or "").strip() result = fetch_social_login_url( temp_config, provider=provider, mode="login", redirect_uri=redirect_uri, allowed_hosts=_allowed_redirect_hosts(), ) except ValueError as exc: return jsonify({"error": str(exc)}), 400 except SocialLoginError as error: return _social_error(error) return jsonify( { "success": True, "provider": result.get("provider"), "has_url": bool(result.get("url")), "has_scan_url": bool(result.get("scan_url")), } ) return _inner() @api_social_bp.route("/yuyx/api/social-login/config", methods=["POST"]) def update_admin_social_config(): from routes.decorators import admin_required @admin_required def _inner(): data = _get_json_payload() enabled = 1 if data.get("social_login_enabled") in (1, True, "1", "true", "on") else 0 endpoint_raw = str(data.get("social_login_endpoint") or "").strip() appid = str(data.get("social_login_appid") or "").strip() appkey = str(data.get("social_login_appkey") or "").strip() providers = parse_providers(data.get("social_login_providers")) try: endpoint = normalize_social_endpoint(endpoint_raw) except ValueError as exc: return jsonify({"error": str(exc)}), 400 old_config = database.get_system_config() or {} old_key = str(old_config.get("social_login_appkey") or "") if enabled: if not providers: return jsonify({"error": "启用聚合登录时至少选择一种登录方式"}), 400 if not appid: return jsonify({"error": "启用聚合登录时必须填写 APPID"}), 400 if not appkey and not old_key: return jsonify({"error": "启用聚合登录时必须填写 APPKEY"}), 400 encrypted_key = encrypt_social_appkey(appkey) if appkey else old_key ok = database.update_system_config( social_login_enabled=enabled, social_login_endpoint=endpoint, social_login_appid=appid, social_login_appkey=encrypted_key, social_login_providers=",".join(providers), ) if not ok: return jsonify({"error": "更新失败"}), 400 logger.info(f"[social/config] updated enabled={enabled} providers={','.join(providers)} at={get_cst_now_str()}") return jsonify({"message": "聚合登录配置已更新", "config": admin_social_config_out(database.get_system_config())}) return _inner()