Files
zsglpt/db/admin.py

436 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import sqlite3
from datetime import datetime, timedelta
import pytz
import db_pool
from db.utils import get_cst_now_str
from password_utils import (
hash_password_bcrypt,
is_sha256_hash,
verify_password_bcrypt,
verify_password_sha256,
)
def ensure_default_admin() -> bool:
    """Create a default ``admin`` account when the admins table is empty.

    A 12-character password is drawn from ASCII letters and digits with
    ``secrets``, stored as a bcrypt hash, and printed once to stdout so the
    operator can log in and change it.

    Returns:
        True when a new admin row was inserted, False when at least one
        admin already existed and nothing was changed.
    """
    import secrets
    import string

    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT COUNT(*) as count FROM admins")
        if cur.fetchone()["count"] != 0:
            # An admin already exists; leave the table alone.
            return False

        charset = string.ascii_letters + string.digits
        random_password = "".join(secrets.choice(charset) for _ in range(12))
        cur.execute(
            "INSERT INTO admins (username, password_hash, created_at) VALUES (?, ?, ?)",
            ("admin", hash_password_bcrypt(random_password), get_cst_now_str()),
        )
        conn.commit()

        # The plaintext password is only ever shown here; it is not stored.
        separator = "=" * 60
        print(separator)
        print("安全提醒:已创建默认管理员账号")
        print("用户名: admin")
        print(f"密码: {random_password}")
        print("请立即登录后修改密码!")
        print(separator)
        return True
def verify_admin(username: str, password: str):
    """Check admin credentials, upgrading legacy SHA256 hashes to bcrypt.

    Args:
        username: Admin username to look up.
        password: Plaintext password supplied at login.

    Returns:
        The admin row as a dict when the password matches, otherwise None
        (unknown username or wrong password). When a legacy hash is upgraded
        during this call, the returned dict still carries the hash that was
        read before the upgrade.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM admins WHERE username = ?", (username,))
        row = cur.fetchone()
        if not row:
            return None

        admin_dict = dict(row)
        stored_hash = admin_dict["password_hash"]

        # Normal path: the stored hash is not a legacy SHA256 hash.
        if not is_sha256_hash(stored_hash):
            if verify_password_bcrypt(password, stored_hash):
                return admin_dict
            return None

        # Legacy path: verify against SHA256, then re-hash with bcrypt.
        if not verify_password_sha256(password, stored_hash):
            return None
        upgraded_hash = hash_password_bcrypt(password)
        cur.execute(
            "UPDATE admins SET password_hash = ? WHERE username = ?",
            (upgraded_hash, username),
        )
        conn.commit()
        print(f"管理员 {username} 密码已自动升级到bcrypt")
        return admin_dict
def update_admin_password(username: str, new_password: str) -> bool:
    """Replace an administrator's password with a bcrypt hash of a new one.

    Args:
        username: Admin whose password should change.
        new_password: New plaintext password; only its bcrypt hash is stored.

    Returns:
        True when a row was updated, False when no admin has that username.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        hashed = hash_password_bcrypt(new_password)
        cur.execute(
            "UPDATE admins SET password_hash = ? WHERE username = ?",
            (hashed, username),
        )
        conn.commit()
        row_changed = cur.rowcount > 0
        return row_changed
def update_admin_username(old_username: str, new_username: str) -> bool:
    """Rename an administrator account.

    Args:
        old_username: Current username of the admin row to rename.
        new_username: Username to store in its place.

    Returns:
        True when a row was actually renamed. False when no admin named
        ``old_username`` exists, or when the UPDATE raised
        ``sqlite3.IntegrityError`` (presumably a uniqueness constraint on
        ``username`` -- confirm against the schema).
    """
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(
                "UPDATE admins SET username = ? WHERE username = ?",
                (new_username, old_username),
            )
            conn.commit()
        except sqlite3.IntegrityError:
            return False
        # An UPDATE that matches nothing succeeds silently; report it as a
        # failure, consistent with update_admin_password.
        return cursor.rowcount > 0
def get_system_stats() -> dict:
    """Collect headline user/account counters for the admin dashboard.

    Returns:
        A dict with the keys ``total_users``, ``approved_users``,
        ``new_users_today``, ``new_users_7d``, ``total_accounts`` and
        ``vip_users``, each holding the result of a ``COUNT(*)`` query.

    NOTE(review): the date filters compare against SQLite's
    ``'now', 'localtime'``; this assumes the stored timestamps are in the
    server's local timezone -- confirm.
    """
    # (result key, query) pairs, executed in this order.
    stat_queries = (
        ("total_users", "SELECT COUNT(*) as count FROM users"),
        ("approved_users", "SELECT COUNT(*) as count FROM users WHERE status = 'approved'"),
        (
            "new_users_today",
            """
            SELECT COUNT(*) as count
            FROM users
            WHERE date(created_at) = date('now', 'localtime')
            """,
        ),
        (
            "new_users_7d",
            """
            SELECT COUNT(*) as count
            FROM users
            WHERE datetime(created_at) >= datetime('now', 'localtime', '-7 days')
            """,
        ),
        ("total_accounts", "SELECT COUNT(*) as count FROM accounts"),
        (
            "vip_users",
            """
            SELECT COUNT(*) as count FROM users
            WHERE vip_expire_time IS NOT NULL
            AND datetime(vip_expire_time) > datetime('now', 'localtime')
            """,
        ),
    )

    with db_pool.get_db() as conn:
        cur = conn.cursor()
        stats = {}
        for key, sql in stat_queries:
            cur.execute(sql)
            stats[key] = cur.fetchone()["count"]
        return stats
