307 lines
12 KiB
Python
307 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
import secrets
|
||
from dataclasses import dataclass
|
||
from typing import Any
|
||
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
|
||
|
||
import requests
|
||
|
||
from crypto_utils import decrypt_password, encrypt_password
|
||
|
||
# connect.php endpoint used whenever no endpoint is configured.
DEFAULT_SPACE_ENDPOINT = "https://www.spacezs.cn/connect.php"
# Provider code -> human-readable label; dict order is qq, wx, alipay.
PROVIDER_LABELS = {"qq": "QQ", "wx": "微信", "alipay": "支付宝"}
# Every provider code this module accepts: exactly the keys of PROVIDER_LABELS.
SUPPORTED_PROVIDERS = set(PROVIDER_LABELS.keys())
# 600 seconds (10 minutes). NOTE(review): not referenced in the visible code;
# presumably the lifetime of tokens from new_bind_token — confirm with callers.
BIND_TOKEN_TTL_SECONDS = 10 * 60

# Desktop-Chrome user agent sent when fetching aggregator HTML pages.
_CHROME_USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
    "(KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36"
)
# Browser-like request headers used for the WeChat scan page and as the base
# for the ajax polling headers.
SPACE_BROWSER_HEADERS = {
    "User-Agent": _CHROME_USER_AGENT,
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
}
|
||
|
||
|
||
class SocialLoginError(Exception):
    """Error raised by the social-login helpers, carrying an HTTP-style status.

    ``str(exc)`` and ``exc.message`` both hold the message. ``status_code`` is
    always an int: it defaults to 400, and any falsy value (``None``, ``0``)
    is also mapped to 400.
    """

    def __init__(self, message: str, status_code: int = 400):
        super().__init__(message)
        self.message = message
        self.status_code = int(status_code) if status_code else 400
|
||
|
||
|
||
@dataclass
class SpaceProfile:
    """User profile assembled from a successful aggregator callback.

    Only ``provider`` and ``social_uid`` are required; ``nickname`` and
    ``avatar_url`` default to an empty string. No validation is performed here.
    """

    # Provider code the profile belongs to (presumably one of the supported
    # provider codes — the dataclass does not check it).
    provider: str
    # User identifier reported by the aggregator for that provider.
    social_uid: str
    # Display name; "" when none was supplied.
    nickname: str = ""
    # Avatar image URL; "" when none was supplied.
    avatar_url: str = ""
|
||
|
||
|
||
def provider_label(provider: str) -> str:
    """Return the display label for *provider*, or *provider* unchanged if unknown."""
    try:
        return PROVIDER_LABELS[provider]
    except KeyError:
        return provider
|
||
|
||
|
||
def normalize_social_endpoint(value: str) -> str:
    """Validate and canonicalise a connect.php endpoint URL.

    A blank *value* falls back to ``DEFAULT_SPACE_ENDPOINT``. Trailing slashes
    are trimmed, and any params/query/fragment are dropped. A bare origin
    (no path) gets ``/connect.php`` appended.

    Raises:
        ValueError: the URL is not absolute http(s), or its path does not end
            in ``/connect.php``.
    """
    candidate = (value or DEFAULT_SPACE_ENDPOINT).strip()
    if not candidate:
        candidate = DEFAULT_SPACE_ENDPOINT

    pieces = urlparse(candidate)
    if not (pieces.scheme in {"http", "https"} and pieces.netloc):
        raise ValueError("聚合接口地址必须是 http 或 https 地址")

    # After rstrip("/") a root path "/" becomes "", so "" is the only
    # "no path given" case.
    trimmed = pieces.path.rstrip("/")
    if trimmed and not trimmed.endswith("/connect.php"):
        raise ValueError("聚合接口地址必须指向 connect.php")

    clean_path = trimmed if trimmed else "/connect.php"
    return pieces._replace(path=clean_path, params="", query="", fragment="").geturl()
|
||
|
||
|
||
def parse_providers(value: str | list[str] | tuple[str, ...] | None) -> list[str]:
    """Normalise a provider setting into an ordered, de-duplicated list of codes.

    *value* may be a comma-separated string or a sequence of entries. Each
    entry is stringified, stripped and lower-cased. Codes outside
    ``SUPPORTED_PROVIDERS`` are dropped, and the first occurrence of each
    remaining code wins.
    """
    raw_items = value.split(",") if isinstance(value, str) else (value or [])
    cleaned = (str(entry or "").strip().lower() for entry in raw_items)
    # dict.fromkeys keeps first-seen order while discarding repeats.
    return list(dict.fromkeys(code for code in cleaned if code in SUPPORTED_PROVIDERS))
