refactor: optimize structure, stability and runtime performance

This commit is contained in:
2026-02-07 00:35:11 +08:00
parent fae21329d7
commit bf29ac1924
44 changed files with 6894 additions and 4792 deletions

View File

@@ -16,8 +16,41 @@ from password_utils import (
verify_password_bcrypt,
verify_password_sha256,
)
logger = get_logger(__name__)
_CST_TZ = pytz.timezone("Asia/Shanghai")
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
def _row_to_dict(row):
return dict(row) if row else None
def _get_user_by_field(field_name: str, field_value):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
return _row_to_dict(cursor.fetchone())
def _parse_cst_datetime(datetime_str: str | None):
if not datetime_str:
return None
try:
naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S")
return _CST_TZ.localize(naive_dt)
except Exception:
return None
def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str:
if int(days) == 999999:
return _PERMANENT_VIP_EXPIRE
if base_dt is None:
base_dt = datetime.now(_CST_TZ)
return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S")
def get_vip_config():
"""获取VIP配置"""
@@ -32,13 +65,12 @@ def set_default_vip_days(days):
"""设置默认VIP天数"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
cursor.execute(
"""
INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
VALUES (1, ?, ?)
""",
(days, cst_time),
(days, get_cst_now_str()),
)
conn.commit()
return True
@@ -47,14 +79,8 @@ def set_default_vip_days(days):
def set_user_vip(user_id, days):
"""设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久"""
with db_pool.get_db() as conn:
cst_tz = pytz.timezone("Asia/Shanghai")
cursor = conn.cursor()
if days == 999999:
expire_time = "2099-12-31 23:59:59"
else:
expire_time = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
expire_time = _format_vip_expire(days)
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id))
conn.commit()
return cursor.rowcount > 0
@@ -63,29 +89,26 @@ def set_user_vip(user_id, days):
def extend_user_vip(user_id, days):
"""延长用户VIP时间"""
user = get_user_by_id(user_id)
cst_tz = pytz.timezone("Asia/Shanghai")
if not user:
return False
current_expire = user.get("vip_expire_time")
now_dt = datetime.now(_CST_TZ)
if current_expire and current_expire != _PERMANENT_VIP_EXPIRE:
expire_time = _parse_cst_datetime(current_expire)
if expire_time is not None:
if expire_time < now_dt:
expire_time = now_dt
new_expire = _format_vip_expire(days, base_dt=expire_time)
else:
logger.warning("解析VIP过期时间失败使用当前时间")
new_expire = _format_vip_expire(days, base_dt=now_dt)
else:
new_expire = _format_vip_expire(days, base_dt=now_dt)
with db_pool.get_db() as conn:
cursor = conn.cursor()
current_expire = user.get("vip_expire_time")
if current_expire and current_expire != "2099-12-31 23:59:59":
try:
expire_time_naive = datetime.strptime(current_expire, "%Y-%m-%d %H:%M:%S")
expire_time = cst_tz.localize(expire_time_naive)
now = datetime.now(cst_tz)
if expire_time < now:
expire_time = now
new_expire = (expire_time + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, AttributeError) as e:
logger.warning(f"解析VIP过期时间失败: {e}, 使用当前时间")
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
else:
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
conn.commit()
return cursor.rowcount > 0
@@ -105,45 +128,49 @@ def is_user_vip(user_id):
注意数据库中存储的时间统一使用CSTAsia/Shanghai时区
"""
cst_tz = pytz.timezone("Asia/Shanghai")
user = get_user_by_id(user_id)
if not user or not user.get("vip_expire_time"):
if not user:
return False
try:
expire_time_naive = datetime.strptime(user["vip_expire_time"], "%Y-%m-%d %H:%M:%S")
expire_time = cst_tz.localize(expire_time_naive)
now = datetime.now(cst_tz)
return now < expire_time
except (ValueError, AttributeError) as e:
logger.warning(f"检查VIP状态失败 (user_id={user_id}): {e}")
vip_expire_time = user.get("vip_expire_time")
if not vip_expire_time:
return False
expire_time = _parse_cst_datetime(vip_expire_time)
if expire_time is None:
logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间")
return False
return datetime.now(_CST_TZ) < expire_time
def get_user_vip_info(user_id):
"""获取用户VIP信息"""
cst_tz = pytz.timezone("Asia/Shanghai")
user = get_user_by_id(user_id)
if not user:
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""}
vip_expire_time = user.get("vip_expire_time")
username = user.get("username", "")
if not vip_expire_time:
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
try:
expire_time_naive = datetime.strptime(vip_expire_time, "%Y-%m-%d %H:%M:%S")
expire_time = cst_tz.localize(expire_time_naive)
now = datetime.now(cst_tz)
is_vip = now < expire_time
days_left = (expire_time - now).days if is_vip else 0
expire_time = _parse_cst_datetime(vip_expire_time)
if expire_time is None:
logger.warning("VIP信息获取错误: 无法解析过期时间")
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
return {"username": user.get("username", ""), "is_vip": is_vip, "expire_time": vip_expire_time, "days_left": max(0, days_left)}
except Exception as e:
logger.warning(f"VIP信息获取错误: {e}")
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
now_dt = datetime.now(_CST_TZ)
is_vip = now_dt < expire_time
days_left = (expire_time - now_dt).days if is_vip else 0
return {
"username": username,
"is_vip": is_vip,
"expire_time": vip_expire_time,
"days_left": max(0, days_left),
}
# ==================== 用户相关 ====================
@@ -151,8 +178,6 @@ def get_user_vip_info(user_id):
def create_user(username, password, email=""):
"""创建新用户(默认直接通过,赠送默认VIP)"""
cst_tz = pytz.timezone("Asia/Shanghai")
with db_pool.get_db() as conn:
cursor = conn.cursor()
password_hash = hash_password_bcrypt(password)
@@ -160,12 +185,8 @@ def create_user(username, password, email=""):
default_vip_days = get_vip_config()["default_vip_days"]
vip_expire_time = None
if default_vip_days > 0:
if default_vip_days == 999999:
vip_expire_time = "2099-12-31 23:59:59"
else:
vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime("%Y-%m-%d %H:%M:%S")
if int(default_vip_days or 0) > 0:
vip_expire_time = _format_vip_expire(int(default_vip_days))
try:
cursor.execute(
@@ -210,28 +231,28 @@ def verify_user(username, password):
def get_user_by_id(user_id):
"""根据ID获取用户"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
user = cursor.fetchone()
return dict(user) if user else None
return _get_user_by_field("id", user_id)
def get_user_kdocs_settings(user_id):
"""获取用户的金山文档配置"""
user = get_user_by_id(user_id)
if not user:
return None
return {
"kdocs_unit": user.get("kdocs_unit") or "",
"kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0,
}
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,))
row = cursor.fetchone()
if not row:
return None
return {
"kdocs_unit": (row["kdocs_unit"] or "") if isinstance(row, dict) else (row[0] or ""),
"kdocs_auto_upload": 1 if ((row["kdocs_auto_upload"] if isinstance(row, dict) else row[1]) or 0) else 0,
}
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
"""更新用户的金山文档配置"""
updates = []
params = []
if kdocs_unit is not None:
updates.append("kdocs_unit = ?")
params.append(kdocs_unit)
@@ -252,11 +273,7 @@ def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=No
def get_user_by_username(username):
"""根据用户名获取用户"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
user = cursor.fetchone()
return dict(user) if user else None
return _get_user_by_field("username", username)
def get_all_users():
@@ -279,14 +296,13 @@ def approve_user(user_id):
"""审核通过用户"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
cursor.execute(
"""
UPDATE users
SET status = 'approved', approved_at = ?
WHERE id = ?
""",
(cst_time, user_id),
(get_cst_now_str(), user_id),
)
conn.commit()
return cursor.rowcount > 0
@@ -315,5 +331,5 @@ def get_user_stats(user_id):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,))
account_count = cursor.fetchone()["count"]
return {"account_count": account_count}
row = cursor.fetchone()
return {"account_count": int((row["count"] if row else 0) or 0)}