Files
zsglpt/services/social_login.py
2026-05-27 20:39:46 +08:00

307 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
import secrets
from dataclasses import dataclass
from typing import Any
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
import requests
from crypto_utils import decrypt_password, encrypt_password
# Aggregator endpoint used whenever the site configuration does not set one.
DEFAULT_SPACE_ENDPOINT = "https://www.spacezs.cn/connect.php"
# Display labels for each supported provider code (QQ / WeChat / Alipay).
PROVIDER_LABELS = {
    "qq": "QQ",
    "wx": "微信",
    "alipay": "支付宝",
}
# Provider codes accepted anywhere in this module.
SUPPORTED_PROVIDERS = set(PROVIDER_LABELS.keys())
# Bind-token lifetime in seconds (10 minutes). Not referenced in this module;
# presumably consumed by the code that issues new_bind_token() values.
BIND_TOKEN_TTL_SECONDS = 10 * 60
# Desktop-browser headers used when scraping the aggregator's HTML scan page
# and calling its AJAX polling endpoint.
SPACE_BROWSER_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/126.0.0.0 Safari/537.36"
    ),
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
}
class SocialLoginError(Exception):
    """Social-login failure carrying a user-facing message and an HTTP status.

    Attributes:
        message: The message passed in (also the exception's ``str()``).
        status_code: Integer HTTP status; falsy values (``0``/``None``/``""``)
            fall back to 400, and numeric strings such as ``"409"`` are coerced.
    """

    def __init__(self, message: str, status_code: int = 400):
        super().__init__(message)
        self.message = message
        self.status_code = int(status_code) if status_code else 400
@dataclass
class SpaceProfile:
    """User identity returned by the aggregator's ``act=callback`` API."""

    # Provider code such as "qq", "wx" or "alipay".
    provider: str
    # Provider-side unique user id (see _profile_uid for the keys consulted).
    social_uid: str
    # Display name; empty when the aggregator supplies none.
    nickname: str = ""
    # Avatar image URL; empty when the aggregator supplies none.
    avatar_url: str = ""
def provider_label(provider: str) -> str:
    """Return the display label for ``provider``, or the code itself if unknown."""
    try:
        return PROVIDER_LABELS[provider]
    except KeyError:
        return provider
def normalize_social_endpoint(value: str) -> str:
    """Return the canonical ``connect.php`` URL of the aggregator.

    A blank value falls back to ``DEFAULT_SPACE_ENDPOINT``. A bare origin (no
    path, or just ``/``) gets ``/connect.php`` appended. Trailing slashes,
    params, query string and fragment are removed.

    Raises:
        ValueError: if the URL is not absolute http/https, or its path does
            not end in ``/connect.php``.
    """
    candidate = (value or DEFAULT_SPACE_ENDPOINT).strip() or DEFAULT_SPACE_ENDPOINT
    pieces = urlparse(candidate)
    is_web_url = pieces.scheme in {"http", "https"} and bool(pieces.netloc)
    if not is_web_url:
        raise ValueError("聚合接口地址必须是 http 或 https 地址")
    trimmed = pieces.path.rstrip("/")
    if trimmed == "":
        trimmed = "/connect.php"
    elif not trimmed.endswith("/connect.php"):
        raise ValueError("聚合接口地址必须指向 connect.php")
    return pieces._replace(path=trimmed, params="", query="", fragment="").geturl()
def parse_providers(value: str | list[str] | tuple[str, ...] | None) -> list[str]:
items = value.split(",") if isinstance(value, str) else list(value or [])
result: list[str] = []
for item in items:
provider = str(item or "").strip().lower()
if provider in SUPPORTED_PROVIDERS and provider not in result:
result.append(provider)
return result
def social_appkey(config: dict[str, Any]) -> str:
encrypted = str((config or {}).get("social_login_appkey") or "").strip()
if not encrypted:
return ""
return decrypt_password(encrypted)
def encrypt_social_appkey(value: str) -> str:
    """Encrypt a plaintext appkey for storage; blank input yields ``""``."""
    cleaned = str(value or "").strip()
    if not cleaned:
        return ""
    return encrypt_password(cleaned)
