343 lines
10 KiB
Python
343 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
from __future__ import annotations
|
||
|
||
import sqlite3
|
||
from datetime import datetime, timedelta
|
||
|
||
import pytz
|
||
from app_logger import get_logger
|
||
|
||
import db_pool
|
||
from db.utils import get_cst_now_str
|
||
from password_utils import (
|
||
hash_password_bcrypt,
|
||
is_sha256_hash,
|
||
verify_password_bcrypt,
|
||
verify_password_sha256,
|
||
)
|
||
|
||
logger = get_logger(__name__)

# Timestamps in the users / vip tables are stored as naive strings in CST.
_CST_TZ = pytz.timezone("Asia/Shanghai")

# Expiry value written for a "permanent" VIP grant.
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"

# Whitelist of lookup columns -> parameterized query. Only these keys are
# accepted by _get_user_by_field, so SQL text is never built from input.
_USER_LOOKUP_SQL = dict(
    id="SELECT * FROM users WHERE id = ?",
    username="SELECT * FROM users WHERE username = ?",
)
|
||
|
||
|
||
def _row_to_dict(row):
    """Convert a DB row into a plain dict; return None when the row is missing (falsy)."""
    if not row:
        return None
    return dict(row)
|
||
|
||
|
||
def _get_user_by_field(field_name: str, field_value):
    """Fetch a single user row by one of the whitelisted lookup columns.

    Only keys of ``_USER_LOOKUP_SQL`` are accepted, so the SQL text is never
    derived from caller input; ``field_value`` is passed as a bound parameter.

    Returns:
        The matching row as a dict, or None if no user matches.

    Raises:
        ValueError: if ``field_name`` is not a supported lookup column.
    """
    lookup_key = str(field_name or "")
    sql = _USER_LOOKUP_SQL.get(lookup_key)
    if not sql:
        raise ValueError(f"unsupported user lookup field: {field_name}")

    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(sql, (field_value,))
        return _row_to_dict(cur.fetchone())
|
||
|
||
|
||
def _parse_cst_datetime(datetime_str: str | None):
    """Parse a ``"%Y-%m-%d %H:%M:%S"`` timestamp that is stored in CST.

    Args:
        datetime_str: the stored timestamp (stringified before parsing), or
            an empty value.

    Returns:
        A timezone-aware datetime localized to ``_CST_TZ`` (Asia/Shanghai),
        or None if the value is empty or does not match the expected format.
    """
    if not datetime_str:
        return None
    try:
        naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S")
    except (TypeError, ValueError):
        # Only a malformed value is an expected failure here; callers decide
        # how to handle None (the VIP helpers in this module log a warning).
        # Anything else is a real bug and should not be silently swallowed.
        return None
    return _CST_TZ.localize(naive_dt)
|
||
|
||
|
||
def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str:
|
||
if int(days) == 999999:
|
||
return _PERMANENT_VIP_EXPIRE
|
||
if base_dt is None:
|
||
base_dt = datetime.now(_CST_TZ)
|
||
return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S")
|
||
|
||
|
||
def get_vip_config():
    """Return the VIP configuration row (``id = 1``) as a dict.

    Falls back to ``{"default_vip_days": 0}`` when no configuration row exists.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM vip_config WHERE id = 1")
        row = cur.fetchone()
        if not row:
            return {"default_vip_days": 0}
        return dict(row)
|
||
|
||
|
||
def set_default_vip_days(days):
    """Persist the number of VIP days granted by default to new users.

    Upserts the single configuration row (``id = 1``) and stamps
    ``updated_at`` with the current CST time. Always returns True.

    NOTE(review): SQLite's ``INSERT OR REPLACE`` rewrites the whole row, so any
    other vip_config columns not listed here are reset to their defaults —
    confirm that is intended.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(
            """
            INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
            VALUES (1, ?, ?)
            """,
            (days, get_cst_now_str()),
        )
        conn.commit()
    return True
|
||
|
||
|
||
def set_user_vip(user_id, days):
    """Set a user's VIP expiry to ``days`` days from now (CST).

    Typical ``days`` values: 7 = one week, 30 = one month, 365 = one year,
    999999 = permanent.

    Returns:
        True if a user row was updated, False if no user has ``user_id``.
    """
    new_expire = _format_vip_expire(days)
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
        conn.commit()
        return cur.rowcount > 0
