refactor: optimize structure, stability and runtime performance

This commit is contained in:
2026-02-07 00:35:11 +08:00
parent fae21329d7
commit bf29ac1924
44 changed files with 6894 additions and 4792 deletions

View File

@@ -6,19 +6,51 @@ import db_pool
from crypto_utils import decrypt_password, encrypt_password
from db.utils import get_cst_now_str
_ACCOUNT_STATUS_QUERY_SQL = """
SELECT status, login_fail_count, last_login_error
FROM accounts
WHERE id = ?
"""
def _decode_account_password(account_dict: dict) -> dict:
    """Replace the encrypted password in *account_dict* with its plaintext, in place.

    Missing passwords decrypt from "" — whatever decrypt_password returns for that.
    Returns the same dict for call-chaining convenience.
    """
    stored = account_dict.get("password", "")
    account_dict["password"] = decrypt_password(stored)
    return account_dict
def _normalize_account_ids(account_ids) -> list[str]:
normalized = []
seen = set()
for account_id in account_ids or []:
if not account_id:
continue
account_key = str(account_id)
if account_key in seen:
continue
seen.add(account_key)
normalized.append(account_key)
return normalized
def create_account(user_id, account_id, username, password, remember=True, remark=""):
    """Create an account row, storing the password encrypted.

    Fix: the diff-merge artifact left BOTH the old inline parameter tuple and the
    new multi-line tuple in the execute() call, so cursor.execute received three
    positional arguments (TypeError at runtime); it also left a stale, unused
    cst_time local. Only the new-style parameter tuple is kept.

    Returns the cursor's lastrowid after commit.
    """
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        encrypted_password = encrypt_password(password)
        cursor.execute(
            """
            INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (
                account_id,
                user_id,
                username,
                encrypted_password,
                1 if remember else 0,  # store the flag as 0/1 for SQLite
                remark,
                get_cst_now_str(),
            ),
        )
        conn.commit()
        return cursor.lastrowid
@@ -29,12 +61,7 @@ def get_user_accounts(user_id):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,))
accounts = []
for row in cursor.fetchall():
account = dict(row)
account["password"] = decrypt_password(account.get("password", ""))
accounts.append(account)
return accounts
return [_decode_account_password(dict(row)) for row in cursor.fetchall()]
def get_account(account_id):
    """Fetch a single account by id, decrypting its stored password.

    Fix: the diff-merge artifact left both the old if/else version and the new
    guard-clause version in the body (plus an `@@` hunk marker), making the new
    lines unreachable dead code. Only the guard-clause version is kept.

    Returns a dict with a decrypted "password" key, or None when no row matches.
    """
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
        row = cursor.fetchone()
        if not row:
            return None
        return _decode_account_password(dict(row))
def update_account_remark(account_id, remark):
@@ -78,33 +103,21 @@ def increment_account_login_fail(account_id, error_message):
if not row:
return False
fail_count = (row["login_fail_count"] or 0) + 1
if fail_count >= 3:
cursor.execute(
"""
UPDATE accounts
SET login_fail_count = ?,
last_login_error = ?,
status = 'suspended'
WHERE id = ?
""",
(fail_count, error_message, account_id),
)
conn.commit()
return True
fail_count = int(row["login_fail_count"] or 0) + 1
is_suspended = fail_count >= 3
cursor.execute(
"""
UPDATE accounts
SET login_fail_count = ?,
last_login_error = ?
last_login_error = ?,
status = CASE WHEN ? = 1 THEN 'suspended' ELSE status END
WHERE id = ?
""",
(fail_count, error_message, account_id),
(fail_count, error_message, 1 if is_suspended else 0, account_id),
)
conn.commit()
return False
return is_suspended
def reset_account_login_status(account_id):
@@ -129,29 +142,22 @@ def get_account_status(account_id):
"""获取账号状态信息"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT status, login_fail_count, last_login_error
FROM accounts
WHERE id = ?
""",
(account_id,),
)
cursor.execute(_ACCOUNT_STATUS_QUERY_SQL, (account_id,))
return cursor.fetchone()
def get_account_status_batch(account_ids):
"""批量获取账号状态信息"""
account_ids = [str(account_id) for account_id in (account_ids or []) if account_id]
if not account_ids:
normalized_ids = _normalize_account_ids(account_ids)
if not normalized_ids:
return {}
results = {}
chunk_size = 900 # 避免触发 SQLite 绑定参数上限
with db_pool.get_db() as conn:
cursor = conn.cursor()
for idx in range(0, len(account_ids), chunk_size):
chunk = account_ids[idx : idx + chunk_size]
for idx in range(0, len(normalized_ids), chunk_size):
chunk = normalized_ids[idx : idx + chunk_size]
placeholders = ",".join("?" for _ in chunk)
cursor.execute(
f"""

View File

@@ -3,9 +3,6 @@
from __future__ import annotations
import sqlite3
from datetime import datetime, timedelta
import pytz
import db_pool
from db.utils import get_cst_now_str
@@ -16,6 +13,99 @@ from password_utils import (
verify_password_sha256,
)
# Fallback configuration used when system_config has no row (see
# get_system_config_raw, which returns dict(_DEFAULT_SYSTEM_CONFIG)).
# NOTE(review): these values presumably mirror the column DEFAULTs added by the
# schema migrations — confirm, and keep both in sync when adding fields.
_DEFAULT_SYSTEM_CONFIG = {
    # concurrency limits
    "max_concurrent_global": 2,
    "max_concurrent_per_account": 1,
    "max_screenshot_concurrent": 3,
    # scheduled-run settings (0/1 flags stored as ints)
    "schedule_enabled": 0,
    "schedule_time": "02:00",
    "schedule_browse_type": "应读",
    "schedule_weekdays": "1,2,3,4,5,6,7",
    # proxy settings
    "proxy_enabled": 0,
    "proxy_api_url": "",
    "proxy_expire_minutes": 3,
    "enable_screenshot": 1,
    # auto-approval of new users
    "auto_approve_enabled": 0,
    "auto_approve_hourly_limit": 10,
    "auto_approve_vip_days": 7,
    # KDocs spreadsheet integration
    "kdocs_enabled": 0,
    "kdocs_doc_url": "",
    "kdocs_default_unit": "",
    "kdocs_sheet_name": "",
    "kdocs_sheet_index": 0,
    "kdocs_unit_column": "A",
    "kdocs_image_column": "D",
    "kdocs_admin_notify_enabled": 0,
    "kdocs_admin_notify_email": "",
    "kdocs_row_start": 0,
    "kdocs_row_end": 0,
}
# (db_column, update_system_config keyword-argument name) pairs used to build the
# dynamic UPDATE statement in update_system_config: arguments that are None are
# skipped, so only explicitly supplied fields are written. The only pair whose
# two names differ is max_concurrent_global <- max_concurrent.
_SYSTEM_CONFIG_UPDATERS = (
    ("max_concurrent_global", "max_concurrent"),
    ("schedule_enabled", "schedule_enabled"),
    ("schedule_time", "schedule_time"),
    ("schedule_browse_type", "schedule_browse_type"),
    ("schedule_weekdays", "schedule_weekdays"),
    ("max_concurrent_per_account", "max_concurrent_per_account"),
    ("max_screenshot_concurrent", "max_screenshot_concurrent"),
    ("enable_screenshot", "enable_screenshot"),
    ("proxy_enabled", "proxy_enabled"),
    ("proxy_api_url", "proxy_api_url"),
    ("proxy_expire_minutes", "proxy_expire_minutes"),
    ("auto_approve_enabled", "auto_approve_enabled"),
    ("auto_approve_hourly_limit", "auto_approve_hourly_limit"),
    ("auto_approve_vip_days", "auto_approve_vip_days"),
    ("kdocs_enabled", "kdocs_enabled"),
    ("kdocs_doc_url", "kdocs_doc_url"),
    ("kdocs_default_unit", "kdocs_default_unit"),
    ("kdocs_sheet_name", "kdocs_sheet_name"),
    ("kdocs_sheet_index", "kdocs_sheet_index"),
    ("kdocs_unit_column", "kdocs_unit_column"),
    ("kdocs_image_column", "kdocs_image_column"),
    ("kdocs_admin_notify_enabled", "kdocs_admin_notify_enabled"),
    ("kdocs_admin_notify_email", "kdocs_admin_notify_email"),
    ("kdocs_row_start", "kdocs_row_start"),
    ("kdocs_row_end", "kdocs_row_end"),
)
def _count_scalar(cursor, sql: str, params=()) -> int:
cursor.execute(sql, params)
row = cursor.fetchone()
if not row:
return 0
try:
if "count" in row.keys():
return int(row["count"] or 0)
except Exception:
pass
try:
return int(row[0] or 0)
except Exception:
return 0
def _table_exists(cursor, table_name: str) -> bool:
cursor.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table' AND name=?
""",
(table_name,),
)
return bool(cursor.fetchone())
def _normalize_days(days, default: int = 30) -> int:
try:
value = int(days)
except Exception:
value = default
if value < 0:
return 0
return value
def ensure_default_admin() -> bool:
"""确保存在默认管理员账号(行为保持不变)。"""
@@ -24,10 +114,9 @@ def ensure_default_admin() -> bool:
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM admins")
result = cursor.fetchone()
count = _count_scalar(cursor, "SELECT COUNT(*) as count FROM admins")
if result["count"] == 0:
if count == 0:
alphabet = string.ascii_letters + string.digits
random_password = "".join(secrets.choice(alphabet) for _ in range(12))
@@ -101,41 +190,33 @@ def get_system_stats() -> dict:
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM users")
total_users = cursor.fetchone()["count"]
cursor.execute("SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
approved_users = cursor.fetchone()["count"]
cursor.execute(
total_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users")
approved_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
new_users_today = _count_scalar(
cursor,
"""
SELECT COUNT(*) as count
FROM users
WHERE date(created_at) = date('now', 'localtime')
"""
""",
)
new_users_today = cursor.fetchone()["count"]
cursor.execute(
new_users_7d = _count_scalar(
cursor,
"""
SELECT COUNT(*) as count
FROM users
WHERE datetime(created_at) >= datetime('now', 'localtime', '-7 days')
"""
""",
)
new_users_7d = cursor.fetchone()["count"]
cursor.execute("SELECT COUNT(*) as count FROM accounts")
total_accounts = cursor.fetchone()["count"]
cursor.execute(
total_accounts = _count_scalar(cursor, "SELECT COUNT(*) as count FROM accounts")
vip_users = _count_scalar(
cursor,
"""
SELECT COUNT(*) as count FROM users
WHERE vip_expire_time IS NOT NULL
AND datetime(vip_expire_time) > datetime('now', 'localtime')
"""
""",
)
vip_users = cursor.fetchone()["count"]
return {
"total_users": total_users,
@@ -153,37 +234,9 @@ def get_system_config_raw() -> dict:
cursor = conn.cursor()
cursor.execute("SELECT * FROM system_config WHERE id = 1")
row = cursor.fetchone()
if row:
return dict(row)
return {
"max_concurrent_global": 2,
"max_concurrent_per_account": 1,
"max_screenshot_concurrent": 3,
"schedule_enabled": 0,
"schedule_time": "02:00",
"schedule_browse_type": "应读",
"schedule_weekdays": "1,2,3,4,5,6,7",
"proxy_enabled": 0,
"proxy_api_url": "",
"proxy_expire_minutes": 3,
"enable_screenshot": 1,
"auto_approve_enabled": 0,
"auto_approve_hourly_limit": 10,
"auto_approve_vip_days": 7,
"kdocs_enabled": 0,
"kdocs_doc_url": "",
"kdocs_default_unit": "",
"kdocs_sheet_name": "",
"kdocs_sheet_index": 0,
"kdocs_unit_column": "A",
"kdocs_image_column": "D",
"kdocs_admin_notify_enabled": 0,
"kdocs_admin_notify_email": "",
"kdocs_row_start": 0,
"kdocs_row_end": 0,
}
return dict(_DEFAULT_SYSTEM_CONFIG)
def update_system_config(
@@ -215,127 +268,51 @@ def update_system_config(
kdocs_row_end=None,
) -> bool:
"""更新系统配置仅更新DB不做缓存处理"""
allowed_fields = {
"max_concurrent_global",
"schedule_enabled",
"schedule_time",
"schedule_browse_type",
"schedule_weekdays",
"max_concurrent_per_account",
"max_screenshot_concurrent",
"enable_screenshot",
"proxy_enabled",
"proxy_api_url",
"proxy_expire_minutes",
"auto_approve_enabled",
"auto_approve_hourly_limit",
"auto_approve_vip_days",
"kdocs_enabled",
"kdocs_doc_url",
"kdocs_default_unit",
"kdocs_sheet_name",
"kdocs_sheet_index",
"kdocs_unit_column",
"kdocs_image_column",
"kdocs_admin_notify_enabled",
"kdocs_admin_notify_email",
"kdocs_row_start",
"kdocs_row_end",
"updated_at",
arg_values = {
"max_concurrent": max_concurrent,
"schedule_enabled": schedule_enabled,
"schedule_time": schedule_time,
"schedule_browse_type": schedule_browse_type,
"schedule_weekdays": schedule_weekdays,
"max_concurrent_per_account": max_concurrent_per_account,
"max_screenshot_concurrent": max_screenshot_concurrent,
"enable_screenshot": enable_screenshot,
"proxy_enabled": proxy_enabled,
"proxy_api_url": proxy_api_url,
"proxy_expire_minutes": proxy_expire_minutes,
"auto_approve_enabled": auto_approve_enabled,
"auto_approve_hourly_limit": auto_approve_hourly_limit,
"auto_approve_vip_days": auto_approve_vip_days,
"kdocs_enabled": kdocs_enabled,
"kdocs_doc_url": kdocs_doc_url,
"kdocs_default_unit": kdocs_default_unit,
"kdocs_sheet_name": kdocs_sheet_name,
"kdocs_sheet_index": kdocs_sheet_index,
"kdocs_unit_column": kdocs_unit_column,
"kdocs_image_column": kdocs_image_column,
"kdocs_admin_notify_enabled": kdocs_admin_notify_enabled,
"kdocs_admin_notify_email": kdocs_admin_notify_email,
"kdocs_row_start": kdocs_row_start,
"kdocs_row_end": kdocs_row_end,
}
updates = []
params = []
for db_field, arg_name in _SYSTEM_CONFIG_UPDATERS:
value = arg_values.get(arg_name)
if value is None:
continue
updates.append(f"{db_field} = ?")
params.append(value)
if not updates:
return False
updates.append("updated_at = ?")
params.append(get_cst_now_str())
with db_pool.get_db() as conn:
cursor = conn.cursor()
updates = []
params = []
if max_concurrent is not None:
updates.append("max_concurrent_global = ?")
params.append(max_concurrent)
if schedule_enabled is not None:
updates.append("schedule_enabled = ?")
params.append(schedule_enabled)
if schedule_time is not None:
updates.append("schedule_time = ?")
params.append(schedule_time)
if schedule_browse_type is not None:
updates.append("schedule_browse_type = ?")
params.append(schedule_browse_type)
if max_concurrent_per_account is not None:
updates.append("max_concurrent_per_account = ?")
params.append(max_concurrent_per_account)
if max_screenshot_concurrent is not None:
updates.append("max_screenshot_concurrent = ?")
params.append(max_screenshot_concurrent)
if enable_screenshot is not None:
updates.append("enable_screenshot = ?")
params.append(enable_screenshot)
if schedule_weekdays is not None:
updates.append("schedule_weekdays = ?")
params.append(schedule_weekdays)
if proxy_enabled is not None:
updates.append("proxy_enabled = ?")
params.append(proxy_enabled)
if proxy_api_url is not None:
updates.append("proxy_api_url = ?")
params.append(proxy_api_url)
if proxy_expire_minutes is not None:
updates.append("proxy_expire_minutes = ?")
params.append(proxy_expire_minutes)
if auto_approve_enabled is not None:
updates.append("auto_approve_enabled = ?")
params.append(auto_approve_enabled)
if auto_approve_hourly_limit is not None:
updates.append("auto_approve_hourly_limit = ?")
params.append(auto_approve_hourly_limit)
if auto_approve_vip_days is not None:
updates.append("auto_approve_vip_days = ?")
params.append(auto_approve_vip_days)
if kdocs_enabled is not None:
updates.append("kdocs_enabled = ?")
params.append(kdocs_enabled)
if kdocs_doc_url is not None:
updates.append("kdocs_doc_url = ?")
params.append(kdocs_doc_url)
if kdocs_default_unit is not None:
updates.append("kdocs_default_unit = ?")
params.append(kdocs_default_unit)
if kdocs_sheet_name is not None:
updates.append("kdocs_sheet_name = ?")
params.append(kdocs_sheet_name)
if kdocs_sheet_index is not None:
updates.append("kdocs_sheet_index = ?")
params.append(kdocs_sheet_index)
if kdocs_unit_column is not None:
updates.append("kdocs_unit_column = ?")
params.append(kdocs_unit_column)
if kdocs_image_column is not None:
updates.append("kdocs_image_column = ?")
params.append(kdocs_image_column)
if kdocs_admin_notify_enabled is not None:
updates.append("kdocs_admin_notify_enabled = ?")
params.append(kdocs_admin_notify_enabled)
if kdocs_admin_notify_email is not None:
updates.append("kdocs_admin_notify_email = ?")
params.append(kdocs_admin_notify_email)
if kdocs_row_start is not None:
updates.append("kdocs_row_start = ?")
params.append(kdocs_row_start)
if kdocs_row_end is not None:
updates.append("kdocs_row_end = ?")
params.append(kdocs_row_end)
if not updates:
return False
updates.append("updated_at = ?")
params.append(get_cst_now_str())
for update_clause in updates:
field_name = update_clause.split("=")[0].strip()
if field_name not in allowed_fields:
raise ValueError(f"非法字段名: {field_name}")
sql = f"UPDATE system_config SET {', '.join(updates)} WHERE id = 1"
cursor.execute(sql, params)
conn.commit()
@@ -346,13 +323,13 @@ def get_hourly_registration_count() -> int:
"""获取最近一小时内的注册用户数"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
return _count_scalar(
cursor,
"""
SELECT COUNT(*) FROM users
SELECT COUNT(*) as count FROM users
WHERE created_at >= datetime('now', 'localtime', '-1 hour')
"""
""",
)
return cursor.fetchone()[0]
# ==================== 密码重置(管理员) ====================
@@ -374,17 +351,12 @@ def admin_reset_user_password(user_id: int, new_password: str) -> bool:
def clean_old_operation_logs(days: int = 30) -> int:
"""清理指定天数前的操作日志如果存在operation_logs表"""
safe_days = _normalize_days(days, default=30)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table' AND name='operation_logs'
"""
)
if not cursor.fetchone():
if not _table_exists(cursor, "operation_logs"):
return 0
try:
@@ -393,11 +365,11 @@ def clean_old_operation_logs(days: int = 30) -> int:
DELETE FROM operation_logs
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
""",
(days,),
(safe_days,),
)
deleted_count = cursor.rowcount
conn.commit()
print(f"已清理 {deleted_count} 条旧操作日志 (>{days}天)")
print(f"已清理 {deleted_count} 条旧操作日志 (>{safe_days}天)")
return deleted_count
except Exception as e:
print(f"清理旧操作日志失败: {e}")

View File

@@ -6,12 +6,38 @@ import db_pool
from db.utils import get_cst_now_str
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
parsed = max(minimum, parsed)
parsed = min(maximum, parsed)
return parsed
def _normalize_offset(value, default: int = 0) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
return max(0, parsed)
def _normalize_announcement_payload(title, content, image_url):
normalized_title = str(title or "").strip()
normalized_content = str(content or "").strip()
normalized_image = str(image_url or "").strip() or None
return normalized_title, normalized_content, normalized_image
def _deactivate_all_active_announcements(cursor, cst_time: str) -> None:
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
def create_announcement(title, content, image_url=None, is_active=True):
"""创建公告(默认启用;启用时会自动停用其他公告)"""
title = (title or "").strip()
content = (content or "").strip()
image_url = (image_url or "").strip()
image_url = image_url or None
title, content, image_url = _normalize_announcement_payload(title, content, image_url)
if not title or not content:
return None
@@ -20,7 +46,7 @@ def create_announcement(title, content, image_url=None, is_active=True):
cst_time = get_cst_now_str()
if is_active:
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
_deactivate_all_active_announcements(cursor, cst_time)
cursor.execute(
"""
@@ -44,6 +70,9 @@ def get_announcement_by_id(announcement_id):
def get_announcements(limit=50, offset=0):
"""获取公告列表(管理员用)"""
safe_limit = _normalize_limit(limit, 50)
safe_offset = _normalize_offset(offset, 0)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -52,7 +81,7 @@ def get_announcements(limit=50, offset=0):
ORDER BY created_at DESC, id DESC
LIMIT ? OFFSET ?
""",
(limit, offset),
(safe_limit, safe_offset),
)
return [dict(row) for row in cursor.fetchall()]
@@ -64,7 +93,7 @@ def set_announcement_active(announcement_id, is_active):
cst_time = get_cst_now_str()
if is_active:
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
_deactivate_all_active_announcements(cursor, cst_time)
cursor.execute(
"""
UPDATE announcements
@@ -121,13 +150,12 @@ def dismiss_announcement_for_user(user_id, announcement_id):
"""用户永久关闭某条公告(幂等)"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
cursor.execute(
"""
INSERT OR IGNORE INTO announcement_dismissals (user_id, announcement_id, dismissed_at)
VALUES (?, ?, ?)
""",
(user_id, announcement_id, cst_time),
(user_id, announcement_id, get_cst_now_str()),
)
conn.commit()
return cursor.rowcount >= 0

View File

@@ -5,6 +5,27 @@ from __future__ import annotations
import db_pool
def _to_bool_with_default(value, default: bool = True) -> bool:
if value is None:
return default
try:
return bool(int(value))
except Exception:
try:
return bool(value)
except Exception:
return default
def _normalize_notify_enabled(enabled) -> int:
if isinstance(enabled, bool):
return 1 if enabled else 0
try:
return 1 if int(enabled) else 0
except Exception:
return 1
def get_user_by_email(email):
"""根据邮箱获取用户"""
with db_pool.get_db() as conn:
@@ -25,7 +46,7 @@ def update_user_email(user_id, email, verified=False):
SET email = ?, email_verified = ?
WHERE id = ?
""",
(email, int(verified), user_id),
(email, 1 if verified else 0, user_id),
)
conn.commit()
return cursor.rowcount > 0
@@ -42,7 +63,7 @@ def update_user_email_notify(user_id, enabled):
SET email_notify_enabled = ?
WHERE id = ?
""",
(int(enabled), user_id),
(_normalize_notify_enabled(enabled), user_id),
)
conn.commit()
return cursor.rowcount > 0
@@ -57,6 +78,6 @@ def get_user_email_notify(user_id):
row = cursor.fetchone()
if row is None:
return True
return bool(row[0]) if row[0] is not None else True
return _to_bool_with_default(row[0], default=True)
except Exception:
return True

View File

@@ -2,32 +2,73 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import datetime
import pytz
import db_pool
from db.utils import escape_html
from db.utils import escape_html, get_cst_now_str
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
parsed = max(minimum, parsed)
parsed = min(maximum, parsed)
return parsed
def _normalize_offset(value, default: int = 0) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
return max(0, parsed)
def _safe_text(value) -> str:
    """HTML-escape *value* rendered as text; None and "" both yield ""."""
    if value is None:
        return ""
    text = str(value)
    if not text:
        return ""
    return escape_html(text)
def _build_feedback_filter_sql(status_filter=None) -> tuple[str, list]:
where_clauses = ["1=1"]
params = []
if status_filter:
where_clauses.append("status = ?")
params.append(status_filter)
return " AND ".join(where_clauses), params
def _normalize_feedback_stats_row(row) -> dict:
row_dict = dict(row) if row else {}
return {
"total": int(row_dict.get("total") or 0),
"pending": int(row_dict.get("pending") or 0),
"replied": int(row_dict.get("replied") or 0),
"closed": int(row_dict.get("closed") or 0),
}
def create_bug_feedback(user_id, username, title, description, contact=""):
"""创建Bug反馈带XSS防护"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
safe_title = escape_html(title) if title else ""
safe_description = escape_html(description) if description else ""
safe_contact = escape_html(contact) if contact else ""
safe_username = escape_html(username) if username else ""
cursor.execute(
"""
INSERT INTO bug_feedbacks (user_id, username, title, description, contact, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
(user_id, safe_username, safe_title, safe_description, safe_contact, cst_time),
(
user_id,
_safe_text(username),
_safe_text(title),
_safe_text(description),
_safe_text(contact),
get_cst_now_str(),
),
)
conn.commit()
@@ -36,25 +77,25 @@ def create_bug_feedback(user_id, username, title, description, contact=""):
def get_bug_feedbacks(limit=100, offset=0, status_filter=None):
    """List bug feedback rows for the admin view, newest first.

    Fix: the diff-merge artifact left BOTH the old hand-built SQL (which
    executed once with raw limit/offset) and the new helper-based version in
    the body, so the query ran twice and `params` was shadowed. Only the
    helper-based version is kept: limit/offset are sanitized and the WHERE
    clause is built from _build_feedback_filter_sql (parameterized — the
    f-string only ever interpolates that fixed clause text).
    """
    safe_limit = _normalize_limit(limit, 100, minimum=1, maximum=1000)
    safe_offset = _normalize_offset(offset, 0)
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        where_sql, params = _build_feedback_filter_sql(status_filter=status_filter)
        sql = f"""
        SELECT * FROM bug_feedbacks
        WHERE {where_sql}
        ORDER BY created_at DESC
        LIMIT ? OFFSET ?
        """
        cursor.execute(sql, params + [safe_limit, safe_offset])
        return [dict(row) for row in cursor.fetchall()]
def get_user_feedbacks(user_id, limit=50):
"""获取用户自己的反馈列表"""
safe_limit = _normalize_limit(limit, 50, minimum=1, maximum=1000)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -64,7 +105,7 @@ def get_user_feedbacks(user_id, limit=50):
ORDER BY created_at DESC
LIMIT ?
""",
(user_id, limit),
(user_id, safe_limit),
)
return [dict(row) for row in cursor.fetchall()]
@@ -82,18 +123,13 @@ def reply_feedback(feedback_id, admin_reply):
"""管理员回复反馈带XSS防护"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
safe_reply = escape_html(admin_reply) if admin_reply else ""
cursor.execute(
"""
UPDATE bug_feedbacks
SET admin_reply = ?, status = 'replied', replied_at = ?
WHERE id = ?
""",
(safe_reply, cst_time, feedback_id),
(_safe_text(admin_reply), get_cst_now_str(), feedback_id),
)
conn.commit()
@@ -139,6 +175,4 @@ def get_feedback_stats():
FROM bug_feedbacks
"""
)
row = cursor.fetchone()
return dict(row) if row else {"total": 0, "pending": 0, "replied": 0, "closed": 0}
return _normalize_feedback_stats_row(cursor.fetchone())

View File

@@ -28,105 +28,136 @@ def set_current_version(conn, version: int) -> None:
conn.commit()
def _table_exists(cursor, table_name: str) -> bool:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (str(table_name),))
return cursor.fetchone() is not None
def _get_table_columns(cursor, table_name: str) -> set[str]:
cursor.execute(f"PRAGMA table_info({table_name})")
return {col[1] for col in cursor.fetchall()}
def _add_column_if_missing(cursor, table_name: str, columns: set[str], column_name: str, column_ddl: str, *, ok_message: str) -> bool:
if column_name in columns:
return False
cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_ddl}")
columns.add(column_name)
print(ok_message)
return True
def _read_row_value(row, key: str, index: int):
if isinstance(row, sqlite3.Row):
return row[key]
return row[index]
def _get_migration_steps():
    """Return the ordered (schema_version, migration callable) pairs.

    Versions are assigned by position (v1 first), so appending a new migration
    function here automatically gets the next version number.
    """
    ordered_migrations = (
        _migrate_to_v1,
        _migrate_to_v2,
        _migrate_to_v3,
        _migrate_to_v4,
        _migrate_to_v5,
        _migrate_to_v6,
        _migrate_to_v7,
        _migrate_to_v8,
        _migrate_to_v9,
        _migrate_to_v10,
        _migrate_to_v11,
        _migrate_to_v12,
        _migrate_to_v13,
        _migrate_to_v14,
        _migrate_to_v15,
        _migrate_to_v16,
        _migrate_to_v17,
        _migrate_to_v18,
    )
    return list(enumerate(ordered_migrations, start=1))
def migrate_database(conn, target_version: int) -> None:
    """Incrementally upgrade the database schema to *target_version*.

    Fix: the diff-merge artifact left the superseded per-version if-chain
    (v1..v18) in front of the new data-driven loop, so every migration ran
    unconditionally — ignoring target_version — and then the loop skipped them
    all. Only the data-driven loop is kept.

    Seeds db_version with version 0 if absent, then applies each step from
    _get_migration_steps() whose version is above the current one and not
    beyond target_version, finally persisting the target version number.
    """
    cursor = conn.cursor()
    # Seed the version row once; INSERT OR IGNORE keeps an existing row intact.
    cursor.execute("INSERT OR IGNORE INTO db_version (id, version, updated_at) VALUES (1, 0, ?)", (get_cst_now_str(),))
    conn.commit()

    target_version = int(target_version)
    current_version = get_current_version(conn)

    for version, migrate_fn in _get_migration_steps():
        if version > target_version or current_version >= version:
            continue
        migrate_fn(conn)
        current_version = version

    if current_version != target_version:
        set_current_version(conn, target_version)
def _migrate_to_v1(conn):
    """Migrate to schema v1: add missing columns to system_config and task_logs.

    Fix: the diff-merge artifact left BOTH the new _add_column_if_missing calls
    and the old if-chain in the body. The old chain rechecked a stale column
    list captured before the ALTERs ran, so it would re-ALTER the same columns
    and fail with duplicate-column errors. Only the helper-based version is kept.
    """
    cursor = conn.cursor()
    system_columns = _get_table_columns(cursor, "system_config")
    _add_column_if_missing(
        cursor,
        "system_config",
        system_columns,
        "schedule_weekdays",
        'TEXT DEFAULT "1,2,3,4,5,6,7"',
        ok_message=" [OK] 添加 schedule_weekdays 字段",
    )
    _add_column_if_missing(
        cursor,
        "system_config",
        system_columns,
        "max_screenshot_concurrent",
        "INTEGER DEFAULT 3",
        ok_message=" [OK] 添加 max_screenshot_concurrent 字段",
    )
    _add_column_if_missing(
        cursor,
        "system_config",
        system_columns,
        "max_concurrent_per_account",
        "INTEGER DEFAULT 1",
        ok_message=" [OK] 添加 max_concurrent_per_account 字段",
    )
    _add_column_if_missing(
        cursor,
        "system_config",
        system_columns,
        "auto_approve_enabled",
        "INTEGER DEFAULT 0",
        ok_message=" [OK] 添加 auto_approve_enabled 字段",
    )
    _add_column_if_missing(
        cursor,
        "system_config",
        system_columns,
        "auto_approve_hourly_limit",
        "INTEGER DEFAULT 10",
        ok_message=" [OK] 添加 auto_approve_hourly_limit 字段",
    )
    _add_column_if_missing(
        cursor,
        "system_config",
        system_columns,
        "auto_approve_vip_days",
        "INTEGER DEFAULT 7",
        ok_message=" [OK] 添加 auto_approve_vip_days 字段",
    )
    task_log_columns = _get_table_columns(cursor, "task_logs")
    _add_column_if_missing(
        cursor,
        "task_logs",
        task_log_columns,
        "duration",
        "INTEGER",
        ok_message=" [OK] 添加 duration 字段到 task_logs",
    )
    conn.commit()
@@ -135,24 +166,39 @@ def _migrate_to_v2(conn):
"""迁移到版本2 - 添加代理配置字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
if "proxy_enabled" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_enabled INTEGER DEFAULT 0")
print(" [OK] 添加 proxy_enabled 字段")
if "proxy_api_url" not in columns:
cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_api_url TEXT DEFAULT ""')
print(" [OK] 添加 proxy_api_url 字段")
if "proxy_expire_minutes" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_expire_minutes INTEGER DEFAULT 3")
print(" [OK] 添加 proxy_expire_minutes 字段")
if "enable_screenshot" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN enable_screenshot INTEGER DEFAULT 1")
print(" [OK] 添加 enable_screenshot 字段")
columns = _get_table_columns(cursor, "system_config")
_add_column_if_missing(
cursor,
"system_config",
columns,
"proxy_enabled",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 proxy_enabled 字段",
)
_add_column_if_missing(
cursor,
"system_config",
columns,
"proxy_api_url",
'TEXT DEFAULT ""',
ok_message=" [OK] 添加 proxy_api_url 字段",
)
_add_column_if_missing(
cursor,
"system_config",
columns,
"proxy_expire_minutes",
"INTEGER DEFAULT 3",
ok_message=" [OK] 添加 proxy_expire_minutes 字段",
)
_add_column_if_missing(
cursor,
"system_config",
columns,
"enable_screenshot",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 enable_screenshot 字段",
)
conn.commit()
@@ -161,20 +207,31 @@ def _migrate_to_v3(conn):
"""迁移到版本3 - 添加账号状态和登录失败计数字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(accounts)")
columns = [col[1] for col in cursor.fetchall()]
if "status" not in columns:
cursor.execute('ALTER TABLE accounts ADD COLUMN status TEXT DEFAULT "active"')
print(" [OK] 添加 accounts.status 字段 (账号状态)")
if "login_fail_count" not in columns:
cursor.execute("ALTER TABLE accounts ADD COLUMN login_fail_count INTEGER DEFAULT 0")
print(" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)")
if "last_login_error" not in columns:
cursor.execute("ALTER TABLE accounts ADD COLUMN last_login_error TEXT")
print(" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)")
columns = _get_table_columns(cursor, "accounts")
_add_column_if_missing(
cursor,
"accounts",
columns,
"status",
'TEXT DEFAULT "active"',
ok_message=" [OK] 添加 accounts.status 字段 (账号状态)",
)
_add_column_if_missing(
cursor,
"accounts",
columns,
"login_fail_count",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)",
)
_add_column_if_missing(
cursor,
"accounts",
columns,
"last_login_error",
"TEXT",
ok_message=" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)",
)
conn.commit()
@@ -183,12 +240,15 @@ def _migrate_to_v4(conn):
"""迁移到版本4 - 添加任务来源字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(task_logs)")
columns = [col[1] for col in cursor.fetchall()]
if "source" not in columns:
cursor.execute('ALTER TABLE task_logs ADD COLUMN source TEXT DEFAULT "manual"')
print(" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)")
columns = _get_table_columns(cursor, "task_logs")
_add_column_if_missing(
cursor,
"task_logs",
columns,
"source",
'TEXT DEFAULT "manual"',
ok_message=" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)",
)
conn.commit()
@@ -300,20 +360,17 @@ def _migrate_to_v6(conn):
def _migrate_to_v7(conn):
"""迁移到版本7 - 统一存储北京时间将历史UTC时间字段整体+8小时"""
cursor = conn.cursor()
def table_exists(table_name: str) -> bool:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
return cursor.fetchone() is not None
def column_exists(table_name: str, column_name: str) -> bool:
cursor.execute(f"PRAGMA table_info({table_name})")
return any(row[1] == column_name for row in cursor.fetchall())
columns_cache: dict[str, set[str]] = {}
def shift_utc_to_cst(table_name: str, column_name: str) -> None:
if not table_exists(table_name):
if not _table_exists(cursor, table_name):
return
if not column_exists(table_name, column_name):
if table_name not in columns_cache:
columns_cache[table_name] = _get_table_columns(cursor, table_name)
if column_name not in columns_cache[table_name]:
return
cursor.execute(
f"""
UPDATE {table_name}
@@ -329,10 +386,6 @@ def _migrate_to_v7(conn):
("accounts", "created_at"),
("password_reset_requests", "created_at"),
("password_reset_requests", "processed_at"),
]:
shift_utc_to_cst(table, col)
for table, col in [
("smtp_configs", "created_at"),
("smtp_configs", "updated_at"),
("smtp_configs", "last_success_at"),
@@ -340,10 +393,6 @@ def _migrate_to_v7(conn):
("email_tokens", "created_at"),
("email_logs", "created_at"),
("email_stats", "last_updated"),
]:
shift_utc_to_cst(table, col)
for table, col in [
("task_checkpoints", "created_at"),
("task_checkpoints", "updated_at"),
("task_checkpoints", "completed_at"),
@@ -359,15 +408,23 @@ def _migrate_to_v8(conn):
cursor = conn.cursor()
# 1) 增量字段random_delay旧库可能不存在
cursor.execute("PRAGMA table_info(user_schedules)")
columns = [col[1] for col in cursor.fetchall()]
if "random_delay" not in columns:
cursor.execute("ALTER TABLE user_schedules ADD COLUMN random_delay INTEGER DEFAULT 0")
print(" [OK] 添加 user_schedules.random_delay 字段")
if "next_run_at" not in columns:
cursor.execute("ALTER TABLE user_schedules ADD COLUMN next_run_at TIMESTAMP")
print(" [OK] 添加 user_schedules.next_run_at 字段")
columns = _get_table_columns(cursor, "user_schedules")
_add_column_if_missing(
cursor,
"user_schedules",
columns,
"random_delay",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 user_schedules.random_delay 字段",
)
_add_column_if_missing(
cursor,
"user_schedules",
columns,
"next_run_at",
"TIMESTAMP",
ok_message=" [OK] 添加 user_schedules.next_run_at 字段",
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_schedules_next_run ON user_schedules(next_run_at)")
conn.commit()
@@ -392,12 +449,12 @@ def _migrate_to_v8(conn):
fixed = 0
for row in rows:
try:
schedule_id = row["id"] if isinstance(row, sqlite3.Row) else row[0]
schedule_time = row["schedule_time"] if isinstance(row, sqlite3.Row) else row[1]
weekdays = row["weekdays"] if isinstance(row, sqlite3.Row) else row[2]
random_delay = row["random_delay"] if isinstance(row, sqlite3.Row) else row[3]
last_run_at = row["last_run_at"] if isinstance(row, sqlite3.Row) else row[4]
next_run_at = row["next_run_at"] if isinstance(row, sqlite3.Row) else row[5]
schedule_id = _read_row_value(row, "id", 0)
schedule_time = _read_row_value(row, "schedule_time", 1)
weekdays = _read_row_value(row, "weekdays", 2)
random_delay = _read_row_value(row, "random_delay", 3)
last_run_at = _read_row_value(row, "last_run_at", 4)
next_run_at = _read_row_value(row, "next_run_at", 5)
except Exception:
continue
@@ -430,27 +487,46 @@ def _migrate_to_v9(conn):
"""迁移到版本9 - 邮件设置字段迁移(清理 email_service scattered ALTER TABLE"""
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
if not cursor.fetchone():
if not _table_exists(cursor, "email_settings"):
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return
cursor.execute("PRAGMA table_info(email_settings)")
columns = [col[1] for col in cursor.fetchall()]
columns = _get_table_columns(cursor, "email_settings")
changed = False
if "register_verify_enabled" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN register_verify_enabled INTEGER DEFAULT 0")
print(" [OK] 添加 email_settings.register_verify_enabled 字段")
changed = True
if "base_url" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN base_url TEXT DEFAULT ''")
print(" [OK] 添加 email_settings.base_url 字段")
changed = True
if "task_notify_enabled" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN task_notify_enabled INTEGER DEFAULT 0")
print(" [OK] 添加 email_settings.task_notify_enabled 字段")
changed = True
changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"register_verify_enabled",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 email_settings.register_verify_enabled 字段",
)
or changed
)
changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"base_url",
"TEXT DEFAULT ''",
ok_message=" [OK] 添加 email_settings.base_url 字段",
)
or changed
)
changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"task_notify_enabled",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 email_settings.task_notify_enabled 字段",
)
or changed
)
if changed:
conn.commit()
@@ -459,18 +535,31 @@ def _migrate_to_v9(conn):
def _migrate_to_v10(conn):
"""迁移到版本10 - users 邮箱字段迁移(避免运行时 ALTER TABLE"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(users)")
columns = [col[1] for col in cursor.fetchall()]
columns = _get_table_columns(cursor, "users")
changed = False
if "email_verified" not in columns:
cursor.execute("ALTER TABLE users ADD COLUMN email_verified INTEGER DEFAULT 0")
print(" [OK] 添加 users.email_verified 字段")
changed = True
if "email_notify_enabled" not in columns:
cursor.execute("ALTER TABLE users ADD COLUMN email_notify_enabled INTEGER DEFAULT 1")
print(" [OK] 添加 users.email_notify_enabled 字段")
changed = True
changed = (
_add_column_if_missing(
cursor,
"users",
columns,
"email_verified",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 users.email_verified 字段",
)
or changed
)
changed = (
_add_column_if_missing(
cursor,
"users",
columns,
"email_notify_enabled",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 users.email_notify_enabled 字段",
)
or changed
)
if changed:
conn.commit()
@@ -657,19 +746,24 @@ def _migrate_to_v15(conn):
"""迁移到版本15 - 邮件设置:新设备登录提醒全局开关"""
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
if not cursor.fetchone():
if not _table_exists(cursor, "email_settings"):
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return
cursor.execute("PRAGMA table_info(email_settings)")
columns = [col[1] for col in cursor.fetchall()]
columns = _get_table_columns(cursor, "email_settings")
changed = False
if "login_alert_enabled" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1")
print(" [OK] 添加 email_settings.login_alert_enabled 字段")
changed = True
changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"login_alert_enabled",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 email_settings.login_alert_enabled 字段",
)
or changed
)
try:
cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL")
@@ -686,22 +780,24 @@ def _migrate_to_v15(conn):
def _migrate_to_v16(conn):
"""迁移到版本16 - 公告支持图片字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(announcements)")
columns = [col[1] for col in cursor.fetchall()]
columns = _get_table_columns(cursor, "announcements")
if "image_url" not in columns:
cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT")
if _add_column_if_missing(
cursor,
"announcements",
columns,
"image_url",
"TEXT",
ok_message=" [OK] 添加 announcements.image_url 字段",
):
conn.commit()
print(" [OK] 添加 announcements.image_url 字段")
def _migrate_to_v17(conn):
"""迁移到版本17 - 金山文档上传配置与用户开关"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
system_columns = _get_table_columns(cursor, "system_config")
system_fields = [
("kdocs_enabled", "INTEGER DEFAULT 0"),
("kdocs_doc_url", "TEXT DEFAULT ''"),
@@ -714,21 +810,29 @@ def _migrate_to_v17(conn):
("kdocs_admin_notify_email", "TEXT DEFAULT ''"),
]
for field, ddl in system_fields:
if field not in columns:
cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}")
print(f" [OK] 添加 system_config.{field} 字段")
cursor.execute("PRAGMA table_info(users)")
columns = [col[1] for col in cursor.fetchall()]
_add_column_if_missing(
cursor,
"system_config",
system_columns,
field,
ddl,
ok_message=f" [OK] 添加 system_config.{field} 字段",
)
user_columns = _get_table_columns(cursor, "users")
user_fields = [
("kdocs_unit", "TEXT DEFAULT ''"),
("kdocs_auto_upload", "INTEGER DEFAULT 0"),
]
for field, ddl in user_fields:
if field not in columns:
cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}")
print(f" [OK] 添加 users.{field} 字段")
_add_column_if_missing(
cursor,
"users",
user_columns,
field,
ddl,
ok_message=f" [OK] 添加 users.{field} 字段",
)
conn.commit()
@@ -737,15 +841,22 @@ def _migrate_to_v18(conn):
"""迁移到版本18 - 金山文档上传:有效行范围配置"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
if "kdocs_row_start" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0")
print(" [OK] 添加 system_config.kdocs_row_start 字段")
if "kdocs_row_end" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0")
print(" [OK] 添加 system_config.kdocs_row_end 字段")
columns = _get_table_columns(cursor, "system_config")
_add_column_if_missing(
cursor,
"system_config",
columns,
"kdocs_row_start",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 system_config.kdocs_row_start 字段",
)
_add_column_if_missing(
cursor,
"system_config",
columns,
"kdocs_row_end",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 system_config.kdocs_row_end 字段",
)
conn.commit()

View File

@@ -2,12 +2,93 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import datetime
import json
from datetime import datetime, timedelta
import db_pool
from services.schedule_utils import compute_next_run_at, format_cst
from services.time_utils import get_beijing_now
_SCHEDULE_DEFAULT_TIME = "08:00"
_SCHEDULE_DEFAULT_WEEKDAYS = "1,2,3,4,5"
_ALLOWED_SCHEDULE_UPDATE_FIELDS = (
"name",
"enabled",
"schedule_time",
"weekdays",
"browse_type",
"enable_screenshot",
"random_delay",
"account_ids",
)
_ALLOWED_EXEC_LOG_UPDATE_FIELDS = (
"total_accounts",
"success_accounts",
"failed_accounts",
"total_items",
"total_attachments",
"total_screenshots",
"duration_seconds",
"status",
"error_message",
)
def _normalize_limit(limit, default: int, *, minimum: int = 1) -> int:
try:
parsed = int(limit)
except Exception:
parsed = default
if parsed < minimum:
return minimum
return parsed
def _to_int(value, default: int = 0) -> int:
try:
return int(value)
except Exception:
return default
def _format_optional_datetime(dt: datetime | None) -> str | None:
if dt is None:
return None
return format_cst(dt)
def _serialize_account_ids(account_ids) -> str:
return json.dumps(account_ids) if account_ids else "[]"
def _compute_schedule_next_run_str(
*,
now_dt,
schedule_time,
weekdays,
random_delay,
last_run_at,
) -> str:
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(schedule_time or _SCHEDULE_DEFAULT_TIME),
weekdays=str(weekdays or _SCHEDULE_DEFAULT_WEEKDAYS),
random_delay=_to_int(random_delay, 0),
last_run_at=str(last_run_at or "") if last_run_at else None,
)
return format_cst(next_dt)
def _map_schedule_log_row(row) -> dict:
log = dict(row)
log["created_at"] = log.get("execute_time")
log["success_count"] = log.get("success_accounts", 0)
log["failed_count"] = log.get("failed_accounts", 0)
log["duration"] = log.get("duration_seconds", 0)
return log
def get_user_schedules(user_id):
"""获取用户的所有定时任务"""
@@ -44,14 +125,10 @@ def create_user_schedule(
account_ids=None,
):
"""创建用户定时任务"""
import json
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = format_cst(get_beijing_now())
account_ids_str = json.dumps(account_ids) if account_ids else "[]"
cursor.execute(
"""
INSERT INTO user_schedules (
@@ -66,8 +143,8 @@ def create_user_schedule(
weekdays,
browse_type,
enable_screenshot,
int(random_delay or 0),
account_ids_str,
_to_int(random_delay, 0),
_serialize_account_ids(account_ids),
cst_time,
cst_time,
),
@@ -79,28 +156,11 @@ def create_user_schedule(
def update_user_schedule(schedule_id, **kwargs):
"""更新用户定时任务"""
import json
with db_pool.get_db() as conn:
cursor = conn.cursor()
now_dt = get_beijing_now()
now_str = format_cst(now_dt)
updates = []
params = []
allowed_fields = [
"name",
"enabled",
"schedule_time",
"weekdays",
"browse_type",
"enable_screenshot",
"random_delay",
"account_ids",
]
# 读取旧值,用于决定是否需要重算 next_run_at
cursor.execute(
"""
SELECT enabled, schedule_time, weekdays, random_delay, last_run_at
@@ -112,10 +172,11 @@ def update_user_schedule(schedule_id, **kwargs):
current = cursor.fetchone()
if not current:
return False
current_enabled = int(current[0] or 0)
current_enabled = _to_int(current[0], 0)
current_time = current[1]
current_weekdays = current[2]
current_random_delay = int(current[3] or 0)
current_random_delay = _to_int(current[3], 0)
current_last_run_at = current[4]
will_enabled = current_enabled
@@ -123,21 +184,28 @@ def update_user_schedule(schedule_id, **kwargs):
next_weekdays = current_weekdays
next_random_delay = current_random_delay
for field in allowed_fields:
if field in kwargs:
value = kwargs[field]
if field == "account_ids" and isinstance(value, list):
value = json.dumps(value)
if field == "enabled":
will_enabled = 1 if value else 0
if field == "schedule_time":
next_time = value
if field == "weekdays":
next_weekdays = value
if field == "random_delay":
next_random_delay = int(value or 0)
updates.append(f"{field} = ?")
params.append(value)
updates = []
params = []
for field in _ALLOWED_SCHEDULE_UPDATE_FIELDS:
if field not in kwargs:
continue
value = kwargs[field]
if field == "account_ids" and isinstance(value, list):
value = json.dumps(value)
if field == "enabled":
will_enabled = 1 if value else 0
if field == "schedule_time":
next_time = value
if field == "weekdays":
next_weekdays = value
if field == "random_delay":
next_random_delay = int(value or 0)
updates.append(f"{field} = ?")
params.append(value)
if not updates:
return False
@@ -145,30 +213,26 @@ def update_user_schedule(schedule_id, **kwargs):
updates.append("updated_at = ?")
params.append(now_str)
# 关键字段变更后重算 next_run_at确保索引驱动不会跑偏
#
# 需求:当用户修改“执行时间/执行日期/随机±15分钟”后即使今天已经执行过也允许按新配置在今天再次触发。
# 做法:这些关键字段发生变更时,重算 next_run_at 时忽略 last_run_at 的“同日仅一次”限制。
config_changed = any(key in kwargs for key in ["schedule_time", "weekdays", "random_delay"])
config_changed = any(key in kwargs for key in ("schedule_time", "weekdays", "random_delay"))
enabled_toggled = "enabled" in kwargs
should_recompute_next = config_changed or (enabled_toggled and will_enabled == 1)
if should_recompute_next:
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(next_time or "08:00"),
weekdays=str(next_weekdays or "1,2,3,4,5"),
random_delay=int(next_random_delay or 0),
last_run_at=None if config_changed else (str(current_last_run_at or "") if current_last_run_at else None),
next_run_at = _compute_schedule_next_run_str(
now_dt=now_dt,
schedule_time=next_time,
weekdays=next_weekdays,
random_delay=next_random_delay,
last_run_at=None if config_changed else current_last_run_at,
)
updates.append("next_run_at = ?")
params.append(format_cst(next_dt))
params.append(next_run_at)
# 若本次显式禁用任务,则 next_run_at 清空(与 toggle 行为保持一致)
if enabled_toggled and will_enabled == 0:
updates.append("next_run_at = ?")
params.append(None)
params.append(schedule_id)
params.append(schedule_id)
sql = f"UPDATE user_schedules SET {', '.join(updates)} WHERE id = ?"
cursor.execute(sql, params)
conn.commit()
@@ -203,28 +267,19 @@ def toggle_user_schedule(schedule_id, enabled):
)
row = cursor.fetchone()
if row:
schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = (
row[0],
row[1],
row[2],
row[3],
row[4],
)
schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = row
existing_next_run_at = str(existing_next_run_at or "").strip() or None
# 若 next_run_at 已经被“修改配置”逻辑预先计算好且仍在未来,则优先沿用,
# 避免 last_run_at 的“同日仅一次”限制阻塞用户把任务调整到今天再次触发。
if existing_next_run_at and existing_next_run_at > now_str:
next_run_at = existing_next_run_at
else:
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(schedule_time or "08:00"),
weekdays=str(weekdays or "1,2,3,4,5"),
random_delay=int(random_delay or 0),
last_run_at=str(last_run_at or "") if last_run_at else None,
next_run_at = _compute_schedule_next_run_str(
now_dt=now_dt,
schedule_time=schedule_time,
weekdays=weekdays,
random_delay=random_delay,
last_run_at=last_run_at,
)
next_run_at = format_cst(next_dt)
cursor.execute(
"""
@@ -272,16 +327,15 @@ def update_schedule_last_run(schedule_id):
row = cursor.fetchone()
if not row:
return False
schedule_time, weekdays, random_delay = row[0], row[1], row[2]
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(schedule_time or "08:00"),
weekdays=str(weekdays or "1,2,3,4,5"),
random_delay=int(random_delay or 0),
schedule_time, weekdays, random_delay = row
next_run_at = _compute_schedule_next_run_str(
now_dt=now_dt,
schedule_time=schedule_time,
weekdays=weekdays,
random_delay=random_delay,
last_run_at=now_str,
)
next_run_at = format_cst(next_dt)
cursor.execute(
"""
@@ -305,7 +359,11 @@ def update_schedule_next_run(schedule_id: int, next_run_at: str) -> bool:
SET next_run_at = ?, updated_at = ?
WHERE id = ?
""",
(str(next_run_at or "").strip() or None, format_cst(get_beijing_now()), int(schedule_id)),
(
str(next_run_at or "").strip() or None,
format_cst(get_beijing_now()),
int(schedule_id),
),
)
conn.commit()
return cursor.rowcount > 0
@@ -328,15 +386,15 @@ def recompute_schedule_next_run(schedule_id: int, *, now_dt=None) -> bool:
if not row:
return False
schedule_time, weekdays, random_delay, last_run_at = row[0], row[1], row[2], row[3]
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(schedule_time or "08:00"),
weekdays=str(weekdays or "1,2,3,4,5"),
random_delay=int(random_delay or 0),
last_run_at=str(last_run_at or "") if last_run_at else None,
schedule_time, weekdays, random_delay, last_run_at = row
next_run_at = _compute_schedule_next_run_str(
now_dt=now_dt,
schedule_time=schedule_time,
weekdays=weekdays,
random_delay=random_delay,
last_run_at=last_run_at,
)
return update_schedule_next_run(int(schedule_id), format_cst(next_dt))
return update_schedule_next_run(int(schedule_id), next_run_at)
def get_due_user_schedules(now_cst: str, limit: int = 50):
@@ -345,6 +403,8 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
if not now_cst:
now_cst = format_cst(get_beijing_now())
safe_limit = _normalize_limit(limit, 50, minimum=1)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -358,7 +418,7 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
ORDER BY us.next_run_at ASC
LIMIT ?
""",
(now_cst, int(limit)),
(now_cst, safe_limit),
)
return [dict(row) for row in cursor.fetchall()]
@@ -370,15 +430,13 @@ def create_schedule_execution_log(schedule_id, user_id, schedule_name):
"""创建定时任务执行日志"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
execute_time = format_cst(get_beijing_now())
cursor.execute(
"""
INSERT INTO schedule_execution_logs (
schedule_id, user_id, schedule_name, execute_time, status
) VALUES (?, ?, ?, ?, 'running')
""",
(schedule_id, user_id, schedule_name, execute_time),
(schedule_id, user_id, schedule_name, format_cst(get_beijing_now())),
)
conn.commit()
@@ -393,22 +451,11 @@ def update_schedule_execution_log(log_id, **kwargs):
updates = []
params = []
allowed_fields = [
"total_accounts",
"success_accounts",
"failed_accounts",
"total_items",
"total_attachments",
"total_screenshots",
"duration_seconds",
"status",
"error_message",
]
for field in allowed_fields:
if field in kwargs:
updates.append(f"{field} = ?")
params.append(kwargs[field])
for field in _ALLOWED_EXEC_LOG_UPDATE_FIELDS:
if field not in kwargs:
continue
updates.append(f"{field} = ?")
params.append(kwargs[field])
if not updates:
return False
@@ -424,6 +471,7 @@ def update_schedule_execution_log(log_id, **kwargs):
def get_schedule_execution_logs(schedule_id, limit=10):
"""获取定时任务执行日志"""
try:
safe_limit = _normalize_limit(limit, 10, minimum=1)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -433,24 +481,16 @@ def get_schedule_execution_logs(schedule_id, limit=10):
ORDER BY execute_time DESC
LIMIT ?
""",
(schedule_id, limit),
(schedule_id, safe_limit),
)
logs = []
rows = cursor.fetchall()
for row in rows:
for row in cursor.fetchall():
try:
log = dict(row)
log["created_at"] = log.get("execute_time")
log["success_count"] = log.get("success_accounts", 0)
log["failed_count"] = log.get("failed_accounts", 0)
log["duration"] = log.get("duration_seconds", 0)
logs.append(log)
logs.append(_map_schedule_log_row(row))
except Exception as e:
print(f"[数据库] 处理日志行时出错: {e}")
continue
return logs
except Exception as e:
print(f"[数据库] 查询定时任务日志时出错: {e}")
@@ -462,6 +502,7 @@ def get_schedule_execution_logs(schedule_id, limit=10):
def get_user_all_schedule_logs(user_id, limit=50):
"""获取用户所有定时任务的执行日志"""
safe_limit = _normalize_limit(limit, 50, minimum=1)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -471,7 +512,7 @@ def get_user_all_schedule_logs(user_id, limit=50):
ORDER BY execute_time DESC
LIMIT ?
""",
(user_id, limit),
(user_id, safe_limit),
)
return [dict(row) for row in cursor.fetchall()]
@@ -493,14 +534,21 @@ def delete_schedule_logs(schedule_id, user_id):
def clean_old_schedule_logs(days=30):
"""清理指定天数前的定时任务执行日志"""
safe_days = _to_int(days, 30)
if safe_days < 0:
safe_days = 0
cutoff_dt = get_beijing_now() - timedelta(days=safe_days)
cutoff_str = format_cst(cutoff_dt)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
DELETE FROM schedule_execution_logs
WHERE execute_time < datetime('now', 'localtime', '-' || ? || ' days')
WHERE execute_time < ?
""",
(days,),
(cutoff_str,),
)
conn.commit()
return cursor.rowcount

View File

@@ -362,6 +362,8 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_status ON users(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_vip_expire ON users(vip_expire_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
@@ -391,6 +393,8 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_id ON task_logs(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_status ON task_logs(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source ON task_logs(source)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source_created_at ON task_logs(source, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)")
@@ -409,6 +413,9 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_id ON schedule_execution_logs(schedule_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_id ON schedule_execution_logs(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_status ON schedule_execution_logs(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_execute_time ON schedule_execution_logs(execute_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_time ON schedule_execution_logs(schedule_id, execute_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_time ON schedule_execution_logs(user_id, execute_time)")
# 初始化VIP配置幂等
try:

View File

@@ -3,13 +3,82 @@
from __future__ import annotations
from datetime import timedelta
from typing import Any, Optional
from typing import Dict
from typing import Any, Dict, Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str
_THREAT_EVENT_SELECT_COLUMNS = """
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
"""
def _normalize_page(page: int) -> int:
try:
page_i = int(page)
except Exception:
page_i = 1
return max(1, page_i)
def _normalize_per_page(per_page: int, default: int = 20) -> int:
try:
value = int(per_page)
except Exception:
value = default
return max(1, min(200, value))
def _normalize_limit(limit: int, default: int = 50) -> int:
try:
value = int(limit)
except Exception:
value = default
return max(1, min(200, value))
def _row_value(row, key: str, index: int = 0, default=None):
if row is None:
return default
try:
return row[key]
except Exception:
try:
return row[index]
except Exception:
return default
def _fetch_threat_events_history(where_clause: str, params: tuple[Any, ...], limit_i: int) -> list[dict]:
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT
{_THREAT_EVENT_SELECT_COLUMNS}
FROM threat_events
WHERE {where_clause}
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
tuple(params) + (limit_i,),
)
return [dict(r) for r in cursor.fetchall()]
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
"""记录登录环境信息,返回是否新设备/新IP。"""
user_id = int(user_id)
@@ -36,7 +105,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
SET last_seen = ?, last_ip = ?
WHERE id = ?
""",
(now_str, ip_text, row["id"] if isinstance(row, dict) else row[0]),
(now_str, ip_text, _row_value(row, "id", 0)),
)
else:
cursor.execute(
@@ -61,7 +130,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
SET last_seen = ?
WHERE id = ?
""",
(now_str, row["id"] if isinstance(row, dict) else row[0]),
(now_str, _row_value(row, "id", 0)),
)
else:
cursor.execute(
@@ -166,15 +235,8 @@ def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, lis
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
"""分页获取威胁事件。"""
try:
page_i = max(1, int(page))
except Exception:
page_i = 1
try:
per_page_i = int(per_page)
except Exception:
per_page_i = 20
per_page_i = max(1, min(200, per_page_i))
page_i = _normalize_page(page)
per_page_i = _normalize_per_page(per_page, default=20)
where_sql, params = _build_threat_events_where_clause(filters)
offset = (page_i - 1) * per_page_i
@@ -188,19 +250,7 @@ def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = N
cursor.execute(
f"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
{_THREAT_EVENT_SELECT_COLUMNS}
FROM threat_events
{where_sql}
ORDER BY created_at DESC, id DESC
@@ -218,75 +268,20 @@ def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE ip = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(ip_text, limit_i),
)
return [dict(r) for r in cursor.fetchall()]
limit_i = _normalize_limit(limit, default=50)
return _fetch_threat_events_history("ip = ?", (ip_text,), limit_i)
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
"""获取用户的威胁历史最近limit条"""
if user_id is None:
return []
try:
user_id_int = int(user_id)
except Exception:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE user_id = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(user_id_int, limit_i),
)
return [dict(r) for r in cursor.fetchall()]
limit_i = _normalize_limit(limit, default=50)
return _fetch_threat_events_history("user_id = ?", (user_id_int,), limit_i)

View File

@@ -2,12 +2,135 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import datetime
import pytz
from datetime import datetime, timedelta
import db_pool
from db.utils import sanitize_sql_like_pattern
from db.utils import get_cst_now, get_cst_now_str, sanitize_sql_like_pattern
_TASK_STATS_SELECT_SQL = """
SELECT
COUNT(*) as total_tasks,
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
"""
_USER_RUN_STATS_SELECT_SQL = """
SELECT
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
"""
def _build_day_bounds(date_filter: str) -> tuple[str | None, str | None]:
"""将 YYYY-MM-DD 转换为 [day_start, day_end) 区间。"""
try:
day_start = datetime.strptime(str(date_filter), "%Y-%m-%d")
except Exception:
return None, None
day_end = day_start + timedelta(days=1)
return day_start.strftime("%Y-%m-%d %H:%M:%S"), day_end.strftime("%Y-%m-%d %H:%M:%S")
def _normalize_int(value, default: int, *, minimum: int | None = None) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
if minimum is not None and parsed < minimum:
return minimum
return parsed
def _stat_value(row, key: str) -> int:
try:
value = row[key] if row else 0
except Exception:
value = 0
return int(value or 0)
def _build_task_logs_where_sql(
*,
date_filter=None,
status_filter=None,
source_filter=None,
user_id_filter=None,
account_filter=None,
) -> tuple[str, list]:
where_clauses = ["1=1"]
params = []
if date_filter:
day_start, day_end = _build_day_bounds(date_filter)
if day_start and day_end:
where_clauses.append("tl.created_at >= ? AND tl.created_at < ?")
params.extend([day_start, day_end])
else:
where_clauses.append("date(tl.created_at) = ?")
params.append(date_filter)
if status_filter:
where_clauses.append("tl.status = ?")
params.append(status_filter)
if source_filter:
source_filter = str(source_filter or "").strip()
if source_filter == "user_scheduled":
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append("user_scheduled:%")
elif source_filter.endswith("*"):
prefix = source_filter[:-1]
safe_prefix = sanitize_sql_like_pattern(prefix)
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append(f"{safe_prefix}%")
else:
where_clauses.append("tl.source = ?")
params.append(source_filter)
if user_id_filter:
where_clauses.append("tl.user_id = ?")
params.append(user_id_filter)
if account_filter:
safe_filter = sanitize_sql_like_pattern(account_filter)
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
params.append(f"%{safe_filter}%")
return " AND ".join(where_clauses), params
def _fetch_task_stats_row(cursor, *, where_clause: str = "", params: tuple | list = ()) -> dict:
    """Run the aggregate task-stats query and return a dict with 0 for NULL/missing values."""
    query = _TASK_STATS_SELECT_SQL
    if where_clause:
        query = f"{query}\nWHERE {where_clause}"
    cursor.execute(query, params)
    row = cursor.fetchone()
    stat_keys = ("total_tasks", "success_tasks", "failed_tasks", "total_items", "total_attachments")
    return {key: _stat_value(row, key) for key in stat_keys}
def _fetch_user_run_stats_row(cursor, *, where_clause: str, params: tuple | list) -> dict:
    """Run the per-user run-stats query (WHERE clause is mandatory) and zero-default the result."""
    cursor.execute(f"{_USER_RUN_STATS_SELECT_SQL}\nWHERE {where_clause}", params)
    row = cursor.fetchone()
    return {
        key: _stat_value(row, key)
        for key in ("completed", "failed", "total_items", "total_attachments")
    }
def create_task_log(
@@ -25,8 +148,6 @@ def create_task_log(
"""创建任务日志记录"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
cursor.execute(
"""
@@ -45,7 +166,7 @@ def create_task_log(
total_attachments,
error_message,
duration,
cst_time,
get_cst_now_str(),
source,
),
)
@@ -64,54 +185,27 @@ def get_task_logs(
account_filter=None,
):
"""获取任务日志列表(支持分页和多种筛选)"""
limit = _normalize_int(limit, 100, minimum=1)
offset = _normalize_int(offset, 0, minimum=0)
with db_pool.get_db() as conn:
cursor = conn.cursor()
where_clauses = ["1=1"]
params = []
if date_filter:
where_clauses.append("date(tl.created_at) = ?")
params.append(date_filter)
if status_filter:
where_clauses.append("tl.status = ?")
params.append(status_filter)
if source_filter:
source_filter = str(source_filter or "").strip()
# 兼容“虚拟来源”:用于筛选 user_scheduled:batch_xxx 这类动态值
if source_filter == "user_scheduled":
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append("user_scheduled:%")
elif source_filter.endswith("*"):
prefix = source_filter[:-1]
safe_prefix = sanitize_sql_like_pattern(prefix)
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append(f"{safe_prefix}%")
else:
where_clauses.append("tl.source = ?")
params.append(source_filter)
if user_id_filter:
where_clauses.append("tl.user_id = ?")
params.append(user_id_filter)
if account_filter:
safe_filter = sanitize_sql_like_pattern(account_filter)
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
params.append(f"%{safe_filter}%")
where_sql = " AND ".join(where_clauses)
where_sql, params = _build_task_logs_where_sql(
date_filter=date_filter,
status_filter=status_filter,
source_filter=source_filter,
user_id_filter=user_id_filter,
account_filter=account_filter,
)
count_sql = f"""
SELECT COUNT(*) as total
FROM task_logs tl
LEFT JOIN users u ON tl.user_id = u.id
WHERE {where_sql}
"""
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
total = _stat_value(cursor.fetchone(), "total")
data_sql = f"""
SELECT
@@ -123,9 +217,10 @@ def get_task_logs(
ORDER BY tl.created_at DESC
LIMIT ? OFFSET ?
"""
params.extend([limit, offset])
data_params = list(params)
data_params.extend([limit, offset])
cursor.execute(data_sql, params)
cursor.execute(data_sql, data_params)
logs = [dict(row) for row in cursor.fetchall()]
return {"logs": logs, "total": total}
@@ -133,61 +228,39 @@ def get_task_logs(
def get_task_stats(date_filter=None):
    """Return aggregate task statistics.

    Args:
        date_filter: CST day as 'YYYY-MM-DD'; defaults to today in CST.

    Returns:
        {"today": {...}, "total": {...}}; each dict holds total_tasks,
        success_tasks, failed_tasks, total_items, total_attachments.
    """
    # Defect fixed: the block contained both the old inline-SQL version and the
    # new helper-based version fused together (duplicate queries, dead return).
    # Keep the helper-based implementation.
    if date_filter is None:
        date_filter = get_cst_now().strftime("%Y-%m-%d")
    day_start, day_end = _build_day_bounds(date_filter)
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        if day_start and day_end:
            # Half-open range comparison keeps an index on created_at usable.
            today_stats = _fetch_task_stats_row(
                cursor,
                where_clause="created_at >= ? AND created_at < ?",
                params=(day_start, day_end),
            )
        else:
            # Unparseable date: fall back to the date() comparison.
            today_stats = _fetch_task_stats_row(
                cursor,
                where_clause="date(created_at) = ?",
                params=(date_filter,),
            )
        total_stats = _fetch_task_stats_row(cursor)
    return {"today": today_stats, "total": total_stats}
def delete_old_task_logs(days=30, batch_size=1000):
"""删除N天前的任务日志分批删除避免长时间锁表"""
days = _normalize_int(days, 30, minimum=0)
batch_size = _normalize_int(batch_size, 1000, minimum=1)
cutoff = (get_cst_now() - timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
total_deleted = 0
while True:
with db_pool.get_db() as conn:
@@ -197,16 +270,16 @@ def delete_old_task_logs(days=30, batch_size=1000):
DELETE FROM task_logs
WHERE rowid IN (
SELECT rowid FROM task_logs
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
WHERE created_at < ?
LIMIT ?
)
""",
(days, batch_size),
(cutoff, batch_size),
)
deleted = cursor.rowcount
conn.commit()
if deleted == 0:
if deleted <= 0:
break
total_deleted += deleted
@@ -215,31 +288,23 @@ def delete_old_task_logs(days=30, batch_size=1000):
def get_user_run_stats(user_id, date_filter=None):
    """Return one user's run statistics for a CST day.

    Args:
        user_id: the user whose task_logs rows are aggregated.
        date_filter: 'YYYY-MM-DD' (CST); defaults to today.

    Returns:
        Dict with completed, failed, total_items, total_attachments (0-defaulted).
    """
    # Defect fixed: the fused diff residue left an old cursor.execute( call
    # unterminated; keep the helper-based implementation.
    if date_filter is None:
        date_filter = get_cst_now().strftime("%Y-%m-%d")
    day_start, day_end = _build_day_bounds(date_filter)
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        if day_start and day_end:
            # Index-friendly half-open range on created_at.
            return _fetch_user_run_stats_row(
                cursor,
                where_clause="user_id = ? AND created_at >= ? AND created_at < ?",
                params=(user_id, day_start, day_end),
            )
        # Fallback for unparseable dates.
        return _fetch_user_run_stats_row(
            cursor,
            where_clause="user_id = ? AND date(created_at) = ?",
            params=(user_id, date_filter),
        )

View File

@@ -16,8 +16,41 @@ from password_utils import (
verify_password_bcrypt,
verify_password_sha256,
)
# Module-level logger for the user/VIP database helpers.
logger = get_logger(__name__)

# Timestamps in the users table are handled in CST (Asia/Shanghai) —
# see the VIP helpers below that parse/format with this zone.
_CST_TZ = pytz.timezone("Asia/Shanghai")
# Sentinel expiry string that marks a permanent VIP.
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
def _row_to_dict(row):
return dict(row) if row else None
def _get_user_by_field(field_name: str, field_value):
    """Fetch one users row matched on ``field_name`` and return it as a dict (or None).

    NOTE(review): field_name is interpolated into the SQL text, so callers must
    pass trusted, hard-coded column names only (current callers use "id" and
    "username"); field_value itself is safely bound as a parameter.
    """
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
        return _row_to_dict(cursor.fetchone())
def _parse_cst_datetime(datetime_str: str | None):
    """Parse a 'YYYY-MM-DD HH:MM:SS' string into a timezone-aware CST datetime.

    Returns None for empty input or any parse/localize failure.
    """
    if not datetime_str:
        return None
    try:
        parsed = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S")
        return _CST_TZ.localize(parsed)
    except Exception:
        return None
def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str:
    """Compute a VIP expiry timestamp string.

    999999 days is the "permanent" sentinel; otherwise the expiry is base_dt
    (default: now in CST) plus ``days`` days.
    """
    day_count = int(days)
    if day_count == 999999:
        return _PERMANENT_VIP_EXPIRE
    anchor = base_dt if base_dt is not None else datetime.now(_CST_TZ)
    return (anchor + timedelta(days=day_count)).strftime("%Y-%m-%d %H:%M:%S")
def get_vip_config():
"""获取VIP配置"""
@@ -32,13 +65,12 @@ def set_default_vip_days(days):
def set_default_vip_days(days):
    """Set the default VIP duration (days) granted to new users.

    Upserts the single vip_config row (id=1) with a fresh CST timestamp.
    """
    # Defect fixed: diff residue left duplicate parameter tuples; the def line
    # was reconstructed from the hunk header.
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(
            """
            INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
            VALUES (1, ?, ?)
            """,
            (days, get_cst_now_str()),
        )
        conn.commit()
        return True
@@ -47,14 +79,8 @@ def set_default_vip_days(days):
def set_user_vip(user_id, days):
    """Set a user's VIP expiry. days: 7=one week, 30=one month, 365=one year, 999999=permanent.

    Returns True if a row was updated.
    """
    # Defect fixed: fused old inline expiry computation removed; keep the
    # shared _format_vip_expire helper.
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        expire_time = _format_vip_expire(days)
        cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id))
        conn.commit()
        return cursor.rowcount > 0
@@ -63,29 +89,26 @@ def set_user_vip(user_id, days):
def extend_user_vip(user_id, days):
    """Extend a user's VIP expiry by ``days``.

    Anchors at the stored future expiry when it parses; otherwise (missing,
    unparseable, or already in the past) anchors at the current CST time.
    Returns True if a row was updated, False if the user does not exist.
    """
    # Defect fixed: the block contained both the pre- and post-refactor
    # implementations fused; keep the helper-based version.
    user = get_user_by_id(user_id)
    if not user:
        return False
    current_expire = user.get("vip_expire_time")
    now_dt = datetime.now(_CST_TZ)
    if current_expire and current_expire != _PERMANENT_VIP_EXPIRE:
        expire_time = _parse_cst_datetime(current_expire)
        if expire_time is not None:
            # Never extend from a moment in the past.
            if expire_time < now_dt:
                expire_time = now_dt
            new_expire = _format_vip_expire(days, base_dt=expire_time)
        else:
            logger.warning("解析VIP过期时间失败使用当前时间")
            new_expire = _format_vip_expire(days, base_dt=now_dt)
    else:
        # NOTE(review): a permanent VIP falls through here and is re-anchored
        # to now + days — confirm this is intended.
        new_expire = _format_vip_expire(days, base_dt=now_dt)
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
        conn.commit()
        return cursor.rowcount > 0
@@ -105,45 +128,49 @@ def is_user_vip(user_id):
def is_user_vip(user_id):
    """Return True if the user currently holds an unexpired VIP.

    Note: timestamps stored in the database use CST (Asia/Shanghai).
    """
    # Defect fixed: fused old/new bodies; def/docstring opening were elided by
    # the diff and reconstructed from the hunk header.
    user = get_user_by_id(user_id)
    if not user:
        return False
    vip_expire_time = user.get("vip_expire_time")
    if not vip_expire_time:
        return False
    expire_time = _parse_cst_datetime(vip_expire_time)
    if expire_time is None:
        logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间")
        return False
    return datetime.now(_CST_TZ) < expire_time
def get_user_vip_info(user_id):
    """Return the user's VIP summary.

    Returns:
        {"is_vip": bool, "expire_time": str | None, "days_left": int,
         "username": str}; a non-VIP / missing / unparseable state yields the
        all-false shape with days_left 0.
    """
    # Defect fixed: fused old/new implementations (old try/except body made the
    # new helper-based path unreachable); keep the new version.
    user = get_user_by_id(user_id)
    if not user:
        return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""}
    vip_expire_time = user.get("vip_expire_time")
    username = user.get("username", "")
    if not vip_expire_time:
        return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
    expire_time = _parse_cst_datetime(vip_expire_time)
    if expire_time is None:
        logger.warning("VIP信息获取错误: 无法解析过期时间")
        return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
    now_dt = datetime.now(_CST_TZ)
    is_vip = now_dt < expire_time
    days_left = (expire_time - now_dt).days if is_vip else 0
    return {
        "username": username,
        "is_vip": is_vip,
        "expire_time": vip_expire_time,
        "days_left": max(0, days_left),
    }
