Files
zsglpt/db/users.py
2026-01-07 12:32:41 +08:00

320 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import sqlite3
from datetime import datetime, timedelta
import pytz
from app_logger import get_logger
import db_pool
from db.utils import get_cst_now_str
from password_utils import (
hash_password_bcrypt,
is_sha256_hash,
verify_password_bcrypt,
verify_password_sha256,
)
# Module-level logger, obtained from app_logger.get_logger with this module's name.
logger = get_logger(__name__)
def get_vip_config():
    """Fetch the VIP configuration row (``vip_config.id = 1``).

    Returns:
        The row converted to a dict, or ``{"default_vip_days": 0}`` when the
        query yields no row.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM vip_config WHERE id = 1")
        row = cur.fetchone()
        if row:
            return dict(row)
        return {"default_vip_days": 0}
def set_default_vip_days(days):
    """Store the default number of VIP days in the ``vip_config`` row with id 1.

    The row is written with ``INSERT OR REPLACE`` and ``updated_at`` is set to
    the string returned by ``get_cst_now_str()``.

    Args:
        days: Value written to the ``default_vip_days`` column.

    Returns:
        ``True`` once the statement has been committed.
    """
    upsert_sql = """
        INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
        VALUES (1, ?, ?)
        """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(upsert_sql, (days, get_cst_now_str()))
        conn.commit()
        return True
def set_user_vip(user_id, days):
    """Set a user's VIP expiry, counted from the current Asia/Shanghai time.

    Args:
        user_id: ``users.id`` of the row to update.
        days: VIP length in days (e.g. 7 = one week, 30 = one month,
            365 = one year). The value 999999 means permanent and stores the
            fixed timestamp ``2099-12-31 23:59:59``.

    Returns:
        ``True`` if the UPDATE affected at least one row, otherwise ``False``.
    """
    with db_pool.get_db() as conn:
        shanghai = pytz.timezone("Asia/Shanghai")
        cur = conn.cursor()
        if days != 999999:
            expires_at = datetime.now(shanghai) + timedelta(days=days)
            vip_until = expires_at.strftime("%Y-%m-%d %H:%M:%S")
        else:
            # Permanent VIP is represented by a fixed far-future timestamp.
            vip_until = "2099-12-31 23:59:59"
        cur.execute(
            "UPDATE users SET vip_expire_time = ? WHERE id = ?",
            (vip_until, user_id),
        )
        conn.commit()
        return cur.rowcount > 0
def extend_user_vip(user_id, days):
    """Extend a user's VIP period by ``days`` days.

    The extension is counted from the stored expiry when that expiry is still
    in the future, otherwise from the current Asia/Shanghai time. A user
    whose expiry is the permanent marker ``2099-12-31 23:59:59`` stays
    permanent; previously such a user was reset to ``now + days``, which
    silently shortened a permanent membership.

    Args:
        user_id: ``users.id`` of the user to extend.
        days: Number of days to add.

    Returns:
        ``False`` if the user does not exist; otherwise ``True`` when the
        UPDATE affected at least one row.
    """
    user = get_user_by_id(user_id)
    if not user:
        return False

    permanent_marker = "2099-12-31 23:59:59"
    time_format = "%Y-%m-%d %H:%M:%S"
    cst_tz = pytz.timezone("Asia/Shanghai")
    current_expire = user.get("vip_expire_time")

    if current_expire == permanent_marker:
        # Never downgrade a permanent VIP to a finite period.
        new_expire = current_expire
    else:
        now = datetime.now(cst_tz)
        base_time = now
        if current_expire:
            try:
                stored = cst_tz.localize(datetime.strptime(current_expire, time_format))
                # Only build on the stored expiry if it has not passed yet.
                if stored > now:
                    base_time = stored
            except (ValueError, AttributeError, TypeError) as e:
                # TypeError covers a stored value that is not a string.
                logger.warning(f"解析VIP过期时间失败: {e}, 使用当前时间")
        new_expire = (base_time + timedelta(days=days)).strftime(time_format)

    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE users SET vip_expire_time = ? WHERE id = ?",
            (new_expire, user_id),
        )
        conn.commit()
        return cursor.rowcount > 0
def remove_user_vip(user_id):
    """Clear a user's VIP status by setting ``vip_expire_time`` to NULL.

    Args:
        user_id: ``users.id`` of the row to update.

    Returns:
        ``True`` if the UPDATE affected at least one row, otherwise ``False``.
    """
    clear_sql = "UPDATE users SET vip_expire_time = NULL WHERE id = ?"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(clear_sql, (user_id,))
        conn.commit()
        updated = cur.rowcount > 0
        return updated
def is_user_vip(user_id):
    """Report whether the user currently holds a non-expired VIP.

    Note: the stored expiry is interpreted as Asia/Shanghai (CST) local time.

    Args:
        user_id: ``users.id`` of the user to check.

    Returns:
        ``True`` when the stored expiry lies after the current time.
        ``False`` when the user is missing, has no expiry, or the expiry
        string cannot be parsed (a warning is logged in that case).
    """
    shanghai = pytz.timezone("Asia/Shanghai")
    user = get_user_by_id(user_id)
    if user and user.get("vip_expire_time"):
        try:
            naive_deadline = datetime.strptime(user["vip_expire_time"], "%Y-%m-%d %H:%M:%S")
            deadline = shanghai.localize(naive_deadline)
            return datetime.now(shanghai) < deadline
        except (ValueError, AttributeError) as e:
            logger.warning(f"检查VIP状态失败 (user_id={user_id}): {e}")
    return False
def get_user_vip_info(user_id):
    """Build a summary of a user's VIP status.

    Args:
        user_id: ``users.id`` of the user to describe.

    Returns:
        A dict with the keys ``is_vip``, ``expire_time``, ``days_left`` and
        ``username``. For a missing user the username is ``""``. When the
        user has no expiry, or processing it fails (a warning is logged),
        the dict reports a non-VIP with ``expire_time`` ``None`` and
        ``days_left`` 0. Otherwise ``expire_time`` is the stored string and
        ``days_left`` is the whole number of days remaining (never negative).
    """
    shanghai = pytz.timezone("Asia/Shanghai")
    user = get_user_by_id(user_id)
    if not user:
        return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""}

    raw_expiry = user.get("vip_expire_time")
    if raw_expiry:
        try:
            deadline = shanghai.localize(datetime.strptime(raw_expiry, "%Y-%m-%d %H:%M:%S"))
            current = datetime.now(shanghai)
            active = current < deadline
            if active:
                remaining = (deadline - current).days
            else:
                remaining = 0
            return {
                "username": user.get("username", ""),
                "is_vip": active,
                "expire_time": raw_expiry,
                "days_left": max(0, remaining),
            }
        except Exception as e:
            logger.warning(f"VIP信息获取错误: {e}")

    # Reached when there is no expiry or when processing it failed.
    return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
# ==================== User-related ====================
def create_user(username, password, email=""):
    """Create a user that is approved immediately and given the default VIP.

    The password is hashed with ``hash_password_bcrypt``. ``created_at`` and
    ``approved_at`` are both set to the string from ``get_cst_now_str()``.
    When the configured ``default_vip_days`` is positive, the new user gets a
    VIP expiry that many days from the current Asia/Shanghai time; the value
    999999 stores the permanent marker ``2099-12-31 23:59:59``.

    Args:
        username: Login name stored in ``users.username``.
        password: Plain-text password; only its hash is stored.
        email: Optional e-mail address, defaults to ``""``.

    Returns:
        The new row id (``cursor.lastrowid``), or ``None`` when the INSERT
        raises ``sqlite3.IntegrityError``.
    """
    # Hash the password and read the VIP config before borrowing a
    # connection: get_vip_config() opens its own db_pool connection, so doing
    # this inside the ``with`` block would hold two connections at once.
    password_hash = hash_password_bcrypt(password)
    # A missing or NULL setting is treated as "no default VIP" rather than
    # failing on the comparison below.
    default_vip_days = get_vip_config().get("default_vip_days") or 0

    vip_expire_time = None
    if default_vip_days > 0:
        if default_vip_days == 999999:
            vip_expire_time = "2099-12-31 23:59:59"
        else:
            cst_tz = pytz.timezone("Asia/Shanghai")
            vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime(
                "%Y-%m-%d %H:%M:%S"
            )

    cst_time = get_cst_now_str()
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(
                """
                INSERT INTO users (username, password_hash, email, status, vip_expire_time, created_at, approved_at)
                VALUES (?, ?, ?, 'approved', ?, ?, ?)
                """,
                (username, password_hash, email, vip_expire_time, cst_time, cst_time),
            )
            conn.commit()
            return cursor.lastrowid
        except sqlite3.IntegrityError:
            return None
def verify_user(username, password):
    """Check login credentials, upgrading SHA256 hashes to bcrypt on success.

    Only users whose status is ``'approved'`` are considered.

    Args:
        username: Login name to look up.
        password: Plain-text password to verify.

    Returns:
        The user row as a dict when the password matches, otherwise ``None``.
        After a hash upgrade the returned dict still carries the hash that
        was read before the UPDATE.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users WHERE username = ? AND status = 'approved'", (username,))
        row = cur.fetchone()
        if not row:
            return None

        user_dict = dict(row)
        stored_hash = user_dict["password_hash"]

        if not is_sha256_hash(stored_hash):
            # Regular bcrypt path.
            return user_dict if verify_password_bcrypt(password, stored_hash) else None

        if not verify_password_sha256(password, stored_hash):
            return None

        # Legacy hash verified: replace it with a bcrypt hash.
        upgraded_hash = hash_password_bcrypt(password)
        cur.execute("UPDATE users SET password_hash = ? WHERE id = ?", (upgraded_hash, user_dict["id"]))
        conn.commit()
        logger.info(f"用户密码已自动升级到bcrypt (user_id={user_dict['id']})")
        return user_dict