|
||
|
||
|
||
def extend_user_vip(user_id, days):
    """Extend a user's VIP membership by ``days`` days.

    The extension is anchored at the later of the current expiry and "now"
    (CST), so time remaining on an active membership is kept. A permanent
    membership stays permanent. ``days == 999999`` grants a permanent
    membership (via ``_format_vip_expire``). An unparseable stored expiry is
    logged and the extension is anchored at "now".

    Returns:
        False if the user does not exist; otherwise whether the UPDATE
        touched a row.
    """
    user = get_user_by_id(user_id)
    if not user:
        return False

    current_expire = user.get("vip_expire_time")
    now_dt = datetime.now(_CST_TZ)

    if current_expire == _PERMANENT_VIP_EXPIRE:
        # A permanent membership must stay permanent. Anchoring at
        # "now + days" would silently downgrade it to a finite one.
        new_expire = _PERMANENT_VIP_EXPIRE
    elif current_expire:
        expire_time = _parse_cst_datetime(current_expire)
        if expire_time is None:
            logger.warning("解析VIP过期时间失败,使用当前时间")
            base_dt = now_dt
        else:
            # An already-expired membership restarts from now rather than
            # from its past expiry.
            base_dt = max(expire_time, now_dt)
        new_expire = _format_vip_expire(days, base_dt=base_dt)
    else:
        new_expire = _format_vip_expire(days, base_dt=now_dt)

    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
        conn.commit()
        return cursor.rowcount > 0
|
||
|
||
|
||
def remove_user_vip(user_id):
    """Revoke a user's VIP status by clearing ``vip_expire_time``.

    Returns:
        True if a user row was updated.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("UPDATE users SET vip_expire_time = NULL WHERE id = ?", (user_id,))
        conn.commit()
        return cur.rowcount > 0
|
||
|
||
|
||
def is_user_vip(user_id):
    """Return True if the user exists and their VIP expiry lies in the future.

    Expiry timestamps in the database are stored in CST (Asia/Shanghai).
    Missing users and empty expiries count as non-VIP. An unparseable expiry
    is logged and also counts as non-VIP.
    """
    user = get_user_by_id(user_id)
    raw_expire = user.get("vip_expire_time") if user else None
    if not raw_expire:
        return False

    parsed_expire = _parse_cst_datetime(raw_expire)
    if parsed_expire is None:
        logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间")
        return False

    return parsed_expire > datetime.now(_CST_TZ)
|
||
|
||
|
||
def get_user_vip_info(user_id):
    """Summarize a user's VIP status.

    Returns a dict with keys ``username``, ``is_vip``, ``expire_time`` (the raw
    stored string) and ``days_left`` (whole days remaining, 0 when not VIP).
    Missing users, empty expiries and unparseable expiries yield a non-VIP
    summary. The unparseable case is also logged.
    """
    def _not_vip(name):
        return {"is_vip": False, "expire_time": None, "days_left": 0, "username": name}

    user = get_user_by_id(user_id)
    if not user:
        return _not_vip("")

    username = user.get("username", "")
    raw_expire = user.get("vip_expire_time")
    if not raw_expire:
        return _not_vip(username)

    expire_dt = _parse_cst_datetime(raw_expire)
    if expire_dt is None:
        logger.warning("VIP信息获取错误: 无法解析过期时间")
        return _not_vip(username)

    now_dt = datetime.now(_CST_TZ)
    active = expire_dt > now_dt
    remaining = (expire_dt - now_dt).days if active else 0

    return {
        "username": username,
        "is_vip": active,
        "expire_time": raw_expire,
        "days_left": max(0, remaining),
    }
|
||
|
||
|
||
# ==================== 用户相关 ====================
|
||
|
||
|
||
def create_user(username, password, email=""):
    """Create a user that is approved immediately and granted the default VIP period.

    The default VIP period comes from ``get_vip_config()``. When it is 0 or
    unset, the user gets no VIP expiry.

    Returns:
        The new row id, or None if the INSERT raises ``sqlite3.IntegrityError``
        (presumably a duplicate username — confirm against the schema).
    """
    # Do the independent work before borrowing a pooled connection. bcrypt is
    # deliberately slow, and get_vip_config() acquires its own connection from
    # db_pool. Doing both while holding one would tie up two connections for a
    # single insert.
    password_hash = hash_password_bcrypt(password)
    cst_time = get_cst_now_str()

    default_vip_days = int(get_vip_config().get("default_vip_days") or 0)
    vip_expire_time = _format_vip_expire(default_vip_days) if default_vip_days > 0 else None

    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(
                """
                INSERT INTO users (username, password_hash, email, status, vip_expire_time, created_at, approved_at)
                VALUES (?, ?, ?, 'approved', ?, ?, ?)
                """,
                (username, password_hash, email, vip_expire_time, cst_time, cst_time),
            )
            conn.commit()
            return cursor.lastrowid
        except sqlite3.IntegrityError:
            # Do not hand the connection back with the failed implicit
            # transaction still open. This is a no-op if nothing is pending.
            conn.rollback()
            return None
|
||
|
||
|
||
def verify_user(username, password):
    """Authenticate an approved user by username and password.

    Legacy SHA-256 hashes are still accepted. On a successful SHA-256 match,
    the stored hash is re-hashed with bcrypt and written back.

    Returns:
        The user row as a dict on success, otherwise None.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users WHERE username = ? AND status = 'approved'", (username,))
        row = cur.fetchone()
        if not row:
            return None

        account = dict(row)
        stored_hash = account["password_hash"]

        if not is_sha256_hash(stored_hash):
            return account if verify_password_bcrypt(password, stored_hash) else None

        if not verify_password_sha256(password, stored_hash):
            return None

        # Legacy SHA-256 hash matched: upgrade it to bcrypt in place.
        uid = account["id"]
        cur.execute("UPDATE users SET password_hash = ? WHERE id = ?", (hash_password_bcrypt(password), uid))
        conn.commit()
        logger.info(f"用户密码已自动升级到bcrypt (user_id={uid})")
        return account
