Files
zsglpt/db/users.py

336 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import sqlite3
from datetime import datetime, timedelta
import pytz
from app_logger import get_logger
import db_pool
from db.utils import get_cst_now_str
from password_utils import (
hash_password_bcrypt,
is_sha256_hash,
verify_password_bcrypt,
verify_password_sha256,
)
logger = get_logger(__name__)
# Time zone used to interpret the naive "%Y-%m-%d %H:%M:%S" timestamp
# strings stored in the users / vip_config tables.
_CST_TZ = pytz.timezone("Asia/Shanghai")
# Sentinel expiry string meaning "permanent VIP"; _format_vip_expire()
# returns it when asked for 999999 days.
_PERMANENT_VIP_EXPIRE: str = "2099-12-31 23:59:59"
def _row_to_dict(row):
return dict(row) if row else None
def _get_user_by_field(field_name: str, field_value):
    """Fetch one row from ``users`` where ``field_name`` equals ``field_value``.

    Returns the row as a dict, or None if no row matches.

    NOTE: ``field_name`` is interpolated directly into the SQL text (only the
    value is bound as a parameter), so it must be a trusted, hard-coded
    column name. Never pass user-controlled input here.
    """
    query = f"SELECT * FROM users WHERE {field_name} = ?"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(query, (field_value,))
        found = cur.fetchone()
        return _row_to_dict(found)
def _parse_cst_datetime(datetime_str: str | None):
if not datetime_str:
return None
try:
naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S")
return _CST_TZ.localize(naive_dt)
except Exception:
return None
def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str:
if int(days) == 999999:
return _PERMANENT_VIP_EXPIRE
if base_dt is None:
base_dt = datetime.now(_CST_TZ)
return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S")
def get_vip_config():
    """Return the VIP configuration row (``vip_config.id = 1``) as a dict.

    Falls back to ``{"default_vip_days": 0}`` when that row does not exist.
    """
    fallback = {"default_vip_days": 0}
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM vip_config WHERE id = 1")
        row = cur.fetchone()
        if not row:
            return fallback
        return dict(row)
def set_default_vip_days(days):
    """Store the default number of VIP days granted to newly created users.

    Writes the singleton ``vip_config`` row (id = 1) with INSERT OR REPLACE
    and stamps ``updated_at`` with the current CST time string.
    Always returns True.
    """
    sql = """
        INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
        VALUES (1, ?, ?)
        """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(sql, (days, get_cst_now_str()))
        conn.commit()
    return True
def set_user_vip(user_id, days):
    """Set a user's VIP expiry to ``days`` days from now (overwriting any old value).

    Typical values: 7 = one week, 30 = one month, 365 = one year,
    999999 = permanent.

    Returns True if a ``users`` row was updated, False otherwise.
    """
    new_expire = _format_vip_expire(days)
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
        conn.commit()
        updated = cur.rowcount
    return updated > 0
def extend_user_vip(user_id, days):
    """Extend a user's VIP membership by ``days`` days.

    The extension starts from the current expiry while it is still in the
    future, and from the current CST time otherwise. A stored expiry that
    cannot be parsed is logged and treated as "now". A permanent membership
    (``_PERMANENT_VIP_EXPIRE``) stays permanent.

    Args:
        user_id: id of the row in ``users``.
        days: number of days to add. 999999 means permanent
            (see ``_format_vip_expire``).

    Returns:
        True if the user row was updated. False if the user does not exist
        or the UPDATE matched no row.
    """
    user = get_user_by_id(user_id)
    if not user:
        return False

    current_expire = user.get("vip_expire_time")
    now_dt = datetime.now(_CST_TZ)

    if current_expire == _PERMANENT_VIP_EXPIRE:
        # Previously this case fell through to "now + days", which silently
        # downgraded a permanent membership to a short one. Keep it permanent.
        new_expire = _PERMANENT_VIP_EXPIRE
    elif current_expire:
        parsed = _parse_cst_datetime(current_expire)
        if parsed is None:
            logger.warning("解析VIP过期时间失败使用当前时间")
            base_dt = now_dt
        else:
            # An already-expired membership restarts from now, not from the past.
            base_dt = max(parsed, now_dt)
        new_expire = _format_vip_expire(days, base_dt=base_dt)
    else:
        new_expire = _format_vip_expire(days, base_dt=now_dt)

    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
        conn.commit()
        return cursor.rowcount > 0
