安全修复:加固CSRF与凭证保护并修复越权风险

This commit is contained in:
2026-02-16 01:19:43 +08:00
parent 14b506e8a1
commit 1389ec7434
22 changed files with 375 additions and 83 deletions

View File

@@ -117,7 +117,11 @@ def get_cookie_jar_path(username: str) -> str:
"""获取截图用的 cookies 文件路径Netscape Cookie 格式)""" """获取截图用的 cookies 文件路径Netscape Cookie 格式)"""
import hashlib import hashlib
os.makedirs(COOKIES_DIR, exist_ok=True) os.makedirs(COOKIES_DIR, mode=0o700, exist_ok=True)
try:
os.chmod(COOKIES_DIR, 0o700)
except Exception:
pass
filename = hashlib.sha256(username.encode()).hexdigest()[:32] + ".cookies.txt" filename = hashlib.sha256(username.encode()).hexdigest()[:32] + ".cookies.txt"
return os.path.join(COOKIES_DIR, filename) return os.path.join(COOKIES_DIR, filename)
@@ -260,6 +264,10 @@ class APIBrowser:
with open(cookies_path, "w", encoding="utf-8") as f: with open(cookies_path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n") f.write("\n".join(lines) + "\n")
try:
os.chmod(cookies_path, 0o600)
except Exception:
pass
self.log(f"[API] Cookies已保存供截图使用") self.log(f"[API] Cookies已保存供截图使用")
return True return True

5
app.py
View File

@@ -212,11 +212,12 @@ def enforce_csrf_protection():
return return
if request.path.startswith("/static/"): if request.path.startswith("/static/"):
return return
# 登录相关路由豁免 CSRF 检查(登录本身就是建立 session 的过程 # 登录挑战相关路由豁免 CSRF(会话尚未建立前需要可用
csrf_exempt_paths = { csrf_exempt_paths = {
"/yuyx/api/login", "/yuyx/api/login",
"/api/login", "/api/login",
"/api/auth/login", "/api/auth/login",
"/api/generate_captcha",
"/yuyx/api/passkeys/login/options", "/yuyx/api/passkeys/login/options",
"/yuyx/api/passkeys/login/verify", "/yuyx/api/passkeys/login/verify",
"/api/passkeys/login/options", "/api/passkeys/login/options",
@@ -224,8 +225,6 @@ def enforce_csrf_protection():
} }
if request.path in csrf_exempt_paths: if request.path in csrf_exempt_paths:
return return
if not (current_user.is_authenticated or "admin_id" in session):
return
token = request.headers.get("X-CSRF-Token") or request.form.get("csrf_token") token = request.headers.get("X-CSRF-Token") or request.form.get("csrf_token")
if not token or not validate_csrf_token(token): if not token or not validate_csrf_token(token):
return jsonify({"error": "CSRF token missing or invalid"}), 403 return jsonify({"error": "CSRF token missing or invalid"}), 403

View File

@@ -33,6 +33,23 @@ except ImportError:
SECRET_KEY_FILE = "data/secret_key.txt" SECRET_KEY_FILE = "data/secret_key.txt"
def _ensure_private_dir(path: str) -> None:
if not path:
return
os.makedirs(path, mode=0o700, exist_ok=True)
try:
os.chmod(path, 0o700)
except Exception:
pass
def _ensure_private_file(path: str) -> None:
try:
os.chmod(path, 0o600)
except Exception:
pass
def get_secret_key(): def get_secret_key():
"""获取SECRET_KEY优先环境变量""" """获取SECRET_KEY优先环境变量"""
# 优先从环境变量读取 # 优先从环境变量读取
@@ -42,14 +59,16 @@ def get_secret_key():
# 从文件读取 # 从文件读取
if os.path.exists(SECRET_KEY_FILE): if os.path.exists(SECRET_KEY_FILE):
_ensure_private_file(SECRET_KEY_FILE)
with open(SECRET_KEY_FILE, "r") as f: with open(SECRET_KEY_FILE, "r") as f:
return f.read().strip() return f.read().strip()
# 生成新的 # 生成新的
new_key = os.urandom(24).hex() new_key = os.urandom(24).hex()
os.makedirs("data", exist_ok=True) _ensure_private_dir("data")
with open(SECRET_KEY_FILE, "w") as f: with open(SECRET_KEY_FILE, "w") as f:
f.write(new_key) f.write(new_key)
_ensure_private_file(SECRET_KEY_FILE)
print(f"[OK] 已生成新的SECRET_KEY并保存到 {SECRET_KEY_FILE}") print(f"[OK] 已生成新的SECRET_KEY并保存到 {SECRET_KEY_FILE}")
return new_key return new_key
@@ -203,7 +222,7 @@ class Config:
SERVER_PORT = int(os.environ.get("SERVER_PORT", "51233")) SERVER_PORT = int(os.environ.get("SERVER_PORT", "51233"))
# ==================== SocketIO配置 ==================== # ==================== SocketIO配置 ====================
SOCKETIO_CORS_ALLOWED_ORIGINS = os.environ.get("SOCKETIO_CORS_ALLOWED_ORIGINS", "*") SOCKETIO_CORS_ALLOWED_ORIGINS = os.environ.get("SOCKETIO_CORS_ALLOWED_ORIGINS", "")
# ==================== 网站基础URL配置 ==================== # ==================== 网站基础URL配置 ====================
# 用于生成邮件中的验证链接等 # 用于生成邮件中的验证链接等

View File

@@ -9,6 +9,7 @@ import os
import re import re
import time import time
import hashlib import hashlib
import hmac
import secrets import secrets
import ipaddress import ipaddress
import socket import socket
@@ -78,7 +79,13 @@ def sanitize_filename(filename):
class IPRateLimiter: class IPRateLimiter:
"""IP访问频率限制器""" """IP访问频率限制器"""
def __init__(self, max_attempts=10, window_seconds=3600, lock_duration=3600): def __init__(
self,
max_attempts=10,
window_seconds=3600,
lock_duration=3600,
max_tracked_ips=20000,
):
""" """
初始化限流器 初始化限流器
@@ -90,6 +97,7 @@ class IPRateLimiter:
self.max_attempts = max_attempts self.max_attempts = max_attempts
self.window_seconds = window_seconds self.window_seconds = window_seconds
self.lock_duration = lock_duration self.lock_duration = lock_duration
self.max_tracked_ips = max(1000, int(max_tracked_ips or 0))
# IP访问记录: {ip: [(timestamp, success), ...]} # IP访问记录: {ip: [(timestamp, success), ...]}
self._attempts = defaultdict(list) self._attempts = defaultdict(list)
@@ -97,6 +105,47 @@ class IPRateLimiter:
self._locked = {} self._locked = {}
self._lock = threading.Lock() self._lock = threading.Lock()
def _prune_if_oversized(self, now_ts: float) -> None:
"""限制内部映射大小避免在高频随机IP攻击下持续膨胀。"""
tracked = len(self._attempts) + len(self._locked)
if tracked <= self.max_tracked_ips:
return
cutoff_time = now_ts - self.window_seconds
for ip in list(self._attempts.keys()):
self._attempts[ip] = [
(ts, succ) for ts, succ in self._attempts[ip]
if ts > cutoff_time
]
if not self._attempts[ip]:
del self._attempts[ip]
for ip in list(self._locked.keys()):
if now_ts >= self._locked[ip]:
del self._locked[ip]
tracked = len(self._attempts) + len(self._locked)
if tracked <= self.max_tracked_ips:
return
# 优先按“最近访问时间最早”淘汰 attempts 中的 IP 记录。
overflow = tracked - self.max_tracked_ips
oldest = []
for ip, attempt_items in self._attempts.items():
if attempt_items:
oldest.append((attempt_items[-1][0], ip))
else:
oldest.append((0.0, ip))
oldest.sort(key=lambda item: item[0])
removed = 0
for _, ip in oldest:
self._attempts.pop(ip, None)
self._locked.pop(ip, None)
removed += 1
if removed >= overflow:
break
def is_locked(self, ip_address): def is_locked(self, ip_address):
""" """
检查IP是否被锁定 检查IP是否被锁定
@@ -129,6 +178,7 @@ class IPRateLimiter:
""" """
with self._lock: with self._lock:
now = time.time() now = time.time()
self._prune_if_oversized(now)
# 清理过期记录 # 清理过期记录
cutoff_time = now - self.window_seconds cutoff_time = now - self.window_seconds
@@ -357,7 +407,19 @@ def generate_csrf_token():
def validate_csrf_token(token): def validate_csrf_token(token):
"""验证CSRF令牌""" """验证CSRF令牌"""
return token == session.get('csrf_token') expected = session.get("csrf_token")
if (token is None) or (expected is None):
return False
provided_text = str(token or "")
expected_text = str(expected or "")
if (not provided_text) or (not expected_text):
return False
return hmac.compare_digest(
provided_text.encode("utf-8"),
expected_text.encode("utf-8"),
)
# ==================== 内容安全 ==================== # ==================== 内容安全 ====================

View File

@@ -14,6 +14,7 @@
import os import os
import sys import sys
import base64 import base64
import threading
from pathlib import Path from pathlib import Path
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
@@ -27,18 +28,37 @@ ENCRYPTION_KEY_FILE = os.environ.get('ENCRYPTION_KEY_FILE', 'data/encryption_key
ENCRYPTION_SALT_FILE = os.environ.get('ENCRYPTION_SALT_FILE', 'data/encryption_salt.bin') ENCRYPTION_SALT_FILE = os.environ.get('ENCRYPTION_SALT_FILE', 'data/encryption_salt.bin')
def _ensure_private_dir(path: Path) -> None:
if not path:
return
os.makedirs(path, mode=0o700, exist_ok=True)
try:
os.chmod(path, 0o700)
except Exception:
pass
def _ensure_private_file(path: Path) -> None:
try:
os.chmod(path, 0o600)
except Exception:
pass
def _get_or_create_salt(): def _get_or_create_salt():
"""获取或创建盐值""" """获取或创建盐值"""
salt_path = Path(ENCRYPTION_SALT_FILE) salt_path = Path(ENCRYPTION_SALT_FILE)
if salt_path.exists(): if salt_path.exists():
_ensure_private_file(salt_path)
with open(salt_path, 'rb') as f: with open(salt_path, 'rb') as f:
return f.read() return f.read()
# 生成新的盐值 # 生成新的盐值
salt = os.urandom(16) salt = os.urandom(16)
os.makedirs(salt_path.parent, exist_ok=True) _ensure_private_dir(salt_path.parent)
with open(salt_path, 'wb') as f: with open(salt_path, 'wb') as f:
f.write(salt) f.write(salt)
_ensure_private_file(salt_path)
return salt return salt
@@ -102,6 +122,7 @@ def get_encryption_key():
key_path = Path(ENCRYPTION_KEY_FILE) key_path = Path(ENCRYPTION_KEY_FILE)
if key_path.exists(): if key_path.exists():
logger.info(f"从文件 {ENCRYPTION_KEY_FILE} 读取加密密钥") logger.info(f"从文件 {ENCRYPTION_KEY_FILE} 读取加密密钥")
_ensure_private_file(key_path)
with open(key_path, 'rb') as f: with open(key_path, 'rb') as f:
return f.read() return f.read()
@@ -127,9 +148,10 @@ def get_encryption_key():
# 生成新的密钥 # 生成新的密钥
key = Fernet.generate_key() key = Fernet.generate_key()
os.makedirs(key_path.parent, exist_ok=True) _ensure_private_dir(key_path.parent)
with open(key_path, 'wb') as f: with open(key_path, 'wb') as f:
f.write(key) f.write(key)
_ensure_private_file(key_path)
logger.info(f"已生成新的加密密钥并保存到 {ENCRYPTION_KEY_FILE}") logger.info(f"已生成新的加密密钥并保存到 {ENCRYPTION_KEY_FILE}")
logger.warning("请立即备份此密钥文件,并建议设置 ENCRYPTION_KEY_RAW 环境变量!") logger.warning("请立即备份此密钥文件,并建议设置 ENCRYPTION_KEY_RAW 环境变量!")
return key return key
@@ -137,11 +159,14 @@ def get_encryption_key():
# 全局Fernet实例 # 全局Fernet实例
_fernet = None _fernet = None
_fernet_lock = threading.Lock()
def _get_fernet(): def _get_fernet():
"""获取Fernet加密器懒加载""" """获取Fernet加密器懒加载"""
global _fernet global _fernet
if _fernet is None:
with _fernet_lock:
if _fernet is None: if _fernet is None:
key = get_encryption_key() key = get_encryption_key()
_fernet = Fernet(key) _fernet = Fernet(key)
@@ -187,7 +212,7 @@ def decrypt_password(encrypted_password: str) -> str:
# 解密失败,可能是旧的明文密码或密钥不匹配 # 解密失败,可能是旧的明文密码或密钥不匹配
if is_encrypted(encrypted_password): if is_encrypted(encrypted_password):
logger.error(f"密码解密失败(密钥可能不匹配): {e}") logger.error(f"密码解密失败(密钥可能不匹配): {e}")
else: return ''
logger.warning(f"密码解密失败,可能是未加密的旧数据: {e}") logger.warning(f"密码解密失败,可能是未加密的旧数据: {e}")
return encrypted_password return encrypted_password

View File

@@ -1,4 +0,0 @@
# Netscape HTTP Cookie File
# This file was generated by zsglpt
postoa.aidunsoft.com FALSE / FALSE 0 ASP.NET_SessionId xtjioeuz4yvk4bx3xqyt0pyp
postoa.aidunsoft.com FALSE / FALSE 1800092244 UserInfo userName=13974663700&Pwd=9B8DC766B11550651353D98805B4995B

View File

@@ -1 +0,0 @@
_S5Vpk71XaK9bm5U8jHJe-x2ASm38YWNweVlmCcIauM=

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
4abccefe523ed05bdbb717d1153e202d25ade95458c4d78e

View File

@@ -109,6 +109,7 @@ from db.users import (
delete_user, delete_user,
extend_user_vip, extend_user_vip,
get_all_users, get_all_users,
get_users_count,
get_pending_users, get_pending_users,
get_user_by_id, get_user_by_id,
get_user_by_username, get_user_by_username,

View File

@@ -25,6 +25,20 @@ _USER_LOOKUP_SQL = {
"id": "SELECT * FROM users WHERE id = ?", "id": "SELECT * FROM users WHERE id = ?",
"username": "SELECT * FROM users WHERE username = ?", "username": "SELECT * FROM users WHERE username = ?",
} }
_USER_ADMIN_SAFE_COLUMNS = (
"id",
"username",
"email",
"email_verified",
"email_notify_enabled",
"kdocs_unit",
"kdocs_auto_upload",
"status",
"vip_expire_time",
"created_at",
"approved_at",
)
_USER_ADMIN_SAFE_COLUMNS_SQL = ", ".join(_USER_ADMIN_SAFE_COLUMNS)
def _row_to_dict(row): def _row_to_dict(row):
@@ -283,19 +297,63 @@ def get_user_by_username(username):
return _get_user_by_field("username", username) return _get_user_by_field("username", username)
def get_all_users(): def _normalize_limit_offset(limit, offset, *, max_limit: int = 500):
"""获取所有用户""" normalized_limit = None
if limit is not None:
try:
normalized_limit = int(limit)
except (TypeError, ValueError):
normalized_limit = 50
normalized_limit = max(1, min(normalized_limit, max_limit))
try:
normalized_offset = int(offset or 0)
except (TypeError, ValueError):
normalized_offset = 0
normalized_offset = max(0, normalized_offset)
return normalized_limit, normalized_offset
def get_users_count(*, status: str | None = None) -> int:
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM users ORDER BY created_at DESC") if status:
cursor.execute("SELECT COUNT(*) AS count FROM users WHERE status = ?", (status,))
else:
cursor.execute("SELECT COUNT(*) AS count FROM users")
row = cursor.fetchone()
return int((row["count"] if row else 0) or 0)
def get_all_users(*, limit=None, offset=0):
"""获取所有用户"""
limit, offset = _normalize_limit_offset(limit, offset)
with db_pool.get_db() as conn:
cursor = conn.cursor()
sql = f"SELECT {_USER_ADMIN_SAFE_COLUMNS_SQL} FROM users ORDER BY created_at DESC"
params = []
if limit is not None:
sql += " LIMIT ? OFFSET ?"
params.extend([limit, offset])
cursor.execute(sql, params)
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
def get_pending_users(): def get_pending_users(*, limit=None, offset=0):
"""获取待审核用户""" """获取待审核用户"""
limit, offset = _normalize_limit_offset(limit, offset)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE status = 'pending' ORDER BY created_at DESC") sql = (
f"SELECT {_USER_ADMIN_SAFE_COLUMNS_SQL} "
"FROM users WHERE status = 'pending' ORDER BY created_at DESC"
)
params = []
if limit is not None:
sql += " LIMIT ? OFFSET ?"
params.extend([limit, offset])
cursor.execute(sql, params)
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]

View File

@@ -6,7 +6,7 @@ schedule==1.2.0
psutil==5.9.6 psutil==5.9.6
pytz==2024.1 pytz==2024.1
bcrypt==4.0.1 bcrypt==4.0.1
requests==2.31.0 requests==2.32.3
python-dotenv==1.0.0 python-dotenv==1.0.0
beautifulsoup4==4.12.2 beautifulsoup4==4.12.2
cryptography>=41.0.0 cryptography>=41.0.0

View File

@@ -355,9 +355,6 @@ def admin_logout():
session.pop("admin_id", None) session.pop("admin_id", None)
session.pop("admin_username", None) session.pop("admin_username", None)
session.pop("admin_reauth_until", None) session.pop("admin_reauth_until", None)
session.pop("_user_id", None)
session.pop("_fresh", None)
session.pop("_id", None)
return jsonify({"success": True}) return jsonify({"success": True})

View File

@@ -21,7 +21,7 @@ def get_email_settings_api():
return jsonify(settings) return jsonify(settings)
except Exception as e: except Exception as e:
logger.error(f"获取邮件设置失败: {e}") logger.error(f"获取邮件设置失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "获取邮件设置失败"}), 500
@admin_api_bp.route("/email/settings", methods=["POST"]) @admin_api_bp.route("/email/settings", methods=["POST"])
@@ -48,7 +48,7 @@ def update_email_settings_api():
return jsonify({"success": True}) return jsonify({"success": True})
except Exception as e: except Exception as e:
logger.error(f"更新邮件设置失败: {e}") logger.error(f"更新邮件设置失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "更新邮件设置失败"}), 500
@admin_api_bp.route("/smtp/configs", methods=["GET"]) @admin_api_bp.route("/smtp/configs", methods=["GET"])
@@ -60,7 +60,7 @@ def get_smtp_configs_api():
return jsonify(configs) return jsonify(configs)
except Exception as e: except Exception as e:
logger.error(f"获取SMTP配置失败: {e}") logger.error(f"获取SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "获取SMTP配置失败"}), 500
@admin_api_bp.route("/smtp/configs", methods=["POST"]) @admin_api_bp.route("/smtp/configs", methods=["POST"])
@@ -78,7 +78,7 @@ def create_smtp_config_api():
return jsonify({"success": True, "id": config_id}) return jsonify({"success": True, "id": config_id})
except Exception as e: except Exception as e:
logger.error(f"创建SMTP配置失败: {e}") logger.error(f"创建SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "创建SMTP配置失败"}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["GET"]) @admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["GET"])
@@ -92,7 +92,7 @@ def get_smtp_config_api(config_id):
return jsonify(config_data) return jsonify(config_data)
except Exception as e: except Exception as e:
logger.error(f"获取SMTP配置失败: {e}") logger.error(f"获取SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "获取SMTP配置失败"}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["PUT"]) @admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["PUT"])
@@ -106,7 +106,7 @@ def update_smtp_config_api(config_id):
return jsonify({"error": "更新失败"}), 400 return jsonify({"error": "更新失败"}), 400
except Exception as e: except Exception as e:
logger.error(f"更新SMTP配置失败: {e}") logger.error(f"更新SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "更新SMTP配置失败"}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["DELETE"]) @admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["DELETE"])
@@ -119,7 +119,7 @@ def delete_smtp_config_api(config_id):
return jsonify({"error": "删除失败"}), 400 return jsonify({"error": "删除失败"}), 400
except Exception as e: except Exception as e:
logger.error(f"删除SMTP配置失败: {e}") logger.error(f"删除SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "删除SMTP配置失败"}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>/test", methods=["POST"]) @admin_api_bp.route("/smtp/configs/<int:config_id>/test", methods=["POST"])
@@ -140,7 +140,7 @@ def test_smtp_config_api(config_id):
return jsonify(result) return jsonify(result)
except Exception as e: except Exception as e:
logger.error(f"测试SMTP配置失败: {e}") logger.error(f"测试SMTP配置失败: {e}")
return jsonify({"success": False, "error": str(e)}), 500 return jsonify({"success": False, "error": "测试SMTP配置失败"}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>/primary", methods=["POST"]) @admin_api_bp.route("/smtp/configs/<int:config_id>/primary", methods=["POST"])
@@ -153,7 +153,7 @@ def set_primary_smtp_config_api(config_id):
return jsonify({"error": "设置失败"}), 400 return jsonify({"error": "设置失败"}), 400
except Exception as e: except Exception as e:
logger.error(f"设置主SMTP配置失败: {e}") logger.error(f"设置主SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "设置主SMTP配置失败"}), 500
@admin_api_bp.route("/smtp/configs/primary/clear", methods=["POST"]) @admin_api_bp.route("/smtp/configs/primary/clear", methods=["POST"])
@@ -165,7 +165,7 @@ def clear_primary_smtp_config_api():
return jsonify({"success": True}) return jsonify({"success": True})
except Exception as e: except Exception as e:
logger.error(f"取消主SMTP配置失败: {e}") logger.error(f"取消主SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "取消主SMTP配置失败"}), 500
@admin_api_bp.route("/email/stats", methods=["GET"]) @admin_api_bp.route("/email/stats", methods=["GET"])
@@ -177,7 +177,7 @@ def get_email_stats_api():
return jsonify(stats) return jsonify(stats)
except Exception as e: except Exception as e:
logger.error(f"获取邮件统计失败: {e}") logger.error(f"获取邮件统计失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "获取邮件统计失败"}), 500
@admin_api_bp.route("/email/logs", methods=["GET"]) @admin_api_bp.route("/email/logs", methods=["GET"])
@@ -195,7 +195,7 @@ def get_email_logs_api():
return jsonify(result) return jsonify(result)
except Exception as e: except Exception as e:
logger.error(f"获取邮件日志失败: {e}") logger.error(f"获取邮件日志失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "获取邮件日志失败"}), 500
@admin_api_bp.route("/email/logs/cleanup", methods=["POST"]) @admin_api_bp.route("/email/logs/cleanup", methods=["POST"])
@@ -211,4 +211,4 @@ def cleanup_email_logs_api():
return jsonify({"success": True, "deleted": deleted}) return jsonify({"success": True, "deleted": deleted})
except Exception as e: except Exception as e:
logger.error(f"清理邮件日志失败: {e}") logger.error(f"清理邮件日志失败: {e}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": "清理邮件日志失败"}), 500

View File

@@ -80,7 +80,7 @@ def update_system_config_api():
if schedule_time is not None: if schedule_time is not None:
import re import re
if not re.match(r"^([01]\\d|2[0-3]):([0-5]\\d)$", schedule_time): if not re.match(r"^([01]\d|2[0-3]):([0-5]\d)$", schedule_time):
return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400 return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400
if schedule_browse_type is not None: if schedule_browse_type is not None:

View File

@@ -13,21 +13,53 @@ from services.state import safe_clear_user_logs, safe_remove_user_accounts
# ==================== 用户管理/统计(管理员) ==================== # ==================== 用户管理/统计(管理员) ====================
def _parse_optional_pagination(default_limit: int = 50, max_limit: int = 500) -> tuple[int | None, int]:
limit_raw = request.args.get("limit")
offset_raw = request.args.get("offset")
if (limit_raw is None) and (offset_raw is None):
return None, 0
try:
limit = int(limit_raw if limit_raw is not None else default_limit)
except (TypeError, ValueError):
limit = default_limit
limit = max(1, min(limit, max_limit))
try:
offset = int(offset_raw if offset_raw is not None else 0)
except (TypeError, ValueError):
offset = 0
offset = max(0, offset)
return limit, offset
@admin_api_bp.route("/users", methods=["GET"]) @admin_api_bp.route("/users", methods=["GET"])
@admin_required @admin_required
def get_all_users(): def get_all_users():
"""获取所有用户""" """获取所有用户"""
limit, offset = _parse_optional_pagination()
if limit is None:
users = database.get_all_users() users = database.get_all_users()
return jsonify(users) return jsonify(users)
users = database.get_all_users(limit=limit, offset=offset)
total = database.get_users_count()
return jsonify({"items": users, "total": total, "limit": limit, "offset": offset})
@admin_api_bp.route("/users/pending", methods=["GET"]) @admin_api_bp.route("/users/pending", methods=["GET"])
@admin_required @admin_required
def get_pending_users(): def get_pending_users():
"""获取待审核用户""" """获取待审核用户"""
limit, offset = _parse_optional_pagination(default_limit=30, max_limit=200)
if limit is None:
users = database.get_pending_users() users = database.get_pending_users()
return jsonify(users) return jsonify(users)
users = database.get_pending_users(limit=limit, offset=offset)
total = database.get_users_count(status="pending")
return jsonify({"items": users, "total": total, "limit": limit, "offset": offset})
@admin_api_bp.route("/users/<int:user_id>/approve", methods=["POST"]) @admin_api_bp.route("/users/<int:user_id>/approve", methods=["POST"])
@admin_required @admin_required

View File

@@ -164,11 +164,13 @@ def update_account(account_id):
""" """
UPDATE accounts UPDATE accounts
SET password = ?, remember = ? SET password = ?, remember = ?
WHERE id = ? WHERE id = ? AND user_id = ?
""", """,
(encrypted_password, new_remember, account_id), (encrypted_password, new_remember, account_id, user_id),
) )
conn.commit() conn.commit()
if cursor.rowcount <= 0:
return jsonify({"error": "账号不存在或无权限"}), 404
database.reset_account_login_status(account_id) database.reset_account_login_status(account_id)
logger.info(f"[账号更新] 用户 {user_id} 修改了账号 {account.username} 的密码,已重置登录状态") logger.info(f"[账号更新] 用户 {user_id} 修改了账号 {account.username} 的密码,已重置登录状态")

View File

@@ -9,6 +9,7 @@ import time as time_mod
import uuid import uuid
import database import database
from app_logger import get_logger
from flask import Blueprint, jsonify, request from flask import Blueprint, jsonify, request
from flask_login import current_user, login_required from flask_login import current_user, login_required
from services.accounts_service import load_user_accounts from services.accounts_service import load_user_accounts
@@ -17,6 +18,7 @@ from services.state import safe_get_account, safe_get_user_accounts_snapshot
from services.tasks import submit_account_task from services.tasks import submit_account_task
api_schedules_bp = Blueprint("api_schedules", __name__) api_schedules_bp = Blueprint("api_schedules", __name__)
logger = get_logger("app")
_HHMM_RE = re.compile(r"^(\d{1,2}):(\d{2})$") _HHMM_RE = re.compile(r"^(\d{1,2}):(\d{2})$")
@@ -391,4 +393,5 @@ def delete_schedule_logs_api(schedule_id):
deleted = database.delete_schedule_logs(schedule_id, current_user.id) deleted = database.delete_schedule_logs(schedule_id, current_user.id)
return jsonify({"success": True, "deleted": deleted}) return jsonify({"success": True, "deleted": deleted})
except Exception as e: except Exception as e:
return jsonify({"error": str(e)}), 500 logger.warning(f"[schedules] 清空定时任务日志失败(schedule_id={schedule_id}): {e}")
return jsonify({"error": "清空日志失败,请稍后重试"}), 500

View File

@@ -8,6 +8,7 @@ from typing import Iterator
import database import database
from app_config import get_config from app_config import get_config
from app_logger import get_logger
from app_security import is_safe_path from app_security import is_safe_path
from flask import Blueprint, jsonify, request, send_from_directory from flask import Blueprint, jsonify, request, send_from_directory
from flask_login import current_user, login_required from flask_login import current_user, login_required
@@ -28,33 +29,89 @@ except AttributeError: # Pillow<9 fallback
_RESAMPLE_FILTER = Image.LANCZOS _RESAMPLE_FILTER = Image.LANCZOS
api_screenshots_bp = Blueprint("api_screenshots", __name__) api_screenshots_bp = Blueprint("api_screenshots", __name__)
logger = get_logger("app")
def _get_user_prefix(user_id: int) -> str: def _get_user_prefix(user_id: int) -> str:
return f"u{int(user_id)}"
def _get_username(user_id: int) -> str:
user_info = database.get_user_by_id(user_id) user_info = database.get_user_by_id(user_id)
return user_info["username"] if user_info else f"user{user_id}" return str(user_info.get("username") or "") if user_info else ""
def _is_user_screenshot(filename: str, username_prefix: str) -> bool: def _list_all_usernames() -> list[str]:
return filename.startswith(username_prefix + "_") and filename.lower().endswith(_IMAGE_EXTENSIONS) users = database.get_all_users()
result = []
for row in users:
username = str(row.get("username") or "").strip()
if username:
result.append(username)
return result
def _iter_user_screenshot_entries(username_prefix: str) -> Iterator[os.DirEntry]: def _resolve_user_owned_prefix(
filename: str,
*,
user_id: int,
username: str,
all_usernames: list[str] | None = None,
) -> str | None:
lower_name = filename.lower()
if not lower_name.endswith(_IMAGE_EXTENSIONS):
return None
# 新版命名u{user_id}_...
id_prefix = _get_user_prefix(user_id)
if filename.startswith(id_prefix + "_"):
return id_prefix
# 兼容旧版命名:{username}_...
username = str(username or "").strip()
if not username:
return None
if all_usernames is None:
all_usernames = _list_all_usernames()
matched_usernames = [item for item in all_usernames if filename.startswith(item + "_")]
if not matched_usernames:
return None
# 取“最长匹配用户名”,避免 foo 越权读取 foo_bar 的截图。
max_len = max(len(item) for item in matched_usernames)
winners = [item for item in matched_usernames if len(item) == max_len]
if len(winners) != 1:
return None
if winners[0] != username:
return None
return winners[0]
def _iter_user_screenshot_entries(user_id: int, username: str, all_usernames: list[str]) -> Iterator[tuple[os.DirEntry, str]]:
if not os.path.exists(SCREENSHOTS_DIR): if not os.path.exists(SCREENSHOTS_DIR):
return return
with os.scandir(SCREENSHOTS_DIR) as entries: with os.scandir(SCREENSHOTS_DIR) as entries:
for entry in entries: for entry in entries:
if (not entry.is_file()) or (not _is_user_screenshot(entry.name, username_prefix)): if not entry.is_file():
continue continue
yield entry matched_prefix = _resolve_user_owned_prefix(
entry.name,
user_id=user_id,
username=username,
all_usernames=all_usernames,
)
if not matched_prefix:
continue
yield entry, matched_prefix
def _build_display_name(filename: str) -> str: def _build_display_name(filename: str, owner_prefix: str) -> str:
base_name, ext = filename.rsplit(".", 1) prefix = f"{owner_prefix}_"
parts = base_name.split("_", 1) if filename.startswith(prefix):
if len(parts) > 1: return filename[len(prefix) :]
return f"{parts[1]}.{ext}"
return filename return filename
@@ -126,11 +183,12 @@ def _parse_optional_pagination(default_limit: int = 24, *, max_limit: int = 100)
def get_screenshots(): def get_screenshots():
"""获取当前用户的截图列表""" """获取当前用户的截图列表"""
user_id = current_user.id user_id = current_user.id
username_prefix = _get_user_prefix(user_id) username = _get_username(user_id)
try: try:
screenshots = [] screenshots = []
for entry in _iter_user_screenshot_entries(username_prefix): all_usernames = _list_all_usernames()
for entry, matched_prefix in _iter_user_screenshot_entries(user_id, username, all_usernames):
filename = entry.name filename = entry.name
stat = entry.stat() stat = entry.stat()
created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ) created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ)
@@ -138,7 +196,7 @@ def get_screenshots():
screenshots.append( screenshots.append(
{ {
"filename": filename, "filename": filename,
"display_name": _build_display_name(filename), "display_name": _build_display_name(filename, matched_prefix),
"size": stat.st_size, "size": stat.st_size,
"created": created_time.strftime("%Y-%m-%d %H:%M:%S"), "created": created_time.strftime("%Y-%m-%d %H:%M:%S"),
"_created_ts": stat.st_mtime, "_created_ts": stat.st_mtime,
@@ -157,7 +215,8 @@ def get_screenshots():
return jsonify(screenshots) return jsonify(screenshots)
except Exception as e: except Exception as e:
return jsonify({"error": str(e)}), 500 logger.warning(f"[screenshots] 获取截图列表失败(user_id={user_id}): {e}")
return jsonify({"error": "获取截图列表失败"}), 500
@api_screenshots_bp.route("/screenshots/<filename>") @api_screenshots_bp.route("/screenshots/<filename>")
@@ -165,9 +224,8 @@ def get_screenshots():
def serve_screenshot(filename): def serve_screenshot(filename):
"""提供原图文件访问""" """提供原图文件访问"""
user_id = current_user.id user_id = current_user.id
username_prefix = _get_user_prefix(user_id) username = _get_username(user_id)
if not _resolve_user_owned_prefix(filename, user_id=user_id, username=username):
if not _is_user_screenshot(filename, username_prefix):
return jsonify({"error": "无权访问"}), 403 return jsonify({"error": "无权访问"}), 403
if not is_safe_path(SCREENSHOTS_DIR, filename): if not is_safe_path(SCREENSHOTS_DIR, filename):
@@ -181,9 +239,8 @@ def serve_screenshot(filename):
def serve_screenshot_thumbnail(filename): def serve_screenshot_thumbnail(filename):
"""提供缩略图访问(失败时自动回退原图)""" """提供缩略图访问(失败时自动回退原图)"""
user_id = current_user.id user_id = current_user.id
username_prefix = _get_user_prefix(user_id) username = _get_username(user_id)
if not _resolve_user_owned_prefix(filename, user_id=user_id, username=username):
if not _is_user_screenshot(filename, username_prefix):
return jsonify({"error": "无权访问"}), 403 return jsonify({"error": "无权访问"}), 403
if not is_safe_path(SCREENSHOTS_DIR, filename): if not is_safe_path(SCREENSHOTS_DIR, filename):
@@ -209,9 +266,8 @@ def serve_screenshot_thumbnail(filename):
def delete_screenshot(filename): def delete_screenshot(filename):
"""删除指定截图""" """删除指定截图"""
user_id = current_user.id user_id = current_user.id
username_prefix = _get_user_prefix(user_id) username = _get_username(user_id)
if not _resolve_user_owned_prefix(filename, user_id=user_id, username=username):
if not _is_user_screenshot(filename, username_prefix):
return jsonify({"error": "无权删除"}), 403 return jsonify({"error": "无权删除"}), 403
if not is_safe_path(SCREENSHOTS_DIR, filename): if not is_safe_path(SCREENSHOTS_DIR, filename):
@@ -226,7 +282,8 @@ def delete_screenshot(filename):
return jsonify({"success": True}) return jsonify({"success": True})
return jsonify({"error": "文件不存在"}), 404 return jsonify({"error": "文件不存在"}), 404
except Exception as e: except Exception as e:
return jsonify({"error": str(e)}), 500 logger.warning(f"[screenshots] 删除截图失败(user_id={user_id}, filename={filename}): {e}")
return jsonify({"error": "删除截图失败"}), 500
@api_screenshots_bp.route("/api/screenshots/clear", methods=["POST"]) @api_screenshots_bp.route("/api/screenshots/clear", methods=["POST"])
@@ -234,11 +291,12 @@ def delete_screenshot(filename):
def clear_all_screenshots(): def clear_all_screenshots():
"""清空当前用户的所有截图""" """清空当前用户的所有截图"""
user_id = current_user.id user_id = current_user.id
username_prefix = _get_user_prefix(user_id) username = _get_username(user_id)
try: try:
deleted_count = 0 deleted_count = 0
for entry in _iter_user_screenshot_entries(username_prefix): all_usernames = _list_all_usernames()
for entry, _ in _iter_user_screenshot_entries(user_id, username, all_usernames):
os.remove(entry.path) os.remove(entry.path)
_remove_thumbnail(entry.name) _remove_thumbnail(entry.name)
deleted_count += 1 deleted_count += 1
@@ -246,4 +304,5 @@ def clear_all_screenshots():
log_to_client(f"清理了 {deleted_count} 个截图文件", user_id) log_to_client(f"清理了 {deleted_count} 个截图文件", user_id)
return jsonify({"success": True, "deleted": deleted_count}) return jsonify({"success": True, "deleted": deleted_count})
except Exception as e: except Exception as e:
return jsonify({"error": str(e)}), 500 logger.warning(f"[screenshots] 清空截图失败(user_id={user_id}): {e}")
return jsonify({"error": "清空截图失败"}), 500

View File

@@ -325,6 +325,10 @@ class KDocsUploader:
if self._context is None: if self._context is None:
storage_state = getattr(config, "KDOCS_LOGIN_STATE_FILE", "data/kdocs_login_state.json") storage_state = getattr(config, "KDOCS_LOGIN_STATE_FILE", "data/kdocs_login_state.json")
if use_storage_state and os.path.exists(storage_state): if use_storage_state and os.path.exists(storage_state):
try:
os.chmod(storage_state, 0o600)
except Exception:
pass
self._context = self._browser.new_context(storage_state=storage_state) self._context = self._browser.new_context(storage_state=storage_state)
else: else:
self._context = self._browser.new_context() self._context = self._browser.new_context()
@@ -837,8 +841,18 @@ class KDocsUploader:
def _save_login_state(self) -> None: def _save_login_state(self) -> None:
try: try:
storage_state = getattr(config, "KDOCS_LOGIN_STATE_FILE", "data/kdocs_login_state.json") storage_state = getattr(config, "KDOCS_LOGIN_STATE_FILE", "data/kdocs_login_state.json")
os.makedirs(os.path.dirname(storage_state), exist_ok=True) state_dir = os.path.dirname(storage_state)
if state_dir:
os.makedirs(state_dir, mode=0o700, exist_ok=True)
try:
os.chmod(state_dir, 0o700)
except Exception:
pass
self._context.storage_state(path=storage_state) self._context.storage_state(path=storage_state)
try:
os.chmod(storage_state, 0o600)
except Exception:
pass
except Exception as e: except Exception as e:
logger.warning(f"[KDocs] 保存登录态失败: {e}") logger.warning(f"[KDocs] 保存登录态失败: {e}")

View File

@@ -538,9 +538,8 @@ def take_screenshot_for_account(
# 标记账号正在截图(防止重复提交截图任务) # 标记账号正在截图(防止重复提交截图任务)
account.is_running = True account.is_running = True
user_info = database.get_user_by_id(user_id) user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}" username_prefix = f"u{int(user_id)}"
def screenshot_task( def screenshot_task(
browser_instance, user_id, account_id, account, browse_type, source, task_start_time, browse_result browser_instance, user_id, account_id, account, browse_type, source, task_start_time, browse_result

View File

@@ -3,7 +3,7 @@ import sys
from pathlib import Path from pathlib import Path
import pytest import pytest
from flask import Flask from flask import Flask, session
PROJECT_ROOT = Path(__file__).resolve().parents[1] PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path: if str(PROJECT_ROOT) not in sys.path:
@@ -56,3 +56,24 @@ def test_get_encryption_key_refuses_regeneration_when_encrypted_data_exists(monk
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
crypto_utils.get_encryption_key() crypto_utils.get_encryption_key()
def test_validate_csrf_token_requires_matching_session_token():
app = Flask(__name__)
app.secret_key = "test-secret-key"
with app.test_request_context("/", method="POST"):
session["csrf_token"] = "fixed-token"
assert app_security.validate_csrf_token("fixed-token") is True
assert app_security.validate_csrf_token("wrong-token") is False
assert app_security.validate_csrf_token("") is False
def test_decrypt_password_returns_empty_for_unreadable_encrypted_payload(monkeypatch):
class BrokenFernet:
def decrypt(self, *_args, **_kwargs):
raise ValueError("bad token")
monkeypatch.setattr(crypto_utils, "_get_fernet", lambda: BrokenFernet())
encrypted_like_value = "gAAAAABrokenPayload"
assert crypto_utils.decrypt_password(encrypted_like_value) == ""