|
||
|
||
|
||
def get_user_by_id(user_id):
    """Return the user with primary key ``user_id`` as a dict, or None."""
    return _get_user_by_field(field_name="id", field_value=user_id)
|
||
|
||
|
||
def get_user_kdocs_settings(user_id):
    """Return a user's Kingsoft Docs (kdocs) settings, or None if the user is missing.

    The result has ``kdocs_unit`` (empty string when unset) and
    ``kdocs_auto_upload`` normalized to 1 or 0. Rows are read by key when they
    are plain dicts, otherwise positionally in SELECT order.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,))
        row = cur.fetchone()
        if not row:
            return None

        if isinstance(row, dict):
            unit = row["kdocs_unit"]
            auto_upload = row["kdocs_auto_upload"]
        else:
            unit = row[0]
            auto_upload = row[1]

        return {
            "kdocs_unit": unit or "",
            "kdocs_auto_upload": 1 if auto_upload else 0,
        }
|
||
|
||
|
||
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
    """Update whichever kdocs settings are given (non-None) for a user.

    Column names come from a fixed list, never from caller input; values are
    bound as parameters.

    Returns:
        False if nothing was supplied to update, otherwise whether a user row
        was updated.
    """
    candidates = (
        ("kdocs_unit", kdocs_unit),
        ("kdocs_auto_upload", kdocs_auto_upload),
    )
    assignments = [(column, value) for column, value in candidates if value is not None]
    if not assignments:
        return False

    set_clause = ", ".join(f"{column} = ?" for column, _ in assignments)
    params = [value for _, value in assignments]
    params.append(user_id)

    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(f"UPDATE users SET {set_clause} WHERE id = ?", params)
        conn.commit()
        return cur.rowcount > 0
|
||
|
||
|
||
def get_user_by_username(username):
    """Return the user named ``username`` as a dict, or None."""
    return _get_user_by_field(field_name="username", field_value=username)
|
||
|
||
|
||
def get_all_users():
    """Return every user as a list of dicts, newest ``created_at`` first."""
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users ORDER BY created_at DESC")
        rows = cur.fetchall()
        return list(map(dict, rows))
|
||
|
||
|
||
def get_pending_users():
    """Return users with ``status = 'pending'`` as dicts, newest ``created_at`` first."""
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM users WHERE status = 'pending' ORDER BY created_at DESC")
        rows = cur.fetchall()
        return list(map(dict, rows))
|
||
|
||
|
||
def approve_user(user_id):
    """Mark a user as approved and record the approval time (CST).

    Returns:
        True if a user row was updated.
    """
    approved_at = get_cst_now_str()
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(
            """
            UPDATE users
            SET status = 'approved', approved_at = ?
            WHERE id = ?
            """,
            (approved_at, user_id),
        )
        conn.commit()
        return cur.rowcount > 0
|
||
|
||
|
||
def reject_user(user_id):
    """Mark a user as rejected (``status = 'rejected'``).

    Returns:
        True if a user row was updated.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("UPDATE users SET status = 'rejected' WHERE id = ?", (user_id,))
        conn.commit()
        return cur.rowcount > 0
|
||
|
||
|
||
def delete_user(user_id):
    """Delete a user row.

    The original intent is that the user's related accounts are deleted too.
    NOTE(review): nothing here deletes them explicitly, so this presumably
    relies on an ``ON DELETE CASCADE`` foreign key (and SQLite's
    ``foreign_keys`` pragma being enabled) — confirm against the schema.

    Returns:
        True if a user row was deleted.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("DELETE FROM users WHERE id = ?", (user_id,))
        conn.commit()
        return cur.rowcount > 0
|
||
|
||
|
||
def get_user_stats(user_id):
    """Return per-user statistics: ``{"account_count": <rows in accounts for user_id>}``."""
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,))
        row = cur.fetchone()
        count = row["count"] if row else 0
        return {"account_count": int(count or 0)}
|