def remove_user_vip(user_id):
    """Revoke a user's VIP by clearing ``vip_expire_time`` (set to NULL).

    Returns True if a ``users`` row was updated, False otherwise.
    """
    sql = "UPDATE users SET vip_expire_time = NULL WHERE id = ?"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(sql, (user_id,))
        conn.commit()
        affected = cur.rowcount
    return affected > 0
def is_user_vip(user_id):
    """Return True if the user currently holds an unexpired VIP membership.

    Note: timestamps in the database are stored in CST (Asia/Shanghai), so
    the comparison uses the current Asia/Shanghai time. A missing user, an
    empty expiry, or an unparseable expiry all yield False (the last case is
    logged).
    """
    user = get_user_by_id(user_id)
    raw_expire = user.get("vip_expire_time") if user else None
    if not raw_expire:
        return False

    expire_dt = _parse_cst_datetime(raw_expire)
    if expire_dt is not None:
        return datetime.now(_CST_TZ) < expire_dt

    logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间")
    return False
def get_user_vip_info(user_id):
    """Return a summary of the user's VIP status.

    The result is a dict with keys ``is_vip``, ``expire_time`` (the raw
    stored string, or None), ``days_left`` (whole days remaining, never
    negative) and ``username``. A missing user, an empty expiry, or an
    unparseable expiry produce an "inactive" summary. The unparseable case
    is logged.
    """

    def _inactive(name):
        # Summary returned whenever no valid expiry is available.
        return {"is_vip": False, "expire_time": None, "days_left": 0, "username": name}

    user = get_user_by_id(user_id)
    if not user:
        return _inactive("")

    username = user.get("username", "")
    raw_expire = user.get("vip_expire_time")
    if not raw_expire:
        return _inactive(username)

    expire_dt = _parse_cst_datetime(raw_expire)
    if expire_dt is None:
        logger.warning("VIP信息获取错误: 无法解析过期时间")
        return _inactive(username)

    now_dt = datetime.now(_CST_TZ)
    active = now_dt < expire_dt
    remaining_days = max(0, (expire_dt - now_dt).days) if active else 0
    return {
        "username": username,
        "is_vip": active,
        "expire_time": raw_expire,
        "days_left": remaining_days,
    }
# ==================== 用户相关 ====================
def create_user(username, password, email=""):
    """Create a new user who is approved immediately and granted the default VIP period.

    The password is stored only as a bcrypt hash. If the configured
    ``default_vip_days`` is positive, ``vip_expire_time`` is set that many
    days from now (999999 means permanent). ``created_at`` and
    ``approved_at`` both receive the current CST time string.

    Args:
        username: login name for the new account.
        password: plaintext password; hashed with bcrypt before storage.
        email: optional email address.

    Returns:
        The new row id (``cursor.lastrowid``), or None if the INSERT raised
        ``sqlite3.IntegrityError`` (presumably a duplicate username, assuming
        the schema has a unique constraint on it; confirm in the schema).
    """
    # Do the slow and connection-acquiring work *before* checking out our own
    # connection. get_vip_config() takes a pooled connection itself, and
    # nesting that inside an open one can exhaust a bounded pool under
    # concurrent sign-ups. bcrypt hashing is deliberately slow and should not
    # pin a DB connection either.
    password_hash = hash_password_bcrypt(password)
    default_vip_days = int(get_vip_config().get("default_vip_days") or 0)
    vip_expire_time = _format_vip_expire(default_vip_days) if default_vip_days > 0 else None
    cst_time = get_cst_now_str()

    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(
                """
                INSERT INTO users (username, password_hash, email, status, vip_expire_time, created_at, approved_at)
                VALUES (?, ?, ?, 'approved', ?, ?, ?)
                """,
                (username, password_hash, email, vip_expire_time, cst_time, cst_time),
            )
            conn.commit()
        except sqlite3.IntegrityError:
            return None
        return cursor.lastrowid