def mask_secret(value: str) -> str:
    """Return a partially hidden form of ``value`` for display.

    Values longer than 8 characters keep their first and last 4 characters.
    Shorter values keep only their first 2. Empty input yields ``""``.
    """
    if not value:
        return ""
    if len(value) > 8:
        return value[:4] + "***" + value[-4:]
    return value[:2] + "***"
def admin_social_config_out(config: dict[str, Any]) -> dict[str, Any]:
    """Build the admin-facing view of the social-login settings.

    The plaintext appkey is never returned: ``social_login_appkey`` is always
    ``""``, and only a configured flag and a masked form are exposed. When no
    valid providers are stored, all three (qq, wx, alipay) are reported.
    """
    appkey = social_appkey(config)
    enabled_flag = int(config.get("social_login_enabled") or 0)
    endpoint = config.get("social_login_endpoint") or DEFAULT_SPACE_ENDPOINT
    appid = config.get("social_login_appid") or ""
    masked_key = mask_secret(appkey)
    providers = parse_providers(config.get("social_login_providers")) or ["qq", "wx", "alipay"]
    return {
        "social_login_enabled": enabled_flag,
        "social_login_endpoint": endpoint,
        "social_login_appid": appid,
        "social_login_appkey": "",
        "social_login_appkey_configured": bool(appkey),
        "social_login_appkey_masked": masked_key,
        "social_login_providers": providers,
    }
def public_social_config(config: dict[str, Any]) -> dict[str, Any]:
    """Return the public login-page view: whether social login is usable, and with which providers.

    It is enabled only when the switch equals 1, both appid and appkey are
    configured, and at least one supported provider is selected. Otherwise
    the provider list is empty.
    """
    providers = parse_providers(config.get("social_login_providers"))
    has_credentials = bool(config.get("social_login_appid") and social_appkey(config))
    switched_on = int(config.get("social_login_enabled") or 0) == 1
    enabled = bool(switched_on and has_credentials and providers)
    visible = providers if enabled else []
    return {"enabled": enabled, "providers": visible}
def validate_social_ready(config: dict[str, Any], provider: str) -> None:
    """Ensure social login is switched on and usable for ``provider``.

    Raises:
        SocialLoginError: 409 if disabled, 400 if ``provider`` (matched
            case-insensitively) is not enabled, 409 if appid/appkey are missing.
    """
    normalized = str(provider or "").strip().lower()
    if int(config.get("social_login_enabled") or 0) != 1:
        raise SocialLoginError("聚合登录未启用", 409)
    enabled_providers = parse_providers(config.get("social_login_providers"))
    if normalized not in enabled_providers:
        raise SocialLoginError("该登录方式未启用", 400)
    has_appid = bool(config.get("social_login_appid"))
    # social_appkey() is only consulted when an appid exists (short-circuit).
    if not (has_appid and social_appkey(config)):
        raise SocialLoginError("聚合登录配置不完整", 409)