|
||
|
||
|
||
def social_appkey(config: dict[str, Any]) -> str:
    """Return the decrypted appkey stored in *config*, or "" when none is stored.

    The stored value is decrypted with ``decrypt_password`` from crypto_utils.
    A ``None`` config is treated as empty.
    """
    stored = (config or {}).get("social_login_appkey")
    ciphertext = str(stored or "").strip()
    return decrypt_password(ciphertext) if ciphertext else ""
|
||
|
||
|
||
def encrypt_social_appkey(value: str) -> str:
    """Encrypt a plaintext appkey for storage; blank input yields ""."""
    plain = str(value or "").strip()
    if not plain:
        return ""
    return encrypt_password(plain)
|
||
|
||
|
||
def mask_secret(value: str) -> str:
    """Return a display-safe masked form of *value*.

    No more than half of the secret's characters are ever revealed:

    * empty / ``None`` -> ``""``
    * 1-4 chars        -> ``"***"`` (a 2-char prefix would expose half or more)
    * 5-15 chars       -> first 2 chars + ``"***"``
    * 16+ chars        -> first 4 chars + ``"***"`` + last 4 chars

    The previous threshold revealed 4+4 characters for secrets of only 9
    characters (8 of 9 exposed).
    """
    if not value:
        return ""
    length = len(value)
    if length >= 16:
        return f"{value[:4]}***{value[-4:]}"
    if length > 4:
        return f"{value[:2]}***"
    return "***"
|
||
|
||
|
||
def admin_social_config_out(config: dict[str, Any]) -> dict[str, Any]:
    """Build the social-login settings payload shown to administrators.

    The decrypted appkey is never echoed back. ``social_login_appkey`` is
    always ``""``; only a configured flag and a ``mask_secret`` form are
    exposed. Missing values fall back to ``DEFAULT_SPACE_ENDPOINT`` for the
    endpoint and to all three provider codes when no valid provider parses.
    """
    plain_key = social_appkey(config)
    enabled_flag = int(config.get("social_login_enabled") or 0)
    endpoint = config.get("social_login_endpoint") or DEFAULT_SPACE_ENDPOINT
    appid = config.get("social_login_appid") or ""
    key_present = bool(plain_key)
    key_masked = mask_secret(plain_key)
    providers = parse_providers(config.get("social_login_providers")) or ["qq", "wx", "alipay"]
    return {
        "social_login_enabled": enabled_flag,
        "social_login_endpoint": endpoint,
        "social_login_appid": appid,
        "social_login_appkey": "",
        "social_login_appkey_configured": key_present,
        "social_login_appkey_masked": key_masked,
        "social_login_providers": providers,
    }
|
||
|
||
|
||
def public_social_config(config: dict[str, Any]) -> dict[str, Any]:
    """Return the social-login state that is safe to expose to end users.

    ``enabled`` is True only when the enable flag equals 1, an appid and a
    decryptable appkey are both present, and at least one supported provider
    is configured. ``providers`` is that list when enabled, otherwise ``[]``.
    """
    providers = parse_providers(config.get("social_login_providers"))
    has_credentials = bool(config.get("social_login_appid") and social_appkey(config))
    flag_on = int(config.get("social_login_enabled") or 0) == 1
    enabled = flag_on and has_credentials and bool(providers)
    visible = providers if enabled else []
    return {"enabled": enabled, "providers": visible}
|
||
|
||
|
||
def validate_social_ready(config: dict[str, Any], provider: str) -> None:
    """Raise SocialLoginError unless *provider* can be used for social login.

    Checks run in this order:

    1. The enable flag must equal 1; otherwise 409.
    2. The provider (stripped, lower-cased) must be in the configured provider
       list; otherwise 400.
    3. Both an appid and an appkey must be configured; otherwise 409.
    """
    wanted = str(provider or "").strip().lower()
    switched_on = int(config.get("social_login_enabled") or 0) == 1
    if not switched_on:
        raise SocialLoginError("聚合登录未启用", 409)
    if wanted not in parse_providers(config.get("social_login_providers")):
        raise SocialLoginError("该登录方式未启用", 400)
    has_appid = bool(config.get("social_login_appid"))
    # The appkey is only decrypted when an appid is present.
    if not (has_appid and social_appkey(config)):
        raise SocialLoginError("聚合登录配置不完整", 409)