def verify_user(username, password):
    """Authenticate an *approved* user by username and password.

    Returns the user row as a dict on success and None otherwise (unknown
    user, user not approved, or wrong password). A matching legacy SHA-256
    hash is transparently re-hashed with bcrypt and saved. The returned dict
    is the row as read before that upgrade.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users WHERE username = ? AND status = 'approved'", (username,))
        record = _row_to_dict(cur.fetchone())
        if record is None:
            return None

        stored_hash = record["password_hash"]
        if not is_sha256_hash(stored_hash):
            return record if verify_password_bcrypt(password, stored_hash) else None

        if not verify_password_sha256(password, stored_hash):
            return None

        # Legacy SHA-256 hash matched: upgrade it to bcrypt in place.
        cur.execute(
            "UPDATE users SET password_hash = ? WHERE id = ?",
            (hash_password_bcrypt(password), record["id"]),
        )
        conn.commit()
        logger.info(f"用户密码已自动升级到bcrypt (user_id={record['id']})")
        return record
def get_user_by_id(user_id):
    """Return the ``users`` row whose ``id`` equals ``user_id`` as a dict, or None."""
    return _get_user_by_field(field_name="id", field_value=user_id)
def get_user_kdocs_settings(user_id):
    """Return the user's Kingsoft Docs (kdocs) settings, or None if the user is missing.

    The result has ``kdocs_unit`` (empty string when unset) and
    ``kdocs_auto_upload`` normalized to 1 or 0.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,))
        row = cur.fetchone()
        if not row:
            return None
        # Depending on the connection's row factory, the row may be a dict
        # (access by key) or a sequence/sqlite3.Row (access by position).
        if isinstance(row, dict):
            unit, auto_upload = row["kdocs_unit"], row["kdocs_auto_upload"]
        else:
            unit, auto_upload = row[0], row[1]
        return {
            "kdocs_unit": unit or "",
            "kdocs_auto_upload": 1 if auto_upload else 0,
        }
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
    """Update whichever kdocs settings were supplied for ``user_id``.

    Arguments left as None are not touched. Returns False without touching
    the database when nothing was supplied. Otherwise returns True if a
    ``users`` row was updated.
    """
    # Column names are fixed literals, so building the SET clause is safe.
    candidates = (("kdocs_unit", kdocs_unit), ("kdocs_auto_upload", kdocs_auto_upload))
    assignments = [(column, value) for column, value in candidates if value is not None]
    if not assignments:
        return False

    set_clause = ", ".join(f"{column} = ?" for column, _ in assignments)
    params = [value for _, value in assignments] + [user_id]
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(f"UPDATE users SET {set_clause} WHERE id = ?", params)
        conn.commit()
        return cur.rowcount > 0
def get_user_by_username(username):
    """Return the ``users`` row with the given ``username`` as a dict, or None."""
    return _get_user_by_field(field_name="username", field_value=username)
def get_all_users():
    """Return every user as a list of dicts, ordered by ``created_at`` descending."""
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users ORDER BY created_at DESC")
        rows = cur.fetchall()
        return list(map(dict, rows))
def get_pending_users():
    """Return users whose status is 'pending', newest ``created_at`` first, as dicts."""
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users WHERE status = 'pending' ORDER BY created_at DESC")
        rows = cur.fetchall()
        return list(map(dict, rows))
def approve_user(user_id):
    """Mark a user as approved and record ``approved_at`` as the current CST time.

    Returns True if a ``users`` row was updated, False otherwise.
    """
    approved_at = get_cst_now_str()
    sql = """
        UPDATE users
        SET status = 'approved', approved_at = ?
        WHERE id = ?
        """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(sql, (approved_at, user_id))
        conn.commit()
        affected = cur.rowcount
    return affected > 0
def reject_user(user_id):
    """Set a user's status to 'rejected'. Returns True if a row was updated."""
    sql = "UPDATE users SET status = 'rejected' WHERE id = ?"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(sql, (user_id,))
        conn.commit()
        affected = cur.rowcount
    return affected > 0
def delete_user(user_id):
    """Delete a user row. Returns True if a row was deleted.

    Only the ``users`` row is deleted here. Removal of the user's related
    accounts relies on the schema (presumably ON DELETE CASCADE with foreign
    key enforcement enabled on the connection; confirm).
    """
    sql = "DELETE FROM users WHERE id = ?"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(sql, (user_id,))
        conn.commit()
        deleted = cur.rowcount
    return deleted > 0
def get_user_stats(user_id):
    """Return ``{"account_count": n}``, where n is the number of ``accounts`` rows owned by the user."""
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,))
        row = cur.fetchone()
        total = row["count"] if row else 0
        return {"account_count": int(total or 0)}