#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import annotations import sqlite3 from datetime import datetime, timedelta import pytz from app_logger import get_logger import db_pool from db.utils import get_cst_now_str from password_utils import ( hash_password_bcrypt, is_sha256_hash, verify_password_bcrypt, verify_password_sha256, ) logger = get_logger(__name__) _CST_TZ = pytz.timezone("Asia/Shanghai") _PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59" def _row_to_dict(row): return dict(row) if row else None def _get_user_by_field(field_name: str, field_value): with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,)) return _row_to_dict(cursor.fetchone()) def _parse_cst_datetime(datetime_str: str | None): if not datetime_str: return None try: naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S") return _CST_TZ.localize(naive_dt) except Exception: return None def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str: if int(days) == 999999: return _PERMANENT_VIP_EXPIRE if base_dt is None: base_dt = datetime.now(_CST_TZ) return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S") def get_vip_config(): """获取VIP配置""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM vip_config WHERE id = 1") config = cursor.fetchone() return dict(config) if config else {"default_vip_days": 0} def set_default_vip_days(days): """设置默认VIP天数""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( """ INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at) VALUES (1, ?, ?) """, (days, get_cst_now_str()), ) conn.commit() return True def set_user_vip(user_id, days): """设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久""" with db_pool.get_db() as conn: cursor = conn.cursor() expire_time = _format_vip_expire(days) cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id)) conn.commit() return cursor.rowcount > 0 def extend_user_vip(user_id, days): """延长用户VIP时间""" user = get_user_by_id(user_id) if not user: return False current_expire = user.get("vip_expire_time") now_dt = datetime.now(_CST_TZ) if current_expire and current_expire != _PERMANENT_VIP_EXPIRE: expire_time = _parse_cst_datetime(current_expire) if expire_time is not None: if expire_time < now_dt: expire_time = now_dt new_expire = _format_vip_expire(days, base_dt=expire_time) else: logger.warning("解析VIP过期时间失败,使用当前时间") new_expire = _format_vip_expire(days, base_dt=now_dt) else: new_expire = _format_vip_expire(days, base_dt=now_dt) with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id)) conn.commit() return cursor.rowcount > 0 def remove_user_vip(user_id): """移除用户VIP""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("UPDATE users SET vip_expire_time = NULL WHERE id = ?", (user_id,)) conn.commit() return cursor.rowcount > 0 def is_user_vip(user_id): """检查用户是否是VIP 注意:数据库中存储的时间统一使用CST(Asia/Shanghai)时区 """ user = get_user_by_id(user_id) if not user: return False vip_expire_time = user.get("vip_expire_time") if not vip_expire_time: return False expire_time = _parse_cst_datetime(vip_expire_time) if expire_time is None: logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间") return False return datetime.now(_CST_TZ) < expire_time def get_user_vip_info(user_id): """获取用户VIP信息""" user = get_user_by_id(user_id) if not user: return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""} vip_expire_time = user.get("vip_expire_time") username = user.get("username", "") if not vip_expire_time: return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username} expire_time = _parse_cst_datetime(vip_expire_time) if expire_time is None: logger.warning("VIP信息获取错误: 无法解析过期时间") return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username} now_dt = datetime.now(_CST_TZ) is_vip = now_dt < expire_time days_left = (expire_time - now_dt).days if is_vip else 0 return { "username": username, "is_vip": is_vip, "expire_time": vip_expire_time, "days_left": max(0, days_left), } # ==================== 用户相关 ==================== def create_user(username, password, email=""): """创建新用户(默认直接通过,赠送默认VIP)""" with db_pool.get_db() as conn: cursor = conn.cursor() password_hash = hash_password_bcrypt(password) cst_time = get_cst_now_str() default_vip_days = get_vip_config()["default_vip_days"] vip_expire_time = None if int(default_vip_days or 0) > 0: vip_expire_time = _format_vip_expire(int(default_vip_days)) try: cursor.execute( """ INSERT INTO users (username, password_hash, email, status, vip_expire_time, created_at, approved_at) VALUES (?, ?, ?, 'approved', ?, ?, ?) """, (username, password_hash, email, vip_expire_time, cst_time, cst_time), ) conn.commit() return cursor.lastrowid except sqlite3.IntegrityError: return None def verify_user(username, password): """验证用户登录 - 自动从SHA256升级到bcrypt""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE username = ? AND status = 'approved'", (username,)) user = cursor.fetchone() if not user: return None user_dict = dict(user) password_hash = user_dict["password_hash"] if is_sha256_hash(password_hash): if verify_password_sha256(password, password_hash): new_hash = hash_password_bcrypt(password) cursor.execute("UPDATE users SET password_hash = ? WHERE id = ?", (new_hash, user_dict["id"])) conn.commit() logger.info(f"用户密码已自动升级到bcrypt (user_id={user_dict['id']})") return user_dict return None if verify_password_bcrypt(password, password_hash): return user_dict return None def get_user_by_id(user_id): """根据ID获取用户""" return _get_user_by_field("id", user_id) def get_user_kdocs_settings(user_id): """获取用户的金山文档配置""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,)) row = cursor.fetchone() if not row: return None return { "kdocs_unit": (row["kdocs_unit"] or "") if isinstance(row, dict) else (row[0] or ""), "kdocs_auto_upload": 1 if ((row["kdocs_auto_upload"] if isinstance(row, dict) else row[1]) or 0) else 0, } def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool: """更新用户的金山文档配置""" updates = [] params = [] if kdocs_unit is not None: updates.append("kdocs_unit = ?") params.append(kdocs_unit) if kdocs_auto_upload is not None: updates.append("kdocs_auto_upload = ?") params.append(kdocs_auto_upload) if not updates: return False params.append(user_id) with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute(f"UPDATE users SET {', '.join(updates)} WHERE id = ?", params) conn.commit() return cursor.rowcount > 0 def get_user_by_username(username): """根据用户名获取用户""" return _get_user_by_field("username", username) def get_all_users(): """获取所有用户""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users ORDER BY created_at DESC") return [dict(row) for row in cursor.fetchall()] def get_pending_users(): """获取待审核用户""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE status = 'pending' ORDER BY created_at DESC") return [dict(row) for row in cursor.fetchall()] def approve_user(user_id): """审核通过用户""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( """ UPDATE users SET status = 'approved', approved_at = ? WHERE id = ? """, (get_cst_now_str(), user_id), ) conn.commit() return cursor.rowcount > 0 def reject_user(user_id): """拒绝用户""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("UPDATE users SET status = 'rejected' WHERE id = ?", (user_id,)) conn.commit() return cursor.rowcount > 0 def delete_user(user_id): """删除用户(级联删除相关账号)""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("DELETE FROM users WHERE id = ?", (user_id,)) conn.commit() return cursor.rowcount > 0 def get_user_stats(user_id): """获取用户统计信息""" with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,)) row = cursor.fetchone() return {"account_count": int((row["count"] if row else 0) or 0)}