def validate_redirect_uri(redirect_uri: str, *, allowed_hosts: set[str] | None = None) -> str:
uri = str(redirect_uri or "").strip()
parsed = urlparse(uri)
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise SocialLoginError("redirect_uri 必须是完整的 http/https 地址", 400)
if allowed_hosts:
host = parsed.netloc.lower()
hostname = (parsed.hostname or "").lower()
if host not in allowed_hosts and hostname not in allowed_hosts:
raise SocialLoginError("redirect_uri 域名不在允许范围内", 400)
return uri
def _space_base_url(endpoint: str) -> str:
parsed = urlparse(endpoint)
return f"{parsed.scheme}://{parsed.netloc}"
def _parse_space_json(payload: dict[str, Any], error_message: str) -> dict[str, Any]:
try:
code_value = int(payload.get("code"))
except (TypeError, ValueError):
code_value = -1
if code_value != 0:
raise SocialLoginError(str(payload.get("msg") or error_message), 409)
return payload
def parse_space_scan_page(html: str) -> tuple[str, str]:
    """Extract ``(qrcode_url, state)`` from the aggregator's WeChat scan page.

    The page is expected to declare ``var qrcode_url = "..."`` and
    ``var state = "..."`` (single or double quotes) in inline JavaScript.
    Both values are whitespace-stripped.

    Raises:
        SocialLoginError: (409) if either value is missing or blank.
    """
    page = html or ""
    patterns = (
        r'var\s+qrcode_url\s*=\s*["\']([^"\']+)["\']',
        r'var\s+state\s*=\s*["\']([^"\']+)["\']',
    )
    extracted = []
    for pattern in patterns:
        hit = re.search(pattern, page)
        extracted.append(hit.group(1).strip() if hit else "")
    qrcode_url, state = extracted
    if qrcode_url and state:
        return qrcode_url, state
    raise SocialLoginError("聚合登录扫码页面解析失败", 409)
def _fetch_wx_scan_page(scan_page_url: str) -> tuple[str, str]:
    """Download the WeChat scan page and return its ``(qrcode_url, state)`` pair.

    Raises:
        SocialLoginError: (409) if the page cannot be fetched, answers with an
            HTTP error status, or lacks the expected JavaScript variables.
    """
    try:
        scan_response = requests.get(scan_page_url, headers=SPACE_BROWSER_HEADERS, timeout=15)
        # Without this, an HTTP error page would reach the parser and be
        # reported as a misleading "page parse failed" error.
        scan_response.raise_for_status()
    except requests.RequestException as exc:
        raise SocialLoginError("微信扫码页获取失败", 409) from exc
    return parse_space_scan_page(scan_response.text)


def fetch_social_login_url(
    config: dict[str, Any],
    *,
    provider: str,
    mode: str,
    redirect_uri: str,
    allowed_hosts: set[str] | None = None,
) -> dict[str, Any]:
    """Ask the aggregator (``act=login``) for an authorization URL for ``provider``.

    Args:
        config: Site settings holding the ``social_login_*`` keys.
        provider: Provider code (``qq``/``wx``/``alipay``); case-insensitive.
        mode: ``"bind"``, or anything else (treated as ``"login"``). It is only
            echoed back in the result and is not sent to the aggregator.
        redirect_uri: Absolute http/https callback URL, checked against
            ``allowed_hosts`` when that is non-empty.
        allowed_hosts: Optional host allow-list for ``redirect_uri``.

    Returns:
        A dict with ``provider``, ``mode``, ``url`` (authorization URL),
        ``qrcode``, ``scan_url``, ``scan_state`` and ``scan_poll_interval``
        (2; presumably seconds, to be confirmed with the frontend). For ``wx``
        the scan page is scraped to fill ``scan_url``/``scan_state``, presumably
        for later use with ``poll_social_scan``. For other providers both are ``""``.

    Raises:
        SocialLoginError: if login is not ready, ``redirect_uri`` is rejected,
            or any aggregator call fails.
        ValueError: if the configured endpoint is not a valid ``connect.php``
            URL (propagated from ``normalize_social_endpoint``).
    """
    provider = str(provider or "").strip().lower()
    mode = "bind" if str(mode or "").strip().lower() == "bind" else "login"
    validate_social_ready(config, provider)
    redirect_uri = validate_redirect_uri(redirect_uri, allowed_hosts=allowed_hosts)
    endpoint = normalize_social_endpoint(str(config.get("social_login_endpoint") or ""))
    # NOTE(review): the appkey travels in the query string, as the aggregator's
    # GET API appears to require. Beware of access logs that record full URLs.
    params = {
        "act": "login",
        "appid": str(config.get("social_login_appid") or "").strip(),
        "appkey": social_appkey(config),
        "type": provider,
        "redirect_uri": redirect_uri,
    }
    try:
        response = requests.get(endpoint, params=params, timeout=15)
        payload = response.json()
    except ValueError as exc:
        # Also covers requests' URL/header validation errors, which subclass
        # ValueError; the "check endpoint" hint is apt for those too.
        raise SocialLoginError("聚合接口未返回 JSON请检查 endpoint 是否为 connect.php", 409) from exc
    except requests.RequestException as exc:
        raise SocialLoginError("聚合接口连接失败", 409) from exc
    data = _parse_space_json(payload, "获取聚合登录地址失败")
    url = str(data.get("url") or "")
    if not url:
        raise SocialLoginError("聚合接口未返回授权地址", 409)
    result = {
        "provider": provider,
        "mode": mode,
        "url": url,
        "qrcode": str(data.get("qrcode") or ""),
        "scan_url": "",
        "scan_state": "",
        "scan_poll_interval": 2,
    }
    if provider == "wx":
        # WeChat login is QR-based: scrape the scan page (the qrcode link if
        # given, else the authorization URL) for the state used in polling.
        scan_page_url = result["qrcode"] or result["url"]
        result["scan_url"], result["scan_state"] = _fetch_wx_scan_page(scan_page_url)
    return result
