安全修复: 收敛认证与日志风险并补充基础测试
This commit is contained in:
@@ -15,6 +15,12 @@ FLASK_DEBUG=false
|
|||||||
SESSION_LIFETIME_HOURS=24
|
SESSION_LIFETIME_HOURS=24
|
||||||
SESSION_COOKIE_SECURE=true # 生产环境HTTPS必须为true,本地HTTP调试可临时设为false
|
SESSION_COOKIE_SECURE=true # 生产环境HTTPS必须为true,本地HTTP调试可临时设为false
|
||||||
HTTPS_ENABLED=true
|
HTTPS_ENABLED=true
|
||||||
|
# 是否信任 X-Forwarded-* 代理头(默认关闭,建议仅在可信反代后开启)
|
||||||
|
TRUST_PROXY_HEADERS=false
|
||||||
|
# TRUST_PROXY_HEADERS=true 时生效,按需配置你的反向代理网段
|
||||||
|
TRUSTED_PROXY_CIDRS=127.0.0.1/32,::1/128
|
||||||
|
# 可选:首次启动时指定默认管理员密码(避免控制台输出明文密码)
|
||||||
|
# DEFAULT_ADMIN_PASSWORD=your-strong-admin-password
|
||||||
|
|
||||||
# ==================== 数据库配置 ====================
|
# ==================== 数据库配置 ====================
|
||||||
DB_FILE=data/app_data.db
|
DB_FILE=data/app_data.db
|
||||||
@@ -36,6 +42,7 @@ DB_PRAGMA_OPTIMIZE_INTERVAL_SECONDS=21600
|
|||||||
DB_ANALYZE_INTERVAL_SECONDS=86400
|
DB_ANALYZE_INTERVAL_SECONDS=86400
|
||||||
DB_WAL_CHECKPOINT_INTERVAL_SECONDS=43200
|
DB_WAL_CHECKPOINT_INTERVAL_SECONDS=43200
|
||||||
DB_WAL_CHECKPOINT_MODE=PASSIVE
|
DB_WAL_CHECKPOINT_MODE=PASSIVE
|
||||||
|
SYSTEM_CONFIG_CACHE_TTL_SECONDS=30
|
||||||
|
|
||||||
# ==================== 并发控制配置 ====================
|
# ==================== 并发控制配置 ====================
|
||||||
MAX_CONCURRENT_GLOBAL=2
|
MAX_CONCURRENT_GLOBAL=2
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -146,3 +146,7 @@ Thumbs.db
|
|||||||
# Temporary files
|
# Temporary files
|
||||||
*.tmp
|
*.tmp
|
||||||
*.temp
|
*.temp
|
||||||
|
|
||||||
|
# Allow committed test cases
|
||||||
|
!tests/
|
||||||
|
!tests/**/*.py
|
||||||
|
|||||||
@@ -790,7 +790,8 @@ docker logs -f knowledge-automation-multiuser
|
|||||||
# 6. 访问系统
|
# 6. 访问系统
|
||||||
# 浏览器打开: http://your-ip:51232
|
# 浏览器打开: http://your-ip:51232
|
||||||
# 后台管理: http://your-ip:51232/yuyx
|
# 后台管理: http://your-ip:51232/yuyx
|
||||||
# 默认管理员账号见容器启动日志(首次运行会生成随机密码)
|
# 首次管理员密码会写入 data/default_admin_credentials.txt(权限600)
|
||||||
|
# 登录后请立即修改密码并删除该文件
|
||||||
```
|
```
|
||||||
|
|
||||||
完成!🎉
|
完成!🎉
|
||||||
|
|||||||
49
app.py
49
app.py
@@ -49,12 +49,13 @@ from services.tasks import get_task_scheduler
|
|||||||
|
|
||||||
# 设置时区为中国标准时间(CST, UTC+8)
|
# 设置时区为中国标准时间(CST, UTC+8)
|
||||||
os.environ["TZ"] = "Asia/Shanghai"
|
os.environ["TZ"] = "Asia/Shanghai"
|
||||||
|
_TZSET_ERROR = None
|
||||||
try:
|
try:
|
||||||
import time as _time
|
import time as _time
|
||||||
|
|
||||||
_time.tzset()
|
_time.tzset()
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
_TZSET_ERROR = e
|
||||||
|
|
||||||
|
|
||||||
def _sigchld_handler(signum, frame):
|
def _sigchld_handler(signum, frame):
|
||||||
@@ -116,6 +117,8 @@ except Exception as socketio_error:
|
|||||||
|
|
||||||
init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE)
|
init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE)
|
||||||
logger = get_logger("app")
|
logger = get_logger("app")
|
||||||
|
if _TZSET_ERROR is not None:
|
||||||
|
logger.warning(f"设置时区失败,将继续使用系统默认时区: {_TZSET_ERROR}")
|
||||||
if _socketio_fallback_reason:
|
if _socketio_fallback_reason:
|
||||||
logger.warning(f"[SocketIO] 初始化失败,已回退 threading 模式: {_socketio_fallback_reason}")
|
logger.warning(f"[SocketIO] 初始化失败,已回退 threading 模式: {_socketio_fallback_reason}")
|
||||||
logger.info(f"[SocketIO] 当前 async_mode: {socketio.async_mode}")
|
logger.info(f"[SocketIO] 当前 async_mode: {socketio.async_mode}")
|
||||||
@@ -139,15 +142,15 @@ def _request_uses_https() -> bool:
|
|||||||
try:
|
try:
|
||||||
if bool(request.is_secure):
|
if bool(request.is_secure):
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.debug(f"检查 request.is_secure 失败: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
forwarded_proto = str(request.headers.get("X-Forwarded-Proto", "") or "").split(",", 1)[0].strip().lower()
|
forwarded_proto = str(request.headers.get("X-Forwarded-Proto", "") or "").split(",", 1)[0].strip().lower()
|
||||||
if forwarded_proto == "https":
|
if forwarded_proto == "https":
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.debug(f"检查 X-Forwarded-Proto 失败: {e}")
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -255,8 +258,8 @@ def _record_request_metric_after_response(response) -> None:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"[API-DIAG] {method} {path} -> {status_code} ({duration_ms:.1f}ms)"
|
f"[API-DIAG] {method} {path} -> {status_code} ({duration_ms:.1f}ms)"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.debug(f"记录请求指标失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
@app.after_request
|
@app.after_request
|
||||||
@@ -312,12 +315,12 @@ def serve_static(filename):
|
|||||||
# 协商缓存:确保存在 ETag,并基于 If-None-Match/If-Modified-Since 返回 304
|
# 协商缓存:确保存在 ETag,并基于 If-None-Match/If-Modified-Since 返回 304
|
||||||
try:
|
try:
|
||||||
response.add_etag(overwrite=False)
|
response.add_etag(overwrite=False)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.debug(f"静态资源 ETag 设置失败({filename}): {e}")
|
||||||
try:
|
try:
|
||||||
response.make_conditional(request)
|
response.make_conditional(request)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.debug(f"静态资源协商缓存处理失败({filename}): {e}")
|
||||||
|
|
||||||
response.headers.setdefault("Vary", "Accept-Encoding")
|
response.headers.setdefault("Vary", "Accept-Encoding")
|
||||||
if is_hashed_asset:
|
if is_hashed_asset:
|
||||||
@@ -341,33 +344,33 @@ def cleanup_on_exit():
|
|||||||
for acc in accounts.values():
|
for acc in accounts.values():
|
||||||
if getattr(acc, "is_running", False):
|
if getattr(acc, "is_running", False):
|
||||||
acc.should_stop = True
|
acc.should_stop = True
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"停止运行中任务失败: {e}")
|
||||||
|
|
||||||
logger.info("- 停止任务调度器...")
|
logger.info("- 停止任务调度器...")
|
||||||
try:
|
try:
|
||||||
scheduler = get_task_scheduler()
|
scheduler = get_task_scheduler()
|
||||||
scheduler.shutdown(timeout=5)
|
scheduler.shutdown(timeout=5)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"停止任务调度器失败: {e}")
|
||||||
|
|
||||||
logger.info("- 关闭截图线程池...")
|
logger.info("- 关闭截图线程池...")
|
||||||
try:
|
try:
|
||||||
shutdown_browser_worker_pool()
|
shutdown_browser_worker_pool()
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"关闭截图线程池失败: {e}")
|
||||||
|
|
||||||
logger.info("- 关闭邮件队列...")
|
logger.info("- 关闭邮件队列...")
|
||||||
try:
|
try:
|
||||||
email_service.shutdown_email_queue()
|
email_service.shutdown_email_queue()
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"关闭邮件队列失败: {e}")
|
||||||
|
|
||||||
logger.info("- 关闭数据库连接池...")
|
logger.info("- 关闭数据库连接池...")
|
||||||
try:
|
try:
|
||||||
db_pool._pool.close_all() if db_pool._pool else None
|
db_pool._pool.close_all() if db_pool._pool else None
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"关闭数据库连接池失败: {e}")
|
||||||
|
|
||||||
logger.info("[OK] 资源清理完成")
|
logger.info("[OK] 资源清理完成")
|
||||||
|
|
||||||
@@ -464,7 +467,7 @@ def _log_startup_urls() -> None:
|
|||||||
logger.info("服务器启动中...")
|
logger.info("服务器启动中...")
|
||||||
logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
|
logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
|
||||||
logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
|
logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
|
||||||
logger.info("默认管理员: admin (首次运行随机密码见日志)")
|
logger.info("默认管理员: admin (首次运行密码写入 data/default_admin_credentials.txt)")
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -197,6 +197,9 @@ class IPRateLimiter:
|
|||||||
# 全局IP限流器实例
|
# 全局IP限流器实例
|
||||||
ip_rate_limiter = IPRateLimiter()
|
ip_rate_limiter = IPRateLimiter()
|
||||||
|
|
||||||
|
_TRUTHY_VALUES = {"1", "true", "yes", "on"}
|
||||||
|
_TRUST_PROXY_HEADERS = str(os.environ.get("TRUST_PROXY_HEADERS", "false") or "").strip().lower() in _TRUTHY_VALUES
|
||||||
|
|
||||||
|
|
||||||
def require_ip_not_locked(f):
|
def require_ip_not_locked(f):
|
||||||
"""装饰器:检查IP是否被锁定"""
|
"""装饰器:检查IP是否被锁定"""
|
||||||
@@ -443,7 +446,7 @@ def get_client_ip(trust_proxy=False):
|
|||||||
"""
|
"""
|
||||||
# 安全说明:X-Forwarded-For 可被伪造
|
# 安全说明:X-Forwarded-For 可被伪造
|
||||||
# 仅在确认请求来自可信代理时才使用代理头
|
# 仅在确认请求来自可信代理时才使用代理头
|
||||||
if trust_proxy:
|
if trust_proxy and _TRUST_PROXY_HEADERS:
|
||||||
if request.headers.get('X-Forwarded-For'):
|
if request.headers.get('X-Forwarded-For'):
|
||||||
return request.headers.get('X-Forwarded-For').split(',')[0].strip()
|
return request.headers.get('X-Forwarded-For').split(',')[0].strip()
|
||||||
elif request.headers.get('X-Real-IP'):
|
elif request.headers.get('X-Real-IP'):
|
||||||
@@ -455,7 +458,7 @@ def get_client_ip(trust_proxy=False):
|
|||||||
|
|
||||||
def _load_trusted_proxy_networks():
|
def _load_trusted_proxy_networks():
|
||||||
"""加载可信代理 CIDR 列表。"""
|
"""加载可信代理 CIDR 列表。"""
|
||||||
default_cidrs = "127.0.0.1/32,::1/128,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fc00::/7"
|
default_cidrs = "127.0.0.1/32,::1/128"
|
||||||
raw = str(os.environ.get("TRUSTED_PROXY_CIDRS", default_cidrs) or "").strip()
|
raw = str(os.environ.get("TRUSTED_PROXY_CIDRS", default_cidrs) or "").strip()
|
||||||
if not raw:
|
if not raw:
|
||||||
return []
|
return []
|
||||||
@@ -525,6 +528,9 @@ def _extract_real_ip_from_forwarded_chain() -> str | None:
|
|||||||
def get_rate_limit_ip() -> str:
|
def get_rate_limit_ip() -> str:
|
||||||
"""在可信代理场景下取真实IP,用于限流/风控。"""
|
"""在可信代理场景下取真实IP,用于限流/风控。"""
|
||||||
remote_addr = request.remote_addr or ""
|
remote_addr = request.remote_addr or ""
|
||||||
|
if not _TRUST_PROXY_HEADERS:
|
||||||
|
return remote_addr
|
||||||
|
|
||||||
remote_ip = _parse_ip_address(remote_addr)
|
remote_ip = _parse_ip_address(remote_addr)
|
||||||
if remote_ip is None:
|
if remote_ip is None:
|
||||||
return remote_addr
|
return remote_addr
|
||||||
|
|||||||
@@ -119,15 +119,11 @@ def get_encryption_key():
|
|||||||
"2. 或在 docker-compose.yml 中设置 ENCRYPTION_KEY_RAW 环境变量\n"
|
"2. 或在 docker-compose.yml 中设置 ENCRYPTION_KEY_RAW 环境变量\n"
|
||||||
"3. 如果密钥确实丢失,需要重新录入所有账号密码\n"
|
"3. 如果密钥确实丢失,需要重新录入所有账号密码\n"
|
||||||
"\n"
|
"\n"
|
||||||
"设置 ALLOW_NEW_KEY=true 环境变量可强制生成新密钥(不推荐)\n"
|
|
||||||
+ "=" * 60
|
+ "=" * 60
|
||||||
)
|
)
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
|
print(error_msg, file=sys.stderr)
|
||||||
# 检查是否强制允许生成新密钥
|
raise RuntimeError("加密密钥丢失且存在已加密数据,请恢复密钥后再启动")
|
||||||
if os.environ.get('ALLOW_NEW_KEY', '').lower() != 'true':
|
|
||||||
print(error_msg, file=sys.stderr)
|
|
||||||
raise RuntimeError("加密密钥丢失且存在已加密数据,请检查配置")
|
|
||||||
|
|
||||||
# 生成新的密钥
|
# 生成新的密钥
|
||||||
key = Fernet.generate_key()
|
key = Fernet.generate_key()
|
||||||
|
|||||||
14
database.py
14
database.py
@@ -19,6 +19,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import db_pool
|
import db_pool
|
||||||
from app_config import get_config
|
from app_config import get_config
|
||||||
|
from app_logger import get_logger
|
||||||
|
|
||||||
from db.schema import ensure_schema
|
from db.schema import ensure_schema
|
||||||
from db.migrations import migrate_database as _migrate_database
|
from db.migrations import migrate_database as _migrate_database
|
||||||
@@ -126,6 +127,7 @@ from db.users import (
|
|||||||
from db.security import record_login_context
|
from db.security import record_login_context
|
||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# 数据库文件路径
|
# 数据库文件路径
|
||||||
DB_FILE = config.DB_FILE
|
DB_FILE = config.DB_FILE
|
||||||
@@ -140,9 +142,9 @@ _system_config_cache_lock = threading.Lock()
|
|||||||
_system_config_cache_value: Optional[dict] = None
|
_system_config_cache_value: Optional[dict] = None
|
||||||
_system_config_cache_loaded_at = 0.0
|
_system_config_cache_loaded_at = 0.0
|
||||||
try:
|
try:
|
||||||
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = float(os.environ.get("SYSTEM_CONFIG_CACHE_TTL_SECONDS", "3"))
|
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = float(os.environ.get("SYSTEM_CONFIG_CACHE_TTL_SECONDS", "30"))
|
||||||
except Exception:
|
except Exception:
|
||||||
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = 3.0
|
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = 30.0
|
||||||
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = max(0.0, _SYSTEM_CONFIG_CACHE_TTL_SECONDS)
|
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = max(0.0, _SYSTEM_CONFIG_CACHE_TTL_SECONDS)
|
||||||
|
|
||||||
|
|
||||||
@@ -197,8 +199,8 @@ def init_database():
|
|||||||
try:
|
try:
|
||||||
config_value = get_system_config()
|
config_value = get_system_config()
|
||||||
db_pool.configure_slow_query_runtime(threshold_ms=config_value.get("db_slow_query_ms"))
|
db_pool.configure_slow_query_runtime(threshold_ms=config_value.get("db_slow_query_ms"))
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"初始化慢查询阈值失败,使用默认值: {e}")
|
||||||
|
|
||||||
|
|
||||||
def migrate_database():
|
def migrate_database():
|
||||||
@@ -293,6 +295,6 @@ def update_system_config(
|
|||||||
try:
|
try:
|
||||||
latest_config = get_system_config()
|
latest_config = get_system_config()
|
||||||
db_pool.configure_slow_query_runtime(threshold_ms=latest_config.get("db_slow_query_ms"))
|
db_pool.configure_slow_query_runtime(threshold_ms=latest_config.get("db_slow_query_ms"))
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"更新慢查询阈值失败,保留当前配置: {e}")
|
||||||
return ok
|
return ok
|
||||||
|
|||||||
36
db/admin.py
36
db/admin.py
@@ -2,7 +2,9 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import db_pool
|
import db_pool
|
||||||
from db.utils import get_cst_now_str
|
from db.utils import get_cst_now_str
|
||||||
@@ -109,6 +111,28 @@ def _normalize_days(days, default: int = 30) -> int:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _store_default_admin_credentials(username: str, password: str) -> str | None:
|
||||||
|
"""将首次管理员账号密码写入受限权限文件,避免打印到日志。"""
|
||||||
|
raw_path = str(
|
||||||
|
os.environ.get("DEFAULT_ADMIN_CREDENTIALS_FILE", "data/default_admin_credentials.txt") or ""
|
||||||
|
).strip()
|
||||||
|
if not raw_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cred_path = Path(raw_path)
|
||||||
|
try:
|
||||||
|
cred_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(cred_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write("安全提醒:首次管理员账号已创建\n")
|
||||||
|
f.write(f"用户名: {username}\n")
|
||||||
|
f.write(f"密码: {password}\n")
|
||||||
|
f.write("请登录后立即修改密码,并删除该文件。\n")
|
||||||
|
os.chmod(cred_path, 0o600)
|
||||||
|
return str(cred_path)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def ensure_default_admin() -> bool:
|
def ensure_default_admin() -> bool:
|
||||||
"""确保存在默认管理员账号(行为保持不变)。"""
|
"""确保存在默认管理员账号(行为保持不变)。"""
|
||||||
import secrets
|
import secrets
|
||||||
@@ -120,7 +144,8 @@ def ensure_default_admin() -> bool:
|
|||||||
|
|
||||||
if count == 0:
|
if count == 0:
|
||||||
alphabet = string.ascii_letters + string.digits
|
alphabet = string.ascii_letters + string.digits
|
||||||
random_password = "".join(secrets.choice(alphabet) for _ in range(12))
|
bootstrap_password = str(os.environ.get("DEFAULT_ADMIN_PASSWORD", "") or "").strip()
|
||||||
|
random_password = bootstrap_password or "".join(secrets.choice(alphabet) for _ in range(12))
|
||||||
|
|
||||||
default_password_hash = hash_password_bcrypt(random_password)
|
default_password_hash = hash_password_bcrypt(random_password)
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -128,11 +153,16 @@ def ensure_default_admin() -> bool:
|
|||||||
("admin", default_password_hash, get_cst_now_str()),
|
("admin", default_password_hash, get_cst_now_str()),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
credential_file = _store_default_admin_credentials("admin", random_password)
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("安全提醒:已创建默认管理员账号")
|
print("安全提醒:已创建默认管理员账号")
|
||||||
print("用户名: admin")
|
print("用户名: admin")
|
||||||
print(f"密码: {random_password}")
|
if credential_file:
|
||||||
print("请立即登录后修改密码!")
|
print(f"初始密码已写入: {credential_file}(权限600)")
|
||||||
|
print("请立即登录后修改密码,并删除该文件。")
|
||||||
|
else:
|
||||||
|
print("未能写入初始密码文件。")
|
||||||
|
print("建议设置 DEFAULT_ADMIN_PASSWORD 后重建管理员账号。")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
_CST_TZ = pytz.timezone("Asia/Shanghai")
|
_CST_TZ = pytz.timezone("Asia/Shanghai")
|
||||||
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
|
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
|
||||||
|
_USER_LOOKUP_SQL = {
|
||||||
|
"id": "SELECT * FROM users WHERE id = ?",
|
||||||
|
"username": "SELECT * FROM users WHERE username = ?",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _row_to_dict(row):
|
def _row_to_dict(row):
|
||||||
@@ -28,9 +32,12 @@ def _row_to_dict(row):
|
|||||||
|
|
||||||
|
|
||||||
def _get_user_by_field(field_name: str, field_value):
|
def _get_user_by_field(field_name: str, field_value):
|
||||||
|
query_sql = _USER_LOOKUP_SQL.get(str(field_name or ""))
|
||||||
|
if not query_sql:
|
||||||
|
raise ValueError(f"unsupported user lookup field: {field_name}")
|
||||||
with db_pool.get_db() as conn:
|
with db_pool.get_db() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
|
cursor.execute(query_sql, (field_value,))
|
||||||
return _row_to_dict(cursor.fetchone())
|
return _row_to_dict(cursor.fetchone())
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -201,8 +201,8 @@ def _send_login_security_alert_if_needed(user: dict, username: str, client_ip: s
|
|||||||
new_device=context.get("new_device", False),
|
new_device=context.get("new_device", False),
|
||||||
user_id=user["id"],
|
user_id=user["id"],
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"发送登录安全提醒失败: user_id={user.get('id')}, error={e}")
|
||||||
|
|
||||||
|
|
||||||
def _parse_credential_payload(data: dict) -> dict | None:
|
def _parse_credential_payload(data: dict) -> dict | None:
|
||||||
@@ -308,10 +308,9 @@ def verify_email(token):
|
|||||||
if result:
|
if result:
|
||||||
token_id = result["token_id"]
|
token_id = result["token_id"]
|
||||||
user_id = result["user_id"]
|
user_id = result["user_id"]
|
||||||
email = result["email"]
|
|
||||||
|
|
||||||
if not database.approve_user(user_id):
|
if not database.approve_user(user_id):
|
||||||
logger.error(f"用户邮箱验证失败: 用户审核更新失败 user_id={user_id}, email={email}")
|
logger.error(f"用户邮箱验证失败: 用户审核更新失败 user_id={user_id}")
|
||||||
error_message = "验证处理失败,请稍后重试"
|
error_message = "验证处理失败,请稍后重试"
|
||||||
spa_initial_state = {
|
spa_initial_state = {
|
||||||
"page": "verify_result",
|
"page": "verify_result",
|
||||||
@@ -333,9 +332,9 @@ def verify_email(token):
|
|||||||
database.set_user_vip(user_id, auto_approve_vip_days)
|
database.set_user_vip(user_id, auto_approve_vip_days)
|
||||||
|
|
||||||
if not email_service.consume_email_token(token_id):
|
if not email_service.consume_email_token(token_id):
|
||||||
logger.warning(f"用户邮箱验证后Token消费失败: token_id={token_id}, user_id={user_id}")
|
logger.warning(f"用户邮箱验证后Token消费失败: user_id={user_id}")
|
||||||
|
|
||||||
logger.info(f"用户邮箱验证成功: user_id={user_id}, email={email}")
|
logger.info(f"用户邮箱验证成功: user_id={user_id}")
|
||||||
spa_initial_state = {
|
spa_initial_state = {
|
||||||
"page": "verify_result",
|
"page": "verify_result",
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -348,7 +347,7 @@ def verify_email(token):
|
|||||||
}
|
}
|
||||||
return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
|
return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
|
||||||
|
|
||||||
logger.warning(f"邮箱验证失败: token={token[:20]}...")
|
logger.warning("邮箱验证失败: token无效或已过期")
|
||||||
error_message = "验证链接无效或已过期,请重新注册或申请重发验证邮件"
|
error_message = "验证链接无效或已过期,请重新注册或申请重发验证邮件"
|
||||||
spa_initial_state = {
|
spa_initial_state = {
|
||||||
"page": "verify_result",
|
"page": "verify_result",
|
||||||
|
|||||||
@@ -365,7 +365,7 @@ def verify_bind_email(token):
|
|||||||
|
|
||||||
if database.update_user_email(user_id, email, verified=True):
|
if database.update_user_email(user_id, email, verified=True):
|
||||||
if not email_service.consume_email_token(token_id):
|
if not email_service.consume_email_token(token_id):
|
||||||
logger.warning(f"邮箱绑定成功但Token消费失败: token_id={token_id}, user_id={user_id}")
|
logger.warning(f"邮箱绑定成功但Token消费失败: user_id={user_id}")
|
||||||
return _render_verify_bind_success(email)
|
return _render_verify_bind_success(email)
|
||||||
|
|
||||||
return _render_verify_bind_failed(title="绑定失败", error_message="邮箱绑定失败,请重试")
|
return _render_verify_bind_failed(title="绑定失败", error_message="邮箱绑定失败,请重试")
|
||||||
|
|||||||
58
tests/test_security_hardening.py
Normal file
58
tests/test_security_hardening.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(PROJECT_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(PROJECT_ROOT))
|
||||||
|
|
||||||
|
import app_security
|
||||||
|
import crypto_utils
|
||||||
|
from db import users as db_users
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_lookup_rejects_dynamic_field_name():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
db_users._get_user_by_field("username OR 1=1 --", "demo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_ip_does_not_trust_proxy_headers_by_default(monkeypatch):
|
||||||
|
monkeypatch.delenv("TRUST_PROXY_HEADERS", raising=False)
|
||||||
|
monkeypatch.delenv("TRUSTED_PROXY_CIDRS", raising=False)
|
||||||
|
security_module = importlib.reload(app_security)
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
with app.test_request_context(
|
||||||
|
"/",
|
||||||
|
environ_base={"REMOTE_ADDR": "10.0.0.9"},
|
||||||
|
headers={"X-Forwarded-For": "198.51.100.10"},
|
||||||
|
):
|
||||||
|
assert security_module.get_rate_limit_ip() == "10.0.0.9"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_ip_can_use_forwarded_chain_when_explicitly_enabled(monkeypatch):
|
||||||
|
monkeypatch.setenv("TRUST_PROXY_HEADERS", "true")
|
||||||
|
monkeypatch.setenv("TRUSTED_PROXY_CIDRS", "10.0.0.0/8")
|
||||||
|
security_module = importlib.reload(app_security)
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
with app.test_request_context(
|
||||||
|
"/",
|
||||||
|
environ_base={"REMOTE_ADDR": "10.2.3.4"},
|
||||||
|
headers={"X-Forwarded-For": "198.51.100.10, 10.2.3.4"},
|
||||||
|
):
|
||||||
|
assert security_module.get_rate_limit_ip() == "198.51.100.10"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_encryption_key_refuses_regeneration_when_encrypted_data_exists(monkeypatch, tmp_path):
|
||||||
|
monkeypatch.setenv("ALLOW_NEW_KEY", "true")
|
||||||
|
monkeypatch.delenv("ENCRYPTION_KEY_RAW", raising=False)
|
||||||
|
monkeypatch.delenv("ENCRYPTION_KEY", raising=False)
|
||||||
|
monkeypatch.setattr(crypto_utils, "ENCRYPTION_KEY_FILE", str(tmp_path / "missing_key.bin"))
|
||||||
|
monkeypatch.setattr(crypto_utils, "_check_existing_encrypted_data", lambda: True)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
crypto_utils.get_encryption_key()
|
||||||
Reference in New Issue
Block a user