refactor: optimize structure, stability and runtime performance
This commit is contained in:
@@ -6,19 +6,51 @@ import db_pool
|
||||
from crypto_utils import decrypt_password, encrypt_password
|
||||
from db.utils import get_cst_now_str
|
||||
|
||||
_ACCOUNT_STATUS_QUERY_SQL = """
|
||||
SELECT status, login_fail_count, last_login_error
|
||||
FROM accounts
|
||||
WHERE id = ?
|
||||
"""
|
||||
|
||||
|
||||
def _decode_account_password(account_dict: dict) -> dict:
|
||||
account_dict["password"] = decrypt_password(account_dict.get("password", ""))
|
||||
return account_dict
|
||||
|
||||
|
||||
def _normalize_account_ids(account_ids) -> list[str]:
|
||||
normalized = []
|
||||
seen = set()
|
||||
for account_id in account_ids or []:
|
||||
if not account_id:
|
||||
continue
|
||||
account_key = str(account_id)
|
||||
if account_key in seen:
|
||||
continue
|
||||
seen.add(account_key)
|
||||
normalized.append(account_key)
|
||||
return normalized
|
||||
|
||||
|
||||
def create_account(user_id, account_id, username, password, remember=True, remark=""):
|
||||
"""创建账号(密码加密存储)"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = get_cst_now_str()
|
||||
encrypted_password = encrypt_password(password)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(account_id, user_id, username, encrypted_password, 1 if remember else 0, remark, cst_time),
|
||||
(
|
||||
account_id,
|
||||
user_id,
|
||||
username,
|
||||
encrypted_password,
|
||||
1 if remember else 0,
|
||||
remark,
|
||||
get_cst_now_str(),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
@@ -29,12 +61,7 @@ def get_user_accounts(user_id):
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,))
|
||||
accounts = []
|
||||
for row in cursor.fetchall():
|
||||
account = dict(row)
|
||||
account["password"] = decrypt_password(account.get("password", ""))
|
||||
accounts.append(account)
|
||||
return accounts
|
||||
return [_decode_account_password(dict(row)) for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def get_account(account_id):
|
||||
@@ -43,11 +70,9 @@ def get_account(account_id):
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
account = dict(row)
|
||||
account["password"] = decrypt_password(account.get("password", ""))
|
||||
return account
|
||||
return None
|
||||
if not row:
|
||||
return None
|
||||
return _decode_account_password(dict(row))
|
||||
|
||||
|
||||
def update_account_remark(account_id, remark):
|
||||
@@ -78,33 +103,21 @@ def increment_account_login_fail(account_id, error_message):
|
||||
if not row:
|
||||
return False
|
||||
|
||||
fail_count = (row["login_fail_count"] or 0) + 1
|
||||
|
||||
if fail_count >= 3:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE accounts
|
||||
SET login_fail_count = ?,
|
||||
last_login_error = ?,
|
||||
status = 'suspended'
|
||||
WHERE id = ?
|
||||
""",
|
||||
(fail_count, error_message, account_id),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
fail_count = int(row["login_fail_count"] or 0) + 1
|
||||
is_suspended = fail_count >= 3
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE accounts
|
||||
SET login_fail_count = ?,
|
||||
last_login_error = ?
|
||||
last_login_error = ?,
|
||||
status = CASE WHEN ? = 1 THEN 'suspended' ELSE status END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(fail_count, error_message, account_id),
|
||||
(fail_count, error_message, 1 if is_suspended else 0, account_id),
|
||||
)
|
||||
conn.commit()
|
||||
return False
|
||||
return is_suspended
|
||||
|
||||
|
||||
def reset_account_login_status(account_id):
|
||||
@@ -129,29 +142,22 @@ def get_account_status(account_id):
|
||||
"""获取账号状态信息"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT status, login_fail_count, last_login_error
|
||||
FROM accounts
|
||||
WHERE id = ?
|
||||
""",
|
||||
(account_id,),
|
||||
)
|
||||
cursor.execute(_ACCOUNT_STATUS_QUERY_SQL, (account_id,))
|
||||
return cursor.fetchone()
|
||||
|
||||
|
||||
def get_account_status_batch(account_ids):
|
||||
"""批量获取账号状态信息"""
|
||||
account_ids = [str(account_id) for account_id in (account_ids or []) if account_id]
|
||||
if not account_ids:
|
||||
normalized_ids = _normalize_account_ids(account_ids)
|
||||
if not normalized_ids:
|
||||
return {}
|
||||
|
||||
results = {}
|
||||
chunk_size = 900 # 避免触发 SQLite 绑定参数上限
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
for idx in range(0, len(account_ids), chunk_size):
|
||||
chunk = account_ids[idx : idx + chunk_size]
|
||||
for idx in range(0, len(normalized_ids), chunk_size):
|
||||
chunk = normalized_ids[idx : idx + chunk_size]
|
||||
placeholders = ",".join("?" for _ in chunk)
|
||||
cursor.execute(
|
||||
f"""
|
||||
|
||||
344
db/admin.py
344
db/admin.py
@@ -3,9 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytz
|
||||
|
||||
import db_pool
|
||||
from db.utils import get_cst_now_str
|
||||
@@ -16,6 +13,99 @@ from password_utils import (
|
||||
verify_password_sha256,
|
||||
)
|
||||
|
||||
_DEFAULT_SYSTEM_CONFIG = {
|
||||
"max_concurrent_global": 2,
|
||||
"max_concurrent_per_account": 1,
|
||||
"max_screenshot_concurrent": 3,
|
||||
"schedule_enabled": 0,
|
||||
"schedule_time": "02:00",
|
||||
"schedule_browse_type": "应读",
|
||||
"schedule_weekdays": "1,2,3,4,5,6,7",
|
||||
"proxy_enabled": 0,
|
||||
"proxy_api_url": "",
|
||||
"proxy_expire_minutes": 3,
|
||||
"enable_screenshot": 1,
|
||||
"auto_approve_enabled": 0,
|
||||
"auto_approve_hourly_limit": 10,
|
||||
"auto_approve_vip_days": 7,
|
||||
"kdocs_enabled": 0,
|
||||
"kdocs_doc_url": "",
|
||||
"kdocs_default_unit": "",
|
||||
"kdocs_sheet_name": "",
|
||||
"kdocs_sheet_index": 0,
|
||||
"kdocs_unit_column": "A",
|
||||
"kdocs_image_column": "D",
|
||||
"kdocs_admin_notify_enabled": 0,
|
||||
"kdocs_admin_notify_email": "",
|
||||
"kdocs_row_start": 0,
|
||||
"kdocs_row_end": 0,
|
||||
}
|
||||
|
||||
_SYSTEM_CONFIG_UPDATERS = (
|
||||
("max_concurrent_global", "max_concurrent"),
|
||||
("schedule_enabled", "schedule_enabled"),
|
||||
("schedule_time", "schedule_time"),
|
||||
("schedule_browse_type", "schedule_browse_type"),
|
||||
("schedule_weekdays", "schedule_weekdays"),
|
||||
("max_concurrent_per_account", "max_concurrent_per_account"),
|
||||
("max_screenshot_concurrent", "max_screenshot_concurrent"),
|
||||
("enable_screenshot", "enable_screenshot"),
|
||||
("proxy_enabled", "proxy_enabled"),
|
||||
("proxy_api_url", "proxy_api_url"),
|
||||
("proxy_expire_minutes", "proxy_expire_minutes"),
|
||||
("auto_approve_enabled", "auto_approve_enabled"),
|
||||
("auto_approve_hourly_limit", "auto_approve_hourly_limit"),
|
||||
("auto_approve_vip_days", "auto_approve_vip_days"),
|
||||
("kdocs_enabled", "kdocs_enabled"),
|
||||
("kdocs_doc_url", "kdocs_doc_url"),
|
||||
("kdocs_default_unit", "kdocs_default_unit"),
|
||||
("kdocs_sheet_name", "kdocs_sheet_name"),
|
||||
("kdocs_sheet_index", "kdocs_sheet_index"),
|
||||
("kdocs_unit_column", "kdocs_unit_column"),
|
||||
("kdocs_image_column", "kdocs_image_column"),
|
||||
("kdocs_admin_notify_enabled", "kdocs_admin_notify_enabled"),
|
||||
("kdocs_admin_notify_email", "kdocs_admin_notify_email"),
|
||||
("kdocs_row_start", "kdocs_row_start"),
|
||||
("kdocs_row_end", "kdocs_row_end"),
|
||||
)
|
||||
|
||||
|
||||
def _count_scalar(cursor, sql: str, params=()) -> int:
|
||||
cursor.execute(sql, params)
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return 0
|
||||
try:
|
||||
if "count" in row.keys():
|
||||
return int(row["count"] or 0)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return int(row[0] or 0)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _table_exists(cursor, table_name: str) -> bool:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name=?
|
||||
""",
|
||||
(table_name,),
|
||||
)
|
||||
return bool(cursor.fetchone())
|
||||
|
||||
|
||||
def _normalize_days(days, default: int = 30) -> int:
|
||||
try:
|
||||
value = int(days)
|
||||
except Exception:
|
||||
value = default
|
||||
if value < 0:
|
||||
return 0
|
||||
return value
|
||||
|
||||
|
||||
def ensure_default_admin() -> bool:
|
||||
"""确保存在默认管理员账号(行为保持不变)。"""
|
||||
@@ -24,10 +114,9 @@ def ensure_default_admin() -> bool:
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) as count FROM admins")
|
||||
result = cursor.fetchone()
|
||||
count = _count_scalar(cursor, "SELECT COUNT(*) as count FROM admins")
|
||||
|
||||
if result["count"] == 0:
|
||||
if count == 0:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_password = "".join(secrets.choice(alphabet) for _ in range(12))
|
||||
|
||||
@@ -101,41 +190,33 @@ def get_system_stats() -> dict:
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT COUNT(*) as count FROM users")
|
||||
total_users = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute("SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
|
||||
approved_users = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute(
|
||||
total_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users")
|
||||
approved_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
|
||||
new_users_today = _count_scalar(
|
||||
cursor,
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM users
|
||||
WHERE date(created_at) = date('now', 'localtime')
|
||||
"""
|
||||
""",
|
||||
)
|
||||
new_users_today = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute(
|
||||
new_users_7d = _count_scalar(
|
||||
cursor,
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM users
|
||||
WHERE datetime(created_at) >= datetime('now', 'localtime', '-7 days')
|
||||
"""
|
||||
""",
|
||||
)
|
||||
new_users_7d = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute("SELECT COUNT(*) as count FROM accounts")
|
||||
total_accounts = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute(
|
||||
total_accounts = _count_scalar(cursor, "SELECT COUNT(*) as count FROM accounts")
|
||||
vip_users = _count_scalar(
|
||||
cursor,
|
||||
"""
|
||||
SELECT COUNT(*) as count FROM users
|
||||
WHERE vip_expire_time IS NOT NULL
|
||||
AND datetime(vip_expire_time) > datetime('now', 'localtime')
|
||||
"""
|
||||
""",
|
||||
)
|
||||
vip_users = cursor.fetchone()["count"]
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
@@ -153,37 +234,9 @@ def get_system_config_raw() -> dict:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM system_config WHERE id = 1")
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return dict(row)
|
||||
|
||||
return {
|
||||
"max_concurrent_global": 2,
|
||||
"max_concurrent_per_account": 1,
|
||||
"max_screenshot_concurrent": 3,
|
||||
"schedule_enabled": 0,
|
||||
"schedule_time": "02:00",
|
||||
"schedule_browse_type": "应读",
|
||||
"schedule_weekdays": "1,2,3,4,5,6,7",
|
||||
"proxy_enabled": 0,
|
||||
"proxy_api_url": "",
|
||||
"proxy_expire_minutes": 3,
|
||||
"enable_screenshot": 1,
|
||||
"auto_approve_enabled": 0,
|
||||
"auto_approve_hourly_limit": 10,
|
||||
"auto_approve_vip_days": 7,
|
||||
"kdocs_enabled": 0,
|
||||
"kdocs_doc_url": "",
|
||||
"kdocs_default_unit": "",
|
||||
"kdocs_sheet_name": "",
|
||||
"kdocs_sheet_index": 0,
|
||||
"kdocs_unit_column": "A",
|
||||
"kdocs_image_column": "D",
|
||||
"kdocs_admin_notify_enabled": 0,
|
||||
"kdocs_admin_notify_email": "",
|
||||
"kdocs_row_start": 0,
|
||||
"kdocs_row_end": 0,
|
||||
}
|
||||
return dict(_DEFAULT_SYSTEM_CONFIG)
|
||||
|
||||
|
||||
def update_system_config(
|
||||
@@ -215,127 +268,51 @@ def update_system_config(
|
||||
kdocs_row_end=None,
|
||||
) -> bool:
|
||||
"""更新系统配置(仅更新DB,不做缓存处理)。"""
|
||||
allowed_fields = {
|
||||
"max_concurrent_global",
|
||||
"schedule_enabled",
|
||||
"schedule_time",
|
||||
"schedule_browse_type",
|
||||
"schedule_weekdays",
|
||||
"max_concurrent_per_account",
|
||||
"max_screenshot_concurrent",
|
||||
"enable_screenshot",
|
||||
"proxy_enabled",
|
||||
"proxy_api_url",
|
||||
"proxy_expire_minutes",
|
||||
"auto_approve_enabled",
|
||||
"auto_approve_hourly_limit",
|
||||
"auto_approve_vip_days",
|
||||
"kdocs_enabled",
|
||||
"kdocs_doc_url",
|
||||
"kdocs_default_unit",
|
||||
"kdocs_sheet_name",
|
||||
"kdocs_sheet_index",
|
||||
"kdocs_unit_column",
|
||||
"kdocs_image_column",
|
||||
"kdocs_admin_notify_enabled",
|
||||
"kdocs_admin_notify_email",
|
||||
"kdocs_row_start",
|
||||
"kdocs_row_end",
|
||||
"updated_at",
|
||||
arg_values = {
|
||||
"max_concurrent": max_concurrent,
|
||||
"schedule_enabled": schedule_enabled,
|
||||
"schedule_time": schedule_time,
|
||||
"schedule_browse_type": schedule_browse_type,
|
||||
"schedule_weekdays": schedule_weekdays,
|
||||
"max_concurrent_per_account": max_concurrent_per_account,
|
||||
"max_screenshot_concurrent": max_screenshot_concurrent,
|
||||
"enable_screenshot": enable_screenshot,
|
||||
"proxy_enabled": proxy_enabled,
|
||||
"proxy_api_url": proxy_api_url,
|
||||
"proxy_expire_minutes": proxy_expire_minutes,
|
||||
"auto_approve_enabled": auto_approve_enabled,
|
||||
"auto_approve_hourly_limit": auto_approve_hourly_limit,
|
||||
"auto_approve_vip_days": auto_approve_vip_days,
|
||||
"kdocs_enabled": kdocs_enabled,
|
||||
"kdocs_doc_url": kdocs_doc_url,
|
||||
"kdocs_default_unit": kdocs_default_unit,
|
||||
"kdocs_sheet_name": kdocs_sheet_name,
|
||||
"kdocs_sheet_index": kdocs_sheet_index,
|
||||
"kdocs_unit_column": kdocs_unit_column,
|
||||
"kdocs_image_column": kdocs_image_column,
|
||||
"kdocs_admin_notify_enabled": kdocs_admin_notify_enabled,
|
||||
"kdocs_admin_notify_email": kdocs_admin_notify_email,
|
||||
"kdocs_row_start": kdocs_row_start,
|
||||
"kdocs_row_end": kdocs_row_end,
|
||||
}
|
||||
|
||||
updates = []
|
||||
params = []
|
||||
for db_field, arg_name in _SYSTEM_CONFIG_UPDATERS:
|
||||
value = arg_values.get(arg_name)
|
||||
if value is None:
|
||||
continue
|
||||
updates.append(f"{db_field} = ?")
|
||||
params.append(value)
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
|
||||
updates.append("updated_at = ?")
|
||||
params.append(get_cst_now_str())
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
if max_concurrent is not None:
|
||||
updates.append("max_concurrent_global = ?")
|
||||
params.append(max_concurrent)
|
||||
if schedule_enabled is not None:
|
||||
updates.append("schedule_enabled = ?")
|
||||
params.append(schedule_enabled)
|
||||
if schedule_time is not None:
|
||||
updates.append("schedule_time = ?")
|
||||
params.append(schedule_time)
|
||||
if schedule_browse_type is not None:
|
||||
updates.append("schedule_browse_type = ?")
|
||||
params.append(schedule_browse_type)
|
||||
if max_concurrent_per_account is not None:
|
||||
updates.append("max_concurrent_per_account = ?")
|
||||
params.append(max_concurrent_per_account)
|
||||
if max_screenshot_concurrent is not None:
|
||||
updates.append("max_screenshot_concurrent = ?")
|
||||
params.append(max_screenshot_concurrent)
|
||||
if enable_screenshot is not None:
|
||||
updates.append("enable_screenshot = ?")
|
||||
params.append(enable_screenshot)
|
||||
if schedule_weekdays is not None:
|
||||
updates.append("schedule_weekdays = ?")
|
||||
params.append(schedule_weekdays)
|
||||
if proxy_enabled is not None:
|
||||
updates.append("proxy_enabled = ?")
|
||||
params.append(proxy_enabled)
|
||||
if proxy_api_url is not None:
|
||||
updates.append("proxy_api_url = ?")
|
||||
params.append(proxy_api_url)
|
||||
if proxy_expire_minutes is not None:
|
||||
updates.append("proxy_expire_minutes = ?")
|
||||
params.append(proxy_expire_minutes)
|
||||
if auto_approve_enabled is not None:
|
||||
updates.append("auto_approve_enabled = ?")
|
||||
params.append(auto_approve_enabled)
|
||||
if auto_approve_hourly_limit is not None:
|
||||
updates.append("auto_approve_hourly_limit = ?")
|
||||
params.append(auto_approve_hourly_limit)
|
||||
if auto_approve_vip_days is not None:
|
||||
updates.append("auto_approve_vip_days = ?")
|
||||
params.append(auto_approve_vip_days)
|
||||
if kdocs_enabled is not None:
|
||||
updates.append("kdocs_enabled = ?")
|
||||
params.append(kdocs_enabled)
|
||||
if kdocs_doc_url is not None:
|
||||
updates.append("kdocs_doc_url = ?")
|
||||
params.append(kdocs_doc_url)
|
||||
if kdocs_default_unit is not None:
|
||||
updates.append("kdocs_default_unit = ?")
|
||||
params.append(kdocs_default_unit)
|
||||
if kdocs_sheet_name is not None:
|
||||
updates.append("kdocs_sheet_name = ?")
|
||||
params.append(kdocs_sheet_name)
|
||||
if kdocs_sheet_index is not None:
|
||||
updates.append("kdocs_sheet_index = ?")
|
||||
params.append(kdocs_sheet_index)
|
||||
if kdocs_unit_column is not None:
|
||||
updates.append("kdocs_unit_column = ?")
|
||||
params.append(kdocs_unit_column)
|
||||
if kdocs_image_column is not None:
|
||||
updates.append("kdocs_image_column = ?")
|
||||
params.append(kdocs_image_column)
|
||||
if kdocs_admin_notify_enabled is not None:
|
||||
updates.append("kdocs_admin_notify_enabled = ?")
|
||||
params.append(kdocs_admin_notify_enabled)
|
||||
if kdocs_admin_notify_email is not None:
|
||||
updates.append("kdocs_admin_notify_email = ?")
|
||||
params.append(kdocs_admin_notify_email)
|
||||
if kdocs_row_start is not None:
|
||||
updates.append("kdocs_row_start = ?")
|
||||
params.append(kdocs_row_start)
|
||||
if kdocs_row_end is not None:
|
||||
updates.append("kdocs_row_end = ?")
|
||||
params.append(kdocs_row_end)
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
|
||||
updates.append("updated_at = ?")
|
||||
params.append(get_cst_now_str())
|
||||
|
||||
for update_clause in updates:
|
||||
field_name = update_clause.split("=")[0].strip()
|
||||
if field_name not in allowed_fields:
|
||||
raise ValueError(f"非法字段名: {field_name}")
|
||||
|
||||
sql = f"UPDATE system_config SET {', '.join(updates)} WHERE id = 1"
|
||||
cursor.execute(sql, params)
|
||||
conn.commit()
|
||||
@@ -346,13 +323,13 @@ def get_hourly_registration_count() -> int:
|
||||
"""获取最近一小时内的注册用户数"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
return _count_scalar(
|
||||
cursor,
|
||||
"""
|
||||
SELECT COUNT(*) FROM users
|
||||
SELECT COUNT(*) as count FROM users
|
||||
WHERE created_at >= datetime('now', 'localtime', '-1 hour')
|
||||
"""
|
||||
""",
|
||||
)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
|
||||
# ==================== 密码重置(管理员) ====================
|
||||
@@ -374,17 +351,12 @@ def admin_reset_user_password(user_id: int, new_password: str) -> bool:
|
||||
|
||||
def clean_old_operation_logs(days: int = 30) -> int:
|
||||
"""清理指定天数前的操作日志(如果存在operation_logs表)"""
|
||||
safe_days = _normalize_days(days, default=30)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='operation_logs'
|
||||
"""
|
||||
)
|
||||
|
||||
if not cursor.fetchone():
|
||||
if not _table_exists(cursor, "operation_logs"):
|
||||
return 0
|
||||
|
||||
try:
|
||||
@@ -393,11 +365,11 @@ def clean_old_operation_logs(days: int = 30) -> int:
|
||||
DELETE FROM operation_logs
|
||||
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
|
||||
""",
|
||||
(days,),
|
||||
(safe_days,),
|
||||
)
|
||||
deleted_count = cursor.rowcount
|
||||
conn.commit()
|
||||
print(f"已清理 {deleted_count} 条旧操作日志 (>{days}天)")
|
||||
print(f"已清理 {deleted_count} 条旧操作日志 (>{safe_days}天)")
|
||||
return deleted_count
|
||||
except Exception as e:
|
||||
print(f"清理旧操作日志失败: {e}")
|
||||
|
||||
@@ -6,12 +6,38 @@ import db_pool
|
||||
from db.utils import get_cst_now_str
|
||||
|
||||
|
||||
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
parsed = max(minimum, parsed)
|
||||
parsed = min(maximum, parsed)
|
||||
return parsed
|
||||
|
||||
|
||||
def _normalize_offset(value, default: int = 0) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
return max(0, parsed)
|
||||
|
||||
|
||||
def _normalize_announcement_payload(title, content, image_url):
|
||||
normalized_title = str(title or "").strip()
|
||||
normalized_content = str(content or "").strip()
|
||||
normalized_image = str(image_url or "").strip() or None
|
||||
return normalized_title, normalized_content, normalized_image
|
||||
|
||||
|
||||
def _deactivate_all_active_announcements(cursor, cst_time: str) -> None:
|
||||
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
|
||||
|
||||
|
||||
def create_announcement(title, content, image_url=None, is_active=True):
|
||||
"""创建公告(默认启用;启用时会自动停用其他公告)"""
|
||||
title = (title or "").strip()
|
||||
content = (content or "").strip()
|
||||
image_url = (image_url or "").strip()
|
||||
image_url = image_url or None
|
||||
title, content, image_url = _normalize_announcement_payload(title, content, image_url)
|
||||
if not title or not content:
|
||||
return None
|
||||
|
||||
@@ -20,7 +46,7 @@ def create_announcement(title, content, image_url=None, is_active=True):
|
||||
cst_time = get_cst_now_str()
|
||||
|
||||
if is_active:
|
||||
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
|
||||
_deactivate_all_active_announcements(cursor, cst_time)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -44,6 +70,9 @@ def get_announcement_by_id(announcement_id):
|
||||
|
||||
def get_announcements(limit=50, offset=0):
|
||||
"""获取公告列表(管理员用)"""
|
||||
safe_limit = _normalize_limit(limit, 50)
|
||||
safe_offset = _normalize_offset(offset, 0)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -52,7 +81,7 @@ def get_announcements(limit=50, offset=0):
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(limit, offset),
|
||||
(safe_limit, safe_offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
@@ -64,7 +93,7 @@ def set_announcement_active(announcement_id, is_active):
|
||||
cst_time = get_cst_now_str()
|
||||
|
||||
if is_active:
|
||||
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
|
||||
_deactivate_all_active_announcements(cursor, cst_time)
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE announcements
|
||||
@@ -121,13 +150,12 @@ def dismiss_announcement_for_user(user_id, announcement_id):
|
||||
"""用户永久关闭某条公告(幂等)"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = get_cst_now_str()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO announcement_dismissals (user_id, announcement_id, dismissed_at)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(user_id, announcement_id, cst_time),
|
||||
(user_id, announcement_id, get_cst_now_str()),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount >= 0
|
||||
|
||||
27
db/email.py
27
db/email.py
@@ -5,6 +5,27 @@ from __future__ import annotations
|
||||
import db_pool
|
||||
|
||||
|
||||
def _to_bool_with_default(value, default: bool = True) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return bool(int(value))
|
||||
except Exception:
|
||||
try:
|
||||
return bool(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _normalize_notify_enabled(enabled) -> int:
|
||||
if isinstance(enabled, bool):
|
||||
return 1 if enabled else 0
|
||||
try:
|
||||
return 1 if int(enabled) else 0
|
||||
except Exception:
|
||||
return 1
|
||||
|
||||
|
||||
def get_user_by_email(email):
|
||||
"""根据邮箱获取用户"""
|
||||
with db_pool.get_db() as conn:
|
||||
@@ -25,7 +46,7 @@ def update_user_email(user_id, email, verified=False):
|
||||
SET email = ?, email_verified = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(email, int(verified), user_id),
|
||||
(email, 1 if verified else 0, user_id),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -42,7 +63,7 @@ def update_user_email_notify(user_id, enabled):
|
||||
SET email_notify_enabled = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(int(enabled), user_id),
|
||||
(_normalize_notify_enabled(enabled), user_id),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -57,6 +78,6 @@ def get_user_email_notify(user_id):
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return True
|
||||
return bool(row[0]) if row[0] is not None else True
|
||||
return _to_bool_with_default(row[0], default=True)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
106
db/feedbacks.py
106
db/feedbacks.py
@@ -2,32 +2,73 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
import db_pool
|
||||
from db.utils import escape_html
|
||||
from db.utils import escape_html, get_cst_now_str
|
||||
|
||||
|
||||
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
parsed = max(minimum, parsed)
|
||||
parsed = min(maximum, parsed)
|
||||
return parsed
|
||||
|
||||
|
||||
def _normalize_offset(value, default: int = 0) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
return max(0, parsed)
|
||||
|
||||
|
||||
def _safe_text(value) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
text = str(value)
|
||||
return escape_html(text) if text else ""
|
||||
|
||||
|
||||
def _build_feedback_filter_sql(status_filter=None) -> tuple[str, list]:
|
||||
where_clauses = ["1=1"]
|
||||
params = []
|
||||
|
||||
if status_filter:
|
||||
where_clauses.append("status = ?")
|
||||
params.append(status_filter)
|
||||
|
||||
return " AND ".join(where_clauses), params
|
||||
|
||||
|
||||
def _normalize_feedback_stats_row(row) -> dict:
|
||||
row_dict = dict(row) if row else {}
|
||||
return {
|
||||
"total": int(row_dict.get("total") or 0),
|
||||
"pending": int(row_dict.get("pending") or 0),
|
||||
"replied": int(row_dict.get("replied") or 0),
|
||||
"closed": int(row_dict.get("closed") or 0),
|
||||
}
|
||||
|
||||
|
||||
def create_bug_feedback(user_id, username, title, description, contact=""):
|
||||
"""创建Bug反馈(带XSS防护)"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
safe_title = escape_html(title) if title else ""
|
||||
safe_description = escape_html(description) if description else ""
|
||||
safe_contact = escape_html(contact) if contact else ""
|
||||
safe_username = escape_html(username) if username else ""
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO bug_feedbacks (user_id, username, title, description, contact, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, safe_username, safe_title, safe_description, safe_contact, cst_time),
|
||||
(
|
||||
user_id,
|
||||
_safe_text(username),
|
||||
_safe_text(title),
|
||||
_safe_text(description),
|
||||
_safe_text(contact),
|
||||
get_cst_now_str(),
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -36,25 +77,25 @@ def create_bug_feedback(user_id, username, title, description, contact=""):
|
||||
|
||||
def get_bug_feedbacks(limit=100, offset=0, status_filter=None):
|
||||
"""获取Bug反馈列表(管理员用)"""
|
||||
safe_limit = _normalize_limit(limit, 100, minimum=1, maximum=1000)
|
||||
safe_offset = _normalize_offset(offset, 0)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
sql = "SELECT * FROM bug_feedbacks WHERE 1=1"
|
||||
params = []
|
||||
|
||||
if status_filter:
|
||||
sql += " AND status = ?"
|
||||
params.append(status_filter)
|
||||
|
||||
sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
|
||||
cursor.execute(sql, params)
|
||||
where_sql, params = _build_feedback_filter_sql(status_filter=status_filter)
|
||||
sql = f"""
|
||||
SELECT * FROM bug_feedbacks
|
||||
WHERE {where_sql}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
cursor.execute(sql, params + [safe_limit, safe_offset])
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def get_user_feedbacks(user_id, limit=50):
|
||||
"""获取用户自己的反馈列表"""
|
||||
safe_limit = _normalize_limit(limit, 50, minimum=1, maximum=1000)
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -64,7 +105,7 @@ def get_user_feedbacks(user_id, limit=50):
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(user_id, limit),
|
||||
(user_id, safe_limit),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
@@ -82,18 +123,13 @@ def reply_feedback(feedback_id, admin_reply):
|
||||
"""管理员回复反馈(带XSS防护)"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
safe_reply = escape_html(admin_reply) if admin_reply else ""
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE bug_feedbacks
|
||||
SET admin_reply = ?, status = 'replied', replied_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(safe_reply, cst_time, feedback_id),
|
||||
(_safe_text(admin_reply), get_cst_now_str(), feedback_id),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -139,6 +175,4 @@ def get_feedback_stats():
|
||||
FROM bug_feedbacks
|
||||
"""
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else {"total": 0, "pending": 0, "replied": 0, "closed": 0}
|
||||
|
||||
return _normalize_feedback_stats_row(cursor.fetchone())
|
||||
|
||||
541
db/migrations.py
541
db/migrations.py
@@ -28,105 +28,136 @@ def set_current_version(conn, version: int) -> None:
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _table_exists(cursor, table_name: str) -> bool:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (str(table_name),))
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
|
||||
def _get_table_columns(cursor, table_name: str) -> set[str]:
|
||||
cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
return {col[1] for col in cursor.fetchall()}
|
||||
|
||||
|
||||
def _add_column_if_missing(cursor, table_name: str, columns: set[str], column_name: str, column_ddl: str, *, ok_message: str) -> bool:
|
||||
if column_name in columns:
|
||||
return False
|
||||
cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_ddl}")
|
||||
columns.add(column_name)
|
||||
print(ok_message)
|
||||
return True
|
||||
|
||||
|
||||
def _read_row_value(row, key: str, index: int):
|
||||
if isinstance(row, sqlite3.Row):
|
||||
return row[key]
|
||||
return row[index]
|
||||
|
||||
|
||||
def _get_migration_steps():
|
||||
return [
|
||||
(1, _migrate_to_v1),
|
||||
(2, _migrate_to_v2),
|
||||
(3, _migrate_to_v3),
|
||||
(4, _migrate_to_v4),
|
||||
(5, _migrate_to_v5),
|
||||
(6, _migrate_to_v6),
|
||||
(7, _migrate_to_v7),
|
||||
(8, _migrate_to_v8),
|
||||
(9, _migrate_to_v9),
|
||||
(10, _migrate_to_v10),
|
||||
(11, _migrate_to_v11),
|
||||
(12, _migrate_to_v12),
|
||||
(13, _migrate_to_v13),
|
||||
(14, _migrate_to_v14),
|
||||
(15, _migrate_to_v15),
|
||||
(16, _migrate_to_v16),
|
||||
(17, _migrate_to_v17),
|
||||
(18, _migrate_to_v18),
|
||||
]
|
||||
|
||||
|
||||
def migrate_database(conn, target_version: int) -> None:
|
||||
"""数据库迁移:按版本增量升级(向前兼容)。"""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("INSERT OR IGNORE INTO db_version (id, version, updated_at) VALUES (1, 0, ?)", (get_cst_now_str(),))
|
||||
conn.commit()
|
||||
|
||||
target_version = int(target_version)
|
||||
current_version = get_current_version(conn)
|
||||
|
||||
if current_version < 1:
|
||||
_migrate_to_v1(conn)
|
||||
current_version = 1
|
||||
if current_version < 2:
|
||||
_migrate_to_v2(conn)
|
||||
current_version = 2
|
||||
if current_version < 3:
|
||||
_migrate_to_v3(conn)
|
||||
current_version = 3
|
||||
if current_version < 4:
|
||||
_migrate_to_v4(conn)
|
||||
current_version = 4
|
||||
if current_version < 5:
|
||||
_migrate_to_v5(conn)
|
||||
current_version = 5
|
||||
if current_version < 6:
|
||||
_migrate_to_v6(conn)
|
||||
current_version = 6
|
||||
if current_version < 7:
|
||||
_migrate_to_v7(conn)
|
||||
current_version = 7
|
||||
if current_version < 8:
|
||||
_migrate_to_v8(conn)
|
||||
current_version = 8
|
||||
if current_version < 9:
|
||||
_migrate_to_v9(conn)
|
||||
current_version = 9
|
||||
if current_version < 10:
|
||||
_migrate_to_v10(conn)
|
||||
current_version = 10
|
||||
if current_version < 11:
|
||||
_migrate_to_v11(conn)
|
||||
current_version = 11
|
||||
if current_version < 12:
|
||||
_migrate_to_v12(conn)
|
||||
current_version = 12
|
||||
if current_version < 13:
|
||||
_migrate_to_v13(conn)
|
||||
current_version = 13
|
||||
if current_version < 14:
|
||||
_migrate_to_v14(conn)
|
||||
current_version = 14
|
||||
if current_version < 15:
|
||||
_migrate_to_v15(conn)
|
||||
current_version = 15
|
||||
if current_version < 16:
|
||||
_migrate_to_v16(conn)
|
||||
current_version = 16
|
||||
if current_version < 17:
|
||||
_migrate_to_v17(conn)
|
||||
current_version = 17
|
||||
if current_version < 18:
|
||||
_migrate_to_v18(conn)
|
||||
current_version = 18
|
||||
for version, migrate_fn in _get_migration_steps():
|
||||
if version > target_version or current_version >= version:
|
||||
continue
|
||||
migrate_fn(conn)
|
||||
current_version = version
|
||||
|
||||
if current_version != int(target_version):
|
||||
set_current_version(conn, int(target_version))
|
||||
if current_version != target_version:
|
||||
set_current_version(conn, target_version)
|
||||
|
||||
|
||||
def _migrate_to_v1(conn):
|
||||
"""迁移到版本1 - 添加缺失字段"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(system_config)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
system_columns = _get_table_columns(cursor, "system_config")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"schedule_weekdays",
|
||||
'TEXT DEFAULT "1,2,3,4,5,6,7"',
|
||||
ok_message=" [OK] 添加 schedule_weekdays 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"max_screenshot_concurrent",
|
||||
"INTEGER DEFAULT 3",
|
||||
ok_message=" [OK] 添加 max_screenshot_concurrent 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"max_concurrent_per_account",
|
||||
"INTEGER DEFAULT 1",
|
||||
ok_message=" [OK] 添加 max_concurrent_per_account 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"auto_approve_enabled",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 auto_approve_enabled 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"auto_approve_hourly_limit",
|
||||
"INTEGER DEFAULT 10",
|
||||
ok_message=" [OK] 添加 auto_approve_hourly_limit 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"auto_approve_vip_days",
|
||||
"INTEGER DEFAULT 7",
|
||||
ok_message=" [OK] 添加 auto_approve_vip_days 字段",
|
||||
)
|
||||
|
||||
if "schedule_weekdays" not in columns:
|
||||
cursor.execute('ALTER TABLE system_config ADD COLUMN schedule_weekdays TEXT DEFAULT "1,2,3,4,5,6,7"')
|
||||
print(" [OK] 添加 schedule_weekdays 字段")
|
||||
|
||||
if "max_screenshot_concurrent" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN max_screenshot_concurrent INTEGER DEFAULT 3")
|
||||
print(" [OK] 添加 max_screenshot_concurrent 字段")
|
||||
if "max_concurrent_per_account" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN max_concurrent_per_account INTEGER DEFAULT 1")
|
||||
print(" [OK] 添加 max_concurrent_per_account 字段")
|
||||
if "auto_approve_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_enabled INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 auto_approve_enabled 字段")
|
||||
if "auto_approve_hourly_limit" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_hourly_limit INTEGER DEFAULT 10")
|
||||
print(" [OK] 添加 auto_approve_hourly_limit 字段")
|
||||
if "auto_approve_vip_days" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_vip_days INTEGER DEFAULT 7")
|
||||
print(" [OK] 添加 auto_approve_vip_days 字段")
|
||||
|
||||
cursor.execute("PRAGMA table_info(task_logs)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
if "duration" not in columns:
|
||||
cursor.execute("ALTER TABLE task_logs ADD COLUMN duration INTEGER")
|
||||
print(" [OK] 添加 duration 字段到 task_logs")
|
||||
task_log_columns = _get_table_columns(cursor, "task_logs")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"task_logs",
|
||||
task_log_columns,
|
||||
"duration",
|
||||
"INTEGER",
|
||||
ok_message=" [OK] 添加 duration 字段到 task_logs",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -135,24 +166,39 @@ def _migrate_to_v2(conn):
|
||||
"""迁移到版本2 - 添加代理配置字段"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(system_config)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
if "proxy_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_enabled INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 proxy_enabled 字段")
|
||||
|
||||
if "proxy_api_url" not in columns:
|
||||
cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_api_url TEXT DEFAULT ""')
|
||||
print(" [OK] 添加 proxy_api_url 字段")
|
||||
|
||||
if "proxy_expire_minutes" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_expire_minutes INTEGER DEFAULT 3")
|
||||
print(" [OK] 添加 proxy_expire_minutes 字段")
|
||||
|
||||
if "enable_screenshot" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN enable_screenshot INTEGER DEFAULT 1")
|
||||
print(" [OK] 添加 enable_screenshot 字段")
|
||||
columns = _get_table_columns(cursor, "system_config")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"proxy_enabled",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 proxy_enabled 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"proxy_api_url",
|
||||
'TEXT DEFAULT ""',
|
||||
ok_message=" [OK] 添加 proxy_api_url 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"proxy_expire_minutes",
|
||||
"INTEGER DEFAULT 3",
|
||||
ok_message=" [OK] 添加 proxy_expire_minutes 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"enable_screenshot",
|
||||
"INTEGER DEFAULT 1",
|
||||
ok_message=" [OK] 添加 enable_screenshot 字段",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -161,20 +207,31 @@ def _migrate_to_v3(conn):
|
||||
"""迁移到版本3 - 添加账号状态和登录失败计数字段"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(accounts)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
if "status" not in columns:
|
||||
cursor.execute('ALTER TABLE accounts ADD COLUMN status TEXT DEFAULT "active"')
|
||||
print(" [OK] 添加 accounts.status 字段 (账号状态)")
|
||||
|
||||
if "login_fail_count" not in columns:
|
||||
cursor.execute("ALTER TABLE accounts ADD COLUMN login_fail_count INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)")
|
||||
|
||||
if "last_login_error" not in columns:
|
||||
cursor.execute("ALTER TABLE accounts ADD COLUMN last_login_error TEXT")
|
||||
print(" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)")
|
||||
columns = _get_table_columns(cursor, "accounts")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"accounts",
|
||||
columns,
|
||||
"status",
|
||||
'TEXT DEFAULT "active"',
|
||||
ok_message=" [OK] 添加 accounts.status 字段 (账号状态)",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"accounts",
|
||||
columns,
|
||||
"login_fail_count",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"accounts",
|
||||
columns,
|
||||
"last_login_error",
|
||||
"TEXT",
|
||||
ok_message=" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -183,12 +240,15 @@ def _migrate_to_v4(conn):
|
||||
"""迁移到版本4 - 添加任务来源字段"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(task_logs)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
if "source" not in columns:
|
||||
cursor.execute('ALTER TABLE task_logs ADD COLUMN source TEXT DEFAULT "manual"')
|
||||
print(" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)")
|
||||
columns = _get_table_columns(cursor, "task_logs")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"task_logs",
|
||||
columns,
|
||||
"source",
|
||||
'TEXT DEFAULT "manual"',
|
||||
ok_message=" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -300,20 +360,17 @@ def _migrate_to_v6(conn):
|
||||
def _migrate_to_v7(conn):
|
||||
"""迁移到版本7 - 统一存储北京时间(将历史UTC时间字段整体+8小时)"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
def table_exists(table_name: str) -> bool:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def column_exists(table_name: str, column_name: str) -> bool:
|
||||
cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
return any(row[1] == column_name for row in cursor.fetchall())
|
||||
columns_cache: dict[str, set[str]] = {}
|
||||
|
||||
def shift_utc_to_cst(table_name: str, column_name: str) -> None:
|
||||
if not table_exists(table_name):
|
||||
if not _table_exists(cursor, table_name):
|
||||
return
|
||||
if not column_exists(table_name, column_name):
|
||||
|
||||
if table_name not in columns_cache:
|
||||
columns_cache[table_name] = _get_table_columns(cursor, table_name)
|
||||
if column_name not in columns_cache[table_name]:
|
||||
return
|
||||
|
||||
cursor.execute(
|
||||
f"""
|
||||
UPDATE {table_name}
|
||||
@@ -329,10 +386,6 @@ def _migrate_to_v7(conn):
|
||||
("accounts", "created_at"),
|
||||
("password_reset_requests", "created_at"),
|
||||
("password_reset_requests", "processed_at"),
|
||||
]:
|
||||
shift_utc_to_cst(table, col)
|
||||
|
||||
for table, col in [
|
||||
("smtp_configs", "created_at"),
|
||||
("smtp_configs", "updated_at"),
|
||||
("smtp_configs", "last_success_at"),
|
||||
@@ -340,10 +393,6 @@ def _migrate_to_v7(conn):
|
||||
("email_tokens", "created_at"),
|
||||
("email_logs", "created_at"),
|
||||
("email_stats", "last_updated"),
|
||||
]:
|
||||
shift_utc_to_cst(table, col)
|
||||
|
||||
for table, col in [
|
||||
("task_checkpoints", "created_at"),
|
||||
("task_checkpoints", "updated_at"),
|
||||
("task_checkpoints", "completed_at"),
|
||||
@@ -359,15 +408,23 @@ def _migrate_to_v8(conn):
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 1) 增量字段:random_delay(旧库可能不存在)
|
||||
cursor.execute("PRAGMA table_info(user_schedules)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
if "random_delay" not in columns:
|
||||
cursor.execute("ALTER TABLE user_schedules ADD COLUMN random_delay INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 user_schedules.random_delay 字段")
|
||||
|
||||
if "next_run_at" not in columns:
|
||||
cursor.execute("ALTER TABLE user_schedules ADD COLUMN next_run_at TIMESTAMP")
|
||||
print(" [OK] 添加 user_schedules.next_run_at 字段")
|
||||
columns = _get_table_columns(cursor, "user_schedules")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"user_schedules",
|
||||
columns,
|
||||
"random_delay",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 user_schedules.random_delay 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"user_schedules",
|
||||
columns,
|
||||
"next_run_at",
|
||||
"TIMESTAMP",
|
||||
ok_message=" [OK] 添加 user_schedules.next_run_at 字段",
|
||||
)
|
||||
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_schedules_next_run ON user_schedules(next_run_at)")
|
||||
conn.commit()
|
||||
@@ -392,12 +449,12 @@ def _migrate_to_v8(conn):
|
||||
fixed = 0
|
||||
for row in rows:
|
||||
try:
|
||||
schedule_id = row["id"] if isinstance(row, sqlite3.Row) else row[0]
|
||||
schedule_time = row["schedule_time"] if isinstance(row, sqlite3.Row) else row[1]
|
||||
weekdays = row["weekdays"] if isinstance(row, sqlite3.Row) else row[2]
|
||||
random_delay = row["random_delay"] if isinstance(row, sqlite3.Row) else row[3]
|
||||
last_run_at = row["last_run_at"] if isinstance(row, sqlite3.Row) else row[4]
|
||||
next_run_at = row["next_run_at"] if isinstance(row, sqlite3.Row) else row[5]
|
||||
schedule_id = _read_row_value(row, "id", 0)
|
||||
schedule_time = _read_row_value(row, "schedule_time", 1)
|
||||
weekdays = _read_row_value(row, "weekdays", 2)
|
||||
random_delay = _read_row_value(row, "random_delay", 3)
|
||||
last_run_at = _read_row_value(row, "last_run_at", 4)
|
||||
next_run_at = _read_row_value(row, "next_run_at", 5)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -430,27 +487,46 @@ def _migrate_to_v9(conn):
|
||||
"""迁移到版本9 - 邮件设置字段迁移(清理 email_service scattered ALTER TABLE)"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
|
||||
if not cursor.fetchone():
|
||||
if not _table_exists(cursor, "email_settings"):
|
||||
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
|
||||
return
|
||||
|
||||
cursor.execute("PRAGMA table_info(email_settings)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
columns = _get_table_columns(cursor, "email_settings")
|
||||
|
||||
changed = False
|
||||
if "register_verify_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE email_settings ADD COLUMN register_verify_enabled INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 email_settings.register_verify_enabled 字段")
|
||||
changed = True
|
||||
if "base_url" not in columns:
|
||||
cursor.execute("ALTER TABLE email_settings ADD COLUMN base_url TEXT DEFAULT ''")
|
||||
print(" [OK] 添加 email_settings.base_url 字段")
|
||||
changed = True
|
||||
if "task_notify_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE email_settings ADD COLUMN task_notify_enabled INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 email_settings.task_notify_enabled 字段")
|
||||
changed = True
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"email_settings",
|
||||
columns,
|
||||
"register_verify_enabled",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 email_settings.register_verify_enabled 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"email_settings",
|
||||
columns,
|
||||
"base_url",
|
||||
"TEXT DEFAULT ''",
|
||||
ok_message=" [OK] 添加 email_settings.base_url 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"email_settings",
|
||||
columns,
|
||||
"task_notify_enabled",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 email_settings.task_notify_enabled 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
|
||||
if changed:
|
||||
conn.commit()
|
||||
@@ -459,18 +535,31 @@ def _migrate_to_v9(conn):
|
||||
def _migrate_to_v10(conn):
|
||||
"""迁移到版本10 - users 邮箱字段迁移(避免运行时 ALTER TABLE)"""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA table_info(users)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
columns = _get_table_columns(cursor, "users")
|
||||
|
||||
changed = False
|
||||
if "email_verified" not in columns:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN email_verified INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 users.email_verified 字段")
|
||||
changed = True
|
||||
if "email_notify_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN email_notify_enabled INTEGER DEFAULT 1")
|
||||
print(" [OK] 添加 users.email_notify_enabled 字段")
|
||||
changed = True
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"users",
|
||||
columns,
|
||||
"email_verified",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 users.email_verified 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"users",
|
||||
columns,
|
||||
"email_notify_enabled",
|
||||
"INTEGER DEFAULT 1",
|
||||
ok_message=" [OK] 添加 users.email_notify_enabled 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
|
||||
if changed:
|
||||
conn.commit()
|
||||
@@ -657,19 +746,24 @@ def _migrate_to_v15(conn):
|
||||
"""迁移到版本15 - 邮件设置:新设备登录提醒全局开关"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
|
||||
if not cursor.fetchone():
|
||||
if not _table_exists(cursor, "email_settings"):
|
||||
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
|
||||
return
|
||||
|
||||
cursor.execute("PRAGMA table_info(email_settings)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
columns = _get_table_columns(cursor, "email_settings")
|
||||
|
||||
changed = False
|
||||
if "login_alert_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1")
|
||||
print(" [OK] 添加 email_settings.login_alert_enabled 字段")
|
||||
changed = True
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"email_settings",
|
||||
columns,
|
||||
"login_alert_enabled",
|
||||
"INTEGER DEFAULT 1",
|
||||
ok_message=" [OK] 添加 email_settings.login_alert_enabled 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
|
||||
try:
|
||||
cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL")
|
||||
@@ -686,22 +780,24 @@ def _migrate_to_v15(conn):
|
||||
def _migrate_to_v16(conn):
|
||||
"""迁移到版本16 - 公告支持图片字段"""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA table_info(announcements)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
columns = _get_table_columns(cursor, "announcements")
|
||||
|
||||
if "image_url" not in columns:
|
||||
cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT")
|
||||
if _add_column_if_missing(
|
||||
cursor,
|
||||
"announcements",
|
||||
columns,
|
||||
"image_url",
|
||||
"TEXT",
|
||||
ok_message=" [OK] 添加 announcements.image_url 字段",
|
||||
):
|
||||
conn.commit()
|
||||
print(" [OK] 添加 announcements.image_url 字段")
|
||||
|
||||
|
||||
def _migrate_to_v17(conn):
|
||||
"""迁移到版本17 - 金山文档上传配置与用户开关"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(system_config)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
system_columns = _get_table_columns(cursor, "system_config")
|
||||
system_fields = [
|
||||
("kdocs_enabled", "INTEGER DEFAULT 0"),
|
||||
("kdocs_doc_url", "TEXT DEFAULT ''"),
|
||||
@@ -714,21 +810,29 @@ def _migrate_to_v17(conn):
|
||||
("kdocs_admin_notify_email", "TEXT DEFAULT ''"),
|
||||
]
|
||||
for field, ddl in system_fields:
|
||||
if field not in columns:
|
||||
cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}")
|
||||
print(f" [OK] 添加 system_config.{field} 字段")
|
||||
|
||||
cursor.execute("PRAGMA table_info(users)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
field,
|
||||
ddl,
|
||||
ok_message=f" [OK] 添加 system_config.{field} 字段",
|
||||
)
|
||||
|
||||
user_columns = _get_table_columns(cursor, "users")
|
||||
user_fields = [
|
||||
("kdocs_unit", "TEXT DEFAULT ''"),
|
||||
("kdocs_auto_upload", "INTEGER DEFAULT 0"),
|
||||
]
|
||||
for field, ddl in user_fields:
|
||||
if field not in columns:
|
||||
cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}")
|
||||
print(f" [OK] 添加 users.{field} 字段")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"users",
|
||||
user_columns,
|
||||
field,
|
||||
ddl,
|
||||
ok_message=f" [OK] 添加 users.{field} 字段",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -737,15 +841,22 @@ def _migrate_to_v18(conn):
|
||||
"""迁移到版本18 - 金山文档上传:有效行范围配置"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(system_config)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
if "kdocs_row_start" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 system_config.kdocs_row_start 字段")
|
||||
|
||||
if "kdocs_row_end" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 system_config.kdocs_row_end 字段")
|
||||
columns = _get_table_columns(cursor, "system_config")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"kdocs_row_start",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 system_config.kdocs_row_start 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"kdocs_row_end",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 system_config.kdocs_row_end 字段",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
292
db/schedules.py
292
db/schedules.py
@@ -2,12 +2,93 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import db_pool
|
||||
from services.schedule_utils import compute_next_run_at, format_cst
|
||||
from services.time_utils import get_beijing_now
|
||||
|
||||
_SCHEDULE_DEFAULT_TIME = "08:00"
|
||||
_SCHEDULE_DEFAULT_WEEKDAYS = "1,2,3,4,5"
|
||||
|
||||
_ALLOWED_SCHEDULE_UPDATE_FIELDS = (
|
||||
"name",
|
||||
"enabled",
|
||||
"schedule_time",
|
||||
"weekdays",
|
||||
"browse_type",
|
||||
"enable_screenshot",
|
||||
"random_delay",
|
||||
"account_ids",
|
||||
)
|
||||
|
||||
_ALLOWED_EXEC_LOG_UPDATE_FIELDS = (
|
||||
"total_accounts",
|
||||
"success_accounts",
|
||||
"failed_accounts",
|
||||
"total_items",
|
||||
"total_attachments",
|
||||
"total_screenshots",
|
||||
"duration_seconds",
|
||||
"status",
|
||||
"error_message",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_limit(limit, default: int, *, minimum: int = 1) -> int:
|
||||
try:
|
||||
parsed = int(limit)
|
||||
except Exception:
|
||||
parsed = default
|
||||
if parsed < minimum:
|
||||
return minimum
|
||||
return parsed
|
||||
|
||||
|
||||
def _to_int(value, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _format_optional_datetime(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
return format_cst(dt)
|
||||
|
||||
|
||||
def _serialize_account_ids(account_ids) -> str:
|
||||
return json.dumps(account_ids) if account_ids else "[]"
|
||||
|
||||
|
||||
def _compute_schedule_next_run_str(
|
||||
*,
|
||||
now_dt,
|
||||
schedule_time,
|
||||
weekdays,
|
||||
random_delay,
|
||||
last_run_at,
|
||||
) -> str:
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(schedule_time or _SCHEDULE_DEFAULT_TIME),
|
||||
weekdays=str(weekdays or _SCHEDULE_DEFAULT_WEEKDAYS),
|
||||
random_delay=_to_int(random_delay, 0),
|
||||
last_run_at=str(last_run_at or "") if last_run_at else None,
|
||||
)
|
||||
return format_cst(next_dt)
|
||||
|
||||
|
||||
def _map_schedule_log_row(row) -> dict:
|
||||
log = dict(row)
|
||||
log["created_at"] = log.get("execute_time")
|
||||
log["success_count"] = log.get("success_accounts", 0)
|
||||
log["failed_count"] = log.get("failed_accounts", 0)
|
||||
log["duration"] = log.get("duration_seconds", 0)
|
||||
return log
|
||||
|
||||
|
||||
def get_user_schedules(user_id):
|
||||
"""获取用户的所有定时任务"""
|
||||
@@ -44,14 +125,10 @@ def create_user_schedule(
|
||||
account_ids=None,
|
||||
):
|
||||
"""创建用户定时任务"""
|
||||
import json
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = format_cst(get_beijing_now())
|
||||
|
||||
account_ids_str = json.dumps(account_ids) if account_ids else "[]"
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO user_schedules (
|
||||
@@ -66,8 +143,8 @@ def create_user_schedule(
|
||||
weekdays,
|
||||
browse_type,
|
||||
enable_screenshot,
|
||||
int(random_delay or 0),
|
||||
account_ids_str,
|
||||
_to_int(random_delay, 0),
|
||||
_serialize_account_ids(account_ids),
|
||||
cst_time,
|
||||
cst_time,
|
||||
),
|
||||
@@ -79,28 +156,11 @@ def create_user_schedule(
|
||||
|
||||
def update_user_schedule(schedule_id, **kwargs):
|
||||
"""更新用户定时任务"""
|
||||
import json
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
now_dt = get_beijing_now()
|
||||
now_str = format_cst(now_dt)
|
||||
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
allowed_fields = [
|
||||
"name",
|
||||
"enabled",
|
||||
"schedule_time",
|
||||
"weekdays",
|
||||
"browse_type",
|
||||
"enable_screenshot",
|
||||
"random_delay",
|
||||
"account_ids",
|
||||
]
|
||||
|
||||
# 读取旧值,用于决定是否需要重算 next_run_at
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT enabled, schedule_time, weekdays, random_delay, last_run_at
|
||||
@@ -112,10 +172,11 @@ def update_user_schedule(schedule_id, **kwargs):
|
||||
current = cursor.fetchone()
|
||||
if not current:
|
||||
return False
|
||||
current_enabled = int(current[0] or 0)
|
||||
|
||||
current_enabled = _to_int(current[0], 0)
|
||||
current_time = current[1]
|
||||
current_weekdays = current[2]
|
||||
current_random_delay = int(current[3] or 0)
|
||||
current_random_delay = _to_int(current[3], 0)
|
||||
current_last_run_at = current[4]
|
||||
|
||||
will_enabled = current_enabled
|
||||
@@ -123,21 +184,28 @@ def update_user_schedule(schedule_id, **kwargs):
|
||||
next_weekdays = current_weekdays
|
||||
next_random_delay = current_random_delay
|
||||
|
||||
for field in allowed_fields:
|
||||
if field in kwargs:
|
||||
value = kwargs[field]
|
||||
if field == "account_ids" and isinstance(value, list):
|
||||
value = json.dumps(value)
|
||||
if field == "enabled":
|
||||
will_enabled = 1 if value else 0
|
||||
if field == "schedule_time":
|
||||
next_time = value
|
||||
if field == "weekdays":
|
||||
next_weekdays = value
|
||||
if field == "random_delay":
|
||||
next_random_delay = int(value or 0)
|
||||
updates.append(f"{field} = ?")
|
||||
params.append(value)
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
for field in _ALLOWED_SCHEDULE_UPDATE_FIELDS:
|
||||
if field not in kwargs:
|
||||
continue
|
||||
|
||||
value = kwargs[field]
|
||||
if field == "account_ids" and isinstance(value, list):
|
||||
value = json.dumps(value)
|
||||
|
||||
if field == "enabled":
|
||||
will_enabled = 1 if value else 0
|
||||
if field == "schedule_time":
|
||||
next_time = value
|
||||
if field == "weekdays":
|
||||
next_weekdays = value
|
||||
if field == "random_delay":
|
||||
next_random_delay = int(value or 0)
|
||||
|
||||
updates.append(f"{field} = ?")
|
||||
params.append(value)
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
@@ -145,30 +213,26 @@ def update_user_schedule(schedule_id, **kwargs):
|
||||
updates.append("updated_at = ?")
|
||||
params.append(now_str)
|
||||
|
||||
# 关键字段变更后重算 next_run_at,确保索引驱动不会跑偏
|
||||
#
|
||||
# 需求:当用户修改“执行时间/执行日期/随机±15分钟”后,即使今天已经执行过,也允许按新配置在今天再次触发。
|
||||
# 做法:这些关键字段发生变更时,重算 next_run_at 时忽略 last_run_at 的“同日仅一次”限制。
|
||||
config_changed = any(key in kwargs for key in ["schedule_time", "weekdays", "random_delay"])
|
||||
config_changed = any(key in kwargs for key in ("schedule_time", "weekdays", "random_delay"))
|
||||
enabled_toggled = "enabled" in kwargs
|
||||
should_recompute_next = config_changed or (enabled_toggled and will_enabled == 1)
|
||||
|
||||
if should_recompute_next:
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(next_time or "08:00"),
|
||||
weekdays=str(next_weekdays or "1,2,3,4,5"),
|
||||
random_delay=int(next_random_delay or 0),
|
||||
last_run_at=None if config_changed else (str(current_last_run_at or "") if current_last_run_at else None),
|
||||
next_run_at = _compute_schedule_next_run_str(
|
||||
now_dt=now_dt,
|
||||
schedule_time=next_time,
|
||||
weekdays=next_weekdays,
|
||||
random_delay=next_random_delay,
|
||||
last_run_at=None if config_changed else current_last_run_at,
|
||||
)
|
||||
updates.append("next_run_at = ?")
|
||||
params.append(format_cst(next_dt))
|
||||
params.append(next_run_at)
|
||||
|
||||
# 若本次显式禁用任务,则 next_run_at 清空(与 toggle 行为保持一致)
|
||||
if enabled_toggled and will_enabled == 0:
|
||||
updates.append("next_run_at = ?")
|
||||
params.append(None)
|
||||
params.append(schedule_id)
|
||||
|
||||
params.append(schedule_id)
|
||||
sql = f"UPDATE user_schedules SET {', '.join(updates)} WHERE id = ?"
|
||||
cursor.execute(sql, params)
|
||||
conn.commit()
|
||||
@@ -203,28 +267,19 @@ def toggle_user_schedule(schedule_id, enabled):
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = (
|
||||
row[0],
|
||||
row[1],
|
||||
row[2],
|
||||
row[3],
|
||||
row[4],
|
||||
)
|
||||
|
||||
schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = row
|
||||
existing_next_run_at = str(existing_next_run_at or "").strip() or None
|
||||
# 若 next_run_at 已经被“修改配置”逻辑预先计算好且仍在未来,则优先沿用,
|
||||
# 避免 last_run_at 的“同日仅一次”限制阻塞用户把任务调整到今天再次触发。
|
||||
|
||||
if existing_next_run_at and existing_next_run_at > now_str:
|
||||
next_run_at = existing_next_run_at
|
||||
else:
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(schedule_time or "08:00"),
|
||||
weekdays=str(weekdays or "1,2,3,4,5"),
|
||||
random_delay=int(random_delay or 0),
|
||||
last_run_at=str(last_run_at or "") if last_run_at else None,
|
||||
next_run_at = _compute_schedule_next_run_str(
|
||||
now_dt=now_dt,
|
||||
schedule_time=schedule_time,
|
||||
weekdays=weekdays,
|
||||
random_delay=random_delay,
|
||||
last_run_at=last_run_at,
|
||||
)
|
||||
next_run_at = format_cst(next_dt)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -272,16 +327,15 @@ def update_schedule_last_run(schedule_id):
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return False
|
||||
schedule_time, weekdays, random_delay = row[0], row[1], row[2]
|
||||
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(schedule_time or "08:00"),
|
||||
weekdays=str(weekdays or "1,2,3,4,5"),
|
||||
random_delay=int(random_delay or 0),
|
||||
schedule_time, weekdays, random_delay = row
|
||||
next_run_at = _compute_schedule_next_run_str(
|
||||
now_dt=now_dt,
|
||||
schedule_time=schedule_time,
|
||||
weekdays=weekdays,
|
||||
random_delay=random_delay,
|
||||
last_run_at=now_str,
|
||||
)
|
||||
next_run_at = format_cst(next_dt)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -305,7 +359,11 @@ def update_schedule_next_run(schedule_id: int, next_run_at: str) -> bool:
|
||||
SET next_run_at = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(str(next_run_at or "").strip() or None, format_cst(get_beijing_now()), int(schedule_id)),
|
||||
(
|
||||
str(next_run_at or "").strip() or None,
|
||||
format_cst(get_beijing_now()),
|
||||
int(schedule_id),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -328,15 +386,15 @@ def recompute_schedule_next_run(schedule_id: int, *, now_dt=None) -> bool:
|
||||
if not row:
|
||||
return False
|
||||
|
||||
schedule_time, weekdays, random_delay, last_run_at = row[0], row[1], row[2], row[3]
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(schedule_time or "08:00"),
|
||||
weekdays=str(weekdays or "1,2,3,4,5"),
|
||||
random_delay=int(random_delay or 0),
|
||||
last_run_at=str(last_run_at or "") if last_run_at else None,
|
||||
schedule_time, weekdays, random_delay, last_run_at = row
|
||||
next_run_at = _compute_schedule_next_run_str(
|
||||
now_dt=now_dt,
|
||||
schedule_time=schedule_time,
|
||||
weekdays=weekdays,
|
||||
random_delay=random_delay,
|
||||
last_run_at=last_run_at,
|
||||
)
|
||||
return update_schedule_next_run(int(schedule_id), format_cst(next_dt))
|
||||
return update_schedule_next_run(int(schedule_id), next_run_at)
|
||||
|
||||
|
||||
def get_due_user_schedules(now_cst: str, limit: int = 50):
|
||||
@@ -345,6 +403,8 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
|
||||
if not now_cst:
|
||||
now_cst = format_cst(get_beijing_now())
|
||||
|
||||
safe_limit = _normalize_limit(limit, 50, minimum=1)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -358,7 +418,7 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
|
||||
ORDER BY us.next_run_at ASC
|
||||
LIMIT ?
|
||||
""",
|
||||
(now_cst, int(limit)),
|
||||
(now_cst, safe_limit),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
@@ -370,15 +430,13 @@ def create_schedule_execution_log(schedule_id, user_id, schedule_name):
|
||||
"""创建定时任务执行日志"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
execute_time = format_cst(get_beijing_now())
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO schedule_execution_logs (
|
||||
schedule_id, user_id, schedule_name, execute_time, status
|
||||
) VALUES (?, ?, ?, ?, 'running')
|
||||
""",
|
||||
(schedule_id, user_id, schedule_name, execute_time),
|
||||
(schedule_id, user_id, schedule_name, format_cst(get_beijing_now())),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -393,22 +451,11 @@ def update_schedule_execution_log(log_id, **kwargs):
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
allowed_fields = [
|
||||
"total_accounts",
|
||||
"success_accounts",
|
||||
"failed_accounts",
|
||||
"total_items",
|
||||
"total_attachments",
|
||||
"total_screenshots",
|
||||
"duration_seconds",
|
||||
"status",
|
||||
"error_message",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
if field in kwargs:
|
||||
updates.append(f"{field} = ?")
|
||||
params.append(kwargs[field])
|
||||
for field in _ALLOWED_EXEC_LOG_UPDATE_FIELDS:
|
||||
if field not in kwargs:
|
||||
continue
|
||||
updates.append(f"{field} = ?")
|
||||
params.append(kwargs[field])
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
@@ -424,6 +471,7 @@ def update_schedule_execution_log(log_id, **kwargs):
|
||||
def get_schedule_execution_logs(schedule_id, limit=10):
|
||||
"""获取定时任务执行日志"""
|
||||
try:
|
||||
safe_limit = _normalize_limit(limit, 10, minimum=1)
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -433,24 +481,16 @@ def get_schedule_execution_logs(schedule_id, limit=10):
|
||||
ORDER BY execute_time DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(schedule_id, limit),
|
||||
(schedule_id, safe_limit),
|
||||
)
|
||||
|
||||
logs = []
|
||||
rows = cursor.fetchall()
|
||||
|
||||
for row in rows:
|
||||
for row in cursor.fetchall():
|
||||
try:
|
||||
log = dict(row)
|
||||
log["created_at"] = log.get("execute_time")
|
||||
log["success_count"] = log.get("success_accounts", 0)
|
||||
log["failed_count"] = log.get("failed_accounts", 0)
|
||||
log["duration"] = log.get("duration_seconds", 0)
|
||||
logs.append(log)
|
||||
logs.append(_map_schedule_log_row(row))
|
||||
except Exception as e:
|
||||
print(f"[数据库] 处理日志行时出错: {e}")
|
||||
continue
|
||||
|
||||
return logs
|
||||
except Exception as e:
|
||||
print(f"[数据库] 查询定时任务日志时出错: {e}")
|
||||
@@ -462,6 +502,7 @@ def get_schedule_execution_logs(schedule_id, limit=10):
|
||||
|
||||
def get_user_all_schedule_logs(user_id, limit=50):
|
||||
"""获取用户所有定时任务的执行日志"""
|
||||
safe_limit = _normalize_limit(limit, 50, minimum=1)
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -471,7 +512,7 @@ def get_user_all_schedule_logs(user_id, limit=50):
|
||||
ORDER BY execute_time DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(user_id, limit),
|
||||
(user_id, safe_limit),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
@@ -493,14 +534,21 @@ def delete_schedule_logs(schedule_id, user_id):
|
||||
|
||||
def clean_old_schedule_logs(days=30):
|
||||
"""清理指定天数前的定时任务执行日志"""
|
||||
safe_days = _to_int(days, 30)
|
||||
if safe_days < 0:
|
||||
safe_days = 0
|
||||
|
||||
cutoff_dt = get_beijing_now() - timedelta(days=safe_days)
|
||||
cutoff_str = format_cst(cutoff_dt)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM schedule_execution_logs
|
||||
WHERE execute_time < datetime('now', 'localtime', '-' || ? || ' days')
|
||||
WHERE execute_time < ?
|
||||
""",
|
||||
(days,),
|
||||
(cutoff_str,),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
@@ -362,6 +362,8 @@ def ensure_schema(conn) -> None:
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_status ON users(status)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_vip_expire ON users(vip_expire_time)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
|
||||
|
||||
@@ -391,6 +393,8 @@ def ensure_schema(conn) -> None:
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_id ON task_logs(user_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_status ON task_logs(status)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source ON task_logs(source)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source_created_at ON task_logs(source, created_at)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)")
|
||||
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)")
|
||||
@@ -409,6 +413,9 @@ def ensure_schema(conn) -> None:
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_id ON schedule_execution_logs(schedule_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_id ON schedule_execution_logs(user_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_status ON schedule_execution_logs(status)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_execute_time ON schedule_execution_logs(execute_time)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_time ON schedule_execution_logs(schedule_id, execute_time)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_time ON schedule_execution_logs(user_id, execute_time)")
|
||||
|
||||
# 初始化VIP配置(幂等)
|
||||
try:
|
||||
|
||||
167
db/security.py
167
db/security.py
@@ -3,13 +3,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
from typing import Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import db_pool
|
||||
from db.utils import get_cst_now, get_cst_now_str
|
||||
|
||||
|
||||
_THREAT_EVENT_SELECT_COLUMNS = """
|
||||
id,
|
||||
threat_type,
|
||||
score,
|
||||
rule,
|
||||
field_name,
|
||||
matched,
|
||||
value_preview,
|
||||
ip,
|
||||
user_id,
|
||||
request_method,
|
||||
request_path,
|
||||
user_agent,
|
||||
created_at
|
||||
"""
|
||||
|
||||
|
||||
def _normalize_page(page: int) -> int:
|
||||
try:
|
||||
page_i = int(page)
|
||||
except Exception:
|
||||
page_i = 1
|
||||
return max(1, page_i)
|
||||
|
||||
|
||||
def _normalize_per_page(per_page: int, default: int = 20) -> int:
|
||||
try:
|
||||
value = int(per_page)
|
||||
except Exception:
|
||||
value = default
|
||||
return max(1, min(200, value))
|
||||
|
||||
|
||||
def _normalize_limit(limit: int, default: int = 50) -> int:
|
||||
try:
|
||||
value = int(limit)
|
||||
except Exception:
|
||||
value = default
|
||||
return max(1, min(200, value))
|
||||
|
||||
|
||||
def _row_value(row, key: str, index: int = 0, default=None):
|
||||
if row is None:
|
||||
return default
|
||||
try:
|
||||
return row[key]
|
||||
except Exception:
|
||||
try:
|
||||
return row[index]
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _fetch_threat_events_history(where_clause: str, params: tuple[Any, ...], limit_i: int) -> list[dict]:
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT
|
||||
{_THREAT_EVENT_SELECT_COLUMNS}
|
||||
FROM threat_events
|
||||
WHERE {where_clause}
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
tuple(params) + (limit_i,),
|
||||
)
|
||||
return [dict(r) for r in cursor.fetchall()]
|
||||
|
||||
|
||||
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
|
||||
"""记录登录环境信息,返回是否新设备/新IP。"""
|
||||
user_id = int(user_id)
|
||||
@@ -36,7 +105,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
|
||||
SET last_seen = ?, last_ip = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(now_str, ip_text, row["id"] if isinstance(row, dict) else row[0]),
|
||||
(now_str, ip_text, _row_value(row, "id", 0)),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
@@ -61,7 +130,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
|
||||
SET last_seen = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(now_str, row["id"] if isinstance(row, dict) else row[0]),
|
||||
(now_str, _row_value(row, "id", 0)),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
@@ -166,15 +235,8 @@ def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, lis
|
||||
|
||||
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
|
||||
"""分页获取威胁事件。"""
|
||||
try:
|
||||
page_i = max(1, int(page))
|
||||
except Exception:
|
||||
page_i = 1
|
||||
try:
|
||||
per_page_i = int(per_page)
|
||||
except Exception:
|
||||
per_page_i = 20
|
||||
per_page_i = max(1, min(200, per_page_i))
|
||||
page_i = _normalize_page(page)
|
||||
per_page_i = _normalize_per_page(per_page, default=20)
|
||||
|
||||
where_sql, params = _build_threat_events_where_clause(filters)
|
||||
offset = (page_i - 1) * per_page_i
|
||||
@@ -188,19 +250,7 @@ def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = N
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT
|
||||
id,
|
||||
threat_type,
|
||||
score,
|
||||
rule,
|
||||
field_name,
|
||||
matched,
|
||||
value_preview,
|
||||
ip,
|
||||
user_id,
|
||||
request_method,
|
||||
request_path,
|
||||
user_agent,
|
||||
created_at
|
||||
{_THREAT_EVENT_SELECT_COLUMNS}
|
||||
FROM threat_events
|
||||
{where_sql}
|
||||
ORDER BY created_at DESC, id DESC
|
||||
@@ -218,75 +268,20 @@ def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
|
||||
ip_text = str(ip or "").strip()[:64]
|
||||
if not ip_text:
|
||||
return []
|
||||
try:
|
||||
limit_i = max(1, min(200, int(limit)))
|
||||
except Exception:
|
||||
limit_i = 50
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
id,
|
||||
threat_type,
|
||||
score,
|
||||
rule,
|
||||
field_name,
|
||||
matched,
|
||||
value_preview,
|
||||
ip,
|
||||
user_id,
|
||||
request_method,
|
||||
request_path,
|
||||
user_agent,
|
||||
created_at
|
||||
FROM threat_events
|
||||
WHERE ip = ?
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(ip_text, limit_i),
|
||||
)
|
||||
return [dict(r) for r in cursor.fetchall()]
|
||||
limit_i = _normalize_limit(limit, default=50)
|
||||
return _fetch_threat_events_history("ip = ?", (ip_text,), limit_i)
|
||||
|
||||
|
||||
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
|
||||
"""获取用户的威胁历史(最近limit条)。"""
|
||||
if user_id is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
user_id_int = int(user_id)
|
||||
except Exception:
|
||||
return []
|
||||
try:
|
||||
limit_i = max(1, min(200, int(limit)))
|
||||
except Exception:
|
||||
limit_i = 50
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
id,
|
||||
threat_type,
|
||||
score,
|
||||
rule,
|
||||
field_name,
|
||||
matched,
|
||||
value_preview,
|
||||
ip,
|
||||
user_id,
|
||||
request_method,
|
||||
request_path,
|
||||
user_agent,
|
||||
created_at
|
||||
FROM threat_events
|
||||
WHERE user_id = ?
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(user_id_int, limit_i),
|
||||
)
|
||||
return [dict(r) for r in cursor.fetchall()]
|
||||
limit_i = _normalize_limit(limit, default=50)
|
||||
return _fetch_threat_events_history("user_id = ?", (user_id_int,), limit_i)
|
||||
|
||||
303
db/tasks.py
303
db/tasks.py
@@ -2,12 +2,135 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import db_pool
|
||||
from db.utils import sanitize_sql_like_pattern
|
||||
from db.utils import get_cst_now, get_cst_now_str, sanitize_sql_like_pattern
|
||||
|
||||
_TASK_STATS_SELECT_SQL = """
|
||||
SELECT
|
||||
COUNT(*) as total_tasks,
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
"""
|
||||
|
||||
_USER_RUN_STATS_SELECT_SQL = """
|
||||
SELECT
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
"""
|
||||
|
||||
|
||||
def _build_day_bounds(date_filter: str) -> tuple[str | None, str | None]:
|
||||
"""将 YYYY-MM-DD 转换为 [day_start, day_end) 区间。"""
|
||||
try:
|
||||
day_start = datetime.strptime(str(date_filter), "%Y-%m-%d")
|
||||
except Exception:
|
||||
return None, None
|
||||
|
||||
day_end = day_start + timedelta(days=1)
|
||||
return day_start.strftime("%Y-%m-%d %H:%M:%S"), day_end.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def _normalize_int(value, default: int, *, minimum: int | None = None) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
if minimum is not None and parsed < minimum:
|
||||
return minimum
|
||||
return parsed
|
||||
|
||||
|
||||
def _stat_value(row, key: str) -> int:
|
||||
try:
|
||||
value = row[key] if row else 0
|
||||
except Exception:
|
||||
value = 0
|
||||
return int(value or 0)
|
||||
|
||||
|
||||
def _build_task_logs_where_sql(
|
||||
*,
|
||||
date_filter=None,
|
||||
status_filter=None,
|
||||
source_filter=None,
|
||||
user_id_filter=None,
|
||||
account_filter=None,
|
||||
) -> tuple[str, list]:
|
||||
where_clauses = ["1=1"]
|
||||
params = []
|
||||
|
||||
if date_filter:
|
||||
day_start, day_end = _build_day_bounds(date_filter)
|
||||
if day_start and day_end:
|
||||
where_clauses.append("tl.created_at >= ? AND tl.created_at < ?")
|
||||
params.extend([day_start, day_end])
|
||||
else:
|
||||
where_clauses.append("date(tl.created_at) = ?")
|
||||
params.append(date_filter)
|
||||
|
||||
if status_filter:
|
||||
where_clauses.append("tl.status = ?")
|
||||
params.append(status_filter)
|
||||
|
||||
if source_filter:
|
||||
source_filter = str(source_filter or "").strip()
|
||||
if source_filter == "user_scheduled":
|
||||
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
|
||||
params.append("user_scheduled:%")
|
||||
elif source_filter.endswith("*"):
|
||||
prefix = source_filter[:-1]
|
||||
safe_prefix = sanitize_sql_like_pattern(prefix)
|
||||
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
|
||||
params.append(f"{safe_prefix}%")
|
||||
else:
|
||||
where_clauses.append("tl.source = ?")
|
||||
params.append(source_filter)
|
||||
|
||||
if user_id_filter:
|
||||
where_clauses.append("tl.user_id = ?")
|
||||
params.append(user_id_filter)
|
||||
|
||||
if account_filter:
|
||||
safe_filter = sanitize_sql_like_pattern(account_filter)
|
||||
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
|
||||
params.append(f"%{safe_filter}%")
|
||||
|
||||
return " AND ".join(where_clauses), params
|
||||
|
||||
|
||||
def _fetch_task_stats_row(cursor, *, where_clause: str = "", params: tuple | list = ()) -> dict:
|
||||
sql = _TASK_STATS_SELECT_SQL
|
||||
if where_clause:
|
||||
sql = f"{sql}\nWHERE {where_clause}"
|
||||
cursor.execute(sql, params)
|
||||
row = cursor.fetchone()
|
||||
return {
|
||||
"total_tasks": _stat_value(row, "total_tasks"),
|
||||
"success_tasks": _stat_value(row, "success_tasks"),
|
||||
"failed_tasks": _stat_value(row, "failed_tasks"),
|
||||
"total_items": _stat_value(row, "total_items"),
|
||||
"total_attachments": _stat_value(row, "total_attachments"),
|
||||
}
|
||||
|
||||
|
||||
def _fetch_user_run_stats_row(cursor, *, where_clause: str, params: tuple | list) -> dict:
|
||||
sql = f"{_USER_RUN_STATS_SELECT_SQL}\nWHERE {where_clause}"
|
||||
cursor.execute(sql, params)
|
||||
row = cursor.fetchone()
|
||||
return {
|
||||
"completed": _stat_value(row, "completed"),
|
||||
"failed": _stat_value(row, "failed"),
|
||||
"total_items": _stat_value(row, "total_items"),
|
||||
"total_attachments": _stat_value(row, "total_attachments"),
|
||||
}
|
||||
|
||||
|
||||
def create_task_log(
|
||||
@@ -25,8 +148,6 @@ def create_task_log(
|
||||
"""创建任务日志记录"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -45,7 +166,7 @@ def create_task_log(
|
||||
total_attachments,
|
||||
error_message,
|
||||
duration,
|
||||
cst_time,
|
||||
get_cst_now_str(),
|
||||
source,
|
||||
),
|
||||
)
|
||||
@@ -64,54 +185,27 @@ def get_task_logs(
|
||||
account_filter=None,
|
||||
):
|
||||
"""获取任务日志列表(支持分页和多种筛选)"""
|
||||
limit = _normalize_int(limit, 100, minimum=1)
|
||||
offset = _normalize_int(offset, 0, minimum=0)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
where_clauses = ["1=1"]
|
||||
params = []
|
||||
|
||||
if date_filter:
|
||||
where_clauses.append("date(tl.created_at) = ?")
|
||||
params.append(date_filter)
|
||||
|
||||
if status_filter:
|
||||
where_clauses.append("tl.status = ?")
|
||||
params.append(status_filter)
|
||||
|
||||
if source_filter:
|
||||
source_filter = str(source_filter or "").strip()
|
||||
# 兼容“虚拟来源”:用于筛选 user_scheduled:batch_xxx 这类动态值
|
||||
if source_filter == "user_scheduled":
|
||||
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
|
||||
params.append("user_scheduled:%")
|
||||
elif source_filter.endswith("*"):
|
||||
prefix = source_filter[:-1]
|
||||
safe_prefix = sanitize_sql_like_pattern(prefix)
|
||||
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
|
||||
params.append(f"{safe_prefix}%")
|
||||
else:
|
||||
where_clauses.append("tl.source = ?")
|
||||
params.append(source_filter)
|
||||
|
||||
if user_id_filter:
|
||||
where_clauses.append("tl.user_id = ?")
|
||||
params.append(user_id_filter)
|
||||
|
||||
if account_filter:
|
||||
safe_filter = sanitize_sql_like_pattern(account_filter)
|
||||
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
|
||||
params.append(f"%{safe_filter}%")
|
||||
|
||||
where_sql = " AND ".join(where_clauses)
|
||||
where_sql, params = _build_task_logs_where_sql(
|
||||
date_filter=date_filter,
|
||||
status_filter=status_filter,
|
||||
source_filter=source_filter,
|
||||
user_id_filter=user_id_filter,
|
||||
account_filter=account_filter,
|
||||
)
|
||||
|
||||
count_sql = f"""
|
||||
SELECT COUNT(*) as total
|
||||
FROM task_logs tl
|
||||
LEFT JOIN users u ON tl.user_id = u.id
|
||||
WHERE {where_sql}
|
||||
"""
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
total = _stat_value(cursor.fetchone(), "total")
|
||||
|
||||
data_sql = f"""
|
||||
SELECT
|
||||
@@ -123,9 +217,10 @@ def get_task_logs(
|
||||
ORDER BY tl.created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
params.extend([limit, offset])
|
||||
data_params = list(params)
|
||||
data_params.extend([limit, offset])
|
||||
|
||||
cursor.execute(data_sql, params)
|
||||
cursor.execute(data_sql, data_params)
|
||||
logs = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
return {"logs": logs, "total": total}
|
||||
@@ -133,61 +228,39 @@ def get_task_logs(
|
||||
|
||||
def get_task_stats(date_filter=None):
|
||||
"""获取任务统计信息"""
|
||||
if date_filter is None:
|
||||
date_filter = get_cst_now().strftime("%Y-%m-%d")
|
||||
|
||||
day_start, day_end = _build_day_bounds(date_filter)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
if date_filter is None:
|
||||
date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d")
|
||||
if day_start and day_end:
|
||||
today_stats = _fetch_task_stats_row(
|
||||
cursor,
|
||||
where_clause="created_at >= ? AND created_at < ?",
|
||||
params=(day_start, day_end),
|
||||
)
|
||||
else:
|
||||
today_stats = _fetch_task_stats_row(
|
||||
cursor,
|
||||
where_clause="date(created_at) = ?",
|
||||
params=(date_filter,),
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(*) as total_tasks,
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
WHERE date(created_at) = ?
|
||||
""",
|
||||
(date_filter,),
|
||||
)
|
||||
today_stats = cursor.fetchone()
|
||||
total_stats = _fetch_task_stats_row(cursor)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(*) as total_tasks,
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
"""
|
||||
)
|
||||
total_stats = cursor.fetchone()
|
||||
|
||||
return {
|
||||
"today": {
|
||||
"total_tasks": today_stats["total_tasks"] or 0,
|
||||
"success_tasks": today_stats["success_tasks"] or 0,
|
||||
"failed_tasks": today_stats["failed_tasks"] or 0,
|
||||
"total_items": today_stats["total_items"] or 0,
|
||||
"total_attachments": today_stats["total_attachments"] or 0,
|
||||
},
|
||||
"total": {
|
||||
"total_tasks": total_stats["total_tasks"] or 0,
|
||||
"success_tasks": total_stats["success_tasks"] or 0,
|
||||
"failed_tasks": total_stats["failed_tasks"] or 0,
|
||||
"total_items": total_stats["total_items"] or 0,
|
||||
"total_attachments": total_stats["total_attachments"] or 0,
|
||||
},
|
||||
}
|
||||
return {"today": today_stats, "total": total_stats}
|
||||
|
||||
|
||||
def delete_old_task_logs(days=30, batch_size=1000):
|
||||
"""删除N天前的任务日志(分批删除,避免长时间锁表)"""
|
||||
days = _normalize_int(days, 30, minimum=0)
|
||||
batch_size = _normalize_int(batch_size, 1000, minimum=1)
|
||||
|
||||
cutoff = (get_cst_now() - timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
total_deleted = 0
|
||||
while True:
|
||||
with db_pool.get_db() as conn:
|
||||
@@ -197,16 +270,16 @@ def delete_old_task_logs(days=30, batch_size=1000):
|
||||
DELETE FROM task_logs
|
||||
WHERE rowid IN (
|
||||
SELECT rowid FROM task_logs
|
||||
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
|
||||
WHERE created_at < ?
|
||||
LIMIT ?
|
||||
)
|
||||
""",
|
||||
(days, batch_size),
|
||||
(cutoff, batch_size),
|
||||
)
|
||||
deleted = cursor.rowcount
|
||||
conn.commit()
|
||||
|
||||
if deleted == 0:
|
||||
if deleted <= 0:
|
||||
break
|
||||
total_deleted += deleted
|
||||
|
||||
@@ -215,31 +288,23 @@ def delete_old_task_logs(days=30, batch_size=1000):
|
||||
|
||||
def get_user_run_stats(user_id, date_filter=None):
|
||||
"""获取用户的运行统计信息"""
|
||||
if date_filter is None:
|
||||
date_filter = get_cst_now().strftime("%Y-%m-%d")
|
||||
|
||||
day_start, day_end = _build_day_bounds(date_filter)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cursor = conn.cursor()
|
||||
|
||||
if date_filter is None:
|
||||
date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d")
|
||||
if day_start and day_end:
|
||||
return _fetch_user_run_stats_row(
|
||||
cursor,
|
||||
where_clause="user_id = ? AND created_at >= ? AND created_at < ?",
|
||||
params=(user_id, day_start, day_end),
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
WHERE user_id = ? AND date(created_at) = ?
|
||||
""",
|
||||
(user_id, date_filter),
|
||||
return _fetch_user_run_stats_row(
|
||||
cursor,
|
||||
where_clause="user_id = ? AND date(created_at) = ?",
|
||||
params=(user_id, date_filter),
|
||||
)
|
||||
|
||||
stats = cursor.fetchone()
|
||||
|
||||
return {
|
||||
"completed": stats["completed"] or 0,
|
||||
"failed": stats["failed"] or 0,
|
||||
"total_items": stats["total_items"] or 0,
|
||||
"total_attachments": stats["total_attachments"] or 0,
|
||||
}
|
||||
|
||||
174
db/users.py
174
db/users.py
@@ -16,8 +16,41 @@ from password_utils import (
|
||||
verify_password_bcrypt,
|
||||
verify_password_sha256,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_CST_TZ = pytz.timezone("Asia/Shanghai")
|
||||
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
|
||||
|
||||
|
||||
def _row_to_dict(row):
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
def _get_user_by_field(field_name: str, field_value):
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
|
||||
return _row_to_dict(cursor.fetchone())
|
||||
|
||||
|
||||
def _parse_cst_datetime(datetime_str: str | None):
|
||||
if not datetime_str:
|
||||
return None
|
||||
try:
|
||||
naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S")
|
||||
return _CST_TZ.localize(naive_dt)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str:
|
||||
if int(days) == 999999:
|
||||
return _PERMANENT_VIP_EXPIRE
|
||||
if base_dt is None:
|
||||
base_dt = datetime.now(_CST_TZ)
|
||||
return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def get_vip_config():
|
||||
"""获取VIP配置"""
|
||||
@@ -32,13 +65,12 @@ def set_default_vip_days(days):
|
||||
"""设置默认VIP天数"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = get_cst_now_str()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
|
||||
VALUES (1, ?, ?)
|
||||
""",
|
||||
(days, cst_time),
|
||||
(days, get_cst_now_str()),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
@@ -47,14 +79,8 @@ def set_default_vip_days(days):
|
||||
def set_user_vip(user_id, days):
|
||||
"""设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久"""
|
||||
with db_pool.get_db() as conn:
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cursor = conn.cursor()
|
||||
|
||||
if days == 999999:
|
||||
expire_time = "2099-12-31 23:59:59"
|
||||
else:
|
||||
expire_time = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
expire_time = _format_vip_expire(days)
|
||||
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -63,29 +89,26 @@ def set_user_vip(user_id, days):
|
||||
def extend_user_vip(user_id, days):
|
||||
"""延长用户VIP时间"""
|
||||
user = get_user_by_id(user_id)
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
if not user:
|
||||
return False
|
||||
|
||||
current_expire = user.get("vip_expire_time")
|
||||
now_dt = datetime.now(_CST_TZ)
|
||||
|
||||
if current_expire and current_expire != _PERMANENT_VIP_EXPIRE:
|
||||
expire_time = _parse_cst_datetime(current_expire)
|
||||
if expire_time is not None:
|
||||
if expire_time < now_dt:
|
||||
expire_time = now_dt
|
||||
new_expire = _format_vip_expire(days, base_dt=expire_time)
|
||||
else:
|
||||
logger.warning("解析VIP过期时间失败,使用当前时间")
|
||||
new_expire = _format_vip_expire(days, base_dt=now_dt)
|
||||
else:
|
||||
new_expire = _format_vip_expire(days, base_dt=now_dt)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
current_expire = user.get("vip_expire_time")
|
||||
|
||||
if current_expire and current_expire != "2099-12-31 23:59:59":
|
||||
try:
|
||||
expire_time_naive = datetime.strptime(current_expire, "%Y-%m-%d %H:%M:%S")
|
||||
expire_time = cst_tz.localize(expire_time_naive)
|
||||
now = datetime.now(cst_tz)
|
||||
if expire_time < now:
|
||||
expire_time = now
|
||||
new_expire = (expire_time + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"解析VIP过期时间失败: {e}, 使用当前时间")
|
||||
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -105,45 +128,49 @@ def is_user_vip(user_id):
|
||||
|
||||
注意:数据库中存储的时间统一使用CST(Asia/Shanghai)时区
|
||||
"""
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
user = get_user_by_id(user_id)
|
||||
|
||||
if not user or not user.get("vip_expire_time"):
|
||||
if not user:
|
||||
return False
|
||||
|
||||
try:
|
||||
expire_time_naive = datetime.strptime(user["vip_expire_time"], "%Y-%m-%d %H:%M:%S")
|
||||
expire_time = cst_tz.localize(expire_time_naive)
|
||||
now = datetime.now(cst_tz)
|
||||
return now < expire_time
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"检查VIP状态失败 (user_id={user_id}): {e}")
|
||||
vip_expire_time = user.get("vip_expire_time")
|
||||
if not vip_expire_time:
|
||||
return False
|
||||
|
||||
expire_time = _parse_cst_datetime(vip_expire_time)
|
||||
if expire_time is None:
|
||||
logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间")
|
||||
return False
|
||||
|
||||
return datetime.now(_CST_TZ) < expire_time
|
||||
|
||||
|
||||
def get_user_vip_info(user_id):
|
||||
"""获取用户VIP信息"""
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
user = get_user_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""}
|
||||
|
||||
vip_expire_time = user.get("vip_expire_time")
|
||||
username = user.get("username", "")
|
||||
|
||||
if not vip_expire_time:
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
|
||||
|
||||
try:
|
||||
expire_time_naive = datetime.strptime(vip_expire_time, "%Y-%m-%d %H:%M:%S")
|
||||
expire_time = cst_tz.localize(expire_time_naive)
|
||||
now = datetime.now(cst_tz)
|
||||
is_vip = now < expire_time
|
||||
days_left = (expire_time - now).days if is_vip else 0
|
||||
expire_time = _parse_cst_datetime(vip_expire_time)
|
||||
if expire_time is None:
|
||||
logger.warning("VIP信息获取错误: 无法解析过期时间")
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
|
||||
|
||||
return {"username": user.get("username", ""), "is_vip": is_vip, "expire_time": vip_expire_time, "days_left": max(0, days_left)}
|
||||
except Exception as e:
|
||||
logger.warning(f"VIP信息获取错误: {e}")
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
|
||||
now_dt = datetime.now(_CST_TZ)
|
||||
is_vip = now_dt < expire_time
|
||||
days_left = (expire_time - now_dt).days if is_vip else 0
|
||||
|
||||
return {
|
||||
"username": username,
|
||||
"is_vip": is_vip,
|
||||
"expire_time": vip_expire_time,
|
||||
"days_left": max(0, days_left),
|
||||
}
|
||||
|
||||
|
||||
# ==================== 用户相关 ====================
|
||||
@@ -151,8 +178,6 @@ def get_user_vip_info(user_id):
|
||||
|
||||
def create_user(username, password, email=""):
|
||||
"""创建新用户(默认直接通过,赠送默认VIP)"""
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
password_hash = hash_password_bcrypt(password)
|
||||
@@ -160,12 +185,8 @@ def create_user(username, password, email=""):
|
||||
|
||||
default_vip_days = get_vip_config()["default_vip_days"]
|
||||
vip_expire_time = None
|
||||
|
||||
if default_vip_days > 0:
|
||||
if default_vip_days == 999999:
|
||||
vip_expire_time = "2099-12-31 23:59:59"
|
||||
else:
|
||||
vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
if int(default_vip_days or 0) > 0:
|
||||
vip_expire_time = _format_vip_expire(int(default_vip_days))
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
@@ -210,28 +231,28 @@ def verify_user(username, password):
|
||||
|
||||
def get_user_by_id(user_id):
|
||||
"""根据ID获取用户"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
return _get_user_by_field("id", user_id)
|
||||
|
||||
|
||||
def get_user_kdocs_settings(user_id):
|
||||
"""获取用户的金山文档配置"""
|
||||
user = get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
return {
|
||||
"kdocs_unit": user.get("kdocs_unit") or "",
|
||||
"kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0,
|
||||
}
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return {
|
||||
"kdocs_unit": (row["kdocs_unit"] or "") if isinstance(row, dict) else (row[0] or ""),
|
||||
"kdocs_auto_upload": 1 if ((row["kdocs_auto_upload"] if isinstance(row, dict) else row[1]) or 0) else 0,
|
||||
}
|
||||
|
||||
|
||||
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
|
||||
"""更新用户的金山文档配置"""
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
if kdocs_unit is not None:
|
||||
updates.append("kdocs_unit = ?")
|
||||
params.append(kdocs_unit)
|
||||
@@ -252,11 +273,7 @@ def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=No
|
||||
|
||||
def get_user_by_username(username):
|
||||
"""根据用户名获取用户"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
return _get_user_by_field("username", username)
|
||||
|
||||
|
||||
def get_all_users():
|
||||
@@ -279,14 +296,13 @@ def approve_user(user_id):
|
||||
"""审核通过用户"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = get_cst_now_str()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET status = 'approved', approved_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(cst_time, user_id),
|
||||
(get_cst_now_str(), user_id),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -315,5 +331,5 @@ def get_user_stats(user_id):
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,))
|
||||
account_count = cursor.fetchone()["count"]
|
||||
return {"account_count": account_count}
|
||||
row = cursor.fetchone()
|
||||
return {"account_count": int((row["count"] if row else 0) or 0)}
|
||||
|
||||
Reference in New Issue
Block a user