def poll_social_scan(config: dict[str, Any], *, provider: str, state: str) -> dict[str, Any]:
    """Poll the aggregator's ``ajax.php`` once for the outcome of a WeChat QR scan.

    Args:
        config: Site settings holding the ``social_login_*`` keys.
        provider: Must be ``wx``; other providers have no scan polling.
        state: The scan state, presumably the ``scan_state`` value returned by
            ``fetch_social_login_url`` and round-tripped through the browser,
            so it is treated as untrusted.

    Returns:
        ``{"status": "pending"}`` while the aggregator answers ``code == 1``, or
        ``{"status": "authorized", "url": ...}`` once it answers ``code == 0``.

    Raises:
        SocialLoginError: on invalid input, connection failure, a non-JSON reply,
            or any other ``code`` (the aggregator's ``msg`` is used when present).
    """
    provider = str(provider or "").strip().lower()
    validate_social_ready(config, provider)
    if provider != "wx":
        raise SocialLoginError("该登录方式不需要轮询", 400)
    state_value = str(state or "").strip()
    if not state_value:
        raise SocialLoginError("缺少扫码状态", 400)
    endpoint = normalize_social_endpoint(str(config.get("social_login_endpoint") or ""))
    base_url = _space_base_url(endpoint)
    ajax_url = urljoin(base_url + "/", "ajax.php")
    # Encode the untrusted state so it cannot inject extra query parameters
    # (or characters requests rejects) into the Referer header. For ordinary
    # alphanumeric states this yields exactly "state=<state>&client=1".
    referer_query = urlencode({"state": state_value, "client": "1"})
    headers = {
        **SPACE_BROWSER_HEADERS,
        "Referer": f"{base_url}/jump.php?{referer_query}",
        "X-Requested-With": "XMLHttpRequest",
        "Accept": "application/json, text/javascript, */*; q=0.01",
    }
    # Separate try blocks keep request-construction errors (InvalidURL and
    # InvalidHeader subclass ValueError) from being reported as "not JSON".
    try:
        response = requests.get(ajax_url, params={"act": "login", "state": state_value}, headers=headers, timeout=10)
    except requests.RequestException as exc:
        raise SocialLoginError("扫码轮询连接失败", 409) from exc
    try:
        payload = response.json()
    except ValueError as exc:
        raise SocialLoginError("扫码轮询未返回 JSON", 409) from exc
    if not isinstance(payload, dict):
        # A JSON array/scalar carries no code or msg; treat it as a failed poll
        # instead of crashing on .get().
        payload = {}
    try:
        code_value = int(payload.get("code"))
    except (TypeError, ValueError):
        code_value = -1
    if code_value == 1:
        return {"status": "pending"}
    if code_value == 0:
        url = str(payload.get("url") or "")
        if not url:
            raise SocialLoginError("扫码成功但未返回回调地址", 409)
        return {"status": "authorized", "url": url}
    raise SocialLoginError(str(payload.get("msg") or "扫码登录失败"), 409)