|
||
|
||
|
||
def validate_redirect_uri(redirect_uri: str, *, allowed_hosts: set[str] | None = None) -> str:
    """Return the stripped *redirect_uri* if it is an acceptable callback URL.

    The URL must be absolute http(s) with a network location. When
    *allowed_hosts* is non-empty, either the lower-cased ``netloc`` (which
    includes any port) or the lower-cased hostname must be in it. Entries of
    *allowed_hosts* are compared as-is, so they should already be lower-case.
    ``None`` or an empty set disables the host check.

    Raises:
        SocialLoginError (400): malformed URL or host not allowed.
    """
    candidate = str(redirect_uri or "").strip()
    pieces = urlparse(candidate)
    if pieces.scheme not in {"http", "https"} or not pieces.netloc:
        raise SocialLoginError("redirect_uri 必须是完整的 http/https 地址", 400)

    if not allowed_hosts:
        return candidate

    netloc = pieces.netloc.lower()
    bare_host = (pieces.hostname or "").lower()
    if netloc in allowed_hosts or bare_host in allowed_hosts:
        return candidate
    raise SocialLoginError("redirect_uri 域名不在允许范围内", 400)
|
||
|
||
|
||
def _space_base_url(endpoint: str) -> str:
|
||
parsed = urlparse(endpoint)
|
||
return f"{parsed.scheme}://{parsed.netloc}"
|
||
|
||
|
||
def _parse_space_json(payload: dict[str, Any], error_message: str) -> dict[str, Any]:
|
||
try:
|
||
code_value = int(payload.get("code"))
|
||
except (TypeError, ValueError):
|
||
code_value = -1
|
||
if code_value != 0:
|
||
raise SocialLoginError(str(payload.get("msg") or error_message), 409)
|
||
return payload
|
||
|
||
|
||
def parse_space_scan_page(html: str) -> tuple[str, str]:
    """Extract ``(qrcode_url, state)`` from the aggregator's WeChat scan page.

    Both values are read from inline ``var qrcode_url = '...'`` and
    ``var state = '...'`` assignments (single or double quotes), then stripped.

    Raises:
        SocialLoginError (409): either value is missing or blank.
    """
    page = html or ""

    def _capture(pattern: str) -> str:
        # First captured quoted value, stripped; "" when the pattern is absent.
        hit = re.search(pattern, page)
        return hit.group(1).strip() if hit is not None else ""

    qrcode_url = _capture(r'var\s+qrcode_url\s*=\s*["\']([^"\']+)["\']')
    state = _capture(r'var\s+state\s*=\s*["\']([^"\']+)["\']')
    if qrcode_url and state:
        return qrcode_url, state
    raise SocialLoginError("聚合登录扫码页面解析失败", 409)
|
||
|
||
|
||
def fetch_social_login_url(
    config: dict[str, Any],
    *,
    provider: str,
    mode: str,
    redirect_uri: str,
    allowed_hosts: set[str] | None = None,
) -> dict[str, Any]:
    """Ask the connect.php aggregator for an authorization URL for *provider*.

    Args:
        config: Social-login settings (``social_login_*`` keys).
        provider: Provider code; stripped and lower-cased before use.
        mode: ``"bind"`` (case-insensitive) selects bind mode; anything else
            becomes ``"login"``. It is only echoed in the result and is not
            sent to the aggregator.
        redirect_uri: Absolute http(s) callback URL, checked against
            *allowed_hosts* when that set is non-empty.
        allowed_hosts: Optional whitelist of ``host`` / ``host:port`` values.

    Returns:
        A dict with keys ``provider``, ``mode``, ``url``, ``qrcode``,
        ``scan_url``, ``scan_state`` and ``scan_poll_interval``. For ``wx``
        the scan page is fetched and ``scan_url`` / ``scan_state`` are filled
        from it. For other providers they stay ``""``.

    Raises:
        SocialLoginError: login not ready, bad redirect_uri, network failure,
            a non-JSON or non-object reply, a non-zero ``code``, a missing
            ``url``, or an unparseable WeChat scan page.
        ValueError: from ``normalize_social_endpoint`` when the configured
            endpoint is invalid (not converted here).
    """
    provider = str(provider or "").strip().lower()
    mode = "bind" if str(mode or "").strip().lower() == "bind" else "login"
    validate_social_ready(config, provider)
    redirect_uri = validate_redirect_uri(redirect_uri, allowed_hosts=allowed_hosts)
    endpoint = normalize_social_endpoint(str(config.get("social_login_endpoint") or ""))
    params = {
        "act": "login",
        "appid": str(config.get("social_login_appid") or "").strip(),
        "appkey": social_appkey(config),
        "type": provider,
        "redirect_uri": redirect_uri,
    }

    # The HTTP call and the JSON decode are guarded separately. Several
    # requests exceptions (InvalidURL, MissingSchema, InvalidHeader, ...)
    # subclass ValueError as well as RequestException. A single try that
    # tests ValueError first would misreport them as "not JSON".
    try:
        response = requests.get(endpoint, params=params, timeout=15)
    except requests.RequestException as exc:
        raise SocialLoginError("聚合接口连接失败", 409) from exc
    try:
        payload = response.json()
    except ValueError as exc:
        raise SocialLoginError("聚合接口未返回 JSON,请检查 endpoint 是否为 connect.php", 409) from exc
    if not isinstance(payload, dict):
        # Valid JSON that is not an object (list, string, null, ...) would
        # otherwise crash with AttributeError on .get().
        raise SocialLoginError("获取聚合登录地址失败", 409)

    data = _parse_space_json(payload, "获取聚合登录地址失败")
    url = str(data.get("url") or "")
    if not url:
        raise SocialLoginError("聚合接口未返回授权地址", 409)

    result = {
        "provider": provider,
        "mode": mode,
        "url": url,
        "qrcode": str(data.get("qrcode") or ""),
        "scan_url": "",
        "scan_state": "",
        "scan_poll_interval": 2,
    }
    if provider == "wx":
        # WeChat: fetch the scan page (the qrcode page when given, else the
        # auth URL) and read the QR image URL and polling state from its script.
        scan_page_url = result["qrcode"] or result["url"]
        try:
            scan_response = requests.get(scan_page_url, headers=SPACE_BROWSER_HEADERS, timeout=15)
        except requests.RequestException as exc:
            raise SocialLoginError("微信扫码页获取失败", 409) from exc
        scan_url, state = parse_space_scan_page(scan_response.text)
        result["scan_url"] = scan_url
        result["scan_state"] = state
    return result