# ==================== 用户相关 ====================
@@ -151,8 +178,6 @@ def get_user_vip_info(user_id):
def create_user(username, password, email=""):
"""创建新用户(默认直接通过,赠送默认VIP)"""
cst_tz = pytz.timezone("Asia/Shanghai")
with db_pool.get_db() as conn:
cursor = conn.cursor()
password_hash = hash_password_bcrypt(password)
@@ -160,12 +185,8 @@ def create_user(username, password, email=""):
default_vip_days = get_vip_config()["default_vip_days"]
vip_expire_time = None
if default_vip_days > 0:
if default_vip_days == 999999:
vip_expire_time = "2099-12-31 23:59:59"
else:
vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime("%Y-%m-%d %H:%M:%S")
if int(default_vip_days or 0) > 0:
vip_expire_time = _format_vip_expire(int(default_vip_days))
try:
cursor.execute(
@@ -210,28 +231,28 @@ def verify_user(username, password):
def get_user_by_id(user_id):
    """Fetch a user row by primary key as a dict, or None when absent."""
    # Defect fixed: the new delegating one-liner was unreachable behind the old
    # inline query; keep the shared helper.
    return _get_user_by_field("id", user_id)
def get_user_kdocs_settings(user_id):
    """Return the user's KDocs (金山文档) settings, or None if the user is absent.

    Returns:
        {"kdocs_unit": str, "kdocs_auto_upload": 0 | 1}
    """
    # Defect fixed: fused old/new bodies — the old full-row version returned
    # first, making the targeted two-column query dead. Keep the new version,
    # which avoids fetching (and decrypting nothing from) the whole user row.
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,))
        row = cursor.fetchone()
        if not row:
            return None
        # Row may be a mapping or a plain sequence depending on the row factory.
        return {
            "kdocs_unit": (row["kdocs_unit"] or "") if isinstance(row, dict) else (row[0] or ""),
            "kdocs_auto_upload": 1 if ((row["kdocs_auto_upload"] if isinstance(row, dict) else row[1]) or 0) else 0,
        }
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
"""更新用户的金山文档配置"""
updates = []
params = []
if kdocs_unit is not None:
updates.append("kdocs_unit = ?")
params.append(kdocs_unit)
@@ -252,11 +273,7 @@ def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=No
def get_user_by_username(username):
    """Fetch a user row by username as a dict, or None when absent."""
    # Defect fixed: dead duplicated inline query removed; keep the shared helper.
    return _get_user_by_field("username", username)
def get_all_users():
@@ -279,14 +296,13 @@ def approve_user(user_id):
def approve_user(user_id):
    """Approve a pending user: set status='approved' and stamp approved_at (CST).

    Returns True if a row was updated.
    """
    # Defect fixed: diff residue left duplicate parameter tuples; the def line
    # was reconstructed from the hunk header.
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(
            """
            UPDATE users
            SET status = 'approved', approved_at = ?
            WHERE id = ?
            """,
            (get_cst_now_str(), user_id),
        )
        conn.commit()
        return cursor.rowcount > 0
@@ -315,5 +331,5 @@ def get_user_stats(user_id):
def get_user_stats(user_id):
    """Return per-user statistics: currently the number of linked accounts."""
    # Defect fixed: fused old/new return statements; def line reconstructed
    # from the hunk header. The int()/or-0 guards cover a missing row or NULL.
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,))
        row = cursor.fetchone()
        return {"account_count": int((row["count"] if row else 0) or 0)}