diff --git a/.env.example b/.env.example index 72d9698..6c5a5cc 100644 --- a/.env.example +++ b/.env.example @@ -15,6 +15,12 @@ FLASK_DEBUG=false SESSION_LIFETIME_HOURS=24 SESSION_COOKIE_SECURE=true # 生产环境HTTPS必须为true,本地HTTP调试可临时设为false HTTPS_ENABLED=true +# 是否信任 X-Forwarded-* 代理头(默认关闭,建议仅在可信反代后开启) +TRUST_PROXY_HEADERS=false +# TRUST_PROXY_HEADERS=true 时生效,按需配置你的反向代理网段 +TRUSTED_PROXY_CIDRS=127.0.0.1/32,::1/128 +# 可选:首次启动时指定默认管理员密码(避免控制台输出明文密码) +# DEFAULT_ADMIN_PASSWORD=your-strong-admin-password # ==================== 数据库配置 ==================== DB_FILE=data/app_data.db @@ -36,6 +42,7 @@ DB_PRAGMA_OPTIMIZE_INTERVAL_SECONDS=21600 DB_ANALYZE_INTERVAL_SECONDS=86400 DB_WAL_CHECKPOINT_INTERVAL_SECONDS=43200 DB_WAL_CHECKPOINT_MODE=PASSIVE +SYSTEM_CONFIG_CACHE_TTL_SECONDS=30 # ==================== 并发控制配置 ==================== MAX_CONCURRENT_GLOBAL=2 diff --git a/.gitignore b/.gitignore index 25452af..f62f0cf 100644 --- a/.gitignore +++ b/.gitignore @@ -146,3 +146,7 @@ Thumbs.db # Temporary files *.tmp *.temp + +# Allow committed test cases +!tests/ +!tests/**/*.py diff --git a/README.md b/README.md index 320ddea..8b0360a 100644 --- a/README.md +++ b/README.md @@ -790,7 +790,8 @@ docker logs -f knowledge-automation-multiuser # 6. 访问系统 # 浏览器打开: http://your-ip:51232 # 后台管理: http://your-ip:51232/yuyx -# 默认管理员账号见容器启动日志(首次运行会生成随机密码) +# 首次管理员密码会写入 data/default_admin_credentials.txt(权限600) +# 登录后请立即修改密码并删除该文件 ``` 完成!🎉 diff --git a/app.py b/app.py index b2c9e35..831b28a 100644 --- a/app.py +++ b/app.py @@ -49,12 +49,13 @@ from services.tasks import get_task_scheduler # 设置时区为中国标准时间(CST, UTC+8) os.environ["TZ"] = "Asia/Shanghai" +_TZSET_ERROR = None try: import time as _time _time.tzset() -except Exception: - pass +except Exception as e: + _TZSET_ERROR = e def _sigchld_handler(signum, frame): @@ -116,6 +117,8 @@ except Exception as socketio_error: init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE) logger = get_logger("app") +if _TZSET_ERROR is not None: + logger.warning(f"设置时区失败,将继续使用系统默认时区: {_TZSET_ERROR}") if _socketio_fallback_reason: logger.warning(f"[SocketIO] 初始化失败,已回退 threading 模式: {_socketio_fallback_reason}") logger.info(f"[SocketIO] 当前 async_mode: {socketio.async_mode}") @@ -139,15 +142,15 @@ def _request_uses_https() -> bool: try: if bool(request.is_secure): return True - except Exception: - pass + except Exception as e: + logger.debug(f"检查 request.is_secure 失败: {e}") try: forwarded_proto = str(request.headers.get("X-Forwarded-Proto", "") or "").split(",", 1)[0].strip().lower() if forwarded_proto == "https": return True - except Exception: - pass + except Exception as e: + logger.debug(f"检查 X-Forwarded-Proto 失败: {e}") return False @@ -255,8 +258,8 @@ def _record_request_metric_after_response(response) -> None: logger.warning( f"[API-DIAG] {method} {path} -> {status_code} ({duration_ms:.1f}ms)" ) - except Exception: - pass + except Exception as e: + logger.debug(f"记录请求指标失败: {e}") @app.after_request @@ -312,12 +315,12 @@ def serve_static(filename): # 协商缓存:确保存在 ETag,并基于 If-None-Match/If-Modified-Since 返回 304 try: response.add_etag(overwrite=False) - except Exception: - pass + except Exception as e: + logger.debug(f"静态资源 ETag 设置失败({filename}): {e}") try: response.make_conditional(request) - except Exception: - pass + except Exception as e: + logger.debug(f"静态资源协商缓存处理失败({filename}): {e}") response.headers.setdefault("Vary", "Accept-Encoding") if is_hashed_asset: @@ -341,33 +344,33 @@ def cleanup_on_exit(): for acc in accounts.values(): if getattr(acc, "is_running", False): acc.should_stop = True - except Exception: - pass + except Exception as e: + logger.warning(f"停止运行中任务失败: {e}") logger.info("- 停止任务调度器...") try: scheduler = get_task_scheduler() scheduler.shutdown(timeout=5) - except Exception: - pass + except Exception as e: + logger.warning(f"停止任务调度器失败: {e}") logger.info("- 关闭截图线程池...") try: shutdown_browser_worker_pool() - except Exception: - pass + except Exception as e: + logger.warning(f"关闭截图线程池失败: {e}") logger.info("- 关闭邮件队列...") try: email_service.shutdown_email_queue() - except Exception: - pass + except Exception as e: + logger.warning(f"关闭邮件队列失败: {e}") logger.info("- 关闭数据库连接池...") try: db_pool._pool.close_all() if db_pool._pool else None - except Exception: - pass + except Exception as e: + logger.warning(f"关闭数据库连接池失败: {e}") logger.info("[OK] 资源清理完成") @@ -464,7 +467,7 @@ def _log_startup_urls() -> None: logger.info("服务器启动中...") logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}") logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx") - logger.info("默认管理员: admin (首次运行随机密码见日志)") + logger.info("默认管理员: admin (首次运行密码写入 data/default_admin_credentials.txt)") logger.info("=" * 60) diff --git a/app_security.py b/app_security.py index 4bce639..7b1b5b9 100755 --- a/app_security.py +++ b/app_security.py @@ -197,6 +197,9 @@ class IPRateLimiter: # 全局IP限流器实例 ip_rate_limiter = IPRateLimiter() +_TRUTHY_VALUES = {"1", "true", "yes", "on"} +_TRUST_PROXY_HEADERS = str(os.environ.get("TRUST_PROXY_HEADERS", "false") or "").strip().lower() in _TRUTHY_VALUES + def require_ip_not_locked(f): """装饰器:检查IP是否被锁定""" @@ -443,7 +446,7 @@ def get_client_ip(trust_proxy=False): """ # 安全说明:X-Forwarded-For 可被伪造 # 仅在确认请求来自可信代理时才使用代理头 - if trust_proxy: + if trust_proxy and _TRUST_PROXY_HEADERS: if request.headers.get('X-Forwarded-For'): return request.headers.get('X-Forwarded-For').split(',')[0].strip() elif request.headers.get('X-Real-IP'): @@ -455,7 +458,7 @@ def get_client_ip(trust_proxy=False): def _load_trusted_proxy_networks(): """加载可信代理 CIDR 列表。""" - default_cidrs = "127.0.0.1/32,::1/128,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fc00::/7" + default_cidrs = "127.0.0.1/32,::1/128" raw = str(os.environ.get("TRUSTED_PROXY_CIDRS", default_cidrs) or "").strip() if not raw: return [] @@ -525,6 +528,9 @@ def _extract_real_ip_from_forwarded_chain() -> str | None: def get_rate_limit_ip() -> str: """在可信代理场景下取真实IP,用于限流/风控。""" remote_addr = request.remote_addr or "" + if not _TRUST_PROXY_HEADERS: + return remote_addr + remote_ip = _parse_ip_address(remote_addr) if remote_ip is None: return remote_addr diff --git a/crypto_utils.py b/crypto_utils.py index de7986d..b937aad 100644 --- a/crypto_utils.py +++ b/crypto_utils.py @@ -119,15 +119,11 @@ def get_encryption_key(): "2. 或在 docker-compose.yml 中设置 ENCRYPTION_KEY_RAW 环境变量\n" "3. 如果密钥确实丢失,需要重新录入所有账号密码\n" "\n" - "设置 ALLOW_NEW_KEY=true 环境变量可强制生成新密钥(不推荐)\n" + "=" * 60 ) logger.error(error_msg) - - # 检查是否强制允许生成新密钥 - if os.environ.get('ALLOW_NEW_KEY', '').lower() != 'true': - print(error_msg, file=sys.stderr) - raise RuntimeError("加密密钥丢失且存在已加密数据,请检查配置") + print(error_msg, file=sys.stderr) + raise RuntimeError("加密密钥丢失且存在已加密数据,请恢复密钥后再启动") # 生成新的密钥 key = Fernet.generate_key() diff --git a/database.py b/database.py index c4fb0d4..64a3406 100644 --- a/database.py +++ b/database.py @@ -19,6 +19,7 @@ from typing import Optional import db_pool from app_config import get_config +from app_logger import get_logger from db.schema import ensure_schema from db.migrations import migrate_database as _migrate_database @@ -126,6 +127,7 @@ from db.users import ( from db.security import record_login_context config = get_config() +logger = get_logger(__name__) # 数据库文件路径 DB_FILE = config.DB_FILE @@ -140,9 +142,9 @@ _system_config_cache_lock = threading.Lock() _system_config_cache_value: Optional[dict] = None _system_config_cache_loaded_at = 0.0 try: - _SYSTEM_CONFIG_CACHE_TTL_SECONDS = float(os.environ.get("SYSTEM_CONFIG_CACHE_TTL_SECONDS", "3")) + _SYSTEM_CONFIG_CACHE_TTL_SECONDS = float(os.environ.get("SYSTEM_CONFIG_CACHE_TTL_SECONDS", "30")) except Exception: - _SYSTEM_CONFIG_CACHE_TTL_SECONDS = 3.0 + _SYSTEM_CONFIG_CACHE_TTL_SECONDS = 30.0 _SYSTEM_CONFIG_CACHE_TTL_SECONDS = max(0.0, _SYSTEM_CONFIG_CACHE_TTL_SECONDS) @@ -197,8 +199,8 @@ def init_database(): try: config_value = get_system_config() db_pool.configure_slow_query_runtime(threshold_ms=config_value.get("db_slow_query_ms")) - except Exception: - pass + except Exception as e: + logger.warning(f"初始化慢查询阈值失败,使用默认值: {e}") def migrate_database(): @@ -293,6 +295,6 @@ def update_system_config( try: latest_config = get_system_config() db_pool.configure_slow_query_runtime(threshold_ms=latest_config.get("db_slow_query_ms")) - except Exception: - pass + except Exception as e: + logger.warning(f"更新慢查询阈值失败,保留当前配置: {e}") return ok diff --git a/db/admin.py b/db/admin.py index 088220f..b9b5255 100644 --- a/db/admin.py +++ b/db/admin.py @@ -2,7 +2,9 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import os import sqlite3 +from pathlib import Path import db_pool from db.utils import get_cst_now_str @@ -109,6 +111,28 @@ def _normalize_days(days, default: int = 30) -> int: return value +def _store_default_admin_credentials(username: str, password: str) -> str | None: + """将首次管理员账号密码写入受限权限文件,避免打印到日志。""" + raw_path = str( + os.environ.get("DEFAULT_ADMIN_CREDENTIALS_FILE", "data/default_admin_credentials.txt") or "" + ).strip() + if not raw_path: + return None + + cred_path = Path(raw_path) + try: + cred_path.parent.mkdir(parents=True, exist_ok=True) + with open(cred_path, "w", encoding="utf-8") as f: + f.write("安全提醒:首次管理员账号已创建\n") + f.write(f"用户名: {username}\n") + f.write(f"密码: {password}\n") + f.write("请登录后立即修改密码,并删除该文件。\n") + os.chmod(cred_path, 0o600) + return str(cred_path) + except Exception: + return None + + def ensure_default_admin() -> bool: """确保存在默认管理员账号(行为保持不变)。""" import secrets @@ -120,7 +144,8 @@ def ensure_default_admin() -> bool: if count == 0: alphabet = string.ascii_letters + string.digits - random_password = "".join(secrets.choice(alphabet) for _ in range(12)) + bootstrap_password = str(os.environ.get("DEFAULT_ADMIN_PASSWORD", "") or "").strip() + random_password = bootstrap_password or "".join(secrets.choice(alphabet) for _ in range(12)) default_password_hash = hash_password_bcrypt(random_password) cursor.execute( @@ -128,11 +153,16 @@ def ensure_default_admin() -> bool: ("admin", default_password_hash, get_cst_now_str()), ) conn.commit() + credential_file = _store_default_admin_credentials("admin", random_password) print("=" * 60) print("安全提醒:已创建默认管理员账号") print("用户名: admin") - print(f"密码: {random_password}") - print("请立即登录后修改密码!") + if credential_file: + print(f"初始密码已写入: {credential_file}(权限600)") + print("请立即登录后修改密码,并删除该文件。") + else: + print("未能写入初始密码文件。") + print("建议设置 DEFAULT_ADMIN_PASSWORD 后重建管理员账号。") print("=" * 60) return True return False diff --git a/db/users.py b/db/users.py index c5cf5d1..81b343f 100644 --- a/db/users.py +++ b/db/users.py @@ -21,6 +21,10 @@ logger = get_logger(__name__) _CST_TZ = pytz.timezone("Asia/Shanghai") _PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59" +_USER_LOOKUP_SQL = { + "id": "SELECT * FROM users WHERE id = ?", + "username": "SELECT * FROM users WHERE username = ?", +} def _row_to_dict(row): @@ -28,9 +32,12 @@ def _row_to_dict(row): def _get_user_by_field(field_name: str, field_value): + query_sql = _USER_LOOKUP_SQL.get(str(field_name or "")) + if not query_sql: + raise ValueError(f"unsupported user lookup field: {field_name}") with db_pool.get_db() as conn: cursor = conn.cursor() - cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,)) + cursor.execute(query_sql, (field_value,)) return _row_to_dict(cursor.fetchone()) diff --git a/routes/api_auth.py b/routes/api_auth.py index aba8c68..26514d2 100644 --- a/routes/api_auth.py +++ b/routes/api_auth.py @@ -201,8 +201,8 @@ def _send_login_security_alert_if_needed(user: dict, username: str, client_ip: s new_device=context.get("new_device", False), user_id=user["id"], ) - except Exception: - pass + except Exception as e: + logger.warning(f"发送登录安全提醒失败: user_id={user.get('id')}, error={e}") def _parse_credential_payload(data: dict) -> dict | None: @@ -308,10 +308,9 @@ def verify_email(token): if result: token_id = result["token_id"] user_id = result["user_id"] - email = result["email"] if not database.approve_user(user_id): - logger.error(f"用户邮箱验证失败: 用户审核更新失败 user_id={user_id}, email={email}") + logger.error(f"用户邮箱验证失败: 用户审核更新失败 user_id={user_id}") error_message = "验证处理失败,请稍后重试" spa_initial_state = { "page": "verify_result", @@ -333,9 +332,9 @@ def verify_email(token): database.set_user_vip(user_id, auto_approve_vip_days) if not email_service.consume_email_token(token_id): - logger.warning(f"用户邮箱验证后Token消费失败: token_id={token_id}, user_id={user_id}") + logger.warning(f"用户邮箱验证后Token消费失败: user_id={user_id}") - logger.info(f"用户邮箱验证成功: user_id={user_id}, email={email}") + logger.info(f"用户邮箱验证成功: user_id={user_id}") spa_initial_state = { "page": "verify_result", "success": True, @@ -348,7 +347,7 @@ def verify_email(token): } return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state) - logger.warning(f"邮箱验证失败: token={token[:20]}...") + logger.warning("邮箱验证失败: token无效或已过期") error_message = "验证链接无效或已过期,请重新注册或申请重发验证邮件" spa_initial_state = { "page": "verify_result", diff --git a/routes/api_user.py b/routes/api_user.py index ea79c0c..800bd74 100644 --- a/routes/api_user.py +++ b/routes/api_user.py @@ -365,7 +365,7 @@ def verify_bind_email(token): if database.update_user_email(user_id, email, verified=True): if not email_service.consume_email_token(token_id): - logger.warning(f"邮箱绑定成功但Token消费失败: token_id={token_id}, user_id={user_id}") + logger.warning(f"邮箱绑定成功但Token消费失败: user_id={user_id}") return _render_verify_bind_success(email) return _render_verify_bind_failed(title="绑定失败", error_message="邮箱绑定失败,请重试") diff --git a/tests/test_security_hardening.py b/tests/test_security_hardening.py new file mode 100644 index 0000000..bf32e09 --- /dev/null +++ b/tests/test_security_hardening.py @@ -0,0 +1,58 @@ +import importlib +import sys +from pathlib import Path + +import pytest +from flask import Flask + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +import app_security +import crypto_utils +from db import users as db_users + + +def test_user_lookup_rejects_dynamic_field_name(): + with pytest.raises(ValueError): + db_users._get_user_by_field("username OR 1=1 --", "demo") + + +def test_rate_limit_ip_does_not_trust_proxy_headers_by_default(monkeypatch): + monkeypatch.delenv("TRUST_PROXY_HEADERS", raising=False) + monkeypatch.delenv("TRUSTED_PROXY_CIDRS", raising=False) + security_module = importlib.reload(app_security) + + app = Flask(__name__) + with app.test_request_context( + "/", + environ_base={"REMOTE_ADDR": "10.0.0.9"}, + headers={"X-Forwarded-For": "198.51.100.10"}, + ): + assert security_module.get_rate_limit_ip() == "10.0.0.9" + + +def test_rate_limit_ip_can_use_forwarded_chain_when_explicitly_enabled(monkeypatch): + monkeypatch.setenv("TRUST_PROXY_HEADERS", "true") + monkeypatch.setenv("TRUSTED_PROXY_CIDRS", "10.0.0.0/8") + security_module = importlib.reload(app_security) + + app = Flask(__name__) + with app.test_request_context( + "/", + environ_base={"REMOTE_ADDR": "10.2.3.4"}, + headers={"X-Forwarded-For": "198.51.100.10, 10.2.3.4"}, + ): + assert security_module.get_rate_limit_ip() == "198.51.100.10" + + +def test_get_encryption_key_refuses_regeneration_when_encrypted_data_exists(monkeypatch, tmp_path): + monkeypatch.setenv("ALLOW_NEW_KEY", "true") + monkeypatch.delenv("ENCRYPTION_KEY_RAW", raising=False) + monkeypatch.delenv("ENCRYPTION_KEY", raising=False) + monkeypatch.setattr(crypto_utils, "ENCRYPTION_KEY_FILE", str(tmp_path / "missing_key.bin")) + monkeypatch.setattr(crypto_utils, "_check_existing_encrypted_data", lambda: True) + + with pytest.raises(RuntimeError): + crypto_utils.get_encryption_key()