|
||
|
||
|
||
def poll_social_scan(config: dict[str, Any], *, provider: str, state: str) -> dict[str, Any]:
    """Poll the aggregator's ajax.php once for the status of a WeChat QR scan.

    Args:
        config: Social-login settings.
        provider: Must be ``"wx"`` after strip/lower-casing.
        state: Scan state to poll; presumably the ``scan_state`` returned by
            fetch_social_login_url. Surrounding whitespace is stripped.

    Returns:
        ``{"status": "pending"}`` when the reply's ``code`` is 1, or
        ``{"status": "authorized", "url": ...}`` when it is 0.

    Raises:
        SocialLoginError: login not ready, non-wx provider, empty state,
            network failure, a non-JSON or non-object reply, a success reply
            without ``url``, or any other ``code``.
        ValueError: from ``normalize_social_endpoint`` for an invalid
            configured endpoint (not converted here).
    """
    provider = str(provider or "").strip().lower()
    validate_social_ready(config, provider)
    if provider != "wx":
        raise SocialLoginError("该登录方式不需要轮询", 400)
    state_value = str(state or "").strip()
    if not state_value:
        raise SocialLoginError("缺少扫码状态", 400)

    endpoint = normalize_social_endpoint(str(config.get("social_login_endpoint") or ""))
    base_url = _space_base_url(endpoint)
    ajax_url = urljoin(base_url + "/", "ajax.php")
    # state is supplied by the caller and may be untrusted. Percent-encode it
    # so "&", "#" or CR/LF cannot inject query parameters into the Referer or
    # break the header value. Plain alphanumeric states produce the same
    # "state=...&client=1" string as before.
    referer_query = urlencode({"state": state_value, "client": 1})
    headers = {
        **SPACE_BROWSER_HEADERS,
        "Referer": f"{base_url}/jump.php?{referer_query}",
        "X-Requested-With": "XMLHttpRequest",
        "Accept": "application/json, text/javascript, */*; q=0.01",
    }

    # Network and JSON-decode failures are handled separately because some
    # requests exceptions (InvalidURL, InvalidHeader, ...) are also ValueErrors.
    try:
        response = requests.get(
            ajax_url,
            params={"act": "login", "state": state_value},
            headers=headers,
            timeout=10,
        )
    except requests.RequestException as exc:
        raise SocialLoginError("扫码轮询连接失败", 409) from exc
    try:
        payload = response.json()
    except ValueError as exc:
        raise SocialLoginError("扫码轮询未返回 JSON", 409) from exc
    if not isinstance(payload, dict):
        raise SocialLoginError("扫码登录失败", 409)

    try:
        code_value = int(payload.get("code"))
    except (TypeError, ValueError, OverflowError):
        code_value = -1
    if code_value == 1:
        return {"status": "pending"}
    if code_value == 0:
        url = str(payload.get("url") or "")
        if not url:
            raise SocialLoginError("扫码成功但未返回回调地址", 409)
        return {"status": "authorized", "url": url}
    raise SocialLoginError(str(payload.get("msg") or "扫码登录失败"), 409)