def get_system_config_raw() -> dict:
    """Read the system configuration straight from the database (no caching).

    Caching and invalidation are left to the caller.

    Returns:
        The ``system_config`` row with ``id = 1`` as a dict, or a dict of
        built-in defaults when that row does not exist.
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT * FROM system_config WHERE id = 1")
        config_row = cur.fetchone()

        if not config_row:
            # No stored configuration yet: fall back to hard-coded defaults.
            fallback = {
                "max_concurrent_global": 2,
                "max_concurrent_per_account": 1,
                "max_screenshot_concurrent": 3,
                "schedule_enabled": 0,
                "schedule_time": "02:00",
                "schedule_browse_type": "应读",
                "schedule_weekdays": "1,2,3,4,5,6,7",
                "proxy_enabled": 0,
                "proxy_api_url": "",
                "proxy_expire_minutes": 3,
                "enable_screenshot": 1,
                "auto_approve_enabled": 0,
                "auto_approve_hourly_limit": 10,
                "auto_approve_vip_days": 7,
            }
            return fallback

        return dict(config_row)
def update_system_config(
    *,
    max_concurrent=None,
    schedule_enabled=None,
    schedule_time=None,
    schedule_browse_type=None,
    schedule_weekdays=None,
    max_concurrent_per_account=None,
    max_screenshot_concurrent=None,
    proxy_enabled=None,
    proxy_api_url=None,
    proxy_expire_minutes=None,
    auto_approve_enabled=None,
    auto_approve_hourly_limit=None,
    auto_approve_vip_days=None,
) -> bool:
    """Update the ``system_config`` row (``id = 1``); DB only, no cache handling.

    Every argument is keyword-only and optional. An argument left as None is
    not written, so a column cannot be set to NULL through this function.
    ``max_concurrent`` maps to the ``max_concurrent_global`` column; every
    other argument maps to the column of the same name. ``updated_at`` is
    always refreshed when at least one value is written.

    Returns:
        False when every argument is None (the database is not touched),
        True after the UPDATE has been executed and committed.
    """
    # (column, value) pairs. Column names are fixed literals, never caller
    # input, so interpolating them into the SQL below is safe.
    candidates = (
        ("max_concurrent_global", max_concurrent),
        ("schedule_enabled", schedule_enabled),
        ("schedule_time", schedule_time),
        ("schedule_browse_type", schedule_browse_type),
        ("max_concurrent_per_account", max_concurrent_per_account),
        ("max_screenshot_concurrent", max_screenshot_concurrent),
        ("schedule_weekdays", schedule_weekdays),
        ("proxy_enabled", proxy_enabled),
        ("proxy_api_url", proxy_api_url),
        ("proxy_expire_minutes", proxy_expire_minutes),
        ("auto_approve_enabled", auto_approve_enabled),
        ("auto_approve_hourly_limit", auto_approve_hourly_limit),
        ("auto_approve_vip_days", auto_approve_vip_days),
    )
    changes = [(column, value) for column, value in candidates if value is not None]

    # Nothing to write: bail out before acquiring a pooled connection.
    if not changes:
        return False

    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        changes.append(("updated_at", get_cst_now_str()))
        set_clause = ", ".join(f"{column} = ?" for column, _ in changes)
        params = [value for _, value in changes]
        cursor.execute(f"UPDATE system_config SET {set_clause} WHERE id = 1", params)
        conn.commit()
        return True
def get_hourly_registration_count() -> int:
    """Count users whose ``created_at`` falls within the last hour.

    The cutoff is computed by SQLite as ``datetime('now', 'localtime',
    '-1 hour')`` and compared directly against the stored ``created_at``.

    Returns:
        The value of the ``COUNT(*)`` query.
    """
    recent_signups_sql = """
        SELECT COUNT(*) FROM users
        WHERE created_at >= datetime('now', 'localtime', '-1 hour')
        """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(recent_signups_sql)
        count_row = cur.fetchone()
        return count_row[0]
# ==================== Password reset (admin) ====================
def create_password_reset_request(user_id: int, new_password: str):
    """Record a pending password-reset request for a user.

    Only the bcrypt hash of ``new_password`` is stored; the request is
    inserted with status ``'pending'`` and the current timestamp.

    Args:
        user_id: ID of the user asking for the reset.
        new_password: Plaintext password the user wants to switch to.

    Returns:
        The row id of the new request, or None when the insert failed (the
        error is printed rather than raised).
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        hashed = hash_password_bcrypt(new_password)
        requested_at = get_cst_now_str()
        try:
            cur.execute(
                """
                INSERT INTO password_reset_requests (user_id, new_password_hash, status, created_at)
                VALUES (?, ?, 'pending', ?)
                """,
                (user_id, hashed, requested_at),
            )
            conn.commit()
            request_id = cur.lastrowid
            return request_id
        except Exception as e:
            # Best-effort: report the failure and signal it with None.
            print(f"创建密码重置申请失败: {e}")
            return None