def _profile_uid(payload: dict[str, Any]) -> str:
for key in ("social_uid", "uid", "openid", "unionid", "id"):
value = str(payload.get(key) or "").strip()
if value:
return value
raise SocialLoginError("聚合回调未返回用户唯一标识", 409)
def fetch_space_profile(config: dict[str, Any], *, provider: str, code: str) -> SpaceProfile:
    """Exchange an authorization ``code`` for the user's profile (``act=callback``).

    Args:
        config: Site settings holding the ``social_login_*`` keys.
        provider: Provider code; case-insensitive.
        code: Authorization code from the aggregator callback; stripped.

    Returns:
        A ``SpaceProfile``. The nickname comes from ``nickname``/``nick``/``name``
        and the avatar from ``faceimg``/``avatar``/``avatar_url``, first
        non-empty value wins.

    Raises:
        SocialLoginError: if login is not ready, ``code`` is blank, the
            aggregator call fails, or no user identifier is returned.
    """
    provider = str(provider or "").strip().lower()
    auth_code = str(code or "").strip()
    validate_social_ready(config, provider)
    if not auth_code:
        raise SocialLoginError("缺少授权 code", 400)
    endpoint = normalize_social_endpoint(str(config.get("social_login_endpoint") or ""))
    query = dict(
        act="callback",
        appid=str(config.get("social_login_appid") or "").strip(),
        appkey=social_appkey(config),
        type=provider,
        code=auth_code,
    )
    try:
        payload = requests.get(endpoint, params=query, timeout=15).json()
    except ValueError as exc:
        raise SocialLoginError("聚合回调接口未返回 JSON", 409) from exc
    except requests.RequestException as exc:
        raise SocialLoginError("聚合回调接口连接失败", 409) from exc
    data = _parse_space_json(payload, "聚合登录回调失败")
    uid = _profile_uid(data)
    nickname = data.get("nickname") or data.get("nick") or data.get("name") or ""
    avatar = data.get("faceimg") or data.get("avatar") or data.get("avatar_url") or ""
    return SpaceProfile(
        provider=provider,
        social_uid=uid,
        nickname=str(nickname),
        avatar_url=str(avatar),
    )
def callback_mode_from_redirect_query(redirect_url: str, fallback: str = "login") -> str:
    """Read the ``mode`` query parameter of ``redirect_url`` as ``"bind"`` or ``"login"``.

    ``fallback`` is used when ``mode`` is absent or blank. Any value other than
    ``"bind"`` maps to ``"login"``.
    """
    query_values = parse_qs(urlparse(redirect_url).query)
    requested = query_values["mode"][0] if "mode" in query_values else fallback
    return "bind" if requested == "bind" else "login"
def append_query(url: str, values: dict[str, str]) -> str:
    """Return ``url`` with each key of ``values`` set in its query string.

    Existing parameters are kept, including ones with blank values such as
    ``a=``. A key already present is replaced in place by the single new value.
    New keys are appended. Repeated existing keys are grouped together (a
    ``parse_qs`` property). The fragment is preserved.

    Args:
        url: URL to extend.
        values: Parameters to set; each value replaces any existing values.

    Returns:
        The rebuilt URL, with the query encoded by ``urlencode`` (spaces become ``+``).
    """
    parsed = urlparse(url)
    # keep_blank_values=True: the default silently drops parameters such as
    # "a=" from the caller's URL, even though only new keys are being added.
    query = parse_qs(parsed.query, keep_blank_values=True)
    for key, value in values.items():
        query[key] = [value]
    return parsed._replace(query=urlencode(query, doseq=True)).geturl()
def new_bind_token() -> str:
    """Return a fresh URL-safe random token (32 random bytes, 43 characters)."""
    token = secrets.token_urlsafe(32)
    return token