安全修复: 收敛认证与日志风险并补充基础测试

This commit is contained in:
2026-02-16 00:34:52 +08:00
parent 7627885b1b
commit 7d42f96e42
12 changed files with 163 additions and 50 deletions

View File

@@ -2,7 +2,9 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import sqlite3
from pathlib import Path
import db_pool
from db.utils import get_cst_now_str
@@ -109,6 +111,28 @@ def _normalize_days(days, default: int = 30) -> int:
return value
def _store_default_admin_credentials(username: str, password: str) -> str | None:
"""将首次管理员账号密码写入受限权限文件,避免打印到日志。"""
raw_path = str(
os.environ.get("DEFAULT_ADMIN_CREDENTIALS_FILE", "data/default_admin_credentials.txt") or ""
).strip()
if not raw_path:
return None
cred_path = Path(raw_path)
try:
cred_path.parent.mkdir(parents=True, exist_ok=True)
with open(cred_path, "w", encoding="utf-8") as f:
f.write("安全提醒:首次管理员账号已创建\n")
f.write(f"用户名: {username}\n")
f.write(f"密码: {password}\n")
f.write("请登录后立即修改密码,并删除该文件。\n")
os.chmod(cred_path, 0o600)
return str(cred_path)
except Exception:
return None
def ensure_default_admin() -> bool:
"""确保存在默认管理员账号(行为保持不变)。"""
import secrets
@@ -120,7 +144,8 @@ def ensure_default_admin() -> bool:
if count == 0:
alphabet = string.ascii_letters + string.digits
random_password = "".join(secrets.choice(alphabet) for _ in range(12))
bootstrap_password = str(os.environ.get("DEFAULT_ADMIN_PASSWORD", "") or "").strip()
random_password = bootstrap_password or "".join(secrets.choice(alphabet) for _ in range(12))
default_password_hash = hash_password_bcrypt(random_password)
cursor.execute(
@@ -128,11 +153,16 @@ def ensure_default_admin() -> bool:
("admin", default_password_hash, get_cst_now_str()),
)
conn.commit()
credential_file = _store_default_admin_credentials("admin", random_password)
print("=" * 60)
print("安全提醒:已创建默认管理员账号")
print("用户名: admin")
print(f"密码: {random_password}")
print("请立即登录后修改密码!")
if credential_file:
print(f"初始密码已写入: {credential_file}权限600")
print("请立即登录后修改密码,并删除该文件。")
else:
print("未能写入初始密码文件。")
print("建议设置 DEFAULT_ADMIN_PASSWORD 后重建管理员账号。")
print("=" * 60)
return True
return False

View File

@@ -21,6 +21,10 @@ logger = get_logger(__name__)
_CST_TZ = pytz.timezone("Asia/Shanghai")
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
_USER_LOOKUP_SQL = {
"id": "SELECT * FROM users WHERE id = ?",
"username": "SELECT * FROM users WHERE username = ?",
}
def _row_to_dict(row):
@@ -28,9 +32,12 @@ def _row_to_dict(row):
def _get_user_by_field(field_name: str, field_value):
query_sql = _USER_LOOKUP_SQL.get(str(field_name or ""))
if not query_sql:
raise ValueError(f"unsupported user lookup field: {field_name}")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
cursor.execute(query_sql, (field_value,))
return _row_to_dict(cursor.fetchone())