安全修复: 收敛认证与日志风险并补充基础测试
This commit is contained in:
@@ -15,6 +15,12 @@ FLASK_DEBUG=false
|
||||
SESSION_LIFETIME_HOURS=24
|
||||
SESSION_COOKIE_SECURE=true # 生产环境HTTPS必须为true,本地HTTP调试可临时设为false
|
||||
HTTPS_ENABLED=true
|
||||
# 是否信任 X-Forwarded-* 代理头(默认关闭,建议仅在可信反代后开启)
|
||||
TRUST_PROXY_HEADERS=false
|
||||
# TRUST_PROXY_HEADERS=true 时生效,按需配置你的反向代理网段
|
||||
TRUSTED_PROXY_CIDRS=127.0.0.1/32,::1/128
|
||||
# 可选:首次启动时指定默认管理员密码(避免控制台输出明文密码)
|
||||
# DEFAULT_ADMIN_PASSWORD=your-strong-admin-password
|
||||
|
||||
# ==================== 数据库配置 ====================
|
||||
DB_FILE=data/app_data.db
|
||||
@@ -36,6 +42,7 @@ DB_PRAGMA_OPTIMIZE_INTERVAL_SECONDS=21600
|
||||
DB_ANALYZE_INTERVAL_SECONDS=86400
|
||||
DB_WAL_CHECKPOINT_INTERVAL_SECONDS=43200
|
||||
DB_WAL_CHECKPOINT_MODE=PASSIVE
|
||||
SYSTEM_CONFIG_CACHE_TTL_SECONDS=30
|
||||
|
||||
# ==================== 并发控制配置 ====================
|
||||
MAX_CONCURRENT_GLOBAL=2
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -146,3 +146,7 @@ Thumbs.db
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
|
||||
# Allow committed test cases
|
||||
!tests/
|
||||
!tests/**/*.py
|
||||
|
||||
@@ -790,7 +790,8 @@ docker logs -f knowledge-automation-multiuser
|
||||
# 6. 访问系统
|
||||
# 浏览器打开: http://your-ip:51232
|
||||
# 后台管理: http://your-ip:51232/yuyx
|
||||
# 默认管理员账号见容器启动日志(首次运行会生成随机密码)
|
||||
# 首次管理员密码会写入 data/default_admin_credentials.txt(权限600)
|
||||
# 登录后请立即修改密码并删除该文件
|
||||
```
|
||||
|
||||
完成!🎉
|
||||
|
||||
49
app.py
49
app.py
@@ -49,12 +49,13 @@ from services.tasks import get_task_scheduler
|
||||
|
||||
# 设置时区为中国标准时间(CST, UTC+8)
|
||||
os.environ["TZ"] = "Asia/Shanghai"
|
||||
_TZSET_ERROR = None
|
||||
try:
|
||||
import time as _time
|
||||
|
||||
_time.tzset()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
_TZSET_ERROR = e
|
||||
|
||||
|
||||
def _sigchld_handler(signum, frame):
|
||||
@@ -116,6 +117,8 @@ except Exception as socketio_error:
|
||||
|
||||
init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE)
|
||||
logger = get_logger("app")
|
||||
if _TZSET_ERROR is not None:
|
||||
logger.warning(f"设置时区失败,将继续使用系统默认时区: {_TZSET_ERROR}")
|
||||
if _socketio_fallback_reason:
|
||||
logger.warning(f"[SocketIO] 初始化失败,已回退 threading 模式: {_socketio_fallback_reason}")
|
||||
logger.info(f"[SocketIO] 当前 async_mode: {socketio.async_mode}")
|
||||
@@ -139,15 +142,15 @@ def _request_uses_https() -> bool:
|
||||
try:
|
||||
if bool(request.is_secure):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"检查 request.is_secure 失败: {e}")
|
||||
|
||||
try:
|
||||
forwarded_proto = str(request.headers.get("X-Forwarded-Proto", "") or "").split(",", 1)[0].strip().lower()
|
||||
if forwarded_proto == "https":
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"检查 X-Forwarded-Proto 失败: {e}")
|
||||
|
||||
return False
|
||||
|
||||
@@ -255,8 +258,8 @@ def _record_request_metric_after_response(response) -> None:
|
||||
logger.warning(
|
||||
f"[API-DIAG] {method} {path} -> {status_code} ({duration_ms:.1f}ms)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"记录请求指标失败: {e}")
|
||||
|
||||
|
||||
@app.after_request
|
||||
@@ -312,12 +315,12 @@ def serve_static(filename):
|
||||
# 协商缓存:确保存在 ETag,并基于 If-None-Match/If-Modified-Since 返回 304
|
||||
try:
|
||||
response.add_etag(overwrite=False)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"静态资源 ETag 设置失败({filename}): {e}")
|
||||
try:
|
||||
response.make_conditional(request)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"静态资源协商缓存处理失败({filename}): {e}")
|
||||
|
||||
response.headers.setdefault("Vary", "Accept-Encoding")
|
||||
if is_hashed_asset:
|
||||
@@ -341,33 +344,33 @@ def cleanup_on_exit():
|
||||
for acc in accounts.values():
|
||||
if getattr(acc, "is_running", False):
|
||||
acc.should_stop = True
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"停止运行中任务失败: {e}")
|
||||
|
||||
logger.info("- 停止任务调度器...")
|
||||
try:
|
||||
scheduler = get_task_scheduler()
|
||||
scheduler.shutdown(timeout=5)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"停止任务调度器失败: {e}")
|
||||
|
||||
logger.info("- 关闭截图线程池...")
|
||||
try:
|
||||
shutdown_browser_worker_pool()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭截图线程池失败: {e}")
|
||||
|
||||
logger.info("- 关闭邮件队列...")
|
||||
try:
|
||||
email_service.shutdown_email_queue()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭邮件队列失败: {e}")
|
||||
|
||||
logger.info("- 关闭数据库连接池...")
|
||||
try:
|
||||
db_pool._pool.close_all() if db_pool._pool else None
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭数据库连接池失败: {e}")
|
||||
|
||||
logger.info("[OK] 资源清理完成")
|
||||
|
||||
@@ -464,7 +467,7 @@ def _log_startup_urls() -> None:
|
||||
logger.info("服务器启动中...")
|
||||
logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
|
||||
logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
|
||||
logger.info("默认管理员: admin (首次运行随机密码见日志)")
|
||||
logger.info("默认管理员: admin (首次运行密码写入 data/default_admin_credentials.txt)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
|
||||
@@ -197,6 +197,9 @@ class IPRateLimiter:
|
||||
# 全局IP限流器实例
|
||||
ip_rate_limiter = IPRateLimiter()
|
||||
|
||||
_TRUTHY_VALUES = {"1", "true", "yes", "on"}
|
||||
_TRUST_PROXY_HEADERS = str(os.environ.get("TRUST_PROXY_HEADERS", "false") or "").strip().lower() in _TRUTHY_VALUES
|
||||
|
||||
|
||||
def require_ip_not_locked(f):
|
||||
"""装饰器:检查IP是否被锁定"""
|
||||
@@ -443,7 +446,7 @@ def get_client_ip(trust_proxy=False):
|
||||
"""
|
||||
# 安全说明:X-Forwarded-For 可被伪造
|
||||
# 仅在确认请求来自可信代理时才使用代理头
|
||||
if trust_proxy:
|
||||
if trust_proxy and _TRUST_PROXY_HEADERS:
|
||||
if request.headers.get('X-Forwarded-For'):
|
||||
return request.headers.get('X-Forwarded-For').split(',')[0].strip()
|
||||
elif request.headers.get('X-Real-IP'):
|
||||
@@ -455,7 +458,7 @@ def get_client_ip(trust_proxy=False):
|
||||
|
||||
def _load_trusted_proxy_networks():
|
||||
"""加载可信代理 CIDR 列表。"""
|
||||
default_cidrs = "127.0.0.1/32,::1/128,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fc00::/7"
|
||||
default_cidrs = "127.0.0.1/32,::1/128"
|
||||
raw = str(os.environ.get("TRUSTED_PROXY_CIDRS", default_cidrs) or "").strip()
|
||||
if not raw:
|
||||
return []
|
||||
@@ -525,6 +528,9 @@ def _extract_real_ip_from_forwarded_chain() -> str | None:
|
||||
def get_rate_limit_ip() -> str:
|
||||
"""在可信代理场景下取真实IP,用于限流/风控。"""
|
||||
remote_addr = request.remote_addr or ""
|
||||
if not _TRUST_PROXY_HEADERS:
|
||||
return remote_addr
|
||||
|
||||
remote_ip = _parse_ip_address(remote_addr)
|
||||
if remote_ip is None:
|
||||
return remote_addr
|
||||
|
||||
@@ -119,15 +119,11 @@ def get_encryption_key():
|
||||
"2. 或在 docker-compose.yml 中设置 ENCRYPTION_KEY_RAW 环境变量\n"
|
||||
"3. 如果密钥确实丢失,需要重新录入所有账号密码\n"
|
||||
"\n"
|
||||
"设置 ALLOW_NEW_KEY=true 环境变量可强制生成新密钥(不推荐)\n"
|
||||
+ "=" * 60
|
||||
)
|
||||
logger.error(error_msg)
|
||||
|
||||
# 检查是否强制允许生成新密钥
|
||||
if os.environ.get('ALLOW_NEW_KEY', '').lower() != 'true':
|
||||
print(error_msg, file=sys.stderr)
|
||||
raise RuntimeError("加密密钥丢失且存在已加密数据,请检查配置")
|
||||
print(error_msg, file=sys.stderr)
|
||||
raise RuntimeError("加密密钥丢失且存在已加密数据,请恢复密钥后再启动")
|
||||
|
||||
# 生成新的密钥
|
||||
key = Fernet.generate_key()
|
||||
|
||||
14
database.py
14
database.py
@@ -19,6 +19,7 @@ from typing import Optional
|
||||
|
||||
import db_pool
|
||||
from app_config import get_config
|
||||
from app_logger import get_logger
|
||||
|
||||
from db.schema import ensure_schema
|
||||
from db.migrations import migrate_database as _migrate_database
|
||||
@@ -126,6 +127,7 @@ from db.users import (
|
||||
from db.security import record_login_context
|
||||
|
||||
config = get_config()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 数据库文件路径
|
||||
DB_FILE = config.DB_FILE
|
||||
@@ -140,9 +142,9 @@ _system_config_cache_lock = threading.Lock()
|
||||
_system_config_cache_value: Optional[dict] = None
|
||||
_system_config_cache_loaded_at = 0.0
|
||||
try:
|
||||
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = float(os.environ.get("SYSTEM_CONFIG_CACHE_TTL_SECONDS", "3"))
|
||||
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = float(os.environ.get("SYSTEM_CONFIG_CACHE_TTL_SECONDS", "30"))
|
||||
except Exception:
|
||||
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = 3.0
|
||||
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = 30.0
|
||||
_SYSTEM_CONFIG_CACHE_TTL_SECONDS = max(0.0, _SYSTEM_CONFIG_CACHE_TTL_SECONDS)
|
||||
|
||||
|
||||
@@ -197,8 +199,8 @@ def init_database():
|
||||
try:
|
||||
config_value = get_system_config()
|
||||
db_pool.configure_slow_query_runtime(threshold_ms=config_value.get("db_slow_query_ms"))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"初始化慢查询阈值失败,使用默认值: {e}")
|
||||
|
||||
|
||||
def migrate_database():
|
||||
@@ -293,6 +295,6 @@ def update_system_config(
|
||||
try:
|
||||
latest_config = get_system_config()
|
||||
db_pool.configure_slow_query_runtime(threshold_ms=latest_config.get("db_slow_query_ms"))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"更新慢查询阈值失败,保留当前配置: {e}")
|
||||
return ok
|
||||
|
||||
36
db/admin.py
36
db/admin.py
@@ -2,7 +2,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import db_pool
|
||||
from db.utils import get_cst_now_str
|
||||
@@ -109,6 +111,28 @@ def _normalize_days(days, default: int = 30) -> int:
|
||||
return value
|
||||
|
||||
|
||||
def _store_default_admin_credentials(username: str, password: str) -> str | None:
|
||||
"""将首次管理员账号密码写入受限权限文件,避免打印到日志。"""
|
||||
raw_path = str(
|
||||
os.environ.get("DEFAULT_ADMIN_CREDENTIALS_FILE", "data/default_admin_credentials.txt") or ""
|
||||
).strip()
|
||||
if not raw_path:
|
||||
return None
|
||||
|
||||
cred_path = Path(raw_path)
|
||||
try:
|
||||
cred_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(cred_path, "w", encoding="utf-8") as f:
|
||||
f.write("安全提醒:首次管理员账号已创建\n")
|
||||
f.write(f"用户名: {username}\n")
|
||||
f.write(f"密码: {password}\n")
|
||||
f.write("请登录后立即修改密码,并删除该文件。\n")
|
||||
os.chmod(cred_path, 0o600)
|
||||
return str(cred_path)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def ensure_default_admin() -> bool:
|
||||
"""确保存在默认管理员账号(行为保持不变)。"""
|
||||
import secrets
|
||||
@@ -120,7 +144,8 @@ def ensure_default_admin() -> bool:
|
||||
|
||||
if count == 0:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_password = "".join(secrets.choice(alphabet) for _ in range(12))
|
||||
bootstrap_password = str(os.environ.get("DEFAULT_ADMIN_PASSWORD", "") or "").strip()
|
||||
random_password = bootstrap_password or "".join(secrets.choice(alphabet) for _ in range(12))
|
||||
|
||||
default_password_hash = hash_password_bcrypt(random_password)
|
||||
cursor.execute(
|
||||
@@ -128,11 +153,16 @@ def ensure_default_admin() -> bool:
|
||||
("admin", default_password_hash, get_cst_now_str()),
|
||||
)
|
||||
conn.commit()
|
||||
credential_file = _store_default_admin_credentials("admin", random_password)
|
||||
print("=" * 60)
|
||||
print("安全提醒:已创建默认管理员账号")
|
||||
print("用户名: admin")
|
||||
print(f"密码: {random_password}")
|
||||
print("请立即登录后修改密码!")
|
||||
if credential_file:
|
||||
print(f"初始密码已写入: {credential_file}(权限600)")
|
||||
print("请立即登录后修改密码,并删除该文件。")
|
||||
else:
|
||||
print("未能写入初始密码文件。")
|
||||
print("建议设置 DEFAULT_ADMIN_PASSWORD 后重建管理员账号。")
|
||||
print("=" * 60)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -21,6 +21,10 @@ logger = get_logger(__name__)
|
||||
|
||||
_CST_TZ = pytz.timezone("Asia/Shanghai")
|
||||
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
|
||||
_USER_LOOKUP_SQL = {
|
||||
"id": "SELECT * FROM users WHERE id = ?",
|
||||
"username": "SELECT * FROM users WHERE username = ?",
|
||||
}
|
||||
|
||||
|
||||
def _row_to_dict(row):
|
||||
@@ -28,9 +32,12 @@ def _row_to_dict(row):
|
||||
|
||||
|
||||
def _get_user_by_field(field_name: str, field_value):
|
||||
query_sql = _USER_LOOKUP_SQL.get(str(field_name or ""))
|
||||
if not query_sql:
|
||||
raise ValueError(f"unsupported user lookup field: {field_name}")
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
|
||||
cursor.execute(query_sql, (field_value,))
|
||||
return _row_to_dict(cursor.fetchone())
|
||||
|
||||
|
||||
|
||||
@@ -201,8 +201,8 @@ def _send_login_security_alert_if_needed(user: dict, username: str, client_ip: s
|
||||
new_device=context.get("new_device", False),
|
||||
user_id=user["id"],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"发送登录安全提醒失败: user_id={user.get('id')}, error={e}")
|
||||
|
||||
|
||||
def _parse_credential_payload(data: dict) -> dict | None:
|
||||
@@ -308,10 +308,9 @@ def verify_email(token):
|
||||
if result:
|
||||
token_id = result["token_id"]
|
||||
user_id = result["user_id"]
|
||||
email = result["email"]
|
||||
|
||||
if not database.approve_user(user_id):
|
||||
logger.error(f"用户邮箱验证失败: 用户审核更新失败 user_id={user_id}, email={email}")
|
||||
logger.error(f"用户邮箱验证失败: 用户审核更新失败 user_id={user_id}")
|
||||
error_message = "验证处理失败,请稍后重试"
|
||||
spa_initial_state = {
|
||||
"page": "verify_result",
|
||||
@@ -333,9 +332,9 @@ def verify_email(token):
|
||||
database.set_user_vip(user_id, auto_approve_vip_days)
|
||||
|
||||
if not email_service.consume_email_token(token_id):
|
||||
logger.warning(f"用户邮箱验证后Token消费失败: token_id={token_id}, user_id={user_id}")
|
||||
logger.warning(f"用户邮箱验证后Token消费失败: user_id={user_id}")
|
||||
|
||||
logger.info(f"用户邮箱验证成功: user_id={user_id}, email={email}")
|
||||
logger.info(f"用户邮箱验证成功: user_id={user_id}")
|
||||
spa_initial_state = {
|
||||
"page": "verify_result",
|
||||
"success": True,
|
||||
@@ -348,7 +347,7 @@ def verify_email(token):
|
||||
}
|
||||
return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
|
||||
|
||||
logger.warning(f"邮箱验证失败: token={token[:20]}...")
|
||||
logger.warning("邮箱验证失败: token无效或已过期")
|
||||
error_message = "验证链接无效或已过期,请重新注册或申请重发验证邮件"
|
||||
spa_initial_state = {
|
||||
"page": "verify_result",
|
||||
|
||||
@@ -365,7 +365,7 @@ def verify_bind_email(token):
|
||||
|
||||
if database.update_user_email(user_id, email, verified=True):
|
||||
if not email_service.consume_email_token(token_id):
|
||||
logger.warning(f"邮箱绑定成功但Token消费失败: token_id={token_id}, user_id={user_id}")
|
||||
logger.warning(f"邮箱绑定成功但Token消费失败: user_id={user_id}")
|
||||
return _render_verify_bind_success(email)
|
||||
|
||||
return _render_verify_bind_failed(title="绑定失败", error_message="邮箱绑定失败,请重试")
|
||||
|
||||
58
tests/test_security_hardening.py
Normal file
58
tests/test_security_hardening.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import importlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
import app_security
|
||||
import crypto_utils
|
||||
from db import users as db_users
|
||||
|
||||
|
||||
def test_user_lookup_rejects_dynamic_field_name():
|
||||
with pytest.raises(ValueError):
|
||||
db_users._get_user_by_field("username OR 1=1 --", "demo")
|
||||
|
||||
|
||||
def test_rate_limit_ip_does_not_trust_proxy_headers_by_default(monkeypatch):
|
||||
monkeypatch.delenv("TRUST_PROXY_HEADERS", raising=False)
|
||||
monkeypatch.delenv("TRUSTED_PROXY_CIDRS", raising=False)
|
||||
security_module = importlib.reload(app_security)
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
environ_base={"REMOTE_ADDR": "10.0.0.9"},
|
||||
headers={"X-Forwarded-For": "198.51.100.10"},
|
||||
):
|
||||
assert security_module.get_rate_limit_ip() == "10.0.0.9"
|
||||
|
||||
|
||||
def test_rate_limit_ip_can_use_forwarded_chain_when_explicitly_enabled(monkeypatch):
|
||||
monkeypatch.setenv("TRUST_PROXY_HEADERS", "true")
|
||||
monkeypatch.setenv("TRUSTED_PROXY_CIDRS", "10.0.0.0/8")
|
||||
security_module = importlib.reload(app_security)
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
environ_base={"REMOTE_ADDR": "10.2.3.4"},
|
||||
headers={"X-Forwarded-For": "198.51.100.10, 10.2.3.4"},
|
||||
):
|
||||
assert security_module.get_rate_limit_ip() == "198.51.100.10"
|
||||
|
||||
|
||||
def test_get_encryption_key_refuses_regeneration_when_encrypted_data_exists(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("ALLOW_NEW_KEY", "true")
|
||||
monkeypatch.delenv("ENCRYPTION_KEY_RAW", raising=False)
|
||||
monkeypatch.delenv("ENCRYPTION_KEY", raising=False)
|
||||
monkeypatch.setattr(crypto_utils, "ENCRYPTION_KEY_FILE", str(tmp_path / "missing_key.bin"))
|
||||
monkeypatch.setattr(crypto_utils, "_check_existing_encrypted_data", lambda: True)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
crypto_utils.get_encryption_key()
|
||||
Reference in New Issue
Block a user