def get_pending_password_resets():
    """List password-reset requests that are still awaiting review.

    Returns:
        A list of dicts, newest request first, each with the request's
        ``id``, ``user_id``, ``created_at`` and ``status`` plus the
        requesting user's ``username`` and ``email``.
    """
    pending_sql = """
        SELECT r.id, r.user_id, r.created_at, r.status,
        u.username, u.email
        FROM password_reset_requests r
        JOIN users u ON r.user_id = u.id
        WHERE r.status = 'pending'
        ORDER BY r.created_at DESC
        """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(pending_sql)
        pending_rows = cur.fetchall()
        return list(map(dict, pending_rows))
def approve_password_reset(request_id: int) -> bool:
    """Approve a pending password-reset request and apply the new password.

    The hash stored on the request is copied onto the user, then the request
    is marked ``'approved'`` with the processing timestamp. Both updates are
    committed together.

    Args:
        request_id: ID of the request to approve.

    Returns:
        True when the request was pending and has been applied. False when
        no pending request has that id or when any step raised (the error is
        printed rather than raised).
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        processed_at = get_cst_now_str()
        try:
            cur.execute(
                """
                SELECT user_id, new_password_hash
                FROM password_reset_requests
                WHERE id = ? AND status = 'pending'
                """,
                (request_id,),
            )
            pending = cur.fetchone()
            if not pending:
                # Unknown id, or the request was already processed.
                return False

            target_user = pending["user_id"]
            requested_hash = pending["new_password_hash"]

            cur.execute(
                "UPDATE users SET password_hash = ? WHERE id = ?",
                (requested_hash, target_user),
            )
            cur.execute(
                """
                UPDATE password_reset_requests
                SET status = 'approved', processed_at = ?
                WHERE id = ?
                """,
                (processed_at, request_id),
            )
            conn.commit()
            return True
        except Exception as e:
            print(f"批准密码重置失败: {e}")
            return False
def reject_password_reset(request_id: int) -> bool:
    """Mark a pending password-reset request as rejected.

    Args:
        request_id: ID of the request to reject.

    Returns:
        True when a pending request with that id was updated. False when
        nothing matched or when the update raised (the error is printed
        rather than raised).
    """
    reject_sql = """
        UPDATE password_reset_requests
        SET status = 'rejected', processed_at = ?
        WHERE id = ? AND status = 'pending'
        """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        processed_at = get_cst_now_str()
        try:
            cur.execute(reject_sql, (processed_at, request_id))
            conn.commit()
            was_pending = cur.rowcount > 0
            return was_pending
        except Exception as e:
            print(f"拒绝密码重置失败: {e}")
            return False
def admin_reset_user_password(user_id: int, new_password: str) -> bool:
    """Overwrite a user's password directly, without a reset request.

    Args:
        user_id: ID of the user whose password is replaced.
        new_password: New plaintext password; only its bcrypt hash is stored.

    Returns:
        True when a user row was updated. False when no user has that id or
        when the update raised (the error is printed rather than raised).
    """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        hashed = hash_password_bcrypt(new_password)
        try:
            cur.execute(
                "UPDATE users SET password_hash = ? WHERE id = ?",
                (hashed, user_id),
            )
            conn.commit()
            user_found = cur.rowcount > 0
            return user_found
        except Exception as e:
            print(f"管理员重置密码失败: {e}")
            return False
def clean_old_operation_logs(days: int = 30) -> int:
    """Delete operation-log rows older than ``days`` days, if the table exists.

    The function first checks ``sqlite_master`` for an ``operation_logs``
    table and does nothing when it is absent.

    Args:
        days: Age threshold in days; rows whose ``created_at`` is earlier
            than now (SQLite local time) minus this many days are removed.

    Returns:
        The number of rows deleted. 0 when the table does not exist or when
        the delete raised (the error is printed rather than raised).
    """
    table_probe_sql = """
        SELECT name FROM sqlite_master
        WHERE type='table' AND name='operation_logs'
        """
    purge_sql = """
        DELETE FROM operation_logs
        WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
        """
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(table_probe_sql)
        table_row = cur.fetchone()
        if not table_row:
            return 0

        try:
            # The modifier string ('-<days> days') is assembled inside SQLite
            # from the bound parameter.
            cur.execute(purge_sql, (days,))
            deleted_count = cur.rowcount
            conn.commit()
            print(f"已清理 {deleted_count} 条旧操作日志 (>{days}天)")
            return deleted_count
        except Exception as e:
            print(f"清理旧操作日志失败: {e}")
            return 0