def get_user_by_id(user_id):
    """Look up a user by primary key.

    Args:
        user_id: Value matched against ``users.id``.

    Returns:
        The row as a dict, or ``None`` when no row matches.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users WHERE id = ?", (user_id,))
        row = cur.fetchone()
        if row:
            return dict(row)
        return None
def get_user_kdocs_settings(user_id):
    """Return the user's Kingsoft Docs (kdocs) settings.

    Args:
        user_id: ``users.id`` of the user.

    Returns:
        ``None`` when the user does not exist; otherwise a dict with
        ``kdocs_unit`` (``""`` when unset) and ``kdocs_auto_upload``
        normalized to 1 or 0.
    """
    user = get_user_by_id(user_id)
    if user:
        unit = user.get("kdocs_unit") or ""
        auto_upload = 1 if user.get("kdocs_auto_upload") else 0
        return {"kdocs_unit": unit, "kdocs_auto_upload": auto_upload}
    return None
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
    """Update the user's Kingsoft Docs (kdocs) settings.

    Only the keyword arguments that are not ``None`` are written.

    Args:
        user_id: ``users.id`` of the row to update.
        kdocs_unit: New ``kdocs_unit`` value, or ``None`` to leave unchanged.
        kdocs_auto_upload: New ``kdocs_auto_upload`` value, or ``None`` to
            leave unchanged.

    Returns:
        ``False`` when nothing was supplied to update (no database access is
        made); otherwise ``True`` if the UPDATE affected at least one row.
    """
    candidates = (
        ("kdocs_unit", kdocs_unit),
        ("kdocs_auto_upload", kdocs_auto_upload),
    )
    supplied = [(column, value) for column, value in candidates if value is not None]
    if not supplied:
        return False

    # Column names come from the fixed tuple above; values stay parameterized.
    set_clause = ", ".join(f"{column} = ?" for column, _ in supplied)
    params = [value for _, value in supplied]
    params.append(user_id)

    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(f"UPDATE users SET {set_clause} WHERE id = ?", params)
        conn.commit()
        return cur.rowcount > 0
def get_user_by_username(username):
    """Look up a user by username.

    Args:
        username: Value matched against ``users.username``.

    Returns:
        The row as a dict, or ``None`` when no row matches.
    """
    lookup_sql = "SELECT * FROM users WHERE username = ?"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(lookup_sql, (username,))
        row = cur.fetchone()
        return None if not row else dict(row)
def get_all_users():
    """Return every user as a list of dicts, ordered by ``created_at`` descending."""
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users ORDER BY created_at DESC")
        users = []
        for record in cur.fetchall():
            users.append(dict(record))
        return users
def get_pending_users():
    """Return users whose status is ``'pending'``, ordered by ``created_at`` descending.

    Returns:
        A list of row dicts (empty when there are no pending users).
    """
    pending_sql = "SELECT * FROM users WHERE status = 'pending' ORDER BY created_at DESC"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(pending_sql)
        records = cur.fetchall()
        return list(map(dict, records))
def approve_user(user_id):
    """Mark a user as approved and stamp ``approved_at``.

    ``approved_at`` is set to the string returned by ``get_cst_now_str()``.

    Args:
        user_id: ``users.id`` of the row to update.

    Returns:
        ``True`` if the UPDATE affected at least one row, otherwise ``False``.
    """
    approve_sql = """
        UPDATE users
        SET status = 'approved', approved_at = ?
        WHERE id = ?
        """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        approved_at = get_cst_now_str()
        cur.execute(approve_sql, (approved_at, user_id))
        conn.commit()
        return cur.rowcount > 0
def reject_user(user_id):
    """Mark a user as rejected.

    Args:
        user_id: ``users.id`` of the row to update.

    Returns:
        ``True`` if the UPDATE affected at least one row, otherwise ``False``.
    """
    reject_sql = "UPDATE users SET status = 'rejected' WHERE id = ?"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(reject_sql, (user_id,))
        conn.commit()
        changed = cur.rowcount > 0
        return changed
def delete_user(user_id):
    """Delete a user row.

    NOTE(review): the original docstring says related accounts are
    cascade-deleted; this function issues only the DELETE on ``users``, so
    any cascade would have to come from schema constraints not visible
    here - confirm.

    Args:
        user_id: ``users.id`` of the row to delete.

    Returns:
        ``True`` if the DELETE affected at least one row, otherwise ``False``.
    """
    delete_sql = "DELETE FROM users WHERE id = ?"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(delete_sql, (user_id,))
        conn.commit()
        return cur.rowcount > 0
def get_user_stats(user_id):
    """Return summary statistics for a user.

    Args:
        user_id: Value matched against ``accounts.user_id``.

    Returns:
        A dict with ``account_count``: the number of ``accounts`` rows that
        belong to the user.
    """
    count_sql = "SELECT COUNT(*) as count FROM accounts WHERE user_id = ?"
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(count_sql, (user_id,))
        row = cur.fetchone()
        return {"account_count": row["count"]}