|
||
|
||
|
||
def _profile_uid(payload: dict[str, Any]) -> str:
|
||
for key in ("social_uid", "uid", "openid", "unionid", "id"):
|
||
value = str(payload.get(key) or "").strip()
|
||
if value:
|
||
return value
|
||
raise SocialLoginError("聚合回调未返回用户唯一标识", 409)
|
||
|
||
|
||
def fetch_space_profile(config: dict[str, Any], *, provider: str, code: str) -> SpaceProfile:
    """Exchange an authorization *code* for the user's profile via connect.php.

    Args:
        config: Social-login settings (``social_login_*`` keys).
        provider: Provider code; stripped and lower-cased before use.
        code: Authorization code from the callback; whitespace is stripped.

    Returns:
        A SpaceProfile. The uid comes from ``_profile_uid``. The nickname is
        taken from ``nickname``/``nick``/``name`` and the avatar from
        ``faceimg``/``avatar``/``avatar_url``, in that order.

    Raises:
        SocialLoginError: login not ready, empty code, network failure, a
            non-JSON or non-object reply, a non-zero ``code``, or no user id.
        ValueError: from ``normalize_social_endpoint`` for an invalid
            configured endpoint (not converted here).
    """
    provider = str(provider or "").strip().lower()
    code_value = str(code or "").strip()
    validate_social_ready(config, provider)
    if not code_value:
        raise SocialLoginError("缺少授权 code", 400)
    endpoint = normalize_social_endpoint(str(config.get("social_login_endpoint") or ""))
    params = {
        "act": "callback",
        "appid": str(config.get("social_login_appid") or "").strip(),
        "appkey": social_appkey(config),
        "type": provider,
        "code": code_value,
    }

    # Separate the network call from the JSON decode. requests' InvalidURL,
    # MissingSchema, InvalidHeader, ... subclass ValueError too, and would
    # otherwise be reported as "not JSON".
    try:
        response = requests.get(endpoint, params=params, timeout=15)
    except requests.RequestException as exc:
        raise SocialLoginError("聚合回调接口连接失败", 409) from exc
    try:
        payload = response.json()
    except ValueError as exc:
        raise SocialLoginError("聚合回调接口未返回 JSON", 409) from exc
    if not isinstance(payload, dict):
        # JSON that is not an object would otherwise crash on .get().
        raise SocialLoginError("聚合登录回调失败", 409)

    data = _parse_space_json(payload, "聚合登录回调失败")
    return SpaceProfile(
        provider=provider,
        social_uid=_profile_uid(data),
        nickname=str(data.get("nickname") or data.get("nick") or data.get("name") or ""),
        avatar_url=str(data.get("faceimg") or data.get("avatar") or data.get("avatar_url") or ""),
    )
|
||
|
||
|
||
def callback_mode_from_redirect_query(redirect_url: str, fallback: str = "login") -> str:
    """Return ``"bind"`` or ``"login"`` from the ``mode`` query param of *redirect_url*.

    Only the first ``mode`` value is considered. When ``mode`` is absent or
    blank, *fallback* is used instead. Any value other than exactly
    ``"bind"`` yields ``"login"``.
    """
    query = urlparse(redirect_url).query
    mode_values = parse_qs(query).get("mode")
    requested = mode_values[0] if mode_values else fallback
    if requested == "bind":
        return "bind"
    return "login"
|
||
|
||
|
||
def append_query(url: str, values: dict[str, str]) -> str:
    """Return *url* with each key in *values* set to that single value.

    Existing parameters are kept, including blank ones such as ``flag=``. A
    key already present in *url* is overwritten at its original position
    rather than duplicated. Repeated keys in *url* end up grouped together.
    The fragment and other URL components are preserved. The query is
    re-encoded with ``urlencode`` (spaces become ``+``). ``None`` for
    *values* is treated as empty.
    """
    parsed = urlparse(url)
    # keep_blank_values=True: by default parse_qs silently drops "key=" pairs,
    # so adding one parameter would delete unrelated blank parameters.
    query = parse_qs(parsed.query, keep_blank_values=True)
    for key, value in (values or {}).items():
        query[key] = [value]
    return parsed._replace(query=urlencode(query, doseq=True)).geturl()
|
||
|
||
|
||
def new_bind_token() -> str:
    """Return a fresh URL-safe random token built from 32 random bytes."""
    return secrets.token_urlsafe(nbytes=32)
|