refactor: optimize structure, stability and runtime performance

This commit is contained in:
2026-02-07 00:35:11 +08:00
parent fae21329d7
commit bf29ac1924
44 changed files with 6894 additions and 4792 deletions

191
app.py
View File

@@ -173,10 +173,28 @@ def serve_static(filename):
if not is_safe_path("static", filename): if not is_safe_path("static", filename):
return jsonify({"error": "非法路径"}), 403 return jsonify({"error": "非法路径"}), 403
response = send_from_directory("static", filename) cache_ttl = 3600
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0" lowered = filename.lower()
response.headers["Pragma"] = "no-cache" if "/assets/" in lowered or lowered.endswith((".js", ".css", ".woff", ".woff2", ".ttf", ".svg")):
response.headers["Expires"] = "0" cache_ttl = 604800 # 7天
if request.args.get("v"):
cache_ttl = max(cache_ttl, 604800)
response = send_from_directory("static", filename, max_age=cache_ttl, conditional=True)
# 协商缓存:确保存在 ETag并基于 If-None-Match/If-Modified-Since 返回 304
try:
response.add_etag(overwrite=False)
except Exception:
pass
try:
response.make_conditional(request)
except Exception:
pass
response.headers.setdefault("Vary", "Accept-Encoding")
response.headers["Cache-Control"] = f"public, max-age={cache_ttl}"
return response return response
@@ -232,6 +250,93 @@ def _signal_handler(sig, frame):
sys.exit(0) sys.exit(0)
def _cleanup_stale_task_state() -> None:
logger.info("清理遗留任务状态...")
try:
from services.state import safe_get_active_task_ids, safe_remove_task, safe_remove_task_status
for _, accounts in safe_iter_user_accounts_items():
for acc in accounts.values():
if not getattr(acc, "is_running", False):
continue
acc.is_running = False
acc.should_stop = False
acc.status = "未开始"
for account_id in list(safe_get_active_task_ids()):
safe_remove_task(account_id)
safe_remove_task_status(account_id)
logger.info("[OK] 遗留任务状态已清理")
except Exception as e:
logger.warning(f"清理遗留任务状态失败: {e}")
def _init_optional_email_service() -> None:
try:
email_service.init_email_service()
logger.info("[OK] 邮件服务已初始化")
except Exception as e:
logger.warning(f"警告: 邮件服务初始化失败: {e}")
def _load_and_apply_scheduler_limits() -> None:
try:
system_config = database.get_system_config() or {}
max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT))
get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account)
logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
except Exception as e:
logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
def _start_background_workers() -> None:
logger.info("启动定时任务调度器...")
threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
logger.info("[OK] 定时任务调度器已启动")
logger.info("[OK] 状态推送线程已启动默认2秒/次)")
threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
def _init_screenshot_worker_pool() -> None:
try:
pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3))
except Exception:
pool_size = 3
try:
logger.info(f"初始化截图线程池({pool_size}个worker按需启动执行环境空闲5分钟后自动释放...")
init_browser_worker_pool(pool_size=pool_size)
logger.info("[OK] 截图线程池初始化完成")
except Exception as e:
logger.warning(f"警告: 截图线程池初始化失败: {e}")
def _warmup_api_connection() -> None:
logger.info("预热 API 连接...")
try:
from api_browser import warmup_api_connection
threading.Thread(
target=warmup_api_connection,
kwargs={"log_callback": lambda msg: logger.info(msg)},
daemon=True,
name="api-warmup",
).start()
except Exception as e:
logger.warning(f"API 预热失败: {e}")
def _log_startup_urls() -> None:
logger.info("服务器启动中...")
logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
logger.info("默认管理员: admin (首次运行随机密码见日志)")
logger.info("=" * 60)
if __name__ == "__main__": if __name__ == "__main__":
atexit.register(cleanup_on_exit) atexit.register(cleanup_on_exit)
signal.signal(signal.SIGINT, _signal_handler) signal.signal(signal.SIGINT, _signal_handler)
@@ -245,81 +350,17 @@ if __name__ == "__main__":
init_checkpoint_manager() init_checkpoint_manager()
logger.info("[OK] 任务断点管理器已初始化") logger.info("[OK] 任务断点管理器已初始化")
# 【新增】容器重启时清理遗留的任务状态 _cleanup_stale_task_state()
logger.info("清理遗留任务状态...") _init_optional_email_service()
try:
from services.state import safe_remove_task, safe_get_active_task_ids, safe_remove_task_status
# 重置所有账号的运行状态
for _, accounts in safe_iter_user_accounts_items():
for acc in accounts.values():
if getattr(acc, "is_running", False):
acc.is_running = False
acc.should_stop = False
acc.status = "未开始"
# 清理活跃任务句柄
for account_id in list(safe_get_active_task_ids()):
safe_remove_task(account_id)
safe_remove_task_status(account_id)
logger.info("[OK] 遗留任务状态已清理")
except Exception as e:
logger.warning(f"清理遗留任务状态失败: {e}")
try:
email_service.init_email_service()
logger.info("[OK] 邮件服务已初始化")
except Exception as e:
logger.warning(f"警告: 邮件服务初始化失败: {e}")
start_cleanup_scheduler() start_cleanup_scheduler()
start_kdocs_monitor() start_kdocs_monitor()
try: _load_and_apply_scheduler_limits()
system_config = database.get_system_config() or {} _start_background_workers()
max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL)) _log_startup_urls()
max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT)) _init_screenshot_worker_pool()
get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account) _warmup_api_connection()
logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
except Exception as e:
logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
logger.info("启动定时任务调度器...")
threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
logger.info("[OK] 定时任务调度器已启动")
logger.info("[OK] 状态推送线程已启动默认2秒/次)")
threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
logger.info("服务器启动中...")
logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
logger.info("默认管理员: admin (首次运行随机密码见日志)")
logger.info("=" * 60)
try:
pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3))
except Exception:
pool_size = 3
try:
logger.info(f"初始化截图线程池({pool_size}个worker按需启动执行环境空闲5分钟后自动释放...")
init_browser_worker_pool(pool_size=pool_size)
logger.info("[OK] 截图线程池初始化完成")
except Exception as e:
logger.warning(f"警告: 截图线程池初始化失败: {e}")
# 预热 API 连接(后台进行,不阻塞启动)
logger.info("预热 API 连接...")
try:
from api_browser import warmup_api_connection
import threading
threading.Thread(
target=warmup_api_connection,
kwargs={"log_callback": lambda msg: logger.info(msg)},
daemon=True,
name="api-warmup",
).start()
except Exception as e:
logger.warning(f"API 预热失败: {e}")
socketio.run( socketio.run(
app, app,

View File

@@ -120,7 +120,7 @@ config = get_config()
DB_FILE = config.DB_FILE DB_FILE = config.DB_FILE
# 数据库版本 (用于迁移管理) # 数据库版本 (用于迁移管理)
DB_VERSION = 17 DB_VERSION = 18
# ==================== 系统配置缓存P1 / O-03 ==================== # ==================== 系统配置缓存P1 / O-03 ====================
@@ -142,6 +142,37 @@ def invalidate_system_config_cache() -> None:
_system_config_cache_loaded_at = 0.0 _system_config_cache_loaded_at = 0.0
def _normalize_system_config_value(value) -> dict:
try:
return dict(value or {})
except Exception:
return {}
def _is_system_config_cache_valid(now_ts: float) -> bool:
if _system_config_cache_value is None:
return False
if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0:
return True
return (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS
def _read_system_config_cache(now_ts: float, *, ignore_ttl: bool = False) -> Optional[dict]:
with _system_config_cache_lock:
if _system_config_cache_value is None:
return None
if (not ignore_ttl) and (not _is_system_config_cache_valid(now_ts)):
return None
return dict(_system_config_cache_value)
def _write_system_config_cache(value: dict, now_ts: float) -> None:
global _system_config_cache_value, _system_config_cache_loaded_at
with _system_config_cache_lock:
_system_config_cache_value = dict(value)
_system_config_cache_loaded_at = now_ts
def init_database(): def init_database():
"""初始化数据库表结构 + 迁移(入口统一)。""" """初始化数据库表结构 + 迁移(入口统一)。"""
db_pool.init_pool(DB_FILE, pool_size=config.DB_POOL_SIZE) db_pool.init_pool(DB_FILE, pool_size=config.DB_POOL_SIZE)
@@ -165,19 +196,21 @@ def migrate_database():
def get_system_config(): def get_system_config():
"""获取系统配置(带进程内缓存)。""" """获取系统配置(带进程内缓存)。"""
global _system_config_cache_value, _system_config_cache_loaded_at
now_ts = time.time() now_ts = time.time()
with _system_config_cache_lock:
if _system_config_cache_value is not None:
if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0 or (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS:
return dict(_system_config_cache_value)
value = _get_system_config_raw() cached_value = _read_system_config_cache(now_ts)
if cached_value is not None:
return cached_value
with _system_config_cache_lock: try:
_system_config_cache_value = dict(value) value = _normalize_system_config_value(_get_system_config_raw())
_system_config_cache_loaded_at = now_ts except Exception:
fallback_value = _read_system_config_cache(now_ts, ignore_ttl=True)
if fallback_value is not None:
return fallback_value
raise
_write_system_config_cache(value, now_ts)
return dict(value) return dict(value)

View File

@@ -6,19 +6,51 @@ import db_pool
from crypto_utils import decrypt_password, encrypt_password from crypto_utils import decrypt_password, encrypt_password
from db.utils import get_cst_now_str from db.utils import get_cst_now_str
_ACCOUNT_STATUS_QUERY_SQL = """
SELECT status, login_fail_count, last_login_error
FROM accounts
WHERE id = ?
"""
def _decode_account_password(account_dict: dict) -> dict:
account_dict["password"] = decrypt_password(account_dict.get("password", ""))
return account_dict
def _normalize_account_ids(account_ids) -> list[str]:
normalized = []
seen = set()
for account_id in account_ids or []:
if not account_id:
continue
account_key = str(account_id)
if account_key in seen:
continue
seen.add(account_key)
normalized.append(account_key)
return normalized
def create_account(user_id, account_id, username, password, remember=True, remark=""): def create_account(user_id, account_id, username, password, remember=True, remark=""):
"""创建账号(密码加密存储)""" """创建账号(密码加密存储)"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cst_time = get_cst_now_str()
encrypted_password = encrypt_password(password) encrypted_password = encrypt_password(password)
cursor.execute( cursor.execute(
""" """
INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at) INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
""", """,
(account_id, user_id, username, encrypted_password, 1 if remember else 0, remark, cst_time), (
account_id,
user_id,
username,
encrypted_password,
1 if remember else 0,
remark,
get_cst_now_str(),
),
) )
conn.commit() conn.commit()
return cursor.lastrowid return cursor.lastrowid
@@ -29,12 +61,7 @@ def get_user_accounts(user_id):
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,)) cursor.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,))
accounts = [] return [_decode_account_password(dict(row)) for row in cursor.fetchall()]
for row in cursor.fetchall():
account = dict(row)
account["password"] = decrypt_password(account.get("password", ""))
accounts.append(account)
return accounts
def get_account(account_id): def get_account(account_id):
@@ -43,11 +70,9 @@ def get_account(account_id):
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,)) cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone() row = cursor.fetchone()
if row: if not row:
account = dict(row) return None
account["password"] = decrypt_password(account.get("password", "")) return _decode_account_password(dict(row))
return account
return None
def update_account_remark(account_id, remark): def update_account_remark(account_id, remark):
@@ -78,33 +103,21 @@ def increment_account_login_fail(account_id, error_message):
if not row: if not row:
return False return False
fail_count = (row["login_fail_count"] or 0) + 1 fail_count = int(row["login_fail_count"] or 0) + 1
is_suspended = fail_count >= 3
if fail_count >= 3:
cursor.execute(
"""
UPDATE accounts
SET login_fail_count = ?,
last_login_error = ?,
status = 'suspended'
WHERE id = ?
""",
(fail_count, error_message, account_id),
)
conn.commit()
return True
cursor.execute( cursor.execute(
""" """
UPDATE accounts UPDATE accounts
SET login_fail_count = ?, SET login_fail_count = ?,
last_login_error = ? last_login_error = ?,
status = CASE WHEN ? = 1 THEN 'suspended' ELSE status END
WHERE id = ? WHERE id = ?
""", """,
(fail_count, error_message, account_id), (fail_count, error_message, 1 if is_suspended else 0, account_id),
) )
conn.commit() conn.commit()
return False return is_suspended
def reset_account_login_status(account_id): def reset_account_login_status(account_id):
@@ -129,29 +142,22 @@ def get_account_status(account_id):
"""获取账号状态信息""" """获取账号状态信息"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(_ACCOUNT_STATUS_QUERY_SQL, (account_id,))
"""
SELECT status, login_fail_count, last_login_error
FROM accounts
WHERE id = ?
""",
(account_id,),
)
return cursor.fetchone() return cursor.fetchone()
def get_account_status_batch(account_ids): def get_account_status_batch(account_ids):
"""批量获取账号状态信息""" """批量获取账号状态信息"""
account_ids = [str(account_id) for account_id in (account_ids or []) if account_id] normalized_ids = _normalize_account_ids(account_ids)
if not account_ids: if not normalized_ids:
return {} return {}
results = {} results = {}
chunk_size = 900 # 避免触发 SQLite 绑定参数上限 chunk_size = 900 # 避免触发 SQLite 绑定参数上限
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
for idx in range(0, len(account_ids), chunk_size): for idx in range(0, len(normalized_ids), chunk_size):
chunk = account_ids[idx : idx + chunk_size] chunk = normalized_ids[idx : idx + chunk_size]
placeholders = ",".join("?" for _ in chunk) placeholders = ",".join("?" for _ in chunk)
cursor.execute( cursor.execute(
f""" f"""

View File

@@ -3,9 +3,6 @@
from __future__ import annotations from __future__ import annotations
import sqlite3 import sqlite3
from datetime import datetime, timedelta
import pytz
import db_pool import db_pool
from db.utils import get_cst_now_str from db.utils import get_cst_now_str
@@ -16,6 +13,99 @@ from password_utils import (
verify_password_sha256, verify_password_sha256,
) )
_DEFAULT_SYSTEM_CONFIG = {
"max_concurrent_global": 2,
"max_concurrent_per_account": 1,
"max_screenshot_concurrent": 3,
"schedule_enabled": 0,
"schedule_time": "02:00",
"schedule_browse_type": "应读",
"schedule_weekdays": "1,2,3,4,5,6,7",
"proxy_enabled": 0,
"proxy_api_url": "",
"proxy_expire_minutes": 3,
"enable_screenshot": 1,
"auto_approve_enabled": 0,
"auto_approve_hourly_limit": 10,
"auto_approve_vip_days": 7,
"kdocs_enabled": 0,
"kdocs_doc_url": "",
"kdocs_default_unit": "",
"kdocs_sheet_name": "",
"kdocs_sheet_index": 0,
"kdocs_unit_column": "A",
"kdocs_image_column": "D",
"kdocs_admin_notify_enabled": 0,
"kdocs_admin_notify_email": "",
"kdocs_row_start": 0,
"kdocs_row_end": 0,
}
_SYSTEM_CONFIG_UPDATERS = (
("max_concurrent_global", "max_concurrent"),
("schedule_enabled", "schedule_enabled"),
("schedule_time", "schedule_time"),
("schedule_browse_type", "schedule_browse_type"),
("schedule_weekdays", "schedule_weekdays"),
("max_concurrent_per_account", "max_concurrent_per_account"),
("max_screenshot_concurrent", "max_screenshot_concurrent"),
("enable_screenshot", "enable_screenshot"),
("proxy_enabled", "proxy_enabled"),
("proxy_api_url", "proxy_api_url"),
("proxy_expire_minutes", "proxy_expire_minutes"),
("auto_approve_enabled", "auto_approve_enabled"),
("auto_approve_hourly_limit", "auto_approve_hourly_limit"),
("auto_approve_vip_days", "auto_approve_vip_days"),
("kdocs_enabled", "kdocs_enabled"),
("kdocs_doc_url", "kdocs_doc_url"),
("kdocs_default_unit", "kdocs_default_unit"),
("kdocs_sheet_name", "kdocs_sheet_name"),
("kdocs_sheet_index", "kdocs_sheet_index"),
("kdocs_unit_column", "kdocs_unit_column"),
("kdocs_image_column", "kdocs_image_column"),
("kdocs_admin_notify_enabled", "kdocs_admin_notify_enabled"),
("kdocs_admin_notify_email", "kdocs_admin_notify_email"),
("kdocs_row_start", "kdocs_row_start"),
("kdocs_row_end", "kdocs_row_end"),
)
def _count_scalar(cursor, sql: str, params=()) -> int:
cursor.execute(sql, params)
row = cursor.fetchone()
if not row:
return 0
try:
if "count" in row.keys():
return int(row["count"] or 0)
except Exception:
pass
try:
return int(row[0] or 0)
except Exception:
return 0
def _table_exists(cursor, table_name: str) -> bool:
cursor.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table' AND name=?
""",
(table_name,),
)
return bool(cursor.fetchone())
def _normalize_days(days, default: int = 30) -> int:
try:
value = int(days)
except Exception:
value = default
if value < 0:
return 0
return value
def ensure_default_admin() -> bool: def ensure_default_admin() -> bool:
"""确保存在默认管理员账号(行为保持不变)。""" """确保存在默认管理员账号(行为保持不变)。"""
@@ -24,10 +114,9 @@ def ensure_default_admin() -> bool:
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM admins") count = _count_scalar(cursor, "SELECT COUNT(*) as count FROM admins")
result = cursor.fetchone()
if result["count"] == 0: if count == 0:
alphabet = string.ascii_letters + string.digits alphabet = string.ascii_letters + string.digits
random_password = "".join(secrets.choice(alphabet) for _ in range(12)) random_password = "".join(secrets.choice(alphabet) for _ in range(12))
@@ -101,41 +190,33 @@ def get_system_stats() -> dict:
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM users") total_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users")
total_users = cursor.fetchone()["count"] approved_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
new_users_today = _count_scalar(
cursor.execute("SELECT COUNT(*) as count FROM users WHERE status = 'approved'") cursor,
approved_users = cursor.fetchone()["count"]
cursor.execute(
""" """
SELECT COUNT(*) as count SELECT COUNT(*) as count
FROM users FROM users
WHERE date(created_at) = date('now', 'localtime') WHERE date(created_at) = date('now', 'localtime')
""" """,
) )
new_users_today = cursor.fetchone()["count"] new_users_7d = _count_scalar(
cursor,
cursor.execute(
""" """
SELECT COUNT(*) as count SELECT COUNT(*) as count
FROM users FROM users
WHERE datetime(created_at) >= datetime('now', 'localtime', '-7 days') WHERE datetime(created_at) >= datetime('now', 'localtime', '-7 days')
""" """,
) )
new_users_7d = cursor.fetchone()["count"] total_accounts = _count_scalar(cursor, "SELECT COUNT(*) as count FROM accounts")
vip_users = _count_scalar(
cursor.execute("SELECT COUNT(*) as count FROM accounts") cursor,
total_accounts = cursor.fetchone()["count"]
cursor.execute(
""" """
SELECT COUNT(*) as count FROM users SELECT COUNT(*) as count FROM users
WHERE vip_expire_time IS NOT NULL WHERE vip_expire_time IS NOT NULL
AND datetime(vip_expire_time) > datetime('now', 'localtime') AND datetime(vip_expire_time) > datetime('now', 'localtime')
""" """,
) )
vip_users = cursor.fetchone()["count"]
return { return {
"total_users": total_users, "total_users": total_users,
@@ -153,37 +234,9 @@ def get_system_config_raw() -> dict:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM system_config WHERE id = 1") cursor.execute("SELECT * FROM system_config WHERE id = 1")
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
return dict(row) return dict(row)
return dict(_DEFAULT_SYSTEM_CONFIG)
return {
"max_concurrent_global": 2,
"max_concurrent_per_account": 1,
"max_screenshot_concurrent": 3,
"schedule_enabled": 0,
"schedule_time": "02:00",
"schedule_browse_type": "应读",
"schedule_weekdays": "1,2,3,4,5,6,7",
"proxy_enabled": 0,
"proxy_api_url": "",
"proxy_expire_minutes": 3,
"enable_screenshot": 1,
"auto_approve_enabled": 0,
"auto_approve_hourly_limit": 10,
"auto_approve_vip_days": 7,
"kdocs_enabled": 0,
"kdocs_doc_url": "",
"kdocs_default_unit": "",
"kdocs_sheet_name": "",
"kdocs_sheet_index": 0,
"kdocs_unit_column": "A",
"kdocs_image_column": "D",
"kdocs_admin_notify_enabled": 0,
"kdocs_admin_notify_email": "",
"kdocs_row_start": 0,
"kdocs_row_end": 0,
}
def update_system_config( def update_system_config(
@@ -215,127 +268,51 @@ def update_system_config(
kdocs_row_end=None, kdocs_row_end=None,
) -> bool: ) -> bool:
"""更新系统配置仅更新DB不做缓存处理""" """更新系统配置仅更新DB不做缓存处理"""
allowed_fields = { arg_values = {
"max_concurrent_global", "max_concurrent": max_concurrent,
"schedule_enabled", "schedule_enabled": schedule_enabled,
"schedule_time", "schedule_time": schedule_time,
"schedule_browse_type", "schedule_browse_type": schedule_browse_type,
"schedule_weekdays", "schedule_weekdays": schedule_weekdays,
"max_concurrent_per_account", "max_concurrent_per_account": max_concurrent_per_account,
"max_screenshot_concurrent", "max_screenshot_concurrent": max_screenshot_concurrent,
"enable_screenshot", "enable_screenshot": enable_screenshot,
"proxy_enabled", "proxy_enabled": proxy_enabled,
"proxy_api_url", "proxy_api_url": proxy_api_url,
"proxy_expire_minutes", "proxy_expire_minutes": proxy_expire_minutes,
"auto_approve_enabled", "auto_approve_enabled": auto_approve_enabled,
"auto_approve_hourly_limit", "auto_approve_hourly_limit": auto_approve_hourly_limit,
"auto_approve_vip_days", "auto_approve_vip_days": auto_approve_vip_days,
"kdocs_enabled", "kdocs_enabled": kdocs_enabled,
"kdocs_doc_url", "kdocs_doc_url": kdocs_doc_url,
"kdocs_default_unit", "kdocs_default_unit": kdocs_default_unit,
"kdocs_sheet_name", "kdocs_sheet_name": kdocs_sheet_name,
"kdocs_sheet_index", "kdocs_sheet_index": kdocs_sheet_index,
"kdocs_unit_column", "kdocs_unit_column": kdocs_unit_column,
"kdocs_image_column", "kdocs_image_column": kdocs_image_column,
"kdocs_admin_notify_enabled", "kdocs_admin_notify_enabled": kdocs_admin_notify_enabled,
"kdocs_admin_notify_email", "kdocs_admin_notify_email": kdocs_admin_notify_email,
"kdocs_row_start", "kdocs_row_start": kdocs_row_start,
"kdocs_row_end", "kdocs_row_end": kdocs_row_end,
"updated_at",
} }
updates = []
params = []
for db_field, arg_name in _SYSTEM_CONFIG_UPDATERS:
value = arg_values.get(arg_name)
if value is None:
continue
updates.append(f"{db_field} = ?")
params.append(value)
if not updates:
return False
updates.append("updated_at = ?")
params.append(get_cst_now_str())
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
updates = []
params = []
if max_concurrent is not None:
updates.append("max_concurrent_global = ?")
params.append(max_concurrent)
if schedule_enabled is not None:
updates.append("schedule_enabled = ?")
params.append(schedule_enabled)
if schedule_time is not None:
updates.append("schedule_time = ?")
params.append(schedule_time)
if schedule_browse_type is not None:
updates.append("schedule_browse_type = ?")
params.append(schedule_browse_type)
if max_concurrent_per_account is not None:
updates.append("max_concurrent_per_account = ?")
params.append(max_concurrent_per_account)
if max_screenshot_concurrent is not None:
updates.append("max_screenshot_concurrent = ?")
params.append(max_screenshot_concurrent)
if enable_screenshot is not None:
updates.append("enable_screenshot = ?")
params.append(enable_screenshot)
if schedule_weekdays is not None:
updates.append("schedule_weekdays = ?")
params.append(schedule_weekdays)
if proxy_enabled is not None:
updates.append("proxy_enabled = ?")
params.append(proxy_enabled)
if proxy_api_url is not None:
updates.append("proxy_api_url = ?")
params.append(proxy_api_url)
if proxy_expire_minutes is not None:
updates.append("proxy_expire_minutes = ?")
params.append(proxy_expire_minutes)
if auto_approve_enabled is not None:
updates.append("auto_approve_enabled = ?")
params.append(auto_approve_enabled)
if auto_approve_hourly_limit is not None:
updates.append("auto_approve_hourly_limit = ?")
params.append(auto_approve_hourly_limit)
if auto_approve_vip_days is not None:
updates.append("auto_approve_vip_days = ?")
params.append(auto_approve_vip_days)
if kdocs_enabled is not None:
updates.append("kdocs_enabled = ?")
params.append(kdocs_enabled)
if kdocs_doc_url is not None:
updates.append("kdocs_doc_url = ?")
params.append(kdocs_doc_url)
if kdocs_default_unit is not None:
updates.append("kdocs_default_unit = ?")
params.append(kdocs_default_unit)
if kdocs_sheet_name is not None:
updates.append("kdocs_sheet_name = ?")
params.append(kdocs_sheet_name)
if kdocs_sheet_index is not None:
updates.append("kdocs_sheet_index = ?")
params.append(kdocs_sheet_index)
if kdocs_unit_column is not None:
updates.append("kdocs_unit_column = ?")
params.append(kdocs_unit_column)
if kdocs_image_column is not None:
updates.append("kdocs_image_column = ?")
params.append(kdocs_image_column)
if kdocs_admin_notify_enabled is not None:
updates.append("kdocs_admin_notify_enabled = ?")
params.append(kdocs_admin_notify_enabled)
if kdocs_admin_notify_email is not None:
updates.append("kdocs_admin_notify_email = ?")
params.append(kdocs_admin_notify_email)
if kdocs_row_start is not None:
updates.append("kdocs_row_start = ?")
params.append(kdocs_row_start)
if kdocs_row_end is not None:
updates.append("kdocs_row_end = ?")
params.append(kdocs_row_end)
if not updates:
return False
updates.append("updated_at = ?")
params.append(get_cst_now_str())
for update_clause in updates:
field_name = update_clause.split("=")[0].strip()
if field_name not in allowed_fields:
raise ValueError(f"非法字段名: {field_name}")
sql = f"UPDATE system_config SET {', '.join(updates)} WHERE id = 1" sql = f"UPDATE system_config SET {', '.join(updates)} WHERE id = 1"
cursor.execute(sql, params) cursor.execute(sql, params)
conn.commit() conn.commit()
@@ -346,13 +323,13 @@ def get_hourly_registration_count() -> int:
"""获取最近一小时内的注册用户数""" """获取最近一小时内的注册用户数"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( return _count_scalar(
cursor,
""" """
SELECT COUNT(*) FROM users SELECT COUNT(*) as count FROM users
WHERE created_at >= datetime('now', 'localtime', '-1 hour') WHERE created_at >= datetime('now', 'localtime', '-1 hour')
""" """,
) )
return cursor.fetchone()[0]
# ==================== 密码重置(管理员) ==================== # ==================== 密码重置(管理员) ====================
@@ -374,17 +351,12 @@ def admin_reset_user_password(user_id: int, new_password: str) -> bool:
def clean_old_operation_logs(days: int = 30) -> int: def clean_old_operation_logs(days: int = 30) -> int:
"""清理指定天数前的操作日志如果存在operation_logs表""" """清理指定天数前的操作日志如果存在operation_logs表"""
safe_days = _normalize_days(days, default=30)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( if not _table_exists(cursor, "operation_logs"):
"""
SELECT name FROM sqlite_master
WHERE type='table' AND name='operation_logs'
"""
)
if not cursor.fetchone():
return 0 return 0
try: try:
@@ -393,11 +365,11 @@ def clean_old_operation_logs(days: int = 30) -> int:
DELETE FROM operation_logs DELETE FROM operation_logs
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days') WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
""", """,
(days,), (safe_days,),
) )
deleted_count = cursor.rowcount deleted_count = cursor.rowcount
conn.commit() conn.commit()
print(f"已清理 {deleted_count} 条旧操作日志 (>{days}天)") print(f"已清理 {deleted_count} 条旧操作日志 (>{safe_days}天)")
return deleted_count return deleted_count
except Exception as e: except Exception as e:
print(f"清理旧操作日志失败: {e}") print(f"清理旧操作日志失败: {e}")

View File

@@ -6,12 +6,38 @@ import db_pool
from db.utils import get_cst_now_str from db.utils import get_cst_now_str
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
parsed = max(minimum, parsed)
parsed = min(maximum, parsed)
return parsed
def _normalize_offset(value, default: int = 0) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
return max(0, parsed)
def _normalize_announcement_payload(title, content, image_url):
normalized_title = str(title or "").strip()
normalized_content = str(content or "").strip()
normalized_image = str(image_url or "").strip() or None
return normalized_title, normalized_content, normalized_image
def _deactivate_all_active_announcements(cursor, cst_time: str) -> None:
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
def create_announcement(title, content, image_url=None, is_active=True): def create_announcement(title, content, image_url=None, is_active=True):
"""创建公告(默认启用;启用时会自动停用其他公告)""" """创建公告(默认启用;启用时会自动停用其他公告)"""
title = (title or "").strip() title, content, image_url = _normalize_announcement_payload(title, content, image_url)
content = (content or "").strip()
image_url = (image_url or "").strip()
image_url = image_url or None
if not title or not content: if not title or not content:
return None return None
@@ -20,7 +46,7 @@ def create_announcement(title, content, image_url=None, is_active=True):
cst_time = get_cst_now_str() cst_time = get_cst_now_str()
if is_active: if is_active:
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,)) _deactivate_all_active_announcements(cursor, cst_time)
cursor.execute( cursor.execute(
""" """
@@ -44,6 +70,9 @@ def get_announcement_by_id(announcement_id):
def get_announcements(limit=50, offset=0): def get_announcements(limit=50, offset=0):
"""获取公告列表(管理员用)""" """获取公告列表(管理员用)"""
safe_limit = _normalize_limit(limit, 50)
safe_offset = _normalize_offset(offset, 0)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
@@ -52,7 +81,7 @@ def get_announcements(limit=50, offset=0):
ORDER BY created_at DESC, id DESC ORDER BY created_at DESC, id DESC
LIMIT ? OFFSET ? LIMIT ? OFFSET ?
""", """,
(limit, offset), (safe_limit, safe_offset),
) )
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
@@ -64,7 +93,7 @@ def set_announcement_active(announcement_id, is_active):
cst_time = get_cst_now_str() cst_time = get_cst_now_str()
if is_active: if is_active:
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,)) _deactivate_all_active_announcements(cursor, cst_time)
cursor.execute( cursor.execute(
""" """
UPDATE announcements UPDATE announcements
@@ -121,13 +150,12 @@ def dismiss_announcement_for_user(user_id, announcement_id):
"""用户永久关闭某条公告(幂等)""" """用户永久关闭某条公告(幂等)"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cst_time = get_cst_now_str()
cursor.execute( cursor.execute(
""" """
INSERT OR IGNORE INTO announcement_dismissals (user_id, announcement_id, dismissed_at) INSERT OR IGNORE INTO announcement_dismissals (user_id, announcement_id, dismissed_at)
VALUES (?, ?, ?) VALUES (?, ?, ?)
""", """,
(user_id, announcement_id, cst_time), (user_id, announcement_id, get_cst_now_str()),
) )
conn.commit() conn.commit()
return cursor.rowcount >= 0 return cursor.rowcount >= 0

View File

@@ -5,6 +5,27 @@ from __future__ import annotations
import db_pool import db_pool
def _to_bool_with_default(value, default: bool = True) -> bool:
if value is None:
return default
try:
return bool(int(value))
except Exception:
try:
return bool(value)
except Exception:
return default
def _normalize_notify_enabled(enabled) -> int:
if isinstance(enabled, bool):
return 1 if enabled else 0
try:
return 1 if int(enabled) else 0
except Exception:
return 1
def get_user_by_email(email): def get_user_by_email(email):
"""根据邮箱获取用户""" """根据邮箱获取用户"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
@@ -25,7 +46,7 @@ def update_user_email(user_id, email, verified=False):
SET email = ?, email_verified = ? SET email = ?, email_verified = ?
WHERE id = ? WHERE id = ?
""", """,
(email, int(verified), user_id), (email, 1 if verified else 0, user_id),
) )
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -42,7 +63,7 @@ def update_user_email_notify(user_id, enabled):
SET email_notify_enabled = ? SET email_notify_enabled = ?
WHERE id = ? WHERE id = ?
""", """,
(int(enabled), user_id), (_normalize_notify_enabled(enabled), user_id),
) )
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -57,6 +78,6 @@ def get_user_email_notify(user_id):
row = cursor.fetchone() row = cursor.fetchone()
if row is None: if row is None:
return True return True
return bool(row[0]) if row[0] is not None else True return _to_bool_with_default(row[0], default=True)
except Exception: except Exception:
return True return True

View File

@@ -2,32 +2,73 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
from datetime import datetime
import pytz
import db_pool import db_pool
from db.utils import escape_html from db.utils import escape_html, get_cst_now_str
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
parsed = max(minimum, parsed)
parsed = min(maximum, parsed)
return parsed
def _normalize_offset(value, default: int = 0) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
return max(0, parsed)
def _safe_text(value) -> str:
if value is None:
return ""
text = str(value)
return escape_html(text) if text else ""
def _build_feedback_filter_sql(status_filter=None) -> tuple[str, list]:
where_clauses = ["1=1"]
params = []
if status_filter:
where_clauses.append("status = ?")
params.append(status_filter)
return " AND ".join(where_clauses), params
def _normalize_feedback_stats_row(row) -> dict:
row_dict = dict(row) if row else {}
return {
"total": int(row_dict.get("total") or 0),
"pending": int(row_dict.get("pending") or 0),
"replied": int(row_dict.get("replied") or 0),
"closed": int(row_dict.get("closed") or 0),
}
def create_bug_feedback(user_id, username, title, description, contact=""): def create_bug_feedback(user_id, username, title, description, contact=""):
"""创建Bug反馈带XSS防护""" """创建Bug反馈带XSS防护"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
safe_title = escape_html(title) if title else ""
safe_description = escape_html(description) if description else ""
safe_contact = escape_html(contact) if contact else ""
safe_username = escape_html(username) if username else ""
cursor.execute( cursor.execute(
""" """
INSERT INTO bug_feedbacks (user_id, username, title, description, contact, created_at) INSERT INTO bug_feedbacks (user_id, username, title, description, contact, created_at)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
""", """,
(user_id, safe_username, safe_title, safe_description, safe_contact, cst_time), (
user_id,
_safe_text(username),
_safe_text(title),
_safe_text(description),
_safe_text(contact),
get_cst_now_str(),
),
) )
conn.commit() conn.commit()
@@ -36,25 +77,25 @@ def create_bug_feedback(user_id, username, title, description, contact=""):
def get_bug_feedbacks(limit=100, offset=0, status_filter=None): def get_bug_feedbacks(limit=100, offset=0, status_filter=None):
"""获取Bug反馈列表管理员用""" """获取Bug反馈列表管理员用"""
safe_limit = _normalize_limit(limit, 100, minimum=1, maximum=1000)
safe_offset = _normalize_offset(offset, 0)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
where_sql, params = _build_feedback_filter_sql(status_filter=status_filter)
sql = "SELECT * FROM bug_feedbacks WHERE 1=1" sql = f"""
params = [] SELECT * FROM bug_feedbacks
WHERE {where_sql}
if status_filter: ORDER BY created_at DESC
sql += " AND status = ?" LIMIT ? OFFSET ?
params.append(status_filter) """
cursor.execute(sql, params + [safe_limit, safe_offset])
sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
params.extend([limit, offset])
cursor.execute(sql, params)
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
def get_user_feedbacks(user_id, limit=50): def get_user_feedbacks(user_id, limit=50):
"""获取用户自己的反馈列表""" """获取用户自己的反馈列表"""
safe_limit = _normalize_limit(limit, 50, minimum=1, maximum=1000)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
@@ -64,7 +105,7 @@ def get_user_feedbacks(user_id, limit=50):
ORDER BY created_at DESC ORDER BY created_at DESC
LIMIT ? LIMIT ?
""", """,
(user_id, limit), (user_id, safe_limit),
) )
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
@@ -82,18 +123,13 @@ def reply_feedback(feedback_id, admin_reply):
"""管理员回复反馈带XSS防护""" """管理员回复反馈带XSS防护"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
safe_reply = escape_html(admin_reply) if admin_reply else ""
cursor.execute( cursor.execute(
""" """
UPDATE bug_feedbacks UPDATE bug_feedbacks
SET admin_reply = ?, status = 'replied', replied_at = ? SET admin_reply = ?, status = 'replied', replied_at = ?
WHERE id = ? WHERE id = ?
""", """,
(safe_reply, cst_time, feedback_id), (_safe_text(admin_reply), get_cst_now_str(), feedback_id),
) )
conn.commit() conn.commit()
@@ -139,6 +175,4 @@ def get_feedback_stats():
FROM bug_feedbacks FROM bug_feedbacks
""" """
) )
row = cursor.fetchone() return _normalize_feedback_stats_row(cursor.fetchone())
return dict(row) if row else {"total": 0, "pending": 0, "replied": 0, "closed": 0}

View File

@@ -28,105 +28,136 @@ def set_current_version(conn, version: int) -> None:
conn.commit() conn.commit()
def _table_exists(cursor, table_name: str) -> bool:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (str(table_name),))
return cursor.fetchone() is not None
def _get_table_columns(cursor, table_name: str) -> set[str]:
cursor.execute(f"PRAGMA table_info({table_name})")
return {col[1] for col in cursor.fetchall()}
def _add_column_if_missing(cursor, table_name: str, columns: set[str], column_name: str, column_ddl: str, *, ok_message: str) -> bool:
if column_name in columns:
return False
cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_ddl}")
columns.add(column_name)
print(ok_message)
return True
def _read_row_value(row, key: str, index: int):
if isinstance(row, sqlite3.Row):
return row[key]
return row[index]
def _get_migration_steps():
return [
(1, _migrate_to_v1),
(2, _migrate_to_v2),
(3, _migrate_to_v3),
(4, _migrate_to_v4),
(5, _migrate_to_v5),
(6, _migrate_to_v6),
(7, _migrate_to_v7),
(8, _migrate_to_v8),
(9, _migrate_to_v9),
(10, _migrate_to_v10),
(11, _migrate_to_v11),
(12, _migrate_to_v12),
(13, _migrate_to_v13),
(14, _migrate_to_v14),
(15, _migrate_to_v15),
(16, _migrate_to_v16),
(17, _migrate_to_v17),
(18, _migrate_to_v18),
]
def migrate_database(conn, target_version: int) -> None: def migrate_database(conn, target_version: int) -> None:
"""数据库迁移:按版本增量升级(向前兼容)。""" """数据库迁移:按版本增量升级(向前兼容)。"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("INSERT OR IGNORE INTO db_version (id, version, updated_at) VALUES (1, 0, ?)", (get_cst_now_str(),)) cursor.execute("INSERT OR IGNORE INTO db_version (id, version, updated_at) VALUES (1, 0, ?)", (get_cst_now_str(),))
conn.commit() conn.commit()
target_version = int(target_version)
current_version = get_current_version(conn) current_version = get_current_version(conn)
if current_version < 1: for version, migrate_fn in _get_migration_steps():
_migrate_to_v1(conn) if version > target_version or current_version >= version:
current_version = 1 continue
if current_version < 2: migrate_fn(conn)
_migrate_to_v2(conn) current_version = version
current_version = 2
if current_version < 3:
_migrate_to_v3(conn)
current_version = 3
if current_version < 4:
_migrate_to_v4(conn)
current_version = 4
if current_version < 5:
_migrate_to_v5(conn)
current_version = 5
if current_version < 6:
_migrate_to_v6(conn)
current_version = 6
if current_version < 7:
_migrate_to_v7(conn)
current_version = 7
if current_version < 8:
_migrate_to_v8(conn)
current_version = 8
if current_version < 9:
_migrate_to_v9(conn)
current_version = 9
if current_version < 10:
_migrate_to_v10(conn)
current_version = 10
if current_version < 11:
_migrate_to_v11(conn)
current_version = 11
if current_version < 12:
_migrate_to_v12(conn)
current_version = 12
if current_version < 13:
_migrate_to_v13(conn)
current_version = 13
if current_version < 14:
_migrate_to_v14(conn)
current_version = 14
if current_version < 15:
_migrate_to_v15(conn)
current_version = 15
if current_version < 16:
_migrate_to_v16(conn)
current_version = 16
if current_version < 17:
_migrate_to_v17(conn)
current_version = 17
if current_version < 18:
_migrate_to_v18(conn)
current_version = 18
if current_version != int(target_version): if current_version != target_version:
set_current_version(conn, int(target_version)) set_current_version(conn, target_version)
def _migrate_to_v1(conn): def _migrate_to_v1(conn):
"""迁移到版本1 - 添加缺失字段""" """迁移到版本1 - 添加缺失字段"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)") system_columns = _get_table_columns(cursor, "system_config")
columns = [col[1] for col in cursor.fetchall()] _add_column_if_missing(
cursor,
"system_config",
system_columns,
"schedule_weekdays",
'TEXT DEFAULT "1,2,3,4,5,6,7"',
ok_message=" [OK] 添加 schedule_weekdays 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"max_screenshot_concurrent",
"INTEGER DEFAULT 3",
ok_message=" [OK] 添加 max_screenshot_concurrent 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"max_concurrent_per_account",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 max_concurrent_per_account 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"auto_approve_enabled",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 auto_approve_enabled 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"auto_approve_hourly_limit",
"INTEGER DEFAULT 10",
ok_message=" [OK] 添加 auto_approve_hourly_limit 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"auto_approve_vip_days",
"INTEGER DEFAULT 7",
ok_message=" [OK] 添加 auto_approve_vip_days 字段",
)
if "schedule_weekdays" not in columns: task_log_columns = _get_table_columns(cursor, "task_logs")
cursor.execute('ALTER TABLE system_config ADD COLUMN schedule_weekdays TEXT DEFAULT "1,2,3,4,5,6,7"') _add_column_if_missing(
print(" [OK] 添加 schedule_weekdays 字段") cursor,
"task_logs",
if "max_screenshot_concurrent" not in columns: task_log_columns,
cursor.execute("ALTER TABLE system_config ADD COLUMN max_screenshot_concurrent INTEGER DEFAULT 3") "duration",
print(" [OK] 添加 max_screenshot_concurrent 字段") "INTEGER",
if "max_concurrent_per_account" not in columns: ok_message=" [OK] 添加 duration 字段到 task_logs",
cursor.execute("ALTER TABLE system_config ADD COLUMN max_concurrent_per_account INTEGER DEFAULT 1") )
print(" [OK] 添加 max_concurrent_per_account 字段")
if "auto_approve_enabled" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_enabled INTEGER DEFAULT 0")
print(" [OK] 添加 auto_approve_enabled 字段")
if "auto_approve_hourly_limit" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_hourly_limit INTEGER DEFAULT 10")
print(" [OK] 添加 auto_approve_hourly_limit 字段")
if "auto_approve_vip_days" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_vip_days INTEGER DEFAULT 7")
print(" [OK] 添加 auto_approve_vip_days 字段")
cursor.execute("PRAGMA table_info(task_logs)")
columns = [col[1] for col in cursor.fetchall()]
if "duration" not in columns:
cursor.execute("ALTER TABLE task_logs ADD COLUMN duration INTEGER")
print(" [OK] 添加 duration 字段到 task_logs")
conn.commit() conn.commit()
@@ -135,24 +166,39 @@ def _migrate_to_v2(conn):
"""迁移到版本2 - 添加代理配置字段""" """迁移到版本2 - 添加代理配置字段"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)") columns = _get_table_columns(cursor, "system_config")
columns = [col[1] for col in cursor.fetchall()] _add_column_if_missing(
cursor,
if "proxy_enabled" not in columns: "system_config",
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_enabled INTEGER DEFAULT 0") columns,
print(" [OK] 添加 proxy_enabled 字段") "proxy_enabled",
"INTEGER DEFAULT 0",
if "proxy_api_url" not in columns: ok_message=" [OK] 添加 proxy_enabled 字段",
cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_api_url TEXT DEFAULT ""') )
print(" [OK] 添加 proxy_api_url 字段") _add_column_if_missing(
cursor,
if "proxy_expire_minutes" not in columns: "system_config",
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_expire_minutes INTEGER DEFAULT 3") columns,
print(" [OK] 添加 proxy_expire_minutes 字段") "proxy_api_url",
'TEXT DEFAULT ""',
if "enable_screenshot" not in columns: ok_message=" [OK] 添加 proxy_api_url 字段",
cursor.execute("ALTER TABLE system_config ADD COLUMN enable_screenshot INTEGER DEFAULT 1") )
print(" [OK] 添加 enable_screenshot 字段") _add_column_if_missing(
cursor,
"system_config",
columns,
"proxy_expire_minutes",
"INTEGER DEFAULT 3",
ok_message=" [OK] 添加 proxy_expire_minutes 字段",
)
_add_column_if_missing(
cursor,
"system_config",
columns,
"enable_screenshot",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 enable_screenshot 字段",
)
conn.commit() conn.commit()
@@ -161,20 +207,31 @@ def _migrate_to_v3(conn):
"""迁移到版本3 - 添加账号状态和登录失败计数字段""" """迁移到版本3 - 添加账号状态和登录失败计数字段"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("PRAGMA table_info(accounts)") columns = _get_table_columns(cursor, "accounts")
columns = [col[1] for col in cursor.fetchall()] _add_column_if_missing(
cursor,
if "status" not in columns: "accounts",
cursor.execute('ALTER TABLE accounts ADD COLUMN status TEXT DEFAULT "active"') columns,
print(" [OK] 添加 accounts.status 字段 (账号状态)") "status",
'TEXT DEFAULT "active"',
if "login_fail_count" not in columns: ok_message=" [OK] 添加 accounts.status 字段 (账号状态)",
cursor.execute("ALTER TABLE accounts ADD COLUMN login_fail_count INTEGER DEFAULT 0") )
print(" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)") _add_column_if_missing(
cursor,
if "last_login_error" not in columns: "accounts",
cursor.execute("ALTER TABLE accounts ADD COLUMN last_login_error TEXT") columns,
print(" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)") "login_fail_count",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)",
)
_add_column_if_missing(
cursor,
"accounts",
columns,
"last_login_error",
"TEXT",
ok_message=" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)",
)
conn.commit() conn.commit()
@@ -183,12 +240,15 @@ def _migrate_to_v4(conn):
"""迁移到版本4 - 添加任务来源字段""" """迁移到版本4 - 添加任务来源字段"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("PRAGMA table_info(task_logs)") columns = _get_table_columns(cursor, "task_logs")
columns = [col[1] for col in cursor.fetchall()] _add_column_if_missing(
cursor,
if "source" not in columns: "task_logs",
cursor.execute('ALTER TABLE task_logs ADD COLUMN source TEXT DEFAULT "manual"') columns,
print(" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)") "source",
'TEXT DEFAULT "manual"',
ok_message=" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)",
)
conn.commit() conn.commit()
@@ -300,20 +360,17 @@ def _migrate_to_v6(conn):
def _migrate_to_v7(conn): def _migrate_to_v7(conn):
"""迁移到版本7 - 统一存储北京时间将历史UTC时间字段整体+8小时""" """迁移到版本7 - 统一存储北京时间将历史UTC时间字段整体+8小时"""
cursor = conn.cursor() cursor = conn.cursor()
columns_cache: dict[str, set[str]] = {}
def table_exists(table_name: str) -> bool:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
return cursor.fetchone() is not None
def column_exists(table_name: str, column_name: str) -> bool:
cursor.execute(f"PRAGMA table_info({table_name})")
return any(row[1] == column_name for row in cursor.fetchall())
def shift_utc_to_cst(table_name: str, column_name: str) -> None: def shift_utc_to_cst(table_name: str, column_name: str) -> None:
if not table_exists(table_name): if not _table_exists(cursor, table_name):
return return
if not column_exists(table_name, column_name):
if table_name not in columns_cache:
columns_cache[table_name] = _get_table_columns(cursor, table_name)
if column_name not in columns_cache[table_name]:
return return
cursor.execute( cursor.execute(
f""" f"""
UPDATE {table_name} UPDATE {table_name}
@@ -329,10 +386,6 @@ def _migrate_to_v7(conn):
("accounts", "created_at"), ("accounts", "created_at"),
("password_reset_requests", "created_at"), ("password_reset_requests", "created_at"),
("password_reset_requests", "processed_at"), ("password_reset_requests", "processed_at"),
]:
shift_utc_to_cst(table, col)
for table, col in [
("smtp_configs", "created_at"), ("smtp_configs", "created_at"),
("smtp_configs", "updated_at"), ("smtp_configs", "updated_at"),
("smtp_configs", "last_success_at"), ("smtp_configs", "last_success_at"),
@@ -340,10 +393,6 @@ def _migrate_to_v7(conn):
("email_tokens", "created_at"), ("email_tokens", "created_at"),
("email_logs", "created_at"), ("email_logs", "created_at"),
("email_stats", "last_updated"), ("email_stats", "last_updated"),
]:
shift_utc_to_cst(table, col)
for table, col in [
("task_checkpoints", "created_at"), ("task_checkpoints", "created_at"),
("task_checkpoints", "updated_at"), ("task_checkpoints", "updated_at"),
("task_checkpoints", "completed_at"), ("task_checkpoints", "completed_at"),
@@ -359,15 +408,23 @@ def _migrate_to_v8(conn):
cursor = conn.cursor() cursor = conn.cursor()
# 1) 增量字段random_delay旧库可能不存在 # 1) 增量字段random_delay旧库可能不存在
cursor.execute("PRAGMA table_info(user_schedules)") columns = _get_table_columns(cursor, "user_schedules")
columns = [col[1] for col in cursor.fetchall()] _add_column_if_missing(
if "random_delay" not in columns: cursor,
cursor.execute("ALTER TABLE user_schedules ADD COLUMN random_delay INTEGER DEFAULT 0") "user_schedules",
print(" [OK] 添加 user_schedules.random_delay 字段") columns,
"random_delay",
if "next_run_at" not in columns: "INTEGER DEFAULT 0",
cursor.execute("ALTER TABLE user_schedules ADD COLUMN next_run_at TIMESTAMP") ok_message=" [OK] 添加 user_schedules.random_delay 字段",
print(" [OK] 添加 user_schedules.next_run_at 字段") )
_add_column_if_missing(
cursor,
"user_schedules",
columns,
"next_run_at",
"TIMESTAMP",
ok_message=" [OK] 添加 user_schedules.next_run_at 字段",
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_schedules_next_run ON user_schedules(next_run_at)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_schedules_next_run ON user_schedules(next_run_at)")
conn.commit() conn.commit()
@@ -392,12 +449,12 @@ def _migrate_to_v8(conn):
fixed = 0 fixed = 0
for row in rows: for row in rows:
try: try:
schedule_id = row["id"] if isinstance(row, sqlite3.Row) else row[0] schedule_id = _read_row_value(row, "id", 0)
schedule_time = row["schedule_time"] if isinstance(row, sqlite3.Row) else row[1] schedule_time = _read_row_value(row, "schedule_time", 1)
weekdays = row["weekdays"] if isinstance(row, sqlite3.Row) else row[2] weekdays = _read_row_value(row, "weekdays", 2)
random_delay = row["random_delay"] if isinstance(row, sqlite3.Row) else row[3] random_delay = _read_row_value(row, "random_delay", 3)
last_run_at = row["last_run_at"] if isinstance(row, sqlite3.Row) else row[4] last_run_at = _read_row_value(row, "last_run_at", 4)
next_run_at = row["next_run_at"] if isinstance(row, sqlite3.Row) else row[5] next_run_at = _read_row_value(row, "next_run_at", 5)
except Exception: except Exception:
continue continue
@@ -430,27 +487,46 @@ def _migrate_to_v9(conn):
"""迁移到版本9 - 邮件设置字段迁移(清理 email_service scattered ALTER TABLE""" """迁移到版本9 - 邮件设置字段迁移(清理 email_service scattered ALTER TABLE"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'") if not _table_exists(cursor, "email_settings"):
if not cursor.fetchone():
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移 # 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return return
cursor.execute("PRAGMA table_info(email_settings)") columns = _get_table_columns(cursor, "email_settings")
columns = [col[1] for col in cursor.fetchall()]
changed = False changed = False
if "register_verify_enabled" not in columns: changed = (
cursor.execute("ALTER TABLE email_settings ADD COLUMN register_verify_enabled INTEGER DEFAULT 0") _add_column_if_missing(
print(" [OK] 添加 email_settings.register_verify_enabled 字段") cursor,
changed = True "email_settings",
if "base_url" not in columns: columns,
cursor.execute("ALTER TABLE email_settings ADD COLUMN base_url TEXT DEFAULT ''") "register_verify_enabled",
print(" [OK] 添加 email_settings.base_url 字段") "INTEGER DEFAULT 0",
changed = True ok_message=" [OK] 添加 email_settings.register_verify_enabled 字段",
if "task_notify_enabled" not in columns: )
cursor.execute("ALTER TABLE email_settings ADD COLUMN task_notify_enabled INTEGER DEFAULT 0") or changed
print(" [OK] 添加 email_settings.task_notify_enabled 字段") )
changed = True changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"base_url",
"TEXT DEFAULT ''",
ok_message=" [OK] 添加 email_settings.base_url 字段",
)
or changed
)
changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"task_notify_enabled",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 email_settings.task_notify_enabled 字段",
)
or changed
)
if changed: if changed:
conn.commit() conn.commit()
@@ -459,18 +535,31 @@ def _migrate_to_v9(conn):
def _migrate_to_v10(conn): def _migrate_to_v10(conn):
"""迁移到版本10 - users 邮箱字段迁移(避免运行时 ALTER TABLE""" """迁移到版本10 - users 邮箱字段迁移(避免运行时 ALTER TABLE"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("PRAGMA table_info(users)") columns = _get_table_columns(cursor, "users")
columns = [col[1] for col in cursor.fetchall()]
changed = False changed = False
if "email_verified" not in columns: changed = (
cursor.execute("ALTER TABLE users ADD COLUMN email_verified INTEGER DEFAULT 0") _add_column_if_missing(
print(" [OK] 添加 users.email_verified 字段") cursor,
changed = True "users",
if "email_notify_enabled" not in columns: columns,
cursor.execute("ALTER TABLE users ADD COLUMN email_notify_enabled INTEGER DEFAULT 1") "email_verified",
print(" [OK] 添加 users.email_notify_enabled 字段") "INTEGER DEFAULT 0",
changed = True ok_message=" [OK] 添加 users.email_verified 字段",
)
or changed
)
changed = (
_add_column_if_missing(
cursor,
"users",
columns,
"email_notify_enabled",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 users.email_notify_enabled 字段",
)
or changed
)
if changed: if changed:
conn.commit() conn.commit()
@@ -657,19 +746,24 @@ def _migrate_to_v15(conn):
"""迁移到版本15 - 邮件设置:新设备登录提醒全局开关""" """迁移到版本15 - 邮件设置:新设备登录提醒全局开关"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'") if not _table_exists(cursor, "email_settings"):
if not cursor.fetchone():
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移 # 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return return
cursor.execute("PRAGMA table_info(email_settings)") columns = _get_table_columns(cursor, "email_settings")
columns = [col[1] for col in cursor.fetchall()]
changed = False changed = False
if "login_alert_enabled" not in columns: changed = (
cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1") _add_column_if_missing(
print(" [OK] 添加 email_settings.login_alert_enabled 字段") cursor,
changed = True "email_settings",
columns,
"login_alert_enabled",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 email_settings.login_alert_enabled 字段",
)
or changed
)
try: try:
cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL") cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL")
@@ -686,22 +780,24 @@ def _migrate_to_v15(conn):
def _migrate_to_v16(conn): def _migrate_to_v16(conn):
"""迁移到版本16 - 公告支持图片字段""" """迁移到版本16 - 公告支持图片字段"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("PRAGMA table_info(announcements)") columns = _get_table_columns(cursor, "announcements")
columns = [col[1] for col in cursor.fetchall()]
if "image_url" not in columns: if _add_column_if_missing(
cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT") cursor,
"announcements",
columns,
"image_url",
"TEXT",
ok_message=" [OK] 添加 announcements.image_url 字段",
):
conn.commit() conn.commit()
print(" [OK] 添加 announcements.image_url 字段")
def _migrate_to_v17(conn): def _migrate_to_v17(conn):
"""迁移到版本17 - 金山文档上传配置与用户开关""" """迁移到版本17 - 金山文档上传配置与用户开关"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)") system_columns = _get_table_columns(cursor, "system_config")
columns = [col[1] for col in cursor.fetchall()]
system_fields = [ system_fields = [
("kdocs_enabled", "INTEGER DEFAULT 0"), ("kdocs_enabled", "INTEGER DEFAULT 0"),
("kdocs_doc_url", "TEXT DEFAULT ''"), ("kdocs_doc_url", "TEXT DEFAULT ''"),
@@ -714,21 +810,29 @@ def _migrate_to_v17(conn):
("kdocs_admin_notify_email", "TEXT DEFAULT ''"), ("kdocs_admin_notify_email", "TEXT DEFAULT ''"),
] ]
for field, ddl in system_fields: for field, ddl in system_fields:
if field not in columns: _add_column_if_missing(
cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}") cursor,
print(f" [OK] 添加 system_config.{field} 字段") "system_config",
system_columns,
cursor.execute("PRAGMA table_info(users)") field,
columns = [col[1] for col in cursor.fetchall()] ddl,
ok_message=f" [OK] 添加 system_config.{field} 字段",
)
user_columns = _get_table_columns(cursor, "users")
user_fields = [ user_fields = [
("kdocs_unit", "TEXT DEFAULT ''"), ("kdocs_unit", "TEXT DEFAULT ''"),
("kdocs_auto_upload", "INTEGER DEFAULT 0"), ("kdocs_auto_upload", "INTEGER DEFAULT 0"),
] ]
for field, ddl in user_fields: for field, ddl in user_fields:
if field not in columns: _add_column_if_missing(
cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}") cursor,
print(f" [OK] 添加 users.{field} 字段") "users",
user_columns,
field,
ddl,
ok_message=f" [OK] 添加 users.{field} 字段",
)
conn.commit() conn.commit()
@@ -737,15 +841,22 @@ def _migrate_to_v18(conn):
"""迁移到版本18 - 金山文档上传:有效行范围配置""" """迁移到版本18 - 金山文档上传:有效行范围配置"""
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)") columns = _get_table_columns(cursor, "system_config")
columns = [col[1] for col in cursor.fetchall()] _add_column_if_missing(
cursor,
if "kdocs_row_start" not in columns: "system_config",
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0") columns,
print(" [OK] 添加 system_config.kdocs_row_start 字段") "kdocs_row_start",
"INTEGER DEFAULT 0",
if "kdocs_row_end" not in columns: ok_message=" [OK] 添加 system_config.kdocs_row_start 字段",
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0") )
print(" [OK] 添加 system_config.kdocs_row_end 字段") _add_column_if_missing(
cursor,
"system_config",
columns,
"kdocs_row_end",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 system_config.kdocs_row_end 字段",
)
conn.commit() conn.commit()

View File

@@ -2,12 +2,93 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
from datetime import datetime import json
from datetime import datetime, timedelta
import db_pool import db_pool
from services.schedule_utils import compute_next_run_at, format_cst from services.schedule_utils import compute_next_run_at, format_cst
from services.time_utils import get_beijing_now from services.time_utils import get_beijing_now
_SCHEDULE_DEFAULT_TIME = "08:00"
_SCHEDULE_DEFAULT_WEEKDAYS = "1,2,3,4,5"
_ALLOWED_SCHEDULE_UPDATE_FIELDS = (
"name",
"enabled",
"schedule_time",
"weekdays",
"browse_type",
"enable_screenshot",
"random_delay",
"account_ids",
)
_ALLOWED_EXEC_LOG_UPDATE_FIELDS = (
"total_accounts",
"success_accounts",
"failed_accounts",
"total_items",
"total_attachments",
"total_screenshots",
"duration_seconds",
"status",
"error_message",
)
def _normalize_limit(limit, default: int, *, minimum: int = 1) -> int:
try:
parsed = int(limit)
except Exception:
parsed = default
if parsed < minimum:
return minimum
return parsed
def _to_int(value, default: int = 0) -> int:
try:
return int(value)
except Exception:
return default
def _format_optional_datetime(dt: datetime | None) -> str | None:
if dt is None:
return None
return format_cst(dt)
def _serialize_account_ids(account_ids) -> str:
return json.dumps(account_ids) if account_ids else "[]"
def _compute_schedule_next_run_str(
*,
now_dt,
schedule_time,
weekdays,
random_delay,
last_run_at,
) -> str:
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(schedule_time or _SCHEDULE_DEFAULT_TIME),
weekdays=str(weekdays or _SCHEDULE_DEFAULT_WEEKDAYS),
random_delay=_to_int(random_delay, 0),
last_run_at=str(last_run_at or "") if last_run_at else None,
)
return format_cst(next_dt)
def _map_schedule_log_row(row) -> dict:
log = dict(row)
log["created_at"] = log.get("execute_time")
log["success_count"] = log.get("success_accounts", 0)
log["failed_count"] = log.get("failed_accounts", 0)
log["duration"] = log.get("duration_seconds", 0)
return log
def get_user_schedules(user_id): def get_user_schedules(user_id):
"""获取用户的所有定时任务""" """获取用户的所有定时任务"""
@@ -44,14 +125,10 @@ def create_user_schedule(
account_ids=None, account_ids=None,
): ):
"""创建用户定时任务""" """创建用户定时任务"""
import json
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cst_time = format_cst(get_beijing_now()) cst_time = format_cst(get_beijing_now())
account_ids_str = json.dumps(account_ids) if account_ids else "[]"
cursor.execute( cursor.execute(
""" """
INSERT INTO user_schedules ( INSERT INTO user_schedules (
@@ -66,8 +143,8 @@ def create_user_schedule(
weekdays, weekdays,
browse_type, browse_type,
enable_screenshot, enable_screenshot,
int(random_delay or 0), _to_int(random_delay, 0),
account_ids_str, _serialize_account_ids(account_ids),
cst_time, cst_time,
cst_time, cst_time,
), ),
@@ -79,28 +156,11 @@ def create_user_schedule(
def update_user_schedule(schedule_id, **kwargs): def update_user_schedule(schedule_id, **kwargs):
"""更新用户定时任务""" """更新用户定时任务"""
import json
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
now_dt = get_beijing_now() now_dt = get_beijing_now()
now_str = format_cst(now_dt) now_str = format_cst(now_dt)
updates = []
params = []
allowed_fields = [
"name",
"enabled",
"schedule_time",
"weekdays",
"browse_type",
"enable_screenshot",
"random_delay",
"account_ids",
]
# 读取旧值,用于决定是否需要重算 next_run_at
cursor.execute( cursor.execute(
""" """
SELECT enabled, schedule_time, weekdays, random_delay, last_run_at SELECT enabled, schedule_time, weekdays, random_delay, last_run_at
@@ -112,10 +172,11 @@ def update_user_schedule(schedule_id, **kwargs):
current = cursor.fetchone() current = cursor.fetchone()
if not current: if not current:
return False return False
current_enabled = int(current[0] or 0)
current_enabled = _to_int(current[0], 0)
current_time = current[1] current_time = current[1]
current_weekdays = current[2] current_weekdays = current[2]
current_random_delay = int(current[3] or 0) current_random_delay = _to_int(current[3], 0)
current_last_run_at = current[4] current_last_run_at = current[4]
will_enabled = current_enabled will_enabled = current_enabled
@@ -123,21 +184,28 @@ def update_user_schedule(schedule_id, **kwargs):
next_weekdays = current_weekdays next_weekdays = current_weekdays
next_random_delay = current_random_delay next_random_delay = current_random_delay
for field in allowed_fields: updates = []
if field in kwargs: params = []
value = kwargs[field]
if field == "account_ids" and isinstance(value, list): for field in _ALLOWED_SCHEDULE_UPDATE_FIELDS:
value = json.dumps(value) if field not in kwargs:
if field == "enabled": continue
will_enabled = 1 if value else 0
if field == "schedule_time": value = kwargs[field]
next_time = value if field == "account_ids" and isinstance(value, list):
if field == "weekdays": value = json.dumps(value)
next_weekdays = value
if field == "random_delay": if field == "enabled":
next_random_delay = int(value or 0) will_enabled = 1 if value else 0
updates.append(f"{field} = ?") if field == "schedule_time":
params.append(value) next_time = value
if field == "weekdays":
next_weekdays = value
if field == "random_delay":
next_random_delay = int(value or 0)
updates.append(f"{field} = ?")
params.append(value)
if not updates: if not updates:
return False return False
@@ -145,30 +213,26 @@ def update_user_schedule(schedule_id, **kwargs):
updates.append("updated_at = ?") updates.append("updated_at = ?")
params.append(now_str) params.append(now_str)
# 关键字段变更后重算 next_run_at确保索引驱动不会跑偏 config_changed = any(key in kwargs for key in ("schedule_time", "weekdays", "random_delay"))
#
# 需求:当用户修改“执行时间/执行日期/随机±15分钟”后即使今天已经执行过也允许按新配置在今天再次触发。
# 做法:这些关键字段发生变更时,重算 next_run_at 时忽略 last_run_at 的“同日仅一次”限制。
config_changed = any(key in kwargs for key in ["schedule_time", "weekdays", "random_delay"])
enabled_toggled = "enabled" in kwargs enabled_toggled = "enabled" in kwargs
should_recompute_next = config_changed or (enabled_toggled and will_enabled == 1) should_recompute_next = config_changed or (enabled_toggled and will_enabled == 1)
if should_recompute_next: if should_recompute_next:
next_dt = compute_next_run_at( next_run_at = _compute_schedule_next_run_str(
now=now_dt, now_dt=now_dt,
schedule_time=str(next_time or "08:00"), schedule_time=next_time,
weekdays=str(next_weekdays or "1,2,3,4,5"), weekdays=next_weekdays,
random_delay=int(next_random_delay or 0), random_delay=next_random_delay,
last_run_at=None if config_changed else (str(current_last_run_at or "") if current_last_run_at else None), last_run_at=None if config_changed else current_last_run_at,
) )
updates.append("next_run_at = ?") updates.append("next_run_at = ?")
params.append(format_cst(next_dt)) params.append(next_run_at)
# 若本次显式禁用任务,则 next_run_at 清空(与 toggle 行为保持一致)
if enabled_toggled and will_enabled == 0: if enabled_toggled and will_enabled == 0:
updates.append("next_run_at = ?") updates.append("next_run_at = ?")
params.append(None) params.append(None)
params.append(schedule_id)
params.append(schedule_id)
sql = f"UPDATE user_schedules SET {', '.join(updates)} WHERE id = ?" sql = f"UPDATE user_schedules SET {', '.join(updates)} WHERE id = ?"
cursor.execute(sql, params) cursor.execute(sql, params)
conn.commit() conn.commit()
@@ -203,28 +267,19 @@ def toggle_user_schedule(schedule_id, enabled):
) )
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = ( schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = row
row[0],
row[1],
row[2],
row[3],
row[4],
)
existing_next_run_at = str(existing_next_run_at or "").strip() or None existing_next_run_at = str(existing_next_run_at or "").strip() or None
# 若 next_run_at 已经被“修改配置”逻辑预先计算好且仍在未来,则优先沿用,
# 避免 last_run_at 的“同日仅一次”限制阻塞用户把任务调整到今天再次触发。
if existing_next_run_at and existing_next_run_at > now_str: if existing_next_run_at and existing_next_run_at > now_str:
next_run_at = existing_next_run_at next_run_at = existing_next_run_at
else: else:
next_dt = compute_next_run_at( next_run_at = _compute_schedule_next_run_str(
now=now_dt, now_dt=now_dt,
schedule_time=str(schedule_time or "08:00"), schedule_time=schedule_time,
weekdays=str(weekdays or "1,2,3,4,5"), weekdays=weekdays,
random_delay=int(random_delay or 0), random_delay=random_delay,
last_run_at=str(last_run_at or "") if last_run_at else None, last_run_at=last_run_at,
) )
next_run_at = format_cst(next_dt)
cursor.execute( cursor.execute(
""" """
@@ -272,16 +327,15 @@ def update_schedule_last_run(schedule_id):
row = cursor.fetchone() row = cursor.fetchone()
if not row: if not row:
return False return False
schedule_time, weekdays, random_delay = row[0], row[1], row[2]
next_dt = compute_next_run_at( schedule_time, weekdays, random_delay = row
now=now_dt, next_run_at = _compute_schedule_next_run_str(
schedule_time=str(schedule_time or "08:00"), now_dt=now_dt,
weekdays=str(weekdays or "1,2,3,4,5"), schedule_time=schedule_time,
random_delay=int(random_delay or 0), weekdays=weekdays,
random_delay=random_delay,
last_run_at=now_str, last_run_at=now_str,
) )
next_run_at = format_cst(next_dt)
cursor.execute( cursor.execute(
""" """
@@ -305,7 +359,11 @@ def update_schedule_next_run(schedule_id: int, next_run_at: str) -> bool:
SET next_run_at = ?, updated_at = ? SET next_run_at = ?, updated_at = ?
WHERE id = ? WHERE id = ?
""", """,
(str(next_run_at or "").strip() or None, format_cst(get_beijing_now()), int(schedule_id)), (
str(next_run_at or "").strip() or None,
format_cst(get_beijing_now()),
int(schedule_id),
),
) )
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -328,15 +386,15 @@ def recompute_schedule_next_run(schedule_id: int, *, now_dt=None) -> bool:
if not row: if not row:
return False return False
schedule_time, weekdays, random_delay, last_run_at = row[0], row[1], row[2], row[3] schedule_time, weekdays, random_delay, last_run_at = row
next_dt = compute_next_run_at( next_run_at = _compute_schedule_next_run_str(
now=now_dt, now_dt=now_dt,
schedule_time=str(schedule_time or "08:00"), schedule_time=schedule_time,
weekdays=str(weekdays or "1,2,3,4,5"), weekdays=weekdays,
random_delay=int(random_delay or 0), random_delay=random_delay,
last_run_at=str(last_run_at or "") if last_run_at else None, last_run_at=last_run_at,
) )
return update_schedule_next_run(int(schedule_id), format_cst(next_dt)) return update_schedule_next_run(int(schedule_id), next_run_at)
def get_due_user_schedules(now_cst: str, limit: int = 50): def get_due_user_schedules(now_cst: str, limit: int = 50):
@@ -345,6 +403,8 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
if not now_cst: if not now_cst:
now_cst = format_cst(get_beijing_now()) now_cst = format_cst(get_beijing_now())
safe_limit = _normalize_limit(limit, 50, minimum=1)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
@@ -358,7 +418,7 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
ORDER BY us.next_run_at ASC ORDER BY us.next_run_at ASC
LIMIT ? LIMIT ?
""", """,
(now_cst, int(limit)), (now_cst, safe_limit),
) )
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
@@ -370,15 +430,13 @@ def create_schedule_execution_log(schedule_id, user_id, schedule_name):
"""创建定时任务执行日志""" """创建定时任务执行日志"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
execute_time = format_cst(get_beijing_now())
cursor.execute( cursor.execute(
""" """
INSERT INTO schedule_execution_logs ( INSERT INTO schedule_execution_logs (
schedule_id, user_id, schedule_name, execute_time, status schedule_id, user_id, schedule_name, execute_time, status
) VALUES (?, ?, ?, ?, 'running') ) VALUES (?, ?, ?, ?, 'running')
""", """,
(schedule_id, user_id, schedule_name, execute_time), (schedule_id, user_id, schedule_name, format_cst(get_beijing_now())),
) )
conn.commit() conn.commit()
@@ -393,22 +451,11 @@ def update_schedule_execution_log(log_id, **kwargs):
updates = [] updates = []
params = [] params = []
allowed_fields = [ for field in _ALLOWED_EXEC_LOG_UPDATE_FIELDS:
"total_accounts", if field not in kwargs:
"success_accounts", continue
"failed_accounts", updates.append(f"{field} = ?")
"total_items", params.append(kwargs[field])
"total_attachments",
"total_screenshots",
"duration_seconds",
"status",
"error_message",
]
for field in allowed_fields:
if field in kwargs:
updates.append(f"{field} = ?")
params.append(kwargs[field])
if not updates: if not updates:
return False return False
@@ -424,6 +471,7 @@ def update_schedule_execution_log(log_id, **kwargs):
def get_schedule_execution_logs(schedule_id, limit=10): def get_schedule_execution_logs(schedule_id, limit=10):
"""获取定时任务执行日志""" """获取定时任务执行日志"""
try: try:
safe_limit = _normalize_limit(limit, 10, minimum=1)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
@@ -433,24 +481,16 @@ def get_schedule_execution_logs(schedule_id, limit=10):
ORDER BY execute_time DESC ORDER BY execute_time DESC
LIMIT ? LIMIT ?
""", """,
(schedule_id, limit), (schedule_id, safe_limit),
) )
logs = [] logs = []
rows = cursor.fetchall() for row in cursor.fetchall():
for row in rows:
try: try:
log = dict(row) logs.append(_map_schedule_log_row(row))
log["created_at"] = log.get("execute_time")
log["success_count"] = log.get("success_accounts", 0)
log["failed_count"] = log.get("failed_accounts", 0)
log["duration"] = log.get("duration_seconds", 0)
logs.append(log)
except Exception as e: except Exception as e:
print(f"[数据库] 处理日志行时出错: {e}") print(f"[数据库] 处理日志行时出错: {e}")
continue continue
return logs return logs
except Exception as e: except Exception as e:
print(f"[数据库] 查询定时任务日志时出错: {e}") print(f"[数据库] 查询定时任务日志时出错: {e}")
@@ -462,6 +502,7 @@ def get_schedule_execution_logs(schedule_id, limit=10):
def get_user_all_schedule_logs(user_id, limit=50): def get_user_all_schedule_logs(user_id, limit=50):
"""获取用户所有定时任务的执行日志""" """获取用户所有定时任务的执行日志"""
safe_limit = _normalize_limit(limit, 50, minimum=1)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
@@ -471,7 +512,7 @@ def get_user_all_schedule_logs(user_id, limit=50):
ORDER BY execute_time DESC ORDER BY execute_time DESC
LIMIT ? LIMIT ?
""", """,
(user_id, limit), (user_id, safe_limit),
) )
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
@@ -493,14 +534,21 @@ def delete_schedule_logs(schedule_id, user_id):
def clean_old_schedule_logs(days=30): def clean_old_schedule_logs(days=30):
"""清理指定天数前的定时任务执行日志""" """清理指定天数前的定时任务执行日志"""
safe_days = _to_int(days, 30)
if safe_days < 0:
safe_days = 0
cutoff_dt = get_beijing_now() - timedelta(days=safe_days)
cutoff_str = format_cst(cutoff_dt)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
""" """
DELETE FROM schedule_execution_logs DELETE FROM schedule_execution_logs
WHERE execute_time < datetime('now', 'localtime', '-' || ? || ' days') WHERE execute_time < ?
""", """,
(days,), (cutoff_str,),
) )
conn.commit() conn.commit()
return cursor.rowcount return cursor.rowcount

View File

@@ -362,6 +362,8 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_status ON users(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_status ON users(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_vip_expire ON users(vip_expire_time)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_vip_expire ON users(vip_expire_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
@@ -391,6 +393,8 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_id ON task_logs(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_id ON task_logs(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_status ON task_logs(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_status ON task_logs(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source ON task_logs(source)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source_created_at ON task_logs(source, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)")
@@ -409,6 +413,9 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_id ON schedule_execution_logs(schedule_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_id ON schedule_execution_logs(schedule_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_id ON schedule_execution_logs(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_id ON schedule_execution_logs(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_status ON schedule_execution_logs(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_status ON schedule_execution_logs(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_execute_time ON schedule_execution_logs(execute_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_time ON schedule_execution_logs(schedule_id, execute_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_time ON schedule_execution_logs(user_id, execute_time)")
# 初始化VIP配置幂等 # 初始化VIP配置幂等
try: try:

View File

@@ -3,13 +3,82 @@
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import timedelta
from typing import Any, Optional from typing import Any, Dict, Optional
from typing import Dict
import db_pool import db_pool
from db.utils import get_cst_now, get_cst_now_str from db.utils import get_cst_now, get_cst_now_str
_THREAT_EVENT_SELECT_COLUMNS = """
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
"""
def _normalize_page(page: int) -> int:
try:
page_i = int(page)
except Exception:
page_i = 1
return max(1, page_i)
def _normalize_per_page(per_page: int, default: int = 20) -> int:
try:
value = int(per_page)
except Exception:
value = default
return max(1, min(200, value))
def _normalize_limit(limit: int, default: int = 50) -> int:
try:
value = int(limit)
except Exception:
value = default
return max(1, min(200, value))
def _row_value(row, key: str, index: int = 0, default=None):
if row is None:
return default
try:
return row[key]
except Exception:
try:
return row[index]
except Exception:
return default
def _fetch_threat_events_history(where_clause: str, params: tuple[Any, ...], limit_i: int) -> list[dict]:
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT
{_THREAT_EVENT_SELECT_COLUMNS}
FROM threat_events
WHERE {where_clause}
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
tuple(params) + (limit_i,),
)
return [dict(r) for r in cursor.fetchall()]
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]: def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
"""记录登录环境信息,返回是否新设备/新IP。""" """记录登录环境信息,返回是否新设备/新IP。"""
user_id = int(user_id) user_id = int(user_id)
@@ -36,7 +105,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
SET last_seen = ?, last_ip = ? SET last_seen = ?, last_ip = ?
WHERE id = ? WHERE id = ?
""", """,
(now_str, ip_text, row["id"] if isinstance(row, dict) else row[0]), (now_str, ip_text, _row_value(row, "id", 0)),
) )
else: else:
cursor.execute( cursor.execute(
@@ -61,7 +130,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
SET last_seen = ? SET last_seen = ?
WHERE id = ? WHERE id = ?
""", """,
(now_str, row["id"] if isinstance(row, dict) else row[0]), (now_str, _row_value(row, "id", 0)),
) )
else: else:
cursor.execute( cursor.execute(
@@ -166,15 +235,8 @@ def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, lis
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict: def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
"""分页获取威胁事件。""" """分页获取威胁事件。"""
try: page_i = _normalize_page(page)
page_i = max(1, int(page)) per_page_i = _normalize_per_page(per_page, default=20)
except Exception:
page_i = 1
try:
per_page_i = int(per_page)
except Exception:
per_page_i = 20
per_page_i = max(1, min(200, per_page_i))
where_sql, params = _build_threat_events_where_clause(filters) where_sql, params = _build_threat_events_where_clause(filters)
offset = (page_i - 1) * per_page_i offset = (page_i - 1) * per_page_i
@@ -188,19 +250,7 @@ def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = N
cursor.execute( cursor.execute(
f""" f"""
SELECT SELECT
id, {_THREAT_EVENT_SELECT_COLUMNS}
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events FROM threat_events
{where_sql} {where_sql}
ORDER BY created_at DESC, id DESC ORDER BY created_at DESC, id DESC
@@ -218,75 +268,20 @@ def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
ip_text = str(ip or "").strip()[:64] ip_text = str(ip or "").strip()[:64]
if not ip_text: if not ip_text:
return [] return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn: limit_i = _normalize_limit(limit, default=50)
cursor = conn.cursor() return _fetch_threat_events_history("ip = ?", (ip_text,), limit_i)
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE ip = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(ip_text, limit_i),
)
return [dict(r) for r in cursor.fetchall()]
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]: def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
"""获取用户的威胁历史最近limit条""" """获取用户的威胁历史最近limit条"""
if user_id is None: if user_id is None:
return [] return []
try: try:
user_id_int = int(user_id) user_id_int = int(user_id)
except Exception: except Exception:
return [] return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn: limit_i = _normalize_limit(limit, default=50)
cursor = conn.cursor() return _fetch_threat_events_history("user_id = ?", (user_id_int,), limit_i)
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE user_id = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(user_id_int, limit_i),
)
return [dict(r) for r in cursor.fetchall()]

View File

@@ -2,12 +2,135 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime, timedelta
import pytz
import db_pool import db_pool
from db.utils import sanitize_sql_like_pattern from db.utils import get_cst_now, get_cst_now_str, sanitize_sql_like_pattern
_TASK_STATS_SELECT_SQL = """
SELECT
COUNT(*) as total_tasks,
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
"""
_USER_RUN_STATS_SELECT_SQL = """
SELECT
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
"""
def _build_day_bounds(date_filter: str) -> tuple[str | None, str | None]:
"""将 YYYY-MM-DD 转换为 [day_start, day_end) 区间。"""
try:
day_start = datetime.strptime(str(date_filter), "%Y-%m-%d")
except Exception:
return None, None
day_end = day_start + timedelta(days=1)
return day_start.strftime("%Y-%m-%d %H:%M:%S"), day_end.strftime("%Y-%m-%d %H:%M:%S")
def _normalize_int(value, default: int, *, minimum: int | None = None) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
if minimum is not None and parsed < minimum:
return minimum
return parsed
def _stat_value(row, key: str) -> int:
try:
value = row[key] if row else 0
except Exception:
value = 0
return int(value or 0)
def _build_task_logs_where_sql(
*,
date_filter=None,
status_filter=None,
source_filter=None,
user_id_filter=None,
account_filter=None,
) -> tuple[str, list]:
where_clauses = ["1=1"]
params = []
if date_filter:
day_start, day_end = _build_day_bounds(date_filter)
if day_start and day_end:
where_clauses.append("tl.created_at >= ? AND tl.created_at < ?")
params.extend([day_start, day_end])
else:
where_clauses.append("date(tl.created_at) = ?")
params.append(date_filter)
if status_filter:
where_clauses.append("tl.status = ?")
params.append(status_filter)
if source_filter:
source_filter = str(source_filter or "").strip()
if source_filter == "user_scheduled":
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append("user_scheduled:%")
elif source_filter.endswith("*"):
prefix = source_filter[:-1]
safe_prefix = sanitize_sql_like_pattern(prefix)
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append(f"{safe_prefix}%")
else:
where_clauses.append("tl.source = ?")
params.append(source_filter)
if user_id_filter:
where_clauses.append("tl.user_id = ?")
params.append(user_id_filter)
if account_filter:
safe_filter = sanitize_sql_like_pattern(account_filter)
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
params.append(f"%{safe_filter}%")
return " AND ".join(where_clauses), params
def _fetch_task_stats_row(cursor, *, where_clause: str = "", params: tuple | list = ()) -> dict:
sql = _TASK_STATS_SELECT_SQL
if where_clause:
sql = f"{sql}\nWHERE {where_clause}"
cursor.execute(sql, params)
row = cursor.fetchone()
return {
"total_tasks": _stat_value(row, "total_tasks"),
"success_tasks": _stat_value(row, "success_tasks"),
"failed_tasks": _stat_value(row, "failed_tasks"),
"total_items": _stat_value(row, "total_items"),
"total_attachments": _stat_value(row, "total_attachments"),
}
def _fetch_user_run_stats_row(cursor, *, where_clause: str, params: tuple | list) -> dict:
sql = f"{_USER_RUN_STATS_SELECT_SQL}\nWHERE {where_clause}"
cursor.execute(sql, params)
row = cursor.fetchone()
return {
"completed": _stat_value(row, "completed"),
"failed": _stat_value(row, "failed"),
"total_items": _stat_value(row, "total_items"),
"total_attachments": _stat_value(row, "total_attachments"),
}
def create_task_log( def create_task_log(
@@ -25,8 +148,6 @@ def create_task_log(
"""创建任务日志记录""" """创建任务日志记录"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
cursor.execute( cursor.execute(
""" """
@@ -45,7 +166,7 @@ def create_task_log(
total_attachments, total_attachments,
error_message, error_message,
duration, duration,
cst_time, get_cst_now_str(),
source, source,
), ),
) )
@@ -64,54 +185,27 @@ def get_task_logs(
account_filter=None, account_filter=None,
): ):
"""获取任务日志列表(支持分页和多种筛选)""" """获取任务日志列表(支持分页和多种筛选)"""
limit = _normalize_int(limit, 100, minimum=1)
offset = _normalize_int(offset, 0, minimum=0)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
where_clauses = ["1=1"] where_sql, params = _build_task_logs_where_sql(
params = [] date_filter=date_filter,
status_filter=status_filter,
if date_filter: source_filter=source_filter,
where_clauses.append("date(tl.created_at) = ?") user_id_filter=user_id_filter,
params.append(date_filter) account_filter=account_filter,
)
if status_filter:
where_clauses.append("tl.status = ?")
params.append(status_filter)
if source_filter:
source_filter = str(source_filter or "").strip()
# 兼容“虚拟来源”:用于筛选 user_scheduled:batch_xxx 这类动态值
if source_filter == "user_scheduled":
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append("user_scheduled:%")
elif source_filter.endswith("*"):
prefix = source_filter[:-1]
safe_prefix = sanitize_sql_like_pattern(prefix)
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append(f"{safe_prefix}%")
else:
where_clauses.append("tl.source = ?")
params.append(source_filter)
if user_id_filter:
where_clauses.append("tl.user_id = ?")
params.append(user_id_filter)
if account_filter:
safe_filter = sanitize_sql_like_pattern(account_filter)
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
params.append(f"%{safe_filter}%")
where_sql = " AND ".join(where_clauses)
count_sql = f""" count_sql = f"""
SELECT COUNT(*) as total SELECT COUNT(*) as total
FROM task_logs tl FROM task_logs tl
LEFT JOIN users u ON tl.user_id = u.id
WHERE {where_sql} WHERE {where_sql}
""" """
cursor.execute(count_sql, params) cursor.execute(count_sql, params)
total = cursor.fetchone()["total"] total = _stat_value(cursor.fetchone(), "total")
data_sql = f""" data_sql = f"""
SELECT SELECT
@@ -123,9 +217,10 @@ def get_task_logs(
ORDER BY tl.created_at DESC ORDER BY tl.created_at DESC
LIMIT ? OFFSET ? LIMIT ? OFFSET ?
""" """
params.extend([limit, offset]) data_params = list(params)
data_params.extend([limit, offset])
cursor.execute(data_sql, params) cursor.execute(data_sql, data_params)
logs = [dict(row) for row in cursor.fetchall()] logs = [dict(row) for row in cursor.fetchall()]
return {"logs": logs, "total": total} return {"logs": logs, "total": total}
@@ -133,61 +228,39 @@ def get_task_logs(
def get_task_stats(date_filter=None): def get_task_stats(date_filter=None):
"""获取任务统计信息""" """获取任务统计信息"""
if date_filter is None:
date_filter = get_cst_now().strftime("%Y-%m-%d")
day_start, day_end = _build_day_bounds(date_filter)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
if date_filter is None: if day_start and day_end:
date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d") today_stats = _fetch_task_stats_row(
cursor,
where_clause="created_at >= ? AND created_at < ?",
params=(day_start, day_end),
)
else:
today_stats = _fetch_task_stats_row(
cursor,
where_clause="date(created_at) = ?",
params=(date_filter,),
)
cursor.execute( total_stats = _fetch_task_stats_row(cursor)
"""
SELECT
COUNT(*) as total_tasks,
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
WHERE date(created_at) = ?
""",
(date_filter,),
)
today_stats = cursor.fetchone()
cursor.execute( return {"today": today_stats, "total": total_stats}
"""
SELECT
COUNT(*) as total_tasks,
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
"""
)
total_stats = cursor.fetchone()
return {
"today": {
"total_tasks": today_stats["total_tasks"] or 0,
"success_tasks": today_stats["success_tasks"] or 0,
"failed_tasks": today_stats["failed_tasks"] or 0,
"total_items": today_stats["total_items"] or 0,
"total_attachments": today_stats["total_attachments"] or 0,
},
"total": {
"total_tasks": total_stats["total_tasks"] or 0,
"success_tasks": total_stats["success_tasks"] or 0,
"failed_tasks": total_stats["failed_tasks"] or 0,
"total_items": total_stats["total_items"] or 0,
"total_attachments": total_stats["total_attachments"] or 0,
},
}
def delete_old_task_logs(days=30, batch_size=1000): def delete_old_task_logs(days=30, batch_size=1000):
"""删除N天前的任务日志分批删除避免长时间锁表""" """删除N天前的任务日志分批删除避免长时间锁表"""
days = _normalize_int(days, 30, minimum=0)
batch_size = _normalize_int(batch_size, 1000, minimum=1)
cutoff = (get_cst_now() - timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
total_deleted = 0 total_deleted = 0
while True: while True:
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
@@ -197,16 +270,16 @@ def delete_old_task_logs(days=30, batch_size=1000):
DELETE FROM task_logs DELETE FROM task_logs
WHERE rowid IN ( WHERE rowid IN (
SELECT rowid FROM task_logs SELECT rowid FROM task_logs
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days') WHERE created_at < ?
LIMIT ? LIMIT ?
) )
""", """,
(days, batch_size), (cutoff, batch_size),
) )
deleted = cursor.rowcount deleted = cursor.rowcount
conn.commit() conn.commit()
if deleted == 0: if deleted <= 0:
break break
total_deleted += deleted total_deleted += deleted
@@ -215,31 +288,23 @@ def delete_old_task_logs(days=30, batch_size=1000):
def get_user_run_stats(user_id, date_filter=None): def get_user_run_stats(user_id, date_filter=None):
"""获取用户的运行统计信息""" """获取用户的运行统计信息"""
if date_filter is None:
date_filter = get_cst_now().strftime("%Y-%m-%d")
day_start, day_end = _build_day_bounds(date_filter)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cst_tz = pytz.timezone("Asia/Shanghai")
cursor = conn.cursor() cursor = conn.cursor()
if date_filter is None: if day_start and day_end:
date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d") return _fetch_user_run_stats_row(
cursor,
where_clause="user_id = ? AND created_at >= ? AND created_at < ?",
params=(user_id, day_start, day_end),
)
cursor.execute( return _fetch_user_run_stats_row(
""" cursor,
SELECT where_clause="user_id = ? AND date(created_at) = ?",
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed, params=(user_id, date_filter),
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
WHERE user_id = ? AND date(created_at) = ?
""",
(user_id, date_filter),
) )
stats = cursor.fetchone()
return {
"completed": stats["completed"] or 0,
"failed": stats["failed"] or 0,
"total_items": stats["total_items"] or 0,
"total_attachments": stats["total_attachments"] or 0,
}

View File

@@ -16,8 +16,41 @@ from password_utils import (
verify_password_bcrypt, verify_password_bcrypt,
verify_password_sha256, verify_password_sha256,
) )
logger = get_logger(__name__) logger = get_logger(__name__)
_CST_TZ = pytz.timezone("Asia/Shanghai")
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
def _row_to_dict(row):
return dict(row) if row else None
def _get_user_by_field(field_name: str, field_value):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
return _row_to_dict(cursor.fetchone())
def _parse_cst_datetime(datetime_str: str | None):
if not datetime_str:
return None
try:
naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S")
return _CST_TZ.localize(naive_dt)
except Exception:
return None
def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str:
if int(days) == 999999:
return _PERMANENT_VIP_EXPIRE
if base_dt is None:
base_dt = datetime.now(_CST_TZ)
return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S")
def get_vip_config(): def get_vip_config():
"""获取VIP配置""" """获取VIP配置"""
@@ -32,13 +65,12 @@ def set_default_vip_days(days):
"""设置默认VIP天数""" """设置默认VIP天数"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cst_time = get_cst_now_str()
cursor.execute( cursor.execute(
""" """
INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at) INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
VALUES (1, ?, ?) VALUES (1, ?, ?)
""", """,
(days, cst_time), (days, get_cst_now_str()),
) )
conn.commit() conn.commit()
return True return True
@@ -47,14 +79,8 @@ def set_default_vip_days(days):
def set_user_vip(user_id, days): def set_user_vip(user_id, days):
"""设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久""" """设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cst_tz = pytz.timezone("Asia/Shanghai")
cursor = conn.cursor() cursor = conn.cursor()
expire_time = _format_vip_expire(days)
if days == 999999:
expire_time = "2099-12-31 23:59:59"
else:
expire_time = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id)) cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id))
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -63,29 +89,26 @@ def set_user_vip(user_id, days):
def extend_user_vip(user_id, days): def extend_user_vip(user_id, days):
"""延长用户VIP时间""" """延长用户VIP时间"""
user = get_user_by_id(user_id) user = get_user_by_id(user_id)
cst_tz = pytz.timezone("Asia/Shanghai")
if not user: if not user:
return False return False
current_expire = user.get("vip_expire_time")
now_dt = datetime.now(_CST_TZ)
if current_expire and current_expire != _PERMANENT_VIP_EXPIRE:
expire_time = _parse_cst_datetime(current_expire)
if expire_time is not None:
if expire_time < now_dt:
expire_time = now_dt
new_expire = _format_vip_expire(days, base_dt=expire_time)
else:
logger.warning("解析VIP过期时间失败使用当前时间")
new_expire = _format_vip_expire(days, base_dt=now_dt)
else:
new_expire = _format_vip_expire(days, base_dt=now_dt)
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
current_expire = user.get("vip_expire_time")
if current_expire and current_expire != "2099-12-31 23:59:59":
try:
expire_time_naive = datetime.strptime(current_expire, "%Y-%m-%d %H:%M:%S")
expire_time = cst_tz.localize(expire_time_naive)
now = datetime.now(cst_tz)
if expire_time < now:
expire_time = now
new_expire = (expire_time + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, AttributeError) as e:
logger.warning(f"解析VIP过期时间失败: {e}, 使用当前时间")
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
else:
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id)) cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -105,45 +128,49 @@ def is_user_vip(user_id):
注意数据库中存储的时间统一使用CSTAsia/Shanghai时区 注意数据库中存储的时间统一使用CSTAsia/Shanghai时区
""" """
cst_tz = pytz.timezone("Asia/Shanghai")
user = get_user_by_id(user_id) user = get_user_by_id(user_id)
if not user:
if not user or not user.get("vip_expire_time"):
return False return False
try: vip_expire_time = user.get("vip_expire_time")
expire_time_naive = datetime.strptime(user["vip_expire_time"], "%Y-%m-%d %H:%M:%S") if not vip_expire_time:
expire_time = cst_tz.localize(expire_time_naive)
now = datetime.now(cst_tz)
return now < expire_time
except (ValueError, AttributeError) as e:
logger.warning(f"检查VIP状态失败 (user_id={user_id}): {e}")
return False return False
expire_time = _parse_cst_datetime(vip_expire_time)
if expire_time is None:
logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间")
return False
return datetime.now(_CST_TZ) < expire_time
def get_user_vip_info(user_id): def get_user_vip_info(user_id):
"""获取用户VIP信息""" """获取用户VIP信息"""
cst_tz = pytz.timezone("Asia/Shanghai")
user = get_user_by_id(user_id) user = get_user_by_id(user_id)
if not user: if not user:
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""} return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""}
vip_expire_time = user.get("vip_expire_time") vip_expire_time = user.get("vip_expire_time")
username = user.get("username", "")
if not vip_expire_time: if not vip_expire_time:
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")} return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
try: expire_time = _parse_cst_datetime(vip_expire_time)
expire_time_naive = datetime.strptime(vip_expire_time, "%Y-%m-%d %H:%M:%S") if expire_time is None:
expire_time = cst_tz.localize(expire_time_naive) logger.warning("VIP信息获取错误: 无法解析过期时间")
now = datetime.now(cst_tz) return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
is_vip = now < expire_time
days_left = (expire_time - now).days if is_vip else 0
return {"username": user.get("username", ""), "is_vip": is_vip, "expire_time": vip_expire_time, "days_left": max(0, days_left)} now_dt = datetime.now(_CST_TZ)
except Exception as e: is_vip = now_dt < expire_time
logger.warning(f"VIP信息获取错误: {e}") days_left = (expire_time - now_dt).days if is_vip else 0
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
return {
"username": username,
"is_vip": is_vip,
"expire_time": vip_expire_time,
"days_left": max(0, days_left),
}
# ==================== 用户相关 ==================== # ==================== 用户相关 ====================
@@ -151,8 +178,6 @@ def get_user_vip_info(user_id):
def create_user(username, password, email=""): def create_user(username, password, email=""):
"""创建新用户(默认直接通过,赠送默认VIP)""" """创建新用户(默认直接通过,赠送默认VIP)"""
cst_tz = pytz.timezone("Asia/Shanghai")
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
password_hash = hash_password_bcrypt(password) password_hash = hash_password_bcrypt(password)
@@ -160,12 +185,8 @@ def create_user(username, password, email=""):
default_vip_days = get_vip_config()["default_vip_days"] default_vip_days = get_vip_config()["default_vip_days"]
vip_expire_time = None vip_expire_time = None
if int(default_vip_days or 0) > 0:
if default_vip_days > 0: vip_expire_time = _format_vip_expire(int(default_vip_days))
if default_vip_days == 999999:
vip_expire_time = "2099-12-31 23:59:59"
else:
vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime("%Y-%m-%d %H:%M:%S")
try: try:
cursor.execute( cursor.execute(
@@ -210,28 +231,28 @@ def verify_user(username, password):
def get_user_by_id(user_id): def get_user_by_id(user_id):
"""根据ID获取用户""" """根据ID获取用户"""
with db_pool.get_db() as conn: return _get_user_by_field("id", user_id)
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
user = cursor.fetchone()
return dict(user) if user else None
def get_user_kdocs_settings(user_id): def get_user_kdocs_settings(user_id):
"""获取用户的金山文档配置""" """获取用户的金山文档配置"""
user = get_user_by_id(user_id) with db_pool.get_db() as conn:
if not user: cursor = conn.cursor()
return None cursor.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,))
return { row = cursor.fetchone()
"kdocs_unit": user.get("kdocs_unit") or "", if not row:
"kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0, return None
} return {
"kdocs_unit": (row["kdocs_unit"] or "") if isinstance(row, dict) else (row[0] or ""),
"kdocs_auto_upload": 1 if ((row["kdocs_auto_upload"] if isinstance(row, dict) else row[1]) or 0) else 0,
}
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool: def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
"""更新用户的金山文档配置""" """更新用户的金山文档配置"""
updates = [] updates = []
params = [] params = []
if kdocs_unit is not None: if kdocs_unit is not None:
updates.append("kdocs_unit = ?") updates.append("kdocs_unit = ?")
params.append(kdocs_unit) params.append(kdocs_unit)
@@ -252,11 +273,7 @@ def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=No
def get_user_by_username(username): def get_user_by_username(username):
"""根据用户名获取用户""" """根据用户名获取用户"""
with db_pool.get_db() as conn: return _get_user_by_field("username", username)
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
user = cursor.fetchone()
return dict(user) if user else None
def get_all_users(): def get_all_users():
@@ -279,14 +296,13 @@ def approve_user(user_id):
"""审核通过用户""" """审核通过用户"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cst_time = get_cst_now_str()
cursor.execute( cursor.execute(
""" """
UPDATE users UPDATE users
SET status = 'approved', approved_at = ? SET status = 'approved', approved_at = ?
WHERE id = ? WHERE id = ?
""", """,
(cst_time, user_id), (get_cst_now_str(), user_id),
) )
conn.commit() conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@@ -315,5 +331,5 @@ def get_user_stats(user_id):
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,)) cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,))
account_count = cursor.fetchone()["count"] row = cursor.fetchone()
return {"account_count": account_count} return {"account_count": int((row["count"] if row else 0) or 0)}

View File

@@ -7,8 +7,12 @@
import sqlite3 import sqlite3
import threading import threading
from queue import Queue, Empty from queue import Empty, Full, Queue
import time
from app_logger import get_logger
logger = get_logger("database")
class ConnectionPool: class ConnectionPool:
@@ -44,12 +48,55 @@ class ConnectionPool:
"""创建新的数据库连接""" """创建新的数据库连接"""
conn = sqlite3.connect(self.database, check_same_thread=False) conn = sqlite3.connect(self.database, check_same_thread=False)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
# 启用外键约束,确保 ON DELETE CASCADE 等约束生效
conn.execute("PRAGMA foreign_keys=ON")
# 设置WAL模式提高并发性能 # 设置WAL模式提高并发性能
conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA journal_mode=WAL")
# 在WAL模式下使用NORMAL同步兼顾性能与可靠性
conn.execute("PRAGMA synchronous=NORMAL")
# 设置合理的超时时间 # 设置合理的超时时间
conn.execute("PRAGMA busy_timeout=5000") conn.execute("PRAGMA busy_timeout=5000")
return conn return conn
def _close_connection(self, conn) -> None:
if conn is None:
return
try:
conn.close()
except Exception as e:
logger.warning(f"关闭连接失败: {e}")
def _is_connection_healthy(self, conn) -> bool:
if conn is None:
return False
try:
conn.rollback()
conn.execute("SELECT 1")
return True
except sqlite3.Error as e:
logger.warning(f"连接健康检查失败(数据库错误): {e}")
except Exception as e:
logger.warning(f"连接健康检查失败(未知错误): {e}")
return False
def _replenish_pool_if_needed(self) -> None:
with self._lock:
if self._pool.qsize() >= self.pool_size:
return
new_conn = None
try:
new_conn = self._create_connection()
self._pool.put(new_conn, block=False)
self._created_connections += 1
except Full:
if new_conn:
self._close_connection(new_conn)
except Exception as e:
if new_conn:
self._close_connection(new_conn)
logger.warning(f"重建连接失败: {e}")
def get_connection(self): def get_connection(self):
""" """
从连接池获取连接 从连接池获取连接
@@ -70,66 +117,20 @@ class ConnectionPool:
Args: Args:
conn: 要归还的连接 conn: 要归还的连接
""" """
import sqlite3
from queue import Full
if conn is None: if conn is None:
return return
connection_healthy = False if self._is_connection_healthy(conn):
try:
# 回滚任何未提交的事务
conn.rollback()
# 安全修复:验证连接是否健康,防止损坏的连接污染连接池
conn.execute("SELECT 1")
connection_healthy = True
except sqlite3.Error as e:
# 数据库相关错误,连接可能损坏
print(f"连接健康检查失败(数据库错误): {e}")
except Exception as e:
print(f"连接健康检查失败(未知错误): {e}")
if connection_healthy:
try: try:
self._pool.put(conn, block=False) self._pool.put(conn, block=False)
return # 成功归还 return
except Full: except Full:
# 队列已满(不应该发生,但处理它) logger.warning("连接池已满,关闭多余连接")
print(f"警告: 连接池已满,关闭多余连接") self._close_connection(conn)
connection_healthy = False # 标记为需要关闭 return
# 连接不健康或队列已满,关闭它 self._close_connection(conn)
try: self._replenish_pool_if_needed()
conn.close()
except Exception as close_error:
print(f"关闭连接失败: {close_error}")
# 如果连接不健康,尝试创建新连接补充池
if not connection_healthy:
with self._lock:
# 双重检查:确保池确实需要补充
if self._pool.qsize() < self.pool_size:
new_conn = None
try:
new_conn = self._create_connection()
self._pool.put(new_conn, block=False)
# 只有成功放入池后才增加计数
self._created_connections += 1
except Full:
# 在获取锁期间池被填满了,关闭新建的连接
if new_conn:
try:
new_conn.close()
except Exception:
pass
except Exception as create_error:
# 创建连接失败,确保关闭已创建的连接
if new_conn:
try:
new_conn.close()
except Exception:
pass
print(f"重建连接失败: {create_error}")
def close_all(self): def close_all(self):
"""关闭所有连接""" """关闭所有连接"""
@@ -138,7 +139,7 @@ class ConnectionPool:
conn = self._pool.get(block=False) conn = self._pool.get(block=False)
conn.close() conn.close()
except Exception as e: except Exception as e:
print(f"关闭连接失败: {e}") logger.warning(f"关闭连接失败: {e}")
def get_stats(self): def get_stats(self):
"""获取连接池统计信息""" """获取连接池统计信息"""
@@ -175,14 +176,14 @@ class PooledConnection:
if exc_type is not None: if exc_type is not None:
# 发生异常,回滚事务 # 发生异常,回滚事务
self._conn.rollback() self._conn.rollback()
print(f"数据库事务已回滚: {exc_type.__name__}") logger.warning(f"数据库事务已回滚: {exc_type.__name__}")
# 注意: 不自动commit要求用户显式调用conn.commit() # 注意: 不自动commit要求用户显式调用conn.commit()
if self._cursor: if self._cursor:
self._cursor.close() self._cursor.close()
self._cursor = None self._cursor = None
except Exception as e: except Exception as e:
print(f"关闭游标失败: {e}") logger.warning(f"关闭游标失败: {e}")
finally: finally:
# 归还连接 # 归还连接
self._pool.return_connection(self._conn) self._pool.return_connection(self._conn)
@@ -254,7 +255,7 @@ def init_pool(database, pool_size=5):
with _pool_lock: with _pool_lock:
if _pool is None: if _pool is None:
_pool = ConnectionPool(database, pool_size) _pool = ConnectionPool(database, pool_size)
print(f"[OK] 数据库连接池已初始化 (大小: {pool_size})") logger.info(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
def get_db(): def get_db():

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
import json
import os import os
import time import time
@@ -9,8 +10,40 @@ from services.runtime import get_logger, get_socketio
from services.state import safe_get_account, safe_iter_task_status_items from services.state import safe_get_account, safe_iter_task_status_items
def _to_int(value, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _payload_signature(payload: dict) -> str:
try:
return json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=str)
except Exception:
return repr(payload)
def _should_emit(
*,
last_sig: str | None,
last_ts: float,
new_sig: str,
now_ts: float,
min_interval: float,
force_interval: float,
) -> bool:
if last_sig is None:
return True
if (now_ts - last_ts) >= force_interval:
return True
if new_sig != last_sig and (now_ts - last_ts) >= min_interval:
return True
return False
def status_push_worker() -> None: def status_push_worker() -> None:
"""后台线程:按间隔推送排队/运行中任务状态更新(可节流)。""" """后台线程:按间隔推送排队/运行中任务状态(变更驱动+心跳兜底)。"""
logger = get_logger() logger = get_logger()
try: try:
push_interval = float(os.environ.get("STATUS_PUSH_INTERVAL_SECONDS", "1")) push_interval = float(os.environ.get("STATUS_PUSH_INTERVAL_SECONDS", "1"))
@@ -18,18 +51,41 @@ def status_push_worker() -> None:
push_interval = 1.0 push_interval = 1.0
push_interval = max(0.5, push_interval) push_interval = max(0.5, push_interval)
try:
queue_min_interval = float(os.environ.get("STATUS_PUSH_MIN_QUEUE_INTERVAL_SECONDS", str(push_interval)))
except Exception:
queue_min_interval = push_interval
queue_min_interval = max(push_interval, queue_min_interval)
try:
progress_min_interval = float(
os.environ.get("STATUS_PUSH_MIN_PROGRESS_INTERVAL_SECONDS", str(push_interval))
)
except Exception:
progress_min_interval = push_interval
progress_min_interval = max(push_interval, progress_min_interval)
try:
force_interval = float(os.environ.get("STATUS_PUSH_FORCE_INTERVAL_SECONDS", "10"))
except Exception:
force_interval = 10.0
force_interval = max(push_interval, force_interval)
socketio = get_socketio() socketio = get_socketio()
from services.tasks import get_task_scheduler from services.tasks import get_task_scheduler
scheduler = get_task_scheduler() scheduler = get_task_scheduler()
emitted_state: dict[str, dict] = {}
while True: while True:
try: try:
now_ts = time.time()
queue_snapshot = scheduler.get_queue_state_snapshot() queue_snapshot = scheduler.get_queue_state_snapshot()
pending_total = int(queue_snapshot.get("pending_total", 0) or 0) pending_total = int(queue_snapshot.get("pending_total", 0) or 0)
running_total = int(queue_snapshot.get("running_total", 0) or 0) running_total = int(queue_snapshot.get("running_total", 0) or 0)
running_by_user = queue_snapshot.get("running_by_user") or {} running_by_user = queue_snapshot.get("running_by_user") or {}
positions = queue_snapshot.get("positions") or {} positions = queue_snapshot.get("positions") or {}
active_account_ids = set()
status_items = safe_iter_task_status_items() status_items = safe_iter_task_status_items()
for account_id, status_info in status_items: for account_id, status_info in status_items:
@@ -39,11 +95,15 @@ def status_push_worker() -> None:
user_id = status_info.get("user_id") user_id = status_info.get("user_id")
if not user_id: if not user_id:
continue continue
active_account_ids.add(str(account_id))
account = safe_get_account(user_id, account_id) account = safe_get_account(user_id, account_id)
if not account: if not account:
continue continue
user_id_int = _to_int(user_id)
account_data = account.to_dict() account_data = account.to_dict()
pos = positions.get(account_id) or {} pos = positions.get(account_id) or positions.get(str(account_id)) or {}
account_data.update( account_data.update(
{ {
"queue_pending_total": pending_total, "queue_pending_total": pending_total,
@@ -51,10 +111,23 @@ def status_push_worker() -> None:
"queue_ahead": pos.get("queue_ahead"), "queue_ahead": pos.get("queue_ahead"),
"queue_position": pos.get("queue_position"), "queue_position": pos.get("queue_position"),
"queue_is_vip": pos.get("is_vip"), "queue_is_vip": pos.get("is_vip"),
"queue_running_user": int(running_by_user.get(int(user_id), 0) or 0), "queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))),
} }
) )
socketio.emit("account_update", account_data, room=f"user_{user_id}")
cache_entry = emitted_state.setdefault(str(account_id), {})
account_sig = _payload_signature(account_data)
if _should_emit(
last_sig=cache_entry.get("account_sig"),
last_ts=float(cache_entry.get("account_ts", 0) or 0),
new_sig=account_sig,
now_ts=now_ts,
min_interval=queue_min_interval,
force_interval=force_interval,
):
socketio.emit("account_update", account_data, room=f"user_{user_id}")
cache_entry["account_sig"] = account_sig
cache_entry["account_ts"] = now_ts
if status != "运行中": if status != "运行中":
continue continue
@@ -74,9 +147,26 @@ def status_push_worker() -> None:
"queue_running_total": running_total, "queue_running_total": running_total,
"queue_ahead": pos.get("queue_ahead"), "queue_ahead": pos.get("queue_ahead"),
"queue_position": pos.get("queue_position"), "queue_position": pos.get("queue_position"),
"queue_running_user": int(running_by_user.get(int(user_id), 0) or 0), "queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))),
} }
socketio.emit("task_progress", progress_data, room=f"user_{user_id}")
progress_sig = _payload_signature(progress_data)
if _should_emit(
last_sig=cache_entry.get("progress_sig"),
last_ts=float(cache_entry.get("progress_ts", 0) or 0),
new_sig=progress_sig,
now_ts=now_ts,
min_interval=progress_min_interval,
force_interval=force_interval,
):
socketio.emit("task_progress", progress_data, room=f"user_{user_id}")
cache_entry["progress_sig"] = progress_sig
cache_entry["progress_ts"] = now_ts
if emitted_state:
stale_ids = [account_id for account_id in emitted_state.keys() if account_id not in active_account_ids]
for account_id in stale_ids:
emitted_state.pop(account_id, None)
time.sleep(push_interval) time.sleep(push_interval)
except Exception as e: except Exception as e:

View File

@@ -8,6 +8,15 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api")
# Import side effects: register routes on blueprint # Import side effects: register routes on blueprint
from routes.admin_api import core as _core # noqa: F401 from routes.admin_api import core as _core # noqa: F401
from routes.admin_api import system_config_api as _system_config_api # noqa: F401
from routes.admin_api import operations_api as _operations_api # noqa: F401
from routes.admin_api import announcements_api as _announcements_api # noqa: F401
from routes.admin_api import users_api as _users_api # noqa: F401
from routes.admin_api import account_api as _account_api # noqa: F401
from routes.admin_api import feedback_api as _feedback_api # noqa: F401
from routes.admin_api import infra_api as _infra_api # noqa: F401
from routes.admin_api import tasks_api as _tasks_api # noqa: F401
from routes.admin_api import email_api as _email_api # noqa: F401
# Export security blueprint for app registration # Export security blueprint for app registration
from routes.admin_api.security import security_bp # noqa: F401 from routes.admin_api.security import security_bp # noqa: F401

View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from app_security import validate_password
from flask import jsonify, request, session
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
# ==================== 密码重置 / 反馈(管理员) ====================
@admin_api_bp.route("/admin/password", methods=["PUT"])
@admin_required
def update_admin_password():
"""修改管理员密码"""
data = request.json or {}
new_password = (data.get("new_password") or "").strip()
if not new_password:
return jsonify({"error": "密码不能为空"}), 400
username = session.get("admin_username")
if database.update_admin_password(username, new_password):
return jsonify({"success": True})
return jsonify({"error": "修改失败"}), 400
@admin_api_bp.route("/admin/username", methods=["PUT"])
@admin_required
def update_admin_username():
"""修改管理员用户名"""
data = request.json or {}
new_username = (data.get("new_username") or "").strip()
if not new_username:
return jsonify({"error": "用户名不能为空"}), 400
old_username = session.get("admin_username")
if database.update_admin_username(old_username, new_username):
session["admin_username"] = new_username
return jsonify({"success": True})
return jsonify({"error": "修改失败,用户名可能已存在"}), 400
@admin_api_bp.route("/users/<int:user_id>/reset_password", methods=["POST"])
@admin_required
def admin_reset_password_route(user_id):
"""管理员直接重置用户密码(无需审核)"""
data = request.json or {}
new_password = (data.get("new_password") or "").strip()
if not new_password:
return jsonify({"error": "新密码不能为空"}), 400
is_valid, error_msg = validate_password(new_password)
if not is_valid:
return jsonify({"error": error_msg}), 400
if database.admin_reset_user_password(user_id, new_password):
return jsonify({"message": "密码重置成功"})
return jsonify({"error": "重置失败,用户不存在"}), 400

View File

@@ -0,0 +1,144 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import posixpath
import secrets
import time
import database
from app_config import get_config
from app_logger import get_logger
from app_security import is_safe_path, sanitize_filename
from flask import current_app, jsonify, request, url_for
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
logger = get_logger("app")
config = get_config()
def _get_upload_dir():
rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements")
if not is_safe_path(current_app.root_path, rel_dir):
rel_dir = "static/announcements"
abs_dir = os.path.join(current_app.root_path, rel_dir)
os.makedirs(abs_dir, exist_ok=True)
return abs_dir, rel_dir
def _get_file_size(file_storage):
try:
file_storage.stream.seek(0, os.SEEK_END)
size = file_storage.stream.tell()
file_storage.stream.seek(0)
return size
except Exception:
return None
# ==================== 公告管理API管理员 ====================
@admin_api_bp.route("/announcements/upload_image", methods=["POST"])
@admin_required
def admin_upload_announcement_image():
"""上传公告图片返回可访问URL"""
file = request.files.get("file")
if not file or not file.filename:
return jsonify({"error": "请选择图片"}), 400
filename = sanitize_filename(file.filename)
ext = os.path.splitext(filename)[1].lower()
allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"})
if not ext or ext not in allowed_exts:
return jsonify({"error": "不支持的图片格式"}), 400
if file.mimetype and not str(file.mimetype).startswith("image/"):
return jsonify({"error": "文件类型无效"}), 400
size = _get_file_size(file)
max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024))
if size is not None and size > max_size:
max_mb = max_size // 1024 // 1024
return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400
abs_dir, rel_dir = _get_upload_dir()
token = secrets.token_hex(6)
name = f"announcement_{int(time.time())}_{token}{ext}"
save_path = os.path.join(abs_dir, name)
file.save(save_path)
static_root = os.path.join(current_app.root_path, "static")
rel_to_static = os.path.relpath(abs_dir, static_root)
if rel_to_static.startswith(".."):
rel_to_static = "announcements"
url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name)
return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)})
@admin_api_bp.route("/announcements", methods=["GET"])
@admin_required
def admin_get_announcements():
"""获取公告列表"""
try:
limit = int(request.args.get("limit", 50))
offset = int(request.args.get("offset", 0))
except (TypeError, ValueError):
limit, offset = 50, 0
limit = max(1, min(200, limit))
offset = max(0, offset)
return jsonify(database.get_announcements(limit=limit, offset=offset))
@admin_api_bp.route("/announcements", methods=["POST"])
@admin_required
def admin_create_announcement():
"""创建公告(默认启用并替换旧公告)"""
data = request.json or {}
title = (data.get("title") or "").strip()
content = (data.get("content") or "").strip()
image_url = (data.get("image_url") or "").strip()
is_active = bool(data.get("is_active", True))
if image_url and len(image_url) > 1000:
return jsonify({"error": "图片地址过长"}), 400
announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active)
if not announcement_id:
return jsonify({"error": "标题和内容不能为空"}), 400
return jsonify({"success": True, "id": announcement_id})
@admin_api_bp.route("/announcements/<int:announcement_id>/activate", methods=["POST"])
@admin_required
def admin_activate_announcement(announcement_id):
"""启用公告(会自动停用其他公告)"""
if not database.get_announcement_by_id(announcement_id):
return jsonify({"error": "公告不存在"}), 404
ok = database.set_announcement_active(announcement_id, True)
return jsonify({"success": ok})
@admin_api_bp.route("/announcements/<int:announcement_id>/deactivate", methods=["POST"])
@admin_required
def admin_deactivate_announcement(announcement_id):
"""停用公告"""
if not database.get_announcement_by_id(announcement_id):
return jsonify({"error": "公告不存在"}), 404
ok = database.set_announcement_active(announcement_id, False)
return jsonify({"success": ok})
@admin_api_bp.route("/announcements/<int:announcement_id>", methods=["DELETE"])
@admin_required
def admin_delete_announcement(announcement_id):
"""删除公告"""
if not database.get_announcement_by_id(announcement_id):
return jsonify({"error": "公告不存在"}), 404
ok = database.delete_announcement(announcement_id)
return jsonify({"success": ok})

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,214 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import email_service
from app_logger import get_logger
from app_security import validate_email
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
logger = get_logger("app")
@admin_api_bp.route("/email/settings", methods=["GET"])
@admin_required
def get_email_settings_api():
"""获取全局邮件设置"""
try:
settings = email_service.get_email_settings()
return jsonify(settings)
except Exception as e:
logger.error(f"获取邮件设置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/email/settings", methods=["POST"])
@admin_required
def update_email_settings_api():
"""更新全局邮件设置"""
try:
data = request.json or {}
enabled = data.get("enabled", False)
failover_enabled = data.get("failover_enabled", True)
register_verify_enabled = data.get("register_verify_enabled")
login_alert_enabled = data.get("login_alert_enabled")
base_url = data.get("base_url")
task_notify_enabled = data.get("task_notify_enabled")
email_service.update_email_settings(
enabled=enabled,
failover_enabled=failover_enabled,
register_verify_enabled=register_verify_enabled,
login_alert_enabled=login_alert_enabled,
base_url=base_url,
task_notify_enabled=task_notify_enabled,
)
return jsonify({"success": True})
except Exception as e:
logger.error(f"更新邮件设置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs", methods=["GET"])
@admin_required
def get_smtp_configs_api():
"""获取所有SMTP配置列表"""
try:
configs = email_service.get_smtp_configs(include_password=False)
return jsonify(configs)
except Exception as e:
logger.error(f"获取SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs", methods=["POST"])
@admin_required
def create_smtp_config_api():
"""创建SMTP配置"""
try:
data = request.json or {}
if not data.get("host"):
return jsonify({"error": "SMTP服务器地址不能为空"}), 400
if not data.get("username"):
return jsonify({"error": "SMTP用户名不能为空"}), 400
config_id = email_service.create_smtp_config(data)
return jsonify({"success": True, "id": config_id})
except Exception as e:
logger.error(f"创建SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["GET"])
@admin_required
def get_smtp_config_api(config_id):
"""获取单个SMTP配置详情"""
try:
config_data = email_service.get_smtp_config(config_id, include_password=False)
if not config_data:
return jsonify({"error": "配置不存在"}), 404
return jsonify(config_data)
except Exception as e:
logger.error(f"获取SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["PUT"])
@admin_required
def update_smtp_config_api(config_id):
"""更新SMTP配置"""
try:
data = request.json or {}
if email_service.update_smtp_config(config_id, data):
return jsonify({"success": True})
return jsonify({"error": "更新失败"}), 400
except Exception as e:
logger.error(f"更新SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["DELETE"])
@admin_required
def delete_smtp_config_api(config_id):
"""删除SMTP配置"""
try:
if email_service.delete_smtp_config(config_id):
return jsonify({"success": True})
return jsonify({"error": "删除失败"}), 400
except Exception as e:
logger.error(f"删除SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>/test", methods=["POST"])
@admin_required
def test_smtp_config_api(config_id):
"""测试SMTP配置"""
try:
data = request.json or {}
test_email = str(data.get("email", "") or "").strip()
if not test_email:
return jsonify({"error": "请提供测试邮箱"}), 400
is_valid, error_msg = validate_email(test_email)
if not is_valid:
return jsonify({"error": error_msg}), 400
result = email_service.test_smtp_config(config_id, test_email)
return jsonify(result)
except Exception as e:
logger.error(f"测试SMTP配置失败: {e}")
return jsonify({"success": False, "error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>/primary", methods=["POST"])
@admin_required
def set_primary_smtp_config_api(config_id):
"""设置主SMTP配置"""
try:
if email_service.set_primary_smtp_config(config_id):
return jsonify({"success": True})
return jsonify({"error": "设置失败"}), 400
except Exception as e:
logger.error(f"设置主SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/primary/clear", methods=["POST"])
@admin_required
def clear_primary_smtp_config_api():
"""取消主SMTP配置"""
try:
email_service.clear_primary_smtp_config()
return jsonify({"success": True})
except Exception as e:
logger.error(f"取消主SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/email/stats", methods=["GET"])
@admin_required
def get_email_stats_api():
"""获取邮件发送统计"""
try:
stats = email_service.get_email_stats()
return jsonify(stats)
except Exception as e:
logger.error(f"获取邮件统计失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/email/logs", methods=["GET"])
@admin_required
def get_email_logs_api():
"""获取邮件发送日志"""
try:
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
email_type = request.args.get("type", None)
status = request.args.get("status", None)
page_size = min(max(page_size, 10), 100)
result = email_service.get_email_logs(page, page_size, email_type, status)
return jsonify(result)
except Exception as e:
logger.error(f"获取邮件日志失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/email/logs/cleanup", methods=["POST"])
@admin_required
def cleanup_email_logs_api():
"""清理过期邮件日志"""
try:
data = request.json or {}
days = data.get("days", 30)
days = min(max(days, 7), 365)
deleted = email_service.cleanup_email_logs(days)
return jsonify({"success": True, "deleted": deleted})
except Exception as e:
logger.error(f"清理邮件日志失败: {e}")
return jsonify({"error": str(e)}), 500

View File

@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
@admin_api_bp.route("/feedbacks", methods=["GET"])
@admin_required
def get_all_feedbacks():
"""管理员获取所有反馈"""
status = request.args.get("status")
try:
limit = int(request.args.get("limit", 100))
offset = int(request.args.get("offset", 0))
limit = min(max(1, limit), 1000)
offset = max(0, offset)
except (ValueError, TypeError):
return jsonify({"error": "无效的分页参数"}), 400
feedbacks = database.get_bug_feedbacks(limit=limit, offset=offset, status_filter=status)
stats = database.get_feedback_stats()
return jsonify({"feedbacks": feedbacks, "stats": stats})
@admin_api_bp.route("/feedbacks/<int:feedback_id>/reply", methods=["POST"])
@admin_required
def reply_to_feedback(feedback_id):
"""管理员回复反馈"""
data = request.get_json() or {}
reply = (data.get("reply") or "").strip()
if not reply:
return jsonify({"error": "回复内容不能为空"}), 400
if database.reply_feedback(feedback_id, reply):
return jsonify({"message": "回复成功"})
return jsonify({"error": "反馈不存在"}), 404
@admin_api_bp.route("/feedbacks/<int:feedback_id>/close", methods=["POST"])
@admin_required
def close_feedback_api(feedback_id):
"""管理员关闭反馈"""
if database.close_feedback(feedback_id):
return jsonify({"message": "已关闭"})
return jsonify({"error": "反馈不存在"}), 404
@admin_api_bp.route("/feedbacks/<int:feedback_id>", methods=["DELETE"])
@admin_required
def delete_feedback_api(feedback_id):
"""管理员删除反馈"""
if database.delete_feedback(feedback_id):
return jsonify({"message": "已删除"})
return jsonify({"error": "反馈不存在"}), 404

View File

@@ -0,0 +1,226 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import time
from datetime import datetime
import database
from app_logger import get_logger
from flask import jsonify, session
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.time_utils import BEIJING_TZ, get_beijing_now
logger = get_logger("app")
@admin_api_bp.route("/stats", methods=["GET"])
@admin_required
def get_system_stats():
"""获取系统统计"""
stats = database.get_system_stats()
stats["admin_username"] = session.get("admin_username", "admin")
return jsonify(stats)
@admin_api_bp.route("/browser_pool/stats", methods=["GET"])
@admin_required
def get_browser_pool_stats():
"""获取截图线程池状态"""
try:
from browser_pool_worker import get_browser_worker_pool
pool = get_browser_worker_pool()
stats = pool.get_stats() or {}
worker_details = []
for w in stats.get("workers") or []:
last_ts = float(w.get("last_active_ts") or 0)
last_active_at = None
if last_ts > 0:
try:
last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
last_active_at = None
created_ts = w.get("browser_created_at")
created_at = None
if created_ts:
try:
created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
created_at = None
worker_details.append(
{
"worker_id": w.get("worker_id"),
"idle": bool(w.get("idle")),
"has_browser": bool(w.get("has_browser")),
"total_tasks": int(w.get("total_tasks") or 0),
"failed_tasks": int(w.get("failed_tasks") or 0),
"browser_use_count": int(w.get("browser_use_count") or 0),
"browser_created_at": created_at,
"browser_created_ts": created_ts,
"last_active_at": last_active_at,
"last_active_ts": last_ts,
"thread_alive": bool(w.get("thread_alive")),
}
)
total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0)
return jsonify(
{
"total_workers": total_workers,
"active_workers": int(stats.get("busy_workers") or 0),
"idle_workers": int(stats.get("idle_workers") or 0),
"queue_size": int(stats.get("queue_size") or 0),
"workers": worker_details,
"summary": {
"total_tasks": int(stats.get("total_tasks") or 0),
"failed_tasks": int(stats.get("failed_tasks") or 0),
"success_rate": stats.get("success_rate"),
},
"server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
}
)
except Exception as e:
logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}")
return jsonify({"error": "获取截图线程池状态失败"}), 500
@admin_api_bp.route("/docker_stats", methods=["GET"])
@admin_required
def get_docker_stats():
"""获取Docker容器运行状态"""
import subprocess
docker_status = {
"running": False,
"container_name": "N/A",
"uptime": "N/A",
"memory_usage": "N/A",
"memory_limit": "N/A",
"memory_percent": "N/A",
"cpu_percent": "N/A",
"status": "Unknown",
}
try:
if os.path.exists("/.dockerenv"):
docker_status["running"] = True
try:
with open("/etc/hostname", "r") as f:
docker_status["container_name"] = f.read().strip()
except Exception as e:
logger.debug(f"读取容器名称失败: {e}")
try:
if os.path.exists("/sys/fs/cgroup/memory.current"):
with open("/sys/fs/cgroup/memory.current", "r") as f:
mem_total = int(f.read().strip())
cache = 0
if os.path.exists("/sys/fs/cgroup/memory.stat"):
with open("/sys/fs/cgroup/memory.stat", "r") as f:
for line in f:
if line.startswith("inactive_file "):
cache = int(line.split()[1])
break
mem_bytes = mem_total - cache
docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
if os.path.exists("/sys/fs/cgroup/memory.max"):
with open("/sys/fs/cgroup/memory.max", "r") as f:
limit_str = f.read().strip()
if limit_str != "max":
limit_bytes = int(limit_str)
docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
elif os.path.exists("/sys/fs/cgroup/memory/memory.usage_in_bytes"):
with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f:
mem_bytes = int(f.read().strip())
docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f:
limit_bytes = int(f.read().strip())
if limit_bytes < 1e18:
docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
except Exception as e:
logger.debug(f"读取内存信息失败: {e}")
try:
if os.path.exists("/sys/fs/cgroup/cpu.stat"):
cpu_usage = 0
with open("/sys/fs/cgroup/cpu.stat", "r") as f:
for line in f:
if line.startswith("usage_usec"):
cpu_usage = int(line.split()[1])
break
time.sleep(0.1)
cpu_usage2 = 0
with open("/sys/fs/cgroup/cpu.stat", "r") as f:
for line in f:
if line.startswith("usage_usec"):
cpu_usage2 = int(line.split()[1])
break
cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e6 * 100
docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
elif os.path.exists("/sys/fs/cgroup/cpu/cpuacct.usage"):
with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
cpu_usage = int(f.read().strip())
time.sleep(0.1)
with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
cpu_usage2 = int(f.read().strip())
cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e9 * 100
docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
except Exception as e:
logger.debug(f"读取CPU信息失败: {e}")
try:
# 读取系统运行时间
with open('/proc/uptime', 'r') as f:
system_uptime = float(f.read().split()[0])
# 读取 PID 1 的启动时间 (jiffies)
with open('/proc/1/stat', 'r') as f:
stat = f.read().split()
starttime_jiffies = int(stat[21])
# 获取 CLK_TCK (通常是 100)
clk_tck = os.sysconf(os.sysconf_names['SC_CLK_TCK'])
# 计算容器运行时长(秒)
container_uptime_seconds = system_uptime - (starttime_jiffies / clk_tck)
# 格式化为可读字符串
days = int(container_uptime_seconds // 86400)
hours = int((container_uptime_seconds % 86400) // 3600)
minutes = int((container_uptime_seconds % 3600) // 60)
if days > 0:
docker_status["uptime"] = f"{days}{hours}小时{minutes}分钟"
elif hours > 0:
docker_status["uptime"] = f"{hours}小时{minutes}分钟"
else:
docker_status["uptime"] = f"{minutes}分钟"
except Exception as e:
logger.debug(f"获取容器运行时间失败: {e}")
docker_status["status"] = "Running"
else:
docker_status["status"] = "Not in Docker"
except Exception as e:
docker_status["status"] = f"Error: {str(e)}"
return jsonify(docker_status)

View File

@@ -0,0 +1,228 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import threading
import time
from datetime import datetime
import database
import requests
from app_logger import get_logger
from app_security import is_safe_outbound_url
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.scheduler import run_scheduled_task
from services.time_utils import BEIJING_TZ, get_beijing_now
logger = get_logger("app")
_server_cpu_percent_lock = threading.Lock()
_server_cpu_percent_last: float | None = None
_server_cpu_percent_last_ts = 0.0
def _get_server_cpu_percent() -> float:
import psutil
global _server_cpu_percent_last, _server_cpu_percent_last_ts
now = time.time()
with _server_cpu_percent_lock:
if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5:
return _server_cpu_percent_last
try:
if _server_cpu_percent_last is None:
cpu_percent = float(psutil.cpu_percent(interval=0.1))
else:
cpu_percent = float(psutil.cpu_percent(interval=None))
except Exception:
cpu_percent = float(_server_cpu_percent_last or 0.0)
if cpu_percent < 0:
cpu_percent = 0.0
_server_cpu_percent_last = cpu_percent
_server_cpu_percent_last_ts = now
return cpu_percent
@admin_api_bp.route("/kdocs/status", methods=["GET"])
@admin_required
def get_kdocs_status_api():
"""获取金山文档上传状态"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
status = uploader.get_status()
live = str(request.args.get("live", "")).lower() in ("1", "true", "yes")
if live:
live_status = uploader.refresh_login_status()
if live_status.get("success"):
logged_in = bool(live_status.get("logged_in"))
status["logged_in"] = logged_in
status["last_login_ok"] = logged_in
status["login_required"] = not logged_in
if live_status.get("error"):
status["last_error"] = live_status.get("error")
else:
status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None
if status.get("last_login_ok") is True and status.get("last_error") == "操作超时":
status["last_error"] = None
return jsonify(status)
except Exception as e:
return jsonify({"error": f"获取状态失败: {e}"}), 500
@admin_api_bp.route("/kdocs/qr", methods=["POST"])
@admin_required
def get_kdocs_qr_api():
"""获取金山文档登录二维码"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
data = request.get_json(silent=True) or {}
force = bool(data.get("force"))
if not force:
force = str(request.args.get("force", "")).lower() in ("1", "true", "yes")
result = uploader.request_qr(force=force)
if not result.get("success"):
return jsonify({"error": result.get("error", "获取二维码失败")}), 400
return jsonify(result)
except Exception as e:
return jsonify({"error": f"获取二维码失败: {e}"}), 500
@admin_api_bp.route("/kdocs/clear-login", methods=["POST"])
@admin_required
def clear_kdocs_login_api():
"""清除金山文档登录态"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
result = uploader.clear_login()
if not result.get("success"):
return jsonify({"error": result.get("error", "清除失败")}), 400
return jsonify({"success": True})
except Exception as e:
return jsonify({"error": f"清除失败: {e}"}), 500
@admin_api_bp.route("/schedule/execute", methods=["POST"])
@admin_required
def execute_schedule_now():
"""立即执行定时任务(无视定时时间和星期限制)"""
try:
threading.Thread(target=run_scheduled_task, args=(True,), daemon=True).start()
logger.info("[立即执行定时任务] 管理员手动触发定时任务执行(跳过星期检查)")
return jsonify({"message": "定时任务已开始执行,请查看任务列表获取进度"})
except Exception as e:
logger.error(f"[立即执行定时任务] 启动失败: {str(e)}")
return jsonify({"error": f"启动失败: {str(e)}"}), 500
@admin_api_bp.route("/proxy/config", methods=["GET"])
@admin_required
def get_proxy_config_api():
"""获取代理配置"""
config_data = database.get_system_config()
return jsonify(
{
"proxy_enabled": config_data.get("proxy_enabled", 0),
"proxy_api_url": config_data.get("proxy_api_url", ""),
"proxy_expire_minutes": config_data.get("proxy_expire_minutes", 3),
}
)
@admin_api_bp.route("/proxy/config", methods=["POST"])
@admin_required
def update_proxy_config_api():
"""更新代理配置"""
data = request.json or {}
proxy_enabled = data.get("proxy_enabled")
proxy_api_url = (data.get("proxy_api_url", "") or "").strip()
proxy_expire_minutes = data.get("proxy_expire_minutes")
if proxy_enabled is not None and proxy_enabled not in [0, 1]:
return jsonify({"error": "proxy_enabled必须是0或1"}), 400
if proxy_expire_minutes is not None:
if not isinstance(proxy_expire_minutes, int) or proxy_expire_minutes < 1:
return jsonify({"error": "代理有效期必须是大于0的整数"}), 400
if database.update_system_config(
proxy_enabled=proxy_enabled,
proxy_api_url=proxy_api_url,
proxy_expire_minutes=proxy_expire_minutes,
):
return jsonify({"message": "代理配置已更新"})
return jsonify({"error": "更新失败"}), 400
@admin_api_bp.route("/proxy/test", methods=["POST"])
@admin_required
def test_proxy_api():
"""测试代理连接"""
data = request.json or {}
api_url = (data.get("api_url") or "").strip()
if not api_url:
return jsonify({"error": "请提供API地址"}), 400
if not is_safe_outbound_url(api_url):
return jsonify({"error": "API地址不可用或不安全"}), 400
try:
response = requests.get(api_url, timeout=10)
if response.status_code == 200:
ip_port = response.text.strip()
if ip_port and ":" in ip_port:
return jsonify({"success": True, "proxy": ip_port, "message": f"代理获取成功: {ip_port}"})
return jsonify({"success": False, "message": f"代理格式错误: {ip_port}"}), 400
return jsonify({"success": False, "message": f"HTTP错误: {response.status_code}"}), 400
except Exception as e:
return jsonify({"success": False, "message": f"连接失败: {str(e)}"}), 500
@admin_api_bp.route("/server/info", methods=["GET"])
@admin_required
def get_server_info_api():
"""获取服务器信息"""
import psutil
cpu_percent = _get_server_cpu_percent()
memory = psutil.virtual_memory()
memory_total = f"{memory.total / (1024**3):.1f}GB"
memory_used = f"{memory.used / (1024**3):.1f}GB"
memory_percent = memory.percent
disk = psutil.disk_usage("/")
disk_total = f"{disk.total / (1024**3):.1f}GB"
disk_used = f"{disk.used / (1024**3):.1f}GB"
disk_percent = disk.percent
boot_time = datetime.fromtimestamp(psutil.boot_time(), tz=BEIJING_TZ)
uptime_delta = get_beijing_now() - boot_time
days = uptime_delta.days
hours = uptime_delta.seconds // 3600
uptime = f"{days}{hours}小时"
return jsonify(
{
"cpu_percent": cpu_percent,
"memory_total": memory_total,
"memory_used": memory_used,
"memory_percent": memory_percent,
"disk_total": disk_total,
"disk_used": disk_used,
"disk_percent": disk_percent,
"uptime": uptime,
}
)

View File

@@ -62,6 +62,19 @@ def _parse_bool(value: Any) -> bool:
return text in {"1", "true", "yes", "y", "on"} return text in {"1", "true", "yes", "y", "on"}
def _parse_int(value: Any, *, default: int | None = None, min_value: int | None = None) -> int | None:
try:
parsed = int(value)
except Exception:
parsed = default
if parsed is None:
return None
if min_value is not None:
parsed = max(int(min_value), parsed)
return parsed
def _sanitize_threat_event(event: dict) -> dict: def _sanitize_threat_event(event: dict) -> dict:
return { return {
"id": event.get("id"), "id": event.get("id"),
@@ -199,10 +212,7 @@ def ban_ip():
if not reason: if not reason:
return jsonify({"error": "reason不能为空"}), 400 return jsonify({"error": "reason不能为空"}), 400
try: duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent) ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent)
if not ok: if not ok:
@@ -235,20 +245,14 @@ def ban_user():
duration_hours_raw = data.get("duration_hours", 24) duration_hours_raw = data.get("duration_hours", 24)
permanent = _parse_bool(data.get("permanent", False)) permanent = _parse_bool(data.get("permanent", False))
try: user_id = _parse_int(user_id_raw)
user_id = int(user_id_raw)
except Exception:
user_id = None
if user_id is None: if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400 return jsonify({"error": "user_id不能为空"}), 400
if not reason: if not reason:
return jsonify({"error": "reason不能为空"}), 400 return jsonify({"error": "reason不能为空"}), 400
try: duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent) ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent)
if not ok: if not ok:
@@ -262,10 +266,7 @@ def unban_user():
"""解除用户封禁""" """解除用户封禁"""
data = _parse_json() data = _parse_json()
user_id_raw = data.get("user_id") user_id_raw = data.get("user_id")
try: user_id = _parse_int(user_id_raw)
user_id = int(user_id_raw)
except Exception:
user_id = None
if user_id is None: if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400 return jsonify({"error": "user_id不能为空"}), 400

View File

@@ -0,0 +1,228 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from app_logger import get_logger
from app_security import is_safe_outbound_url, validate_email
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.browse_types import BROWSE_TYPE_SHOULD_READ, validate_browse_type
from services.tasks import get_task_scheduler
logger = get_logger("app")
@admin_api_bp.route("/system/config", methods=["GET"])
@admin_required
def get_system_config_api():
"""获取系统配置"""
return jsonify(database.get_system_config())
@admin_api_bp.route("/system/config", methods=["POST"])
@admin_required
def update_system_config_api():
"""更新系统配置"""
data = request.json or {}
max_concurrent = data.get("max_concurrent_global")
schedule_enabled = data.get("schedule_enabled")
schedule_time = data.get("schedule_time")
schedule_browse_type = data.get("schedule_browse_type")
schedule_weekdays = data.get("schedule_weekdays")
new_max_concurrent_per_account = data.get("max_concurrent_per_account")
new_max_screenshot_concurrent = data.get("max_screenshot_concurrent")
enable_screenshot = data.get("enable_screenshot")
auto_approve_enabled = data.get("auto_approve_enabled")
auto_approve_hourly_limit = data.get("auto_approve_hourly_limit")
auto_approve_vip_days = data.get("auto_approve_vip_days")
kdocs_enabled = data.get("kdocs_enabled")
kdocs_doc_url = data.get("kdocs_doc_url")
kdocs_default_unit = data.get("kdocs_default_unit")
kdocs_sheet_name = data.get("kdocs_sheet_name")
kdocs_sheet_index = data.get("kdocs_sheet_index")
kdocs_unit_column = data.get("kdocs_unit_column")
kdocs_image_column = data.get("kdocs_image_column")
kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled")
kdocs_admin_notify_email = data.get("kdocs_admin_notify_email")
kdocs_row_start = data.get("kdocs_row_start")
kdocs_row_end = data.get("kdocs_row_end")
if max_concurrent is not None:
if not isinstance(max_concurrent, int) or max_concurrent < 1:
return jsonify({"error": "全局并发数必须大于0建议小型服务器2-5中型5-10大型10-20"}), 400
if new_max_concurrent_per_account is not None:
if not isinstance(new_max_concurrent_per_account, int) or new_max_concurrent_per_account < 1:
return jsonify({"error": "单账号并发数必须大于0建议设为1避免同一用户任务相互影响"}), 400
if new_max_screenshot_concurrent is not None:
if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1:
return jsonify({"error": "截图并发数必须大于0建议根据服务器配置设置wkhtmltoimage 资源占用较低)"}), 400
if enable_screenshot is not None:
if isinstance(enable_screenshot, bool):
enable_screenshot = 1 if enable_screenshot else 0
if enable_screenshot not in (0, 1):
return jsonify({"error": "截图开关必须是0或1"}), 400
if schedule_time is not None:
import re
if not re.match(r"^([01]\\d|2[0-3]):([0-5]\\d)$", schedule_time):
return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400
if schedule_browse_type is not None:
normalized = validate_browse_type(schedule_browse_type, default=BROWSE_TYPE_SHOULD_READ)
if not normalized:
return jsonify({"error": "浏览类型无效"}), 400
schedule_browse_type = normalized
if schedule_weekdays is not None:
try:
days = [int(d.strip()) for d in schedule_weekdays.split(",") if d.strip()]
if not all(1 <= d <= 7 for d in days):
return jsonify({"error": "星期数字必须在1-7之间"}), 400
except (ValueError, AttributeError):
return jsonify({"error": "星期格式错误"}), 400
if auto_approve_hourly_limit is not None:
if not isinstance(auto_approve_hourly_limit, int) or auto_approve_hourly_limit < 1:
return jsonify({"error": "每小时注册限制必须大于0"}), 400
if auto_approve_vip_days is not None:
if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0:
return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400
if kdocs_enabled is not None:
if isinstance(kdocs_enabled, bool):
kdocs_enabled = 1 if kdocs_enabled else 0
if kdocs_enabled not in (0, 1):
return jsonify({"error": "表格上传开关必须是0或1"}), 400
if kdocs_doc_url is not None:
kdocs_doc_url = str(kdocs_doc_url or "").strip()
if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url):
return jsonify({"error": "文档链接格式不正确"}), 400
if kdocs_default_unit is not None:
kdocs_default_unit = str(kdocs_default_unit or "").strip()
if len(kdocs_default_unit) > 50:
return jsonify({"error": "默认县区长度不能超过50"}), 400
if kdocs_sheet_name is not None:
kdocs_sheet_name = str(kdocs_sheet_name or "").strip()
if len(kdocs_sheet_name) > 50:
return jsonify({"error": "Sheet名称长度不能超过50"}), 400
if kdocs_sheet_index is not None:
try:
kdocs_sheet_index = int(kdocs_sheet_index)
except Exception:
return jsonify({"error": "Sheet序号必须是数字"}), 400
if kdocs_sheet_index < 0:
return jsonify({"error": "Sheet序号不能为负数"}), 400
if kdocs_unit_column is not None:
kdocs_unit_column = str(kdocs_unit_column or "").strip().upper()
if not kdocs_unit_column:
return jsonify({"error": "县区列不能为空"}), 400
import re
if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column):
return jsonify({"error": "县区列格式错误"}), 400
if kdocs_image_column is not None:
kdocs_image_column = str(kdocs_image_column or "").strip().upper()
if not kdocs_image_column:
return jsonify({"error": "图片列不能为空"}), 400
import re
if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column):
return jsonify({"error": "图片列格式错误"}), 400
if kdocs_admin_notify_enabled is not None:
if isinstance(kdocs_admin_notify_enabled, bool):
kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0
if kdocs_admin_notify_enabled not in (0, 1):
return jsonify({"error": "管理员通知开关必须是0或1"}), 400
if kdocs_admin_notify_email is not None:
kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip()
if kdocs_admin_notify_email:
is_valid, error_msg = validate_email(kdocs_admin_notify_email)
if not is_valid:
return jsonify({"error": error_msg}), 400
if kdocs_row_start is not None:
try:
kdocs_row_start = int(kdocs_row_start)
except (ValueError, TypeError):
return jsonify({"error": "起始行必须是数字"}), 400
if kdocs_row_start < 0:
return jsonify({"error": "起始行不能为负数"}), 400
if kdocs_row_end is not None:
try:
kdocs_row_end = int(kdocs_row_end)
except (ValueError, TypeError):
return jsonify({"error": "结束行必须是数字"}), 400
if kdocs_row_end < 0:
return jsonify({"error": "结束行不能为负数"}), 400
old_config = database.get_system_config() or {}
if not database.update_system_config(
max_concurrent=max_concurrent,
schedule_enabled=schedule_enabled,
schedule_time=schedule_time,
schedule_browse_type=schedule_browse_type,
schedule_weekdays=schedule_weekdays,
max_concurrent_per_account=new_max_concurrent_per_account,
max_screenshot_concurrent=new_max_screenshot_concurrent,
enable_screenshot=enable_screenshot,
auto_approve_enabled=auto_approve_enabled,
auto_approve_hourly_limit=auto_approve_hourly_limit,
auto_approve_vip_days=auto_approve_vip_days,
kdocs_enabled=kdocs_enabled,
kdocs_doc_url=kdocs_doc_url,
kdocs_default_unit=kdocs_default_unit,
kdocs_sheet_name=kdocs_sheet_name,
kdocs_sheet_index=kdocs_sheet_index,
kdocs_unit_column=kdocs_unit_column,
kdocs_image_column=kdocs_image_column,
kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
kdocs_admin_notify_email=kdocs_admin_notify_email,
kdocs_row_start=kdocs_row_start,
kdocs_row_end=kdocs_row_end,
):
return jsonify({"error": "更新失败"}), 400
try:
new_config = database.get_system_config() or {}
scheduler = get_task_scheduler()
scheduler.update_limits(
max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))),
max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))),
)
if new_max_screenshot_concurrent is not None:
try:
from browser_pool_worker import resize_browser_worker_pool
if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))):
logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}")
except Exception as pool_error:
logger.warning(f"截图线程池并发更新失败: {pool_error}")
except Exception:
pass
if max_concurrent is not None and max_concurrent != old_config.get("max_concurrent_global"):
logger.info(f"全局并发数已更新为: {max_concurrent}")
if new_max_concurrent_per_account is not None and new_max_concurrent_per_account != old_config.get("max_concurrent_per_account"):
logger.info(f"单用户并发数已更新为: {new_max_concurrent_per_account}")
if new_max_screenshot_concurrent is not None:
logger.info(f"截图并发数已更新为: {new_max_screenshot_concurrent}")
return jsonify({"message": "系统配置已更新"})

View File

@@ -0,0 +1,138 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from app_logger import get_logger
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.state import safe_iter_task_status_items
from services.tasks import get_task_scheduler
logger = get_logger("app")
def _parse_page_int(name: str, default: int, *, minimum: int, maximum: int) -> int:
try:
value = int(request.args.get(name, default))
return max(minimum, min(value, maximum))
except (ValueError, TypeError):
return default
@admin_api_bp.route("/task/stats", methods=["GET"])
@admin_required
def get_task_stats_api():
"""获取任务统计数据"""
date_filter = request.args.get("date")
stats = database.get_task_stats(date_filter)
return jsonify(stats)
@admin_api_bp.route("/task/running", methods=["GET"])
@admin_required
def get_running_tasks_api():
"""获取当前运行中和排队中的任务"""
import time as time_mod
current_time = time_mod.time()
running = []
queuing = []
user_cache = {}
for account_id, info in safe_iter_task_status_items():
elapsed = int(current_time - info.get("start_time", current_time))
info_user_id = info.get("user_id")
if info_user_id not in user_cache:
user_cache[info_user_id] = database.get_user_by_id(info_user_id)
user = user_cache.get(info_user_id)
user_username = user["username"] if user else "N/A"
progress = info.get("progress", {"items": 0, "attachments": 0})
task_info = {
"account_id": account_id,
"user_id": info.get("user_id"),
"user_username": user_username,
"username": info.get("username"),
"browse_type": info.get("browse_type"),
"source": info.get("source", "manual"),
"detail_status": info.get("detail_status", "未知"),
"progress_items": progress.get("items", 0),
"progress_attachments": progress.get("attachments", 0),
"elapsed_seconds": elapsed,
"elapsed_display": f"{elapsed // 60}{elapsed % 60}" if elapsed >= 60 else f"{elapsed}",
}
if info.get("status") == "运行中":
running.append(task_info)
else:
queuing.append(task_info)
running.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
queuing.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
try:
max_concurrent = int(get_task_scheduler().max_global)
except Exception:
max_concurrent = int((database.get_system_config() or {}).get("max_concurrent_global", 2))
return jsonify(
{
"running": running,
"queuing": queuing,
"running_count": len(running),
"queuing_count": len(queuing),
"max_concurrent": max_concurrent,
}
)
@admin_api_bp.route("/task/logs", methods=["GET"])
@admin_required
def get_task_logs_api():
"""获取任务日志列表(支持分页和多种筛选)"""
limit = _parse_page_int("limit", 20, minimum=1, maximum=200)
offset = _parse_page_int("offset", 0, minimum=0, maximum=10**9)
date_filter = request.args.get("date")
status_filter = request.args.get("status")
source_filter = request.args.get("source")
user_id_filter = request.args.get("user_id")
account_filter = (request.args.get("account") or "").strip()
if user_id_filter:
try:
user_id_filter = int(user_id_filter)
except (ValueError, TypeError):
user_id_filter = None
try:
result = database.get_task_logs(
limit=limit,
offset=offset,
date_filter=date_filter,
status_filter=status_filter,
source_filter=source_filter,
user_id_filter=user_id_filter,
account_filter=account_filter if account_filter else None,
)
return jsonify(result)
except Exception as e:
logger.error(f"获取任务日志失败: {e}")
return jsonify({"logs": [], "total": 0, "error": "查询失败"})
@admin_api_bp.route("/task/logs/clear", methods=["POST"])
@admin_required
def clear_old_task_logs_api():
"""清理旧的任务日志"""
data = request.json or {}
days = data.get("days", 30)
if not isinstance(days, int) or days < 1:
return jsonify({"error": "天数必须是大于0的整数"}), 400
deleted_count = database.delete_old_task_logs(days)
return jsonify({"message": f"已删除{days}天前的{deleted_count}条日志"})

View File

@@ -0,0 +1,117 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.state import safe_clear_user_logs, safe_remove_user_accounts
# ==================== 用户管理/统计(管理员) ====================
@admin_api_bp.route("/users", methods=["GET"])
@admin_required
def get_all_users():
"""获取所有用户"""
users = database.get_all_users()
return jsonify(users)
@admin_api_bp.route("/users/pending", methods=["GET"])
@admin_required
def get_pending_users():
"""获取待审核用户"""
users = database.get_pending_users()
return jsonify(users)
@admin_api_bp.route("/users/<int:user_id>/approve", methods=["POST"])
@admin_required
def approve_user_route(user_id):
"""审核通过用户"""
if database.approve_user(user_id):
return jsonify({"success": True})
return jsonify({"error": "审核失败"}), 400
@admin_api_bp.route("/users/<int:user_id>/reject", methods=["POST"])
@admin_required
def reject_user_route(user_id):
"""拒绝用户"""
if database.reject_user(user_id):
return jsonify({"success": True})
return jsonify({"error": "拒绝失败"}), 400
@admin_api_bp.route("/users/<int:user_id>", methods=["DELETE"])
@admin_required
def delete_user_route(user_id):
"""删除用户"""
if database.delete_user(user_id):
safe_remove_user_accounts(user_id)
safe_clear_user_logs(user_id)
return jsonify({"success": True})
return jsonify({"error": "删除失败"}), 400
# ==================== VIP 管理(管理员) ====================
@admin_api_bp.route("/vip/config", methods=["GET"])
@admin_required
def get_vip_config_api():
"""获取VIP配置"""
config = database.get_vip_config()
return jsonify(config)
@admin_api_bp.route("/vip/config", methods=["POST"])
@admin_required
def set_vip_config_api():
"""设置默认VIP天数"""
data = request.json or {}
days = data.get("default_vip_days", 0)
if not isinstance(days, int) or days < 0:
return jsonify({"error": "VIP天数必须是非负整数"}), 400
database.set_default_vip_days(days)
return jsonify({"message": "VIP配置已更新", "default_vip_days": days})
@admin_api_bp.route("/users/<int:user_id>/vip", methods=["POST"])
@admin_required
def set_user_vip_api(user_id):
"""设置用户VIP"""
data = request.json or {}
days = data.get("days", 30)
valid_days = [7, 30, 365, 999999]
if days not in valid_days:
return jsonify({"error": "VIP天数必须是 7/30/365/999999 之一"}), 400
if database.set_user_vip(user_id, days):
vip_type = {7: "一周", 30: "一个月", 365: "一年", 999999: "永久"}[days]
return jsonify({"message": f"VIP设置成功: {vip_type}"})
return jsonify({"error": "设置失败,用户不存在"}), 400
@admin_api_bp.route("/users/<int:user_id>/vip", methods=["DELETE"])
@admin_required
def remove_user_vip_api(user_id):
"""移除用户VIP"""
if database.remove_user_vip(user_id):
return jsonify({"message": "VIP已移除"})
return jsonify({"error": "移除失败"}), 400
@admin_api_bp.route("/users/<int:user_id>/vip", methods=["GET"])
@admin_required
def get_user_vip_info_api(user_id):
"""获取用户VIP信息(管理员)"""
vip_info = database.get_user_vip_info(user_id)
return jsonify(vip_info)

View File

@@ -40,6 +40,48 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
pass pass
def _emit_account_update(user_id: int, account) -> None:
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
def _request_json(default=None):
if default is None:
default = {}
data = request.get_json(silent=True)
return data if isinstance(data, dict) else default
def _ensure_accounts_loaded(user_id: int) -> dict:
accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
return accounts
load_user_accounts(user_id)
return safe_get_user_accounts_snapshot(user_id)
def _get_user_account(user_id: int, account_id: str, *, refresh_if_missing: bool = False):
account = safe_get_account(user_id, account_id)
if account or (not refresh_if_missing):
return account
load_user_accounts(user_id)
return safe_get_account(user_id, account_id)
def _validate_browse_type_input(raw_browse_type, *, default=BROWSE_TYPE_SHOULD_READ):
browse_type = validate_browse_type(raw_browse_type, default=default)
if not browse_type:
return None, (jsonify({"error": "浏览类型无效"}), 400)
return browse_type, None
def _cancel_pending_account_task(user_id: int, account_id: str) -> bool:
try:
scheduler = get_task_scheduler()
return bool(scheduler.cancel_pending_task(user_id=user_id, account_id=account_id))
except Exception:
return False
@api_accounts_bp.route("/api/accounts", methods=["GET"]) @api_accounts_bp.route("/api/accounts", methods=["GET"])
@login_required @login_required
def get_accounts(): def get_accounts():
@@ -49,8 +91,7 @@ def get_accounts():
accounts = safe_get_user_accounts_snapshot(user_id) accounts = safe_get_user_accounts_snapshot(user_id)
if refresh or not accounts: if refresh or not accounts:
load_user_accounts(user_id) accounts = _ensure_accounts_loaded(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
return jsonify([acc.to_dict() for acc in accounts.values()]) return jsonify([acc.to_dict() for acc in accounts.values()])
@@ -63,20 +104,18 @@ def add_account():
current_count = len(database.get_user_accounts(user_id)) current_count = len(database.get_user_accounts(user_id))
is_vip = database.is_user_vip(user_id) is_vip = database.is_user_vip(user_id)
if not is_vip and current_count >= 3: if (not is_vip) and current_count >= 3:
return jsonify({"error": "普通用户最多添加3个账号升级VIP可无限添加"}), 403 return jsonify({"error": "普通用户最多添加3个账号升级VIP可无限添加"}), 403
data = request.json
username = data.get("username", "").strip() data = _request_json()
password = data.get("password", "").strip() username = str(data.get("username", "")).strip()
remark = data.get("remark", "").strip()[:200] password = str(data.get("password", "")).strip()
remark = str(data.get("remark", "")).strip()[:200]
if not username or not password: if not username or not password:
return jsonify({"error": "用户名和密码不能为空"}), 400 return jsonify({"error": "用户名和密码不能为空"}), 400
accounts = safe_get_user_accounts_snapshot(user_id) accounts = _ensure_accounts_loaded(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
for acc in accounts.values(): for acc in accounts.values():
if acc.username == username: if acc.username == username:
return jsonify({"error": f"账号 '{username}' 已存在"}), 400 return jsonify({"error": f"账号 '{username}' 已存在"}), 400
@@ -92,7 +131,7 @@ def add_account():
safe_set_account(user_id, account_id, account) safe_set_account(user_id, account_id, account)
log_to_client(f"添加账号: {username}", user_id) log_to_client(f"添加账号: {username}", user_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}") _emit_account_update(user_id, account)
return jsonify(account.to_dict()) return jsonify(account.to_dict())
@@ -103,15 +142,15 @@ def update_account(account_id):
"""更新账号信息(密码等)""" """更新账号信息(密码等)"""
user_id = current_user.id user_id = current_user.id
account = safe_get_account(user_id, account_id) account = _get_user_account(user_id, account_id)
if not account: if not account:
return jsonify({"error": "账号不存在"}), 404 return jsonify({"error": "账号不存在"}), 404
if account.is_running: if account.is_running:
return jsonify({"error": "账号正在运行中,请先停止"}), 400 return jsonify({"error": "账号正在运行中,请先停止"}), 400
data = request.json data = _request_json()
new_password = data.get("password", "").strip() new_password = str(data.get("password", "")).strip()
new_remember = data.get("remember", account.remember) new_remember = data.get("remember", account.remember)
if not new_password: if not new_password:
@@ -147,7 +186,7 @@ def delete_account(account_id):
"""删除账号""" """删除账号"""
user_id = current_user.id user_id = current_user.id
account = safe_get_account(user_id, account_id) account = _get_user_account(user_id, account_id)
if not account: if not account:
return jsonify({"error": "账号不存在"}), 404 return jsonify({"error": "账号不存在"}), 404
@@ -159,7 +198,6 @@ def delete_account(account_id):
username = account.username username = account.username
database.delete_account(account_id) database.delete_account(account_id)
safe_remove_account(user_id, account_id) safe_remove_account(user_id, account_id)
log_to_client(f"删除账号: {username}", user_id) log_to_client(f"删除账号: {username}", user_id)
@@ -196,12 +234,12 @@ def update_remark(account_id):
"""更新备注""" """更新备注"""
user_id = current_user.id user_id = current_user.id
account = safe_get_account(user_id, account_id) account = _get_user_account(user_id, account_id)
if not account: if not account:
return jsonify({"error": "账号不存在"}), 404 return jsonify({"error": "账号不存在"}), 404
data = request.json data = _request_json()
remark = data.get("remark", "").strip()[:200] remark = str(data.get("remark", "")).strip()[:200]
database.update_account_remark(account_id, remark) database.update_account_remark(account_id, remark)
@@ -217,17 +255,18 @@ def start_account(account_id):
"""启动账号任务""" """启动账号任务"""
user_id = current_user.id user_id = current_user.id
account = safe_get_account(user_id, account_id) account = _get_user_account(user_id, account_id)
if not account: if not account:
return jsonify({"error": "账号不存在"}), 404 return jsonify({"error": "账号不存在"}), 404
if account.is_running: if account.is_running:
return jsonify({"error": "任务已在运行中"}), 400 return jsonify({"error": "任务已在运行中"}), 400
data = request.json or {} data = _request_json()
browse_type = validate_browse_type(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ) browse_type, browse_error = _validate_browse_type_input(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
if not browse_type: if browse_error:
return jsonify({"error": "浏览类型无效"}), 400 return browse_error
enable_screenshot = data.get("enable_screenshot", True) enable_screenshot = data.get("enable_screenshot", True)
ok, message = submit_account_task( ok, message = submit_account_task(
user_id=user_id, user_id=user_id,
@@ -249,7 +288,7 @@ def stop_account(account_id):
"""停止账号任务""" """停止账号任务"""
user_id = current_user.id user_id = current_user.id
account = safe_get_account(user_id, account_id) account = _get_user_account(user_id, account_id)
if not account: if not account:
return jsonify({"error": "账号不存在"}), 404 return jsonify({"error": "账号不存在"}), 404
@@ -259,20 +298,16 @@ def stop_account(account_id):
account.should_stop = True account.should_stop = True
account.status = "正在停止" account.status = "正在停止"
try: if _cancel_pending_account_task(user_id, account_id):
scheduler = get_task_scheduler() account.status = "已停止"
if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id): account.is_running = False
account.status = "已停止" safe_remove_task_status(account_id)
account.is_running = False _emit_account_update(user_id, account)
safe_remove_task_status(account_id) log_to_client(f"任务已取消: {account.username}", user_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}") return jsonify({"success": True, "canceled": True})
log_to_client(f"任务已取消: {account.username}", user_id)
return jsonify({"success": True, "canceled": True})
except Exception:
pass
log_to_client(f"停止任务: {account.username}", user_id) log_to_client(f"停止任务: {account.username}", user_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}") _emit_account_update(user_id, account)
return jsonify({"success": True}) return jsonify({"success": True})
@@ -283,23 +318,20 @@ def manual_screenshot(account_id):
"""手动为指定账号截图""" """手动为指定账号截图"""
user_id = current_user.id user_id = current_user.id
account = safe_get_account(user_id, account_id) account = _get_user_account(user_id, account_id, refresh_if_missing=True)
if not account:
load_user_accounts(user_id)
account = safe_get_account(user_id, account_id)
if not account: if not account:
return jsonify({"error": "账号不存在"}), 404 return jsonify({"error": "账号不存在"}), 404
if account.is_running: if account.is_running:
return jsonify({"error": "任务运行中,无法截图"}), 400 return jsonify({"error": "任务运行中,无法截图"}), 400
data = request.json or {} data = _request_json()
requested_browse_type = data.get("browse_type", None) requested_browse_type = data.get("browse_type", None)
if requested_browse_type is None: if requested_browse_type is None:
browse_type = normalize_browse_type(account.last_browse_type) browse_type = normalize_browse_type(account.last_browse_type)
else: else:
browse_type = validate_browse_type(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ) browse_type, browse_error = _validate_browse_type_input(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ)
if not browse_type: if browse_error:
return jsonify({"error": "浏览类型无效"}), 400 return browse_error
account.last_browse_type = browse_type account.last_browse_type = browse_type
@@ -317,12 +349,16 @@ def manual_screenshot(account_id):
def batch_start_accounts(): def batch_start_accounts():
"""批量启动账号""" """批量启动账号"""
user_id = current_user.id user_id = current_user.id
data = request.json or {} data = _request_json()
account_ids = data.get("account_ids", []) account_ids = data.get("account_ids", [])
browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ) browse_type, browse_error = _validate_browse_type_input(
if not browse_type: data.get("browse_type", BROWSE_TYPE_SHOULD_READ),
return jsonify({"error": "浏览类型无效"}), 400 default=BROWSE_TYPE_SHOULD_READ,
)
if browse_error:
return browse_error
enable_screenshot = data.get("enable_screenshot", True) enable_screenshot = data.get("enable_screenshot", True)
if not account_ids: if not account_ids:
@@ -331,11 +367,10 @@ def batch_start_accounts():
started = [] started = []
failed = [] failed = []
if not safe_get_user_accounts_snapshot(user_id): _ensure_accounts_loaded(user_id)
load_user_accounts(user_id)
for account_id in account_ids: for account_id in account_ids:
account = safe_get_account(user_id, account_id) account = _get_user_account(user_id, account_id)
if not account: if not account:
failed.append({"id": account_id, "reason": "账号不存在"}) failed.append({"id": account_id, "reason": "账号不存在"})
continue continue
@@ -357,7 +392,13 @@ def batch_start_accounts():
failed.append({"id": account_id, "reason": msg}) failed.append({"id": account_id, "reason": msg})
return jsonify( return jsonify(
{"success": True, "started_count": len(started), "failed_count": len(failed), "started": started, "failed": failed} {
"success": True,
"started_count": len(started),
"failed_count": len(failed),
"started": started,
"failed": failed,
}
) )
@@ -366,39 +407,29 @@ def batch_start_accounts():
def batch_stop_accounts(): def batch_stop_accounts():
"""批量停止账号""" """批量停止账号"""
user_id = current_user.id user_id = current_user.id
data = request.json data = _request_json()
account_ids = data.get("account_ids", []) account_ids = data.get("account_ids", [])
if not account_ids: if not account_ids:
return jsonify({"error": "请选择要停止的账号"}), 400 return jsonify({"error": "请选择要停止的账号"}), 400
stopped = [] stopped = []
_ensure_accounts_loaded(user_id)
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
for account_id in account_ids: for account_id in account_ids:
account = safe_get_account(user_id, account_id) account = _get_user_account(user_id, account_id)
if not account: if (not account) or (not account.is_running):
continue
if not account.is_running:
continue continue
account.should_stop = True account.should_stop = True
account.status = "正在停止" account.status = "正在停止"
stopped.append(account_id) stopped.append(account_id)
try: if _cancel_pending_account_task(user_id, account_id):
scheduler = get_task_scheduler() account.status = "已停止"
if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id): account.is_running = False
account.status = "已停止" safe_remove_task_status(account_id)
account.is_running = False
safe_remove_task_status(account_id)
except Exception:
pass
_emit("account_update", account.to_dict(), room=f"user_{user_id}") _emit_account_update(user_id, account)
return jsonify({"success": True, "stopped_count": len(stopped), "stopped": stopped}) return jsonify({"success": True, "stopped_count": len(stopped), "stopped": stopped})

View File

@@ -2,16 +2,20 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
import base64
import random import random
import secrets import secrets
import threading
import time import time
import uuid
from io import BytesIO
import database import database
import email_service import email_service
from app_config import get_config from app_config import get_config
from app_logger import get_logger from app_logger import get_logger
from app_security import get_rate_limit_ip, require_ip_not_locked, validate_email, validate_password, validate_username from app_security import get_rate_limit_ip, require_ip_not_locked, validate_email, validate_password, validate_username
from flask import Blueprint, jsonify, redirect, render_template, request, url_for from flask import Blueprint, jsonify, request
from flask_login import login_required, login_user, logout_user from flask_login import login_required, login_user, logout_user
from routes.pages import render_app_spa_or_legacy from routes.pages import render_app_spa_or_legacy
from services.accounts_service import load_user_accounts from services.accounts_service import load_user_accounts
@@ -39,12 +43,162 @@ config = get_config()
api_auth_bp = Blueprint("api_auth", __name__) api_auth_bp = Blueprint("api_auth", __name__)
_CAPTCHA_FONT_PATHS = [
"/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
"/usr/share/fonts/truetype/freefont/FreeSansBold.ttf",
]
_CAPTCHA_FONT = None
_CAPTCHA_FONT_LOCK = threading.Lock()
def _get_json_payload() -> dict:
data = request.get_json(silent=True)
return data if isinstance(data, dict) else {}
def _load_captcha_font(image_font_module):
global _CAPTCHA_FONT
if _CAPTCHA_FONT is not None:
return _CAPTCHA_FONT
with _CAPTCHA_FONT_LOCK:
if _CAPTCHA_FONT is not None:
return _CAPTCHA_FONT
for font_path in _CAPTCHA_FONT_PATHS:
try:
_CAPTCHA_FONT = image_font_module.truetype(font_path, 42)
break
except Exception:
continue
if _CAPTCHA_FONT is None:
_CAPTCHA_FONT = image_font_module.load_default()
return _CAPTCHA_FONT
def _generate_captcha_image_data_uri(code: str) -> str:
from PIL import Image, ImageDraw, ImageFont
width, height = 160, 60
image = Image.new("RGB", (width, height), color=(255, 255, 255))
draw = ImageDraw.Draw(image)
for _ in range(6):
x1 = random.randint(0, width)
y1 = random.randint(0, height)
x2 = random.randint(0, width)
y2 = random.randint(0, height)
draw.line(
[(x1, y1), (x2, y2)],
fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)),
width=1,
)
for _ in range(80):
x = random.randint(0, width)
y = random.randint(0, height)
draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)))
font = _load_captcha_font(ImageFont)
for i, char in enumerate(code):
x = 12 + i * 35 + random.randint(-3, 3)
y = random.randint(5, 12)
color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150))
draw.text((x, y), char, font=font, fill=color)
buffer = BytesIO()
image.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return f"data:image/png;base64,{img_base64}"
def _with_vip_suffix(message: str, auto_approve_enabled: bool, auto_approve_vip_days: int) -> str:
if auto_approve_enabled and auto_approve_vip_days > 0:
return f"{message},赠送{auto_approve_vip_days}天VIP"
return message
def _verify_common_captcha(client_ip: str, captcha_session: str, captcha_code: str):
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
if success:
return True, None
is_locked = record_failed_captcha(client_ip)
if is_locked:
return False, (jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429)
return False, (jsonify({"error": message}), 400)
def _verify_login_captcha_if_needed(
*,
captcha_required: bool,
captcha_session: str,
captcha_code: str,
client_ip: str,
username_key: str,
):
if not captcha_required:
return True, None
if not captcha_session or not captcha_code:
return False, (jsonify({"error": "请填写验证码", "need_captcha": True}), 400)
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
if success:
return True, None
record_login_failure(client_ip, username_key)
return False, (jsonify({"error": message, "need_captcha": True}), 400)
def _send_password_reset_email_if_possible(email: str, username: str, user_id: int) -> None:
result = email_service.send_password_reset_email(email=email, username=username, user_id=user_id)
if not result["success"]:
logger.error(f"密码重置邮件发送失败: {result['error']}")
def _send_login_security_alert_if_needed(user: dict, username: str, client_ip: str) -> None:
try:
user_agent = request.headers.get("User-Agent", "")
context = database.record_login_context(user["id"], client_ip, user_agent)
if not context or (not context.get("new_ip") and not context.get("new_device")):
return
if not config.LOGIN_ALERT_ENABLED:
return
if not should_send_login_alert(user["id"], client_ip):
return
if not email_service.get_email_settings().get("login_alert_enabled", True):
return
user_info = database.get_user_by_id(user["id"]) or {}
if (not user_info.get("email")) or (not user_info.get("email_verified")):
return
if not database.get_user_email_notify(user["id"]):
return
email_service.send_security_alert_email(
email=user_info.get("email"),
username=user_info.get("username") or username,
ip_address=client_ip,
user_agent=user_agent,
new_ip=context.get("new_ip", False),
new_device=context.get("new_device", False),
user_id=user["id"],
)
except Exception:
pass
@api_auth_bp.route("/api/register", methods=["POST"]) @api_auth_bp.route("/api/register", methods=["POST"])
@require_ip_not_locked @require_ip_not_locked
def register(): def register():
"""用户注册""" """用户注册"""
data = request.json or {} data = _get_json_payload()
username = data.get("username", "").strip() username = data.get("username", "").strip()
password = data.get("password", "").strip() password = data.get("password", "").strip()
email = data.get("email", "").strip().lower() email = data.get("email", "").strip().lower()
@@ -67,12 +221,9 @@ def register():
if not allowed: if not allowed:
return jsonify({"error": error_msg}), 429 return jsonify({"error": error_msg}), 429
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
if not success: if not captcha_ok:
is_locked = record_failed_captcha(client_ip) return captcha_error_response
if is_locked:
return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
return jsonify({"error": message}), 400
email_settings = email_service.get_email_settings() email_settings = email_service.get_email_settings()
email_verify_enabled = email_settings.get("register_verify_enabled", False) and email_settings.get("enabled", False) email_verify_enabled = email_settings.get("register_verify_enabled", False) and email_settings.get("enabled", False)
@@ -105,20 +256,22 @@ def register():
if email_verify_enabled and email: if email_verify_enabled and email:
result = email_service.send_register_verification_email(email=email, username=username, user_id=user_id) result = email_service.send_register_verification_email(email=email, username=username, user_id=user_id)
if result["success"]: if result["success"]:
message = "注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)" message = _with_vip_suffix(
if auto_approve_enabled and auto_approve_vip_days > 0: "注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)",
message += f",赠送{auto_approve_vip_days}天VIP" auto_approve_enabled,
auto_approve_vip_days,
)
return jsonify({"success": True, "message": message, "need_verify": True}) return jsonify({"success": True, "message": message, "need_verify": True})
logger.error(f"注册验证邮件发送失败: {result['error']}") logger.error(f"注册验证邮件发送失败: {result['error']}")
message = f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录" message = _with_vip_suffix(
if auto_approve_enabled and auto_approve_vip_days > 0: f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录",
message += f",赠送{auto_approve_vip_days}天VIP" auto_approve_enabled,
auto_approve_vip_days,
)
return jsonify({"success": True, "message": message, "need_verify": True}) return jsonify({"success": True, "message": message, "need_verify": True})
message = "注册成功!可直接登录" message = _with_vip_suffix("注册成功!可直接登录", auto_approve_enabled, auto_approve_vip_days)
if auto_approve_enabled and auto_approve_vip_days > 0:
message += f",赠送{auto_approve_vip_days}天VIP"
return jsonify({"success": True, "message": message}) return jsonify({"success": True, "message": message})
return jsonify({"error": "用户名已存在"}), 400 return jsonify({"error": "用户名已存在"}), 400
@@ -175,7 +328,7 @@ def verify_email(token):
@require_ip_not_locked @require_ip_not_locked
def resend_verify_email(): def resend_verify_email():
"""重发验证邮件""" """重发验证邮件"""
data = request.json or {} data = _get_json_payload()
email = data.get("email", "").strip().lower() email = data.get("email", "").strip().lower()
captcha_session = data.get("captcha_session", "") captcha_session = data.get("captcha_session", "")
captcha_code = data.get("captcha", "").strip() captcha_code = data.get("captcha", "").strip()
@@ -195,12 +348,9 @@ def resend_verify_email():
if not allowed: if not allowed:
return jsonify({"error": error_msg}), 429 return jsonify({"error": error_msg}), 429
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
if not success: if not captcha_ok:
is_locked = record_failed_captcha(client_ip) return captcha_error_response
if is_locked:
return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
return jsonify({"error": message}), 400
user = database.get_user_by_email(email) user = database.get_user_by_email(email)
if not user: if not user:
@@ -235,7 +385,7 @@ def get_email_verify_status():
@require_ip_not_locked @require_ip_not_locked
def forgot_password(): def forgot_password():
"""发送密码重置邮件""" """发送密码重置邮件"""
data = request.json or {} data = _get_json_payload()
email = data.get("email", "").strip().lower() email = data.get("email", "").strip().lower()
username = data.get("username", "").strip() username = data.get("username", "").strip()
captcha_session = data.get("captcha_session", "") captcha_session = data.get("captcha_session", "")
@@ -263,12 +413,9 @@ def forgot_password():
if not allowed: if not allowed:
return jsonify({"error": error_msg}), 429 return jsonify({"error": error_msg}), 429
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
if not success: if not captcha_ok:
is_locked = record_failed_captcha(client_ip) return captcha_error_response
if is_locked:
return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
return jsonify({"error": message}), 400
email_settings = email_service.get_email_settings() email_settings = email_service.get_email_settings()
if not email_settings.get("enabled", False): if not email_settings.get("enabled", False):
@@ -293,20 +440,16 @@ def forgot_password():
if not allowed: if not allowed:
return jsonify({"error": error_msg}), 429 return jsonify({"error": error_msg}), 429
result = email_service.send_password_reset_email( _send_password_reset_email_if_possible(
email=bound_email, email=bound_email,
username=user["username"], username=user["username"],
user_id=user["id"], user_id=user["id"],
) )
if not result["success"]:
logger.error(f"密码重置邮件发送失败: {result['error']}")
return jsonify({"success": True, "message": "如果该账号已绑定邮箱,您将收到密码重置邮件"}) return jsonify({"success": True, "message": "如果该账号已绑定邮箱,您将收到密码重置邮件"})
user = database.get_user_by_email(email) user = database.get_user_by_email(email)
if user and user.get("status") == "approved": if user and user.get("status") == "approved":
result = email_service.send_password_reset_email(email=email, username=user["username"], user_id=user["id"]) _send_password_reset_email_if_possible(email=email, username=user["username"], user_id=user["id"])
if not result["success"]:
logger.error(f"密码重置邮件发送失败: {result['error']}")
return jsonify({"success": True, "message": "如果该邮箱已注册,您将收到密码重置邮件"}) return jsonify({"success": True, "message": "如果该邮箱已注册,您将收到密码重置邮件"})
@@ -331,7 +474,7 @@ def reset_password_page(token):
@api_auth_bp.route("/api/reset-password-confirm", methods=["POST"]) @api_auth_bp.route("/api/reset-password-confirm", methods=["POST"])
def reset_password_confirm(): def reset_password_confirm():
"""确认密码重置""" """确认密码重置"""
data = request.json or {} data = _get_json_payload()
token = data.get("token", "").strip() token = data.get("token", "").strip()
new_password = data.get("new_password", "").strip() new_password = data.get("new_password", "").strip()
@@ -356,67 +499,15 @@ def reset_password_confirm():
@api_auth_bp.route("/api/generate_captcha", methods=["POST"]) @api_auth_bp.route("/api/generate_captcha", methods=["POST"])
def generate_captcha(): def generate_captcha():
"""生成4位数字验证码图片""" """生成4位数字验证码图片"""
import base64
import uuid
from io import BytesIO
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
code = "".join(str(secrets.randbelow(10)) for _ in range(4))
code = "".join([str(secrets.randbelow(10)) for _ in range(4)])
safe_set_captcha(session_id, {"code": code, "expire_time": time.time() + 300, "failed_attempts": 0}) safe_set_captcha(session_id, {"code": code, "expire_time": time.time() + 300, "failed_attempts": 0})
safe_cleanup_expired_captcha() safe_cleanup_expired_captcha()
try: try:
from PIL import Image, ImageDraw, ImageFont captcha_image = _generate_captcha_image_data_uri(code)
import io return jsonify({"session_id": session_id, "captcha_image": captcha_image})
width, height = 160, 60
image = Image.new("RGB", (width, height), color=(255, 255, 255))
draw = ImageDraw.Draw(image)
for _ in range(6):
x1 = random.randint(0, width)
y1 = random.randint(0, height)
x2 = random.randint(0, width)
y2 = random.randint(0, height)
draw.line(
[(x1, y1), (x2, y2)],
fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)),
width=1,
)
for _ in range(80):
x = random.randint(0, width)
y = random.randint(0, height)
draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)))
font = None
font_paths = [
"/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
"/usr/share/fonts/truetype/freefont/FreeSansBold.ttf",
]
for font_path in font_paths:
try:
font = ImageFont.truetype(font_path, 42)
break
except Exception:
continue
if font is None:
font = ImageFont.load_default()
for i, char in enumerate(code):
x = 12 + i * 35 + random.randint(-3, 3)
y = random.randint(5, 12)
color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150))
draw.text((x, y), char, font=font, fill=color)
buffer = io.BytesIO()
image.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return jsonify({"session_id": session_id, "captcha_image": f"data:image/png;base64,{img_base64}"})
except ImportError as e: except ImportError as e:
logger.error(f"PIL库未安装验证码功能不可用: {e}") logger.error(f"PIL库未安装验证码功能不可用: {e}")
safe_delete_captcha(session_id) safe_delete_captcha(session_id)
@@ -427,7 +518,7 @@ def generate_captcha():
@require_ip_not_locked @require_ip_not_locked
def login(): def login():
"""用户登录""" """用户登录"""
data = request.json or {} data = _get_json_payload()
username = data.get("username", "").strip() username = data.get("username", "").strip()
password = data.get("password", "").strip() password = data.get("password", "").strip()
captcha_session = data.get("captcha_session", "") captcha_session = data.get("captcha_session", "")
@@ -452,13 +543,15 @@ def login():
return jsonify({"error": error_msg, "need_captcha": True}), 429 return jsonify({"error": error_msg, "need_captcha": True}), 429
captcha_required = check_login_captcha_required(client_ip, username_key) or scan_locked or bool(need_captcha) captcha_required = check_login_captcha_required(client_ip, username_key) or scan_locked or bool(need_captcha)
if captcha_required: captcha_ok, captcha_error_response = _verify_login_captcha_if_needed(
if not captcha_session or not captcha_code: captcha_required=captcha_required,
return jsonify({"error": "请填写验证码", "need_captcha": True}), 400 captcha_session=captcha_session,
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) captcha_code=captcha_code,
if not success: client_ip=client_ip,
record_login_failure(client_ip, username_key) username_key=username_key,
return jsonify({"error": message, "need_captcha": True}), 400 )
if not captcha_ok:
return captcha_error_response
user = database.verify_user(username, password) user = database.verify_user(username, password)
if not user: if not user:
@@ -476,29 +569,7 @@ def login():
login_user(user_obj) login_user(user_obj)
load_user_accounts(user["id"]) load_user_accounts(user["id"])
try: _send_login_security_alert_if_needed(user=user, username=username, client_ip=client_ip)
user_agent = request.headers.get("User-Agent", "")
context = database.record_login_context(user["id"], client_ip, user_agent)
if context and (context.get("new_ip") or context.get("new_device")):
if (
config.LOGIN_ALERT_ENABLED
and should_send_login_alert(user["id"], client_ip)
and email_service.get_email_settings().get("login_alert_enabled", True)
):
user_info = database.get_user_by_id(user["id"]) or {}
if user_info.get("email") and user_info.get("email_verified"):
if database.get_user_email_notify(user["id"]):
email_service.send_security_alert_email(
email=user_info.get("email"),
username=user_info.get("username") or username,
ip_address=client_ip,
user_agent=user_agent,
new_ip=context.get("new_ip", False),
new_device=context.get("new_device", False),
user_id=user["id"],
)
except Exception:
pass
return jsonify({"success": True}) return jsonify({"success": True})

View File

@@ -2,7 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
import json
import re import re
import threading
import time as time_mod
import uuid
import database import database
from flask import Blueprint, jsonify, request from flask import Blueprint, jsonify, request
@@ -17,6 +21,13 @@ api_schedules_bp = Blueprint("api_schedules", __name__)
_HHMM_RE = re.compile(r"^(\d{1,2}):(\d{2})$") _HHMM_RE = re.compile(r"^(\d{1,2}):(\d{2})$")
def _request_json(default=None):
if default is None:
default = {}
data = request.get_json(silent=True)
return data if isinstance(data, dict) else default
def _normalize_hhmm(value: object) -> str | None: def _normalize_hhmm(value: object) -> str | None:
match = _HHMM_RE.match(str(value or "").strip()) match = _HHMM_RE.match(str(value or "").strip())
if not match: if not match:
@@ -28,18 +39,53 @@ def _normalize_hhmm(value: object) -> str | None:
return f"{hour:02d}:{minute:02d}" return f"{hour:02d}:{minute:02d}"
def _normalize_random_delay(value) -> tuple[int | None, str | None]:
try:
normalized = int(value or 0)
except Exception:
return None, "random_delay必须是0或1"
if normalized not in (0, 1):
return None, "random_delay必须是0或1"
return normalized, None
def _parse_schedule_account_ids(raw_value) -> list:
try:
parsed = json.loads(raw_value or "[]")
except (json.JSONDecodeError, TypeError):
return []
return parsed if isinstance(parsed, list) else []
def _get_owned_schedule_or_error(schedule_id: int):
schedule = database.get_schedule_by_id(schedule_id)
if not schedule:
return None, (jsonify({"error": "定时任务不存在"}), 404)
if schedule.get("user_id") != current_user.id:
return None, (jsonify({"error": "无权访问"}), 403)
return schedule, None
def _ensure_user_accounts_loaded(user_id: int) -> None:
if safe_get_user_accounts_snapshot(user_id):
return
load_user_accounts(user_id)
def _parse_browse_type_or_error(raw_value, *, default=BROWSE_TYPE_SHOULD_READ):
browse_type = validate_browse_type(raw_value, default=default)
if not browse_type:
return None, (jsonify({"error": "浏览类型无效"}), 400)
return browse_type, None
@api_schedules_bp.route("/api/schedules", methods=["GET"]) @api_schedules_bp.route("/api/schedules", methods=["GET"])
@login_required @login_required
def get_user_schedules_api(): def get_user_schedules_api():
"""获取当前用户的所有定时任务""" """获取当前用户的所有定时任务"""
schedules = database.get_user_schedules(current_user.id) schedules = database.get_user_schedules(current_user.id)
import json for schedule in schedules:
schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids"))
for s in schedules:
try:
s["account_ids"] = json.loads(s.get("account_ids", "[]") or "[]")
except (json.JSONDecodeError, TypeError):
s["account_ids"] = []
return jsonify(schedules) return jsonify(schedules)
@@ -47,23 +93,26 @@ def get_user_schedules_api():
@login_required @login_required
def create_user_schedule_api(): def create_user_schedule_api():
"""创建用户定时任务""" """创建用户定时任务"""
data = request.json or {} data = _request_json()
name = data.get("name", "我的定时任务") name = data.get("name", "我的定时任务")
schedule_time = data.get("schedule_time", "08:00") schedule_time = data.get("schedule_time", "08:00")
weekdays = data.get("weekdays", "1,2,3,4,5") weekdays = data.get("weekdays", "1,2,3,4,5")
browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ)
if not browse_type: browse_type, browse_error = _parse_browse_type_or_error(data.get("browse_type", BROWSE_TYPE_SHOULD_READ))
return jsonify({"error": "浏览类型无效"}), 400 if browse_error:
return browse_error
enable_screenshot = data.get("enable_screenshot", 1) enable_screenshot = data.get("enable_screenshot", 1)
random_delay = int(data.get("random_delay", 0) or 0) random_delay, delay_error = _normalize_random_delay(data.get("random_delay", 0))
if delay_error:
return jsonify({"error": delay_error}), 400
account_ids = data.get("account_ids", []) account_ids = data.get("account_ids", [])
normalized_time = _normalize_hhmm(schedule_time) normalized_time = _normalize_hhmm(schedule_time)
if not normalized_time: if not normalized_time:
return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400 return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400
if random_delay not in (0, 1):
return jsonify({"error": "random_delay必须是0或1"}), 400
schedule_id = database.create_user_schedule( schedule_id = database.create_user_schedule(
user_id=current_user.id, user_id=current_user.id,
@@ -85,18 +134,11 @@ def create_user_schedule_api():
@login_required @login_required
def get_schedule_detail_api(schedule_id): def get_schedule_detail_api(schedule_id):
"""获取定时任务详情""" """获取定时任务详情"""
schedule = database.get_schedule_by_id(schedule_id) schedule, error_response = _get_owned_schedule_or_error(schedule_id)
if not schedule: if error_response:
return jsonify({"error": "定时任务不存在"}), 404 return error_response
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
import json schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids"))
try:
schedule["account_ids"] = json.loads(schedule.get("account_ids", "[]") or "[]")
except (json.JSONDecodeError, TypeError):
schedule["account_ids"] = []
return jsonify(schedule) return jsonify(schedule)
@@ -104,14 +146,12 @@ def get_schedule_detail_api(schedule_id):
@login_required @login_required
def update_schedule_api(schedule_id): def update_schedule_api(schedule_id):
"""更新定时任务""" """更新定时任务"""
schedule = database.get_schedule_by_id(schedule_id) _, error_response = _get_owned_schedule_or_error(schedule_id)
if not schedule: if error_response:
return jsonify({"error": "定时任务不存在"}), 404 return error_response
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
data = request.json or {} data = _request_json()
allowed_fields = [ allowed_fields = {
"name", "name",
"schedule_time", "schedule_time",
"weekdays", "weekdays",
@@ -120,27 +160,26 @@ def update_schedule_api(schedule_id):
"random_delay", "random_delay",
"account_ids", "account_ids",
"enabled", "enabled",
] }
update_data = {key: value for key, value in data.items() if key in allowed_fields}
update_data = {k: v for k, v in data.items() if k in allowed_fields}
if "schedule_time" in update_data: if "schedule_time" in update_data:
normalized_time = _normalize_hhmm(update_data["schedule_time"]) normalized_time = _normalize_hhmm(update_data["schedule_time"])
if not normalized_time: if not normalized_time:
return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400 return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400
update_data["schedule_time"] = normalized_time update_data["schedule_time"] = normalized_time
if "random_delay" in update_data: if "random_delay" in update_data:
try: random_delay, delay_error = _normalize_random_delay(update_data.get("random_delay"))
update_data["random_delay"] = int(update_data.get("random_delay") or 0) if delay_error:
except Exception: return jsonify({"error": delay_error}), 400
return jsonify({"error": "random_delay必须是0或1"}), 400 update_data["random_delay"] = random_delay
if update_data["random_delay"] not in (0, 1):
return jsonify({"error": "random_delay必须是0或1"}), 400
if "browse_type" in update_data: if "browse_type" in update_data:
normalized = validate_browse_type(update_data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ) normalized_browse_type, browse_error = _parse_browse_type_or_error(update_data.get("browse_type"))
if not normalized: if browse_error:
return jsonify({"error": "浏览类型无效"}), 400 return browse_error
update_data["browse_type"] = normalized update_data["browse_type"] = normalized_browse_type
success = database.update_user_schedule(schedule_id, **update_data) success = database.update_user_schedule(schedule_id, **update_data)
if success: if success:
@@ -152,11 +191,9 @@ def update_schedule_api(schedule_id):
@login_required @login_required
def delete_schedule_api(schedule_id): def delete_schedule_api(schedule_id):
"""删除定时任务""" """删除定时任务"""
schedule = database.get_schedule_by_id(schedule_id) _, error_response = _get_owned_schedule_or_error(schedule_id)
if not schedule: if error_response:
return jsonify({"error": "定时任务不存在"}), 404 return error_response
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
success = database.delete_user_schedule(schedule_id) success = database.delete_user_schedule(schedule_id)
if success: if success:
@@ -168,13 +205,11 @@ def delete_schedule_api(schedule_id):
@login_required @login_required
def toggle_schedule_api(schedule_id): def toggle_schedule_api(schedule_id):
"""启用/禁用定时任务""" """启用/禁用定时任务"""
schedule = database.get_schedule_by_id(schedule_id) schedule, error_response = _get_owned_schedule_or_error(schedule_id)
if not schedule: if error_response:
return jsonify({"error": "定时任务不存在"}), 404 return error_response
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
data = request.json data = _request_json()
enabled = data.get("enabled", not schedule["enabled"]) enabled = data.get("enabled", not schedule["enabled"])
success = database.toggle_user_schedule(schedule_id, enabled) success = database.toggle_user_schedule(schedule_id, enabled)
@@ -187,22 +222,11 @@ def toggle_schedule_api(schedule_id):
@login_required @login_required
def run_schedule_now_api(schedule_id): def run_schedule_now_api(schedule_id):
"""立即执行定时任务""" """立即执行定时任务"""
import json schedule, error_response = _get_owned_schedule_or_error(schedule_id)
import threading if error_response:
import time as time_mod return error_response
import uuid
schedule = database.get_schedule_by_id(schedule_id)
if not schedule:
return jsonify({"error": "定时任务不存在"}), 404
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
try:
account_ids = json.loads(schedule.get("account_ids", "[]") or "[]")
except (json.JSONDecodeError, TypeError):
account_ids = []
account_ids = _parse_schedule_account_ids(schedule.get("account_ids"))
if not account_ids: if not account_ids:
return jsonify({"error": "没有配置账号"}), 400 return jsonify({"error": "没有配置账号"}), 400
@@ -210,8 +234,7 @@ def run_schedule_now_api(schedule_id):
browse_type = normalize_browse_type(schedule.get("browse_type", BROWSE_TYPE_SHOULD_READ)) browse_type = normalize_browse_type(schedule.get("browse_type", BROWSE_TYPE_SHOULD_READ))
enable_screenshot = schedule["enable_screenshot"] enable_screenshot = schedule["enable_screenshot"]
if not safe_get_user_accounts_snapshot(user_id): _ensure_user_accounts_loaded(user_id)
load_user_accounts(user_id)
from services.state import safe_create_batch, safe_finalize_batch_after_dispatch from services.state import safe_create_batch, safe_finalize_batch_after_dispatch
from services.task_batches import _send_batch_task_email_if_configured from services.task_batches import _send_batch_task_email_if_configured
@@ -250,6 +273,7 @@ def run_schedule_now_api(schedule_id):
if remaining["done"] or remaining["count"] > 0: if remaining["done"] or remaining["count"] > 0:
return return
remaining["done"] = True remaining["done"] = True
execution_duration = int(time_mod.time() - execution_start_time) execution_duration = int(time_mod.time() - execution_start_time)
database.update_schedule_execution_log( database.update_schedule_execution_log(
log_id, log_id,
@@ -260,19 +284,17 @@ def run_schedule_now_api(schedule_id):
status="completed", status="completed",
) )
task_source = f"user_scheduled:{batch_id}"
for account_id in account_ids: for account_id in account_ids:
account = safe_get_account(user_id, account_id) account = safe_get_account(user_id, account_id)
if not account: if (not account) or account.is_running:
skipped_count += 1
continue
if account.is_running:
skipped_count += 1 skipped_count += 1
continue continue
task_source = f"user_scheduled:{batch_id}"
with completion_lock: with completion_lock:
remaining["count"] += 1 remaining["count"] += 1
ok, msg = submit_account_task(
ok, _ = submit_account_task(
user_id=user_id, user_id=user_id,
account_id=account_id, account_id=account_id,
browse_type=browse_type, browse_type=browse_type,

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import os import os
from datetime import datetime from datetime import datetime
from typing import Iterator
import database import database
from app_config import get_config from app_config import get_config
@@ -15,41 +16,67 @@ from services.time_utils import BEIJING_TZ
config = get_config() config = get_config()
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
api_screenshots_bp = Blueprint("api_screenshots", __name__) api_screenshots_bp = Blueprint("api_screenshots", __name__)
def _get_user_prefix(user_id: int) -> str:
user_info = database.get_user_by_id(user_id)
return user_info["username"] if user_info else f"user{user_id}"
def _is_user_screenshot(filename: str, username_prefix: str) -> bool:
return filename.startswith(username_prefix + "_") and filename.lower().endswith(_IMAGE_EXTENSIONS)
def _iter_user_screenshot_entries(username_prefix: str) -> Iterator[os.DirEntry]:
if not os.path.exists(SCREENSHOTS_DIR):
return
with os.scandir(SCREENSHOTS_DIR) as entries:
for entry in entries:
if (not entry.is_file()) or (not _is_user_screenshot(entry.name, username_prefix)):
continue
yield entry
def _build_display_name(filename: str) -> str:
base_name, ext = filename.rsplit(".", 1)
parts = base_name.split("_", 1)
if len(parts) > 1:
return f"{parts[1]}.{ext}"
return filename
@api_screenshots_bp.route("/api/screenshots", methods=["GET"]) @api_screenshots_bp.route("/api/screenshots", methods=["GET"])
@login_required @login_required
def get_screenshots(): def get_screenshots():
"""获取当前用户的截图列表""" """获取当前用户的截图列表"""
user_id = current_user.id user_id = current_user.id
user_info = database.get_user_by_id(user_id) username_prefix = _get_user_prefix(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
try: try:
screenshots = [] screenshots = []
if os.path.exists(SCREENSHOTS_DIR): for entry in _iter_user_screenshot_entries(username_prefix):
for filename in os.listdir(SCREENSHOTS_DIR): filename = entry.name
if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"): stat = entry.stat()
filepath = os.path.join(SCREENSHOTS_DIR, filename) created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ)
stat = os.stat(filepath)
created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ) screenshots.append(
parts = filename.rsplit(".", 1)[0].split("_", 1) {
if len(parts) > 1: "filename": filename,
display_name = parts[1] + "." + filename.rsplit(".", 1)[1] "display_name": _build_display_name(filename),
else: "size": stat.st_size,
display_name = filename "created": created_time.strftime("%Y-%m-%d %H:%M:%S"),
"_created_ts": stat.st_mtime,
}
)
screenshots.sort(key=lambda item: item.get("_created_ts", 0), reverse=True)
for item in screenshots:
item.pop("_created_ts", None)
screenshots.append(
{
"filename": filename,
"display_name": display_name,
"size": stat.st_size,
"created": created_time.strftime("%Y-%m-%d %H:%M:%S"),
}
)
screenshots.sort(key=lambda x: x["created"], reverse=True)
return jsonify(screenshots) return jsonify(screenshots)
except Exception as e: except Exception as e:
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500
@@ -60,10 +87,9 @@ def get_screenshots():
def serve_screenshot(filename): def serve_screenshot(filename):
"""提供截图文件访问""" """提供截图文件访问"""
user_id = current_user.id user_id = current_user.id
user_info = database.get_user_by_id(user_id) username_prefix = _get_user_prefix(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
if not filename.startswith(username_prefix + "_"): if not _is_user_screenshot(filename, username_prefix):
return jsonify({"error": "无权访问"}), 403 return jsonify({"error": "无权访问"}), 403
if not is_safe_path(SCREENSHOTS_DIR, filename): if not is_safe_path(SCREENSHOTS_DIR, filename):
@@ -77,12 +103,14 @@ def serve_screenshot(filename):
def delete_screenshot(filename): def delete_screenshot(filename):
"""删除指定截图""" """删除指定截图"""
user_id = current_user.id user_id = current_user.id
user_info = database.get_user_by_id(user_id) username_prefix = _get_user_prefix(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
if not filename.startswith(username_prefix + "_"): if not _is_user_screenshot(filename, username_prefix):
return jsonify({"error": "无权删除"}), 403 return jsonify({"error": "无权删除"}), 403
if not is_safe_path(SCREENSHOTS_DIR, filename):
return jsonify({"error": "非法路径"}), 403
try: try:
filepath = os.path.join(SCREENSHOTS_DIR, filename) filepath = os.path.join(SCREENSHOTS_DIR, filename)
if os.path.exists(filepath): if os.path.exists(filepath):
@@ -99,19 +127,15 @@ def delete_screenshot(filename):
def clear_all_screenshots(): def clear_all_screenshots():
"""清空当前用户的所有截图""" """清空当前用户的所有截图"""
user_id = current_user.id user_id = current_user.id
user_info = database.get_user_by_id(user_id) username_prefix = _get_user_prefix(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
try: try:
deleted_count = 0 deleted_count = 0
if os.path.exists(SCREENSHOTS_DIR): for entry in _iter_user_screenshot_entries(username_prefix):
for filename in os.listdir(SCREENSHOTS_DIR): os.remove(entry.path)
if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"): deleted_count += 1
filepath = os.path.join(SCREENSHOTS_DIR, filename)
os.remove(filepath)
deleted_count += 1
log_to_client(f"清理了 {deleted_count} 个截图文件", user_id) log_to_client(f"清理了 {deleted_count} 个截图文件", user_id)
return jsonify({"success": True, "deleted": deleted_count}) return jsonify({"success": True, "deleted": deleted_count})
except Exception as e: except Exception as e:
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500

View File

@@ -10,12 +10,96 @@ from flask import Blueprint, jsonify, request
from flask_login import current_user, login_required from flask_login import current_user, login_required
from routes.pages import render_app_spa_or_legacy from routes.pages import render_app_spa_or_legacy
from services.state import check_email_rate_limit, check_ip_request_rate, safe_iter_task_status_items from services.state import check_email_rate_limit, check_ip_request_rate, safe_iter_task_status_items
from services.tasks import get_task_scheduler
logger = get_logger("app") logger = get_logger("app")
api_user_bp = Blueprint("api_user", __name__) api_user_bp = Blueprint("api_user", __name__)
def _get_current_user_record():
return database.get_user_by_id(current_user.id)
def _get_current_user_or_404():
user = _get_current_user_record()
if user:
return user, None
return None, (jsonify({"error": "用户不存在"}), 404)
def _get_current_username(*, fallback: str) -> str:
user = _get_current_user_record()
username = (user or {}).get("username", "")
return username if username else fallback
def _coerce_binary_flag(value, *, field_label: str):
if isinstance(value, bool):
value = 1 if value else 0
try:
value = int(value)
except Exception:
return None, f"{field_label}必须是0或1"
if value not in (0, 1):
return None, f"{field_label}必须是0或1"
return value, None
def _check_bind_email_rate_limits(email: str):
client_ip = get_rate_limit_ip()
allowed, error_msg = check_ip_request_rate(client_ip, "email")
if not allowed:
return False, error_msg, 429
allowed, error_msg = check_email_rate_limit(email, "bind_email")
if not allowed:
return False, error_msg, 429
return True, "", 200
def _render_verify_bind_failed(*, title: str, error_message: str):
spa_initial_state = {
"page": "verify_result",
"success": False,
"title": title,
"error_message": error_message,
"primary_label": "返回登录",
"primary_url": "/login",
}
return render_app_spa_or_legacy(
"verify_failed.html",
legacy_context={"error_message": error_message},
spa_initial_state=spa_initial_state,
)
def _render_verify_bind_success(email: str):
spa_initial_state = {
"page": "verify_result",
"success": True,
"title": "邮箱绑定成功",
"message": f"邮箱 {email} 已成功绑定到您的账号!",
"primary_label": "返回登录",
"primary_url": "/login",
"redirect_url": "/login",
"redirect_seconds": 5,
}
return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
def _get_current_running_count(user_id: int) -> int:
try:
queue_snapshot = get_task_scheduler().get_queue_state_snapshot() or {}
running_by_user = queue_snapshot.get("running_by_user") or {}
return int(running_by_user.get(int(user_id), running_by_user.get(str(user_id), 0)) or 0)
except Exception:
current_running = 0
for _, info in safe_iter_task_status_items():
if info.get("user_id") == user_id and info.get("status") == "运行中":
current_running += 1
return current_running
@api_user_bp.route("/api/announcements/active", methods=["GET"]) @api_user_bp.route("/api/announcements/active", methods=["GET"])
@login_required @login_required
def get_active_announcement(): def get_active_announcement():
@@ -77,8 +161,7 @@ def submit_feedback():
if len(description) > 2000: if len(description) > 2000:
return jsonify({"error": "描述不能超过2000个字符"}), 400 return jsonify({"error": "描述不能超过2000个字符"}), 400
user_info = database.get_user_by_id(current_user.id) username = _get_current_username(fallback=f"用户{current_user.id}")
username = user_info["username"] if user_info else f"用户{current_user.id}"
feedback_id = database.create_bug_feedback( feedback_id = database.create_bug_feedback(
user_id=current_user.id, user_id=current_user.id,
@@ -104,8 +187,7 @@ def get_my_feedbacks():
def get_current_user_vip(): def get_current_user_vip():
"""获取当前用户VIP信息""" """获取当前用户VIP信息"""
vip_info = database.get_user_vip_info(current_user.id) vip_info = database.get_user_vip_info(current_user.id)
user_info = database.get_user_by_id(current_user.id) vip_info["username"] = _get_current_username(fallback="Unknown")
vip_info["username"] = user_info["username"] if user_info else "Unknown"
return jsonify(vip_info) return jsonify(vip_info)
@@ -124,9 +206,9 @@ def change_user_password():
if not is_valid: if not is_valid:
return jsonify({"error": error_msg}), 400 return jsonify({"error": error_msg}), 400
user = database.get_user_by_id(current_user.id) user, error_response = _get_current_user_or_404()
if not user: if error_response:
return jsonify({"error": "用户不存在"}), 404 return error_response
username = user.get("username", "") username = user.get("username", "")
if not username or not database.verify_user(username, current_password): if not username or not database.verify_user(username, current_password):
@@ -141,9 +223,9 @@ def change_user_password():
@login_required @login_required
def get_user_email(): def get_user_email():
"""获取当前用户的邮箱信息""" """获取当前用户的邮箱信息"""
user = database.get_user_by_id(current_user.id) user, error_response = _get_current_user_or_404()
if not user: if error_response:
return jsonify({"error": "用户不存在"}), 404 return error_response
return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)}) return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)})
@@ -172,14 +254,9 @@ def update_user_kdocs_settings():
return jsonify({"error": "县区长度不能超过50"}), 400 return jsonify({"error": "县区长度不能超过50"}), 400
if kdocs_auto_upload is not None: if kdocs_auto_upload is not None:
if isinstance(kdocs_auto_upload, bool): kdocs_auto_upload, parse_error = _coerce_binary_flag(kdocs_auto_upload, field_label="自动上传开关")
kdocs_auto_upload = 1 if kdocs_auto_upload else 0 if parse_error:
try: return jsonify({"error": parse_error}), 400
kdocs_auto_upload = int(kdocs_auto_upload)
except Exception:
return jsonify({"error": "自动上传开关必须是0或1"}), 400
if kdocs_auto_upload not in (0, 1):
return jsonify({"error": "自动上传开关必须是0或1"}), 400
if not database.update_user_kdocs_settings( if not database.update_user_kdocs_settings(
current_user.id, current_user.id,
@@ -207,13 +284,9 @@ def bind_user_email():
if not is_valid: if not is_valid:
return jsonify({"error": error_msg}), 400 return jsonify({"error": error_msg}), 400
client_ip = get_rate_limit_ip() allowed, error_msg, status_code = _check_bind_email_rate_limits(email)
allowed, error_msg = check_ip_request_rate(client_ip, "email")
if not allowed: if not allowed:
return jsonify({"error": error_msg}), 429 return jsonify({"error": error_msg}), status_code
allowed, error_msg = check_email_rate_limit(email, "bind_email")
if not allowed:
return jsonify({"error": error_msg}), 429
settings = email_service.get_email_settings() settings = email_service.get_email_settings()
if not settings.get("enabled", False): if not settings.get("enabled", False):
@@ -223,9 +296,9 @@ def bind_user_email():
if existing_user and existing_user["id"] != current_user.id: if existing_user and existing_user["id"] != current_user.id:
return jsonify({"error": "该邮箱已被其他用户绑定"}), 400 return jsonify({"error": "该邮箱已被其他用户绑定"}), 400
user = database.get_user_by_id(current_user.id) user, error_response = _get_current_user_or_404()
if not user: if error_response:
return jsonify({"error": "用户不存在"}), 404 return error_response
if user.get("email") == email and user.get("email_verified"): if user.get("email") == email and user.get("email_verified"):
return jsonify({"error": "该邮箱已绑定并验证"}), 400 return jsonify({"error": "该邮箱已绑定并验证"}), 400
@@ -247,56 +320,20 @@ def verify_bind_email(token):
email = result["email"] email = result["email"]
if database.update_user_email(user_id, email, verified=True): if database.update_user_email(user_id, email, verified=True):
spa_initial_state = { return _render_verify_bind_success(email)
"page": "verify_result",
"success": True,
"title": "邮箱绑定成功",
"message": f"邮箱 {email} 已成功绑定到您的账号!",
"primary_label": "返回登录",
"primary_url": "/login",
"redirect_url": "/login",
"redirect_seconds": 5,
}
return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
error_message = "邮箱绑定失败,请重试" return _render_verify_bind_failed(title="绑定失败", error_message="邮箱绑定失败,请重试")
spa_initial_state = {
"page": "verify_result",
"success": False,
"title": "绑定失败",
"error_message": error_message,
"primary_label": "返回登录",
"primary_url": "/login",
}
return render_app_spa_or_legacy(
"verify_failed.html",
legacy_context={"error_message": error_message},
spa_initial_state=spa_initial_state,
)
error_message = "验证链接已过期或无效,请重新发送验证邮件" return _render_verify_bind_failed(title="链接无效", error_message="验证链接已过期或无效,请重新发送验证邮件")
spa_initial_state = {
"page": "verify_result",
"success": False,
"title": "链接无效",
"error_message": error_message,
"primary_label": "返回登录",
"primary_url": "/login",
}
return render_app_spa_or_legacy(
"verify_failed.html",
legacy_context={"error_message": error_message},
spa_initial_state=spa_initial_state,
)
@api_user_bp.route("/api/user/unbind-email", methods=["POST"]) @api_user_bp.route("/api/user/unbind-email", methods=["POST"])
@login_required @login_required
def unbind_user_email(): def unbind_user_email():
"""解绑用户邮箱""" """解绑用户邮箱"""
user = database.get_user_by_id(current_user.id) user, error_response = _get_current_user_or_404()
if not user: if error_response:
return jsonify({"error": "用户不存在"}), 404 return error_response
if not user.get("email"): if not user.get("email"):
return jsonify({"error": "当前未绑定邮箱"}), 400 return jsonify({"error": "当前未绑定邮箱"}), 400
@@ -334,10 +371,7 @@ def get_run_stats():
stats = database.get_user_run_stats(user_id) stats = database.get_user_run_stats(user_id)
current_running = 0 current_running = _get_current_running_count(user_id)
for _, info in safe_iter_task_status_items():
if info.get("user_id") == user_id and info.get("status") == "运行中":
current_running += 1
return jsonify( return jsonify(
{ {

View File

@@ -31,7 +31,7 @@ def admin_required(f):
if is_api: if is_api:
return jsonify({"error": "需要管理员权限"}), 403 return jsonify({"error": "需要管理员权限"}), 403
return redirect(url_for("pages.admin_login_page")) return redirect(url_for("pages.admin_login_page"))
logger.info(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}") logger.debug(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}")
return f(*args, **kwargs) return f(*args, **kwargs)
return decorated_function return decorated_function

View File

@@ -2,12 +2,62 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
import os
import time
from flask import Blueprint, jsonify from flask import Blueprint, jsonify
import database import database
import db_pool
from services.time_utils import get_beijing_now from services.time_utils import get_beijing_now
health_bp = Blueprint("health", __name__) health_bp = Blueprint("health", __name__)
_PROCESS_START_TS = time.time()
def _build_runtime_metrics() -> dict:
metrics = {
"uptime_seconds": max(0, int(time.time() - _PROCESS_START_TS)),
}
try:
pool_stats = db_pool.get_pool_stats() or {}
metrics["db_pool"] = {
"pool_size": int(pool_stats.get("pool_size", 0) or 0),
"available": int(pool_stats.get("available", 0) or 0),
"in_use": int(pool_stats.get("in_use", 0) or 0),
}
except Exception:
pass
try:
import psutil
proc = psutil.Process(os.getpid())
with proc.oneshot():
mem_info = proc.memory_info()
metrics["process"] = {
"rss_mb": round(float(mem_info.rss) / 1024 / 1024, 2),
"cpu_percent": round(float(proc.cpu_percent(interval=None)), 2),
"threads": int(proc.num_threads()),
}
except Exception:
pass
try:
from services import tasks as tasks_module
scheduler = getattr(tasks_module, "_task_scheduler", None)
if scheduler is not None:
queue_snapshot = scheduler.get_queue_state_snapshot() or {}
metrics["task_queue"] = {
"pending_total": int(queue_snapshot.get("pending_total", 0) or 0),
"running_total": int(queue_snapshot.get("running_total", 0) or 0),
}
except Exception:
pass
return metrics
@health_bp.route("/health", methods=["GET"]) @health_bp.route("/health", methods=["GET"])
@@ -26,6 +76,6 @@ def health_check():
"time": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"), "time": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
"db_ok": db_ok, "db_ok": db_ok,
"db_error": db_error, "db_error": db_error,
"metrics": _build_runtime_metrics(),
} }
return jsonify(payload), (200 if db_ok else 500) return jsonify(payload), (200 if db_ok else 500)

View File

@@ -0,0 +1,60 @@
# 健康监控(邮件版)
本目录提供 `health_email_monitor.py`,通过调用 `/health` 接口并使用**容器内已有邮件配置**发告警邮件。
## 1) 快速试跑
```bash
cd /root/zsglpt
python3 scripts/health_email_monitor.py \
--to 你的告警邮箱@example.com \
--container knowledge-automation-multiuser \
--url http://127.0.0.1:51232/health \
--dry-run
```
去掉 `--dry-run` 即会实际发邮件。
## 2) 建议 cron每分钟
```bash
* * * * * cd /root/zsglpt && /usr/bin/python3 scripts/health_email_monitor.py \
--to 你的告警邮箱@example.com \
--container knowledge-automation-multiuser \
--url http://127.0.0.1:51232/health \
>> /root/zsglpt/logs/health_monitor.log 2>&1
```
## 3) 支持的规则
- `service_down`:健康接口请求失败(立即告警)
- `health_fail`:返回 `ok/db_ok` 异常或 HTTP 5xx立即告警
- `db_pool_exhausted`:连接池耗尽(默认连续 3 次才告警)
- `queue_backlog_high`:任务堆积过高(默认 `pending_total >= 50` 且连续 5 次)
脚本支持恢复通知(规则恢复正常会发“恢复”邮件)。
## 4) 常用参数
- `--to`:收件人(必填)
- `--container`Docker 容器名(默认 `knowledge-automation-multiuser`
- `--url`:健康地址(默认 `http://127.0.0.1:51232/health`
- `--state-file`:状态文件路径(默认 `/tmp/zsglpt_health_monitor_state.json`
- `--remind-seconds`:重复告警间隔(默认 3600 秒)
- `--queue-threshold`:队列告警阈值(默认 50
- `--queue-streak`:队列连续次数阈值(默认 5
- `--db-pool-streak`:连接池连续次数阈值(默认 3
## 5) 环境变量方式(可选)
也可不用命令行参数,改用环境变量:
- `MONITOR_EMAIL_TO`
- `MONITOR_DOCKER_CONTAINER`
- `HEALTH_URL`
- `MONITOR_STATE_FILE`
- `MONITOR_REMIND_SECONDS`
- `MONITOR_QUEUE_THRESHOLD`
- `MONITOR_QUEUE_STREAK`
- `MONITOR_DB_POOL_STREAK`

View File

@@ -0,0 +1,348 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import json
import os
import subprocess
import time
from datetime import datetime
from typing import Any, Dict, Tuple
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
DEFAULT_STATE_FILE = "/tmp/zsglpt_health_monitor_state.json"
def _now_text() -> str:
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def _safe_int(value: Any, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except Exception:
return float(default)
def _load_state(path: str) -> Dict[str, Any]:
if not path or not os.path.exists(path):
return {
"version": 1,
"rules": {},
"counters": {},
}
try:
with open(path, "r", encoding="utf-8") as f:
raw = json.load(f)
if not isinstance(raw, dict):
raise ValueError("state is not dict")
raw.setdefault("version", 1)
raw.setdefault("rules", {})
raw.setdefault("counters", {})
return raw
except Exception:
return {
"version": 1,
"rules": {},
"counters": {},
}
def _save_state(path: str, state: Dict[str, Any]) -> None:
if not path:
return
state_dir = os.path.dirname(path)
if state_dir:
os.makedirs(state_dir, exist_ok=True)
tmp_path = f"{path}.tmp"
with open(tmp_path, "w", encoding="utf-8") as f:
json.dump(state, f, ensure_ascii=False, indent=2)
os.replace(tmp_path, path)
def _fetch_health(url: str, timeout: int) -> Tuple[int | None, Dict[str, Any], str | None]:
req = Request(
url,
headers={
"User-Agent": "zsglpt-health-email-monitor/1.0",
"Accept": "application/json",
},
method="GET",
)
try:
with urlopen(req, timeout=max(1, int(timeout))) as resp:
status = int(resp.getcode())
body = resp.read().decode("utf-8", errors="ignore")
except HTTPError as e:
status = int(getattr(e, "code", 0) or 0)
body = ""
try:
body = e.read().decode("utf-8", errors="ignore")
except Exception:
pass
data = {}
if body:
try:
data = json.loads(body)
if not isinstance(data, dict):
data = {}
except Exception:
data = {}
return status, data, f"HTTPError: {e}"
except URLError as e:
return None, {}, f"URLError: {e}"
except Exception as e:
return None, {}, f"RequestError: {e}"
data: Dict[str, Any] = {}
if body:
try:
loaded = json.loads(body)
if isinstance(loaded, dict):
data = loaded
except Exception:
data = {}
return status, data, None
def _inc_streak(state: Dict[str, Any], key: str, bad: bool) -> int:
counters = state.setdefault("counters", {})
current = _safe_int(counters.get(key), 0)
current = (current + 1) if bad else 0
counters[key] = current
return current
def _rule_transition(
state: Dict[str, Any],
*,
rule_name: str,
bad: bool,
streak: int,
threshold: int,
remind_seconds: int,
now_ts: float,
) -> str | None:
rules = state.setdefault("rules", {})
rule_state = rules.setdefault(rule_name, {"active": False, "last_sent": 0})
is_active = bool(rule_state.get("active", False))
last_sent = _safe_float(rule_state.get("last_sent", 0), 0.0)
threshold = max(1, int(threshold))
remind_seconds = max(60, int(remind_seconds))
if bad and streak >= threshold:
if not is_active:
rule_state["active"] = True
rule_state["last_sent"] = now_ts
return "alert"
if (now_ts - last_sent) >= remind_seconds:
rule_state["last_sent"] = now_ts
return "alert"
return None
if is_active and (not bad):
rule_state["active"] = False
rule_state["last_sent"] = now_ts
return "recover"
return None
def _send_email_via_container(
*,
container_name: str,
to_email: str,
subject: str,
body: str,
timeout_seconds: int = 45,
) -> Tuple[bool, str]:
code = (
"import sys,email_service;"
"res=email_service.send_email(to_email=sys.argv[1],subject=sys.argv[2],body=sys.argv[3],email_type='health_monitor');"
"ok=bool(res.get('success'));"
"print('ok' if ok else ('error:'+str(res.get('error'))));"
"raise SystemExit(0 if ok else 2)"
)
cmd = [
"docker",
"exec",
container_name,
"python",
"-c",
code,
to_email,
subject,
body,
]
try:
proc = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=max(5, int(timeout_seconds)),
check=False,
)
except Exception as e:
return False, str(e)
output = (proc.stdout or "") + (proc.stderr or "")
output = output.strip()
return proc.returncode == 0, output
def _build_common_lines(status: int | None, data: Dict[str, Any], fetch_error: str | None) -> list[str]:
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {}
db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {}
process = metrics.get("process") if isinstance(metrics.get("process"), dict) else {}
task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {}
lines = [
f"时间: {_now_text()}",
f"健康地址: {data.get('_monitor_url', '')}",
f"HTTP状态: {status if status is not None else '请求失败'}",
f"ok/db_ok: {data.get('ok')} / {data.get('db_ok')}",
]
if fetch_error:
lines.append(f"请求错误: {fetch_error}")
lines.extend(
[
f"队列: pending={task_queue.get('pending_total', 'N/A')}, running={task_queue.get('running_total', 'N/A')}",
f"连接池: size={db_pool.get('pool_size', 'N/A')}, available={db_pool.get('available', 'N/A')}, in_use={db_pool.get('in_use', 'N/A')}",
f"进程: rss_mb={process.get('rss_mb', 'N/A')}, cpu%={process.get('cpu_percent', 'N/A')}, threads={process.get('threads', 'N/A')}",
f"运行时长: {metrics.get('uptime_seconds', 'N/A')}",
]
)
return lines
def main() -> int:
parser = argparse.ArgumentParser(description="zsglpt 邮件健康监控(基于 /health")
parser.add_argument("--url", default=os.environ.get("HEALTH_URL", "http://127.0.0.1:51232/health"))
parser.add_argument("--to", default=os.environ.get("MONITOR_EMAIL_TO", ""))
parser.add_argument(
"--container",
default=os.environ.get("MONITOR_DOCKER_CONTAINER", "knowledge-automation-multiuser"),
)
parser.add_argument("--state-file", default=os.environ.get("MONITOR_STATE_FILE", DEFAULT_STATE_FILE))
parser.add_argument("--timeout", type=int, default=_safe_int(os.environ.get("MONITOR_TIMEOUT", 8), 8))
parser.add_argument(
"--remind-seconds",
type=int,
default=_safe_int(os.environ.get("MONITOR_REMIND_SECONDS", 3600), 3600),
)
parser.add_argument(
"--queue-threshold",
type=int,
default=_safe_int(os.environ.get("MONITOR_QUEUE_THRESHOLD", 50), 50),
)
parser.add_argument(
"--queue-streak",
type=int,
default=_safe_int(os.environ.get("MONITOR_QUEUE_STREAK", 5), 5),
)
parser.add_argument(
"--db-pool-streak",
type=int,
default=_safe_int(os.environ.get("MONITOR_DB_POOL_STREAK", 3), 3),
)
parser.add_argument("--dry-run", action="store_true", help="仅打印,不实际发邮件")
args = parser.parse_args()
if not args.to:
print("[monitor] 缺少收件人,请设置 --to 或 MONITOR_EMAIL_TO", flush=True)
return 2
state = _load_state(args.state_file)
now_ts = time.time()
status, data, fetch_error = _fetch_health(args.url, args.timeout)
if not isinstance(data, dict):
data = {}
data["_monitor_url"] = args.url
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {}
db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {}
task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {}
service_down = status is None
health_fail = bool(status is not None and (status >= 500 or (not data.get("ok", False)) or (not data.get("db_ok", False))))
db_pool_exhausted = (
_safe_int(db_pool.get("pool_size"), 0) > 0
and _safe_int(db_pool.get("available"), 0) <= 0
and _safe_int(db_pool.get("in_use"), 0) >= _safe_int(db_pool.get("pool_size"), 0)
)
queue_backlog_high = _safe_int(task_queue.get("pending_total"), 0) >= max(1, int(args.queue_threshold))
rule_defs = [
("service_down", service_down, 1),
("health_fail", health_fail, 1),
("db_pool_exhausted", db_pool_exhausted, max(1, int(args.db_pool_streak))),
("queue_backlog_high", queue_backlog_high, max(1, int(args.queue_streak))),
]
pending_notifications: list[tuple[str, str]] = []
for rule_name, bad, threshold in rule_defs:
streak = _inc_streak(state, rule_name, bad)
action = _rule_transition(
state,
rule_name=rule_name,
bad=bad,
streak=streak,
threshold=threshold,
remind_seconds=args.remind_seconds,
now_ts=now_ts,
)
if action:
pending_notifications.append((rule_name, action))
_save_state(args.state_file, state)
if not pending_notifications:
print(f"[monitor] {_now_text()} 正常,无需发送邮件")
return 0
common_lines = _build_common_lines(status, data, fetch_error)
for rule_name, action in pending_notifications:
level = "告警" if action == "alert" else "恢复"
subject = f"[zsglpt健康监控][{level}] {rule_name}"
body_lines = [
f"规则: {rule_name}",
f"状态: {level}",
"",
*common_lines,
]
body = "\n".join(body_lines)
if args.dry_run:
print(f"[monitor][dry-run] subject={subject}\n{body}\n")
continue
ok, msg = _send_email_via_container(
container_name=args.container,
to_email=args.to,
subject=subject,
body=body,
)
if ok:
print(f"[monitor] 邮件已发送: {subject}")
else:
print(f"[monitor] 邮件发送失败: {subject} | {msg}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -243,6 +243,35 @@ class KDocsUploader:
except queue.Empty: except queue.Empty:
return {"success": False, "error": "操作超时"} return {"success": False, "error": "操作超时"}
def _put_task_response(self, task: Dict[str, Any], result: Dict[str, Any]) -> None:
response_queue = task.get("response")
if not response_queue:
return
try:
response_queue.put(result)
except Exception:
return
def _process_task(self, task: Dict[str, Any]) -> bool:
action = task.get("action")
payload = task.get("payload") or {}
if action == "shutdown":
return False
if action == "upload":
self._handle_upload(payload)
return True
if action == "qr":
self._put_task_response(task, self._handle_qr(payload))
return True
if action == "clear_login":
self._put_task_response(task, self._handle_clear_login())
return True
if action == "status":
self._put_task_response(task, self._handle_status_check())
return True
return True
def _run(self) -> None: def _run(self) -> None:
thread_id = self._thread_id thread_id = self._thread_id
logger.info(f"[KDocs] 上传线程启动 (ID={thread_id})") logger.info(f"[KDocs] 上传线程启动 (ID={thread_id})")
@@ -261,34 +290,17 @@ class KDocsUploader:
# 更新最后活动时间 # 更新最后活动时间
self._last_activity = time.time() self._last_activity = time.time()
action = task.get("action")
if action == "shutdown":
break
try: try:
if action == "upload": should_continue = self._process_task(task)
self._handle_upload(task.get("payload") or {}) if not should_continue:
elif action == "qr": break
result = self._handle_qr(task.get("payload") or {})
task.get("response").put(result)
elif action == "clear_login":
result = self._handle_clear_login()
task.get("response").put(result)
elif action == "status":
result = self._handle_status_check()
task.get("response").put(result)
# 任务处理完成后更新活动时间 # 任务处理完成后更新活动时间
self._last_activity = time.time() self._last_activity = time.time()
except Exception as e: except Exception as e:
logger.warning(f"[KDocs] 处理任务失败: {e}") logger.warning(f"[KDocs] 处理任务失败: {e}")
# 如果有响应队列,返回错误 self._put_task_response(task, {"success": False, "error": str(e)})
if "response" in task and task.get("response"):
try:
task["response"].put({"success": False, "error": str(e)})
except Exception:
pass
except Exception as e: except Exception as e:
logger.warning(f"[KDocs] 线程主循环异常: {e}") logger.warning(f"[KDocs] 线程主循环异常: {e}")
@@ -830,18 +842,180 @@ class KDocsUploader:
except Exception as e: except Exception as e:
logger.warning(f"[KDocs] 保存登录态失败: {e}") logger.warning(f"[KDocs] 保存登录态失败: {e}")
def _resolve_doc_url(self, cfg: Dict[str, Any]) -> str:
return (cfg.get("kdocs_doc_url") or "").strip()
def _ensure_doc_access(
self,
doc_url: str,
*,
fast: bool = False,
use_storage_state: bool = True,
) -> Optional[str]:
if not self._ensure_playwright(use_storage_state=use_storage_state):
return self._last_error or "浏览器不可用"
if not self._open_document(doc_url, fast=fast):
return self._last_error or "打开文档失败"
return None
def _trigger_fast_login_dialog(self, timeout_ms: int) -> None:
self._ensure_login_dialog(
timeout_ms=timeout_ms,
frame_timeout_ms=timeout_ms,
quick=True,
)
def _capture_qr_with_retry(self, fast_login_timeout: int) -> Tuple[Optional[bytes], Optional[bytes]]:
qr_image = None
invalid_qr = None
for attempt in range(10):
if attempt in (3, 7):
self._trigger_fast_login_dialog(fast_login_timeout)
candidate = self._capture_qr_image()
if candidate and self._is_valid_qr_image(candidate):
qr_image = candidate
break
if candidate:
invalid_qr = candidate
time.sleep(0.8) # 优化: 1 -> 0.8
return qr_image, invalid_qr
def _save_qr_debug_artifacts(self, invalid_qr: Optional[bytes]) -> None:
try:
pages = self._iter_pages()
page_urls = [getattr(p, "url", "") for p in pages]
logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
ts = int(time.time())
saved = []
for idx, page in enumerate(pages[:3]):
try:
path = f"data/kdocs_debug_{ts}_{idx}.png"
page.screenshot(path=path, full_page=True)
saved.append(path)
except Exception:
continue
if saved:
logger.warning(f"[KDocs] 已保存调试截图: {saved}")
if invalid_qr:
try:
path = f"data/kdocs_invalid_qr_{ts}.png"
with open(path, "wb") as handle:
handle.write(invalid_qr)
logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
except Exception:
pass
except Exception:
pass
def _log_upload_failure(self, message: str, user_id: Any, account_id: Any) -> None:
try:
log_to_client(f"表格上传失败: {message}", user_id, account_id)
except Exception:
pass
def _mark_upload_tracking(self, user_id: Any, account_id: Any) -> Tuple[Any, Optional[str], bool]:
account = None
prev_status = None
status_tracked = False
try:
account = safe_get_account(user_id, account_id)
if account and self._should_mark_upload(account):
prev_status = getattr(account, "status", None)
account.status = "上传截图"
self._emit_account_update(user_id, account)
status_tracked = True
except Exception:
prev_status = None
return account, prev_status, status_tracked
def _parse_upload_payload(self, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
unit = (payload.get("unit") or "").strip()
name = (payload.get("name") or "").strip()
image_path = payload.get("image_path")
if not unit or not name:
return None
if not image_path or not os.path.exists(image_path):
return None
return {
"unit": unit,
"name": name,
"image_path": image_path,
"user_id": payload.get("user_id"),
"account_id": payload.get("account_id"),
}
def _resolve_upload_sheet_config(self, cfg: Dict[str, Any]) -> Dict[str, Any]:
return {
"sheet_name": (cfg.get("kdocs_sheet_name") or "").strip(),
"sheet_index": int(cfg.get("kdocs_sheet_index") or 0),
"unit_col": (cfg.get("kdocs_unit_column") or "A").strip().upper(),
"image_col": (cfg.get("kdocs_image_column") or "D").strip().upper(),
"row_start": int(cfg.get("kdocs_row_start") or 0),
"row_end": int(cfg.get("kdocs_row_end") or 0),
}
def _try_upload_to_sheet(self, cfg: Dict[str, Any], unit: str, name: str, image_path: str) -> Tuple[bool, str]:
sheet_cfg = self._resolve_upload_sheet_config(cfg)
success = False
error_msg = ""
for _ in range(2):
try:
if sheet_cfg["sheet_name"] or sheet_cfg["sheet_index"]:
self._select_sheet(sheet_cfg["sheet_name"], sheet_cfg["sheet_index"])
row_num = self._find_person_with_unit(
unit,
name,
sheet_cfg["unit_col"],
row_start=sheet_cfg["row_start"],
row_end=sheet_cfg["row_end"],
)
if row_num < 0:
error_msg = f"未找到人员: {unit}-{name}"
break
success = self._upload_image_to_cell(row_num, image_path, sheet_cfg["image_col"])
if success:
break
except Exception as e:
error_msg = str(e)
return success, error_msg
def _handle_upload_login_invalid(
self,
*,
unit: str,
name: str,
image_path: str,
user_id: Any,
account_id: Any,
) -> None:
error_msg = "登录已失效,请管理员重新扫码登录"
self._login_required = True
self._last_login_ok = False
self._notify_admin(unit, name, image_path, error_msg)
self._log_upload_failure(error_msg, user_id, account_id)
def _handle_qr(self, payload: Dict[str, Any]) -> Dict[str, Any]: def _handle_qr(self, payload: Dict[str, Any]) -> Dict[str, Any]:
cfg = self._load_system_config() cfg = self._load_system_config()
doc_url = (cfg.get("kdocs_doc_url") or "").strip() doc_url = self._resolve_doc_url(cfg)
if not doc_url: if not doc_url:
return {"success": False, "error": "未配置金山文档链接"} return {"success": False, "error": "未配置金山文档链接"}
force = bool(payload.get("force")) force = bool(payload.get("force"))
if force: if force:
self._handle_clear_login() self._handle_clear_login()
if not self._ensure_playwright(use_storage_state=not force):
return {"success": False, "error": self._last_error or "浏览器不可用"} doc_error = self._ensure_doc_access(doc_url, fast=True, use_storage_state=not force)
if not self._open_document(doc_url, fast=True): if doc_error:
return {"success": False, "error": self._last_error or "打开文档失败"} return {"success": False, "error": doc_error}
if not force and self._has_saved_login_state() and self._is_logged_in(): if not force and self._has_saved_login_state() and self._is_logged_in():
self._login_required = False self._login_required = False
@@ -850,54 +1024,12 @@ class KDocsUploader:
return {"success": True, "logged_in": True, "qr_image": ""} return {"success": True, "logged_in": True, "qr_image": ""}
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300")) fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
self._ensure_login_dialog( self._trigger_fast_login_dialog(fast_login_timeout)
timeout_ms=fast_login_timeout, qr_image, invalid_qr = self._capture_qr_with_retry(fast_login_timeout)
frame_timeout_ms=fast_login_timeout,
quick=True,
)
qr_image = None
invalid_qr = None
for attempt in range(10):
if attempt in (3, 7):
self._ensure_login_dialog(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
candidate = self._capture_qr_image()
if candidate and self._is_valid_qr_image(candidate):
qr_image = candidate
break
if candidate:
invalid_qr = candidate
time.sleep(0.8) # 优化: 1 -> 0.8
if not qr_image: if not qr_image:
self._last_error = "二维码识别异常" if invalid_qr else "二维码获取失败" self._last_error = "二维码识别异常" if invalid_qr else "二维码获取失败"
try: self._save_qr_debug_artifacts(invalid_qr)
pages = self._iter_pages()
page_urls = [getattr(p, "url", "") for p in pages]
logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
ts = int(time.time())
saved = []
for idx, page in enumerate(pages[:3]):
try:
path = f"data/kdocs_debug_{ts}_{idx}.png"
page.screenshot(path=path, full_page=True)
saved.append(path)
except Exception:
continue
if saved:
logger.warning(f"[KDocs] 已保存调试截图: {saved}")
if invalid_qr:
try:
path = f"data/kdocs_invalid_qr_{ts}.png"
with open(path, "wb") as handle:
handle.write(invalid_qr)
logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
except Exception:
pass
except Exception:
pass
return {"success": False, "error": self._last_error} return {"success": False, "error": self._last_error}
try: try:
@@ -933,24 +1065,22 @@ class KDocsUploader:
def _handle_status_check(self) -> Dict[str, Any]: def _handle_status_check(self) -> Dict[str, Any]:
cfg = self._load_system_config() cfg = self._load_system_config()
doc_url = (cfg.get("kdocs_doc_url") or "").strip() doc_url = self._resolve_doc_url(cfg)
if not doc_url: if not doc_url:
return {"success": True, "logged_in": False, "error": "未配置文档链接"} return {"success": True, "logged_in": False, "error": "未配置文档链接"}
if not self._ensure_playwright():
return {"success": False, "logged_in": False, "error": self._last_error or "浏览器不可用"} doc_error = self._ensure_doc_access(doc_url, fast=True)
if not self._open_document(doc_url, fast=True): if doc_error:
return {"success": False, "logged_in": False, "error": self._last_error or "打开文档失败"} return {"success": False, "logged_in": False, "error": doc_error}
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300")) fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
self._ensure_login_dialog( self._trigger_fast_login_dialog(fast_login_timeout)
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
self._try_confirm_login( self._try_confirm_login(
timeout_ms=fast_login_timeout, timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout, frame_timeout_ms=fast_login_timeout,
quick=True, quick=True,
) )
logged_in = self._is_logged_in() logged_in = self._is_logged_in()
self._last_login_ok = logged_in self._last_login_ok = logged_in
self._login_required = not logged_in self._login_required = not logged_in
@@ -962,79 +1092,43 @@ class KDocsUploader:
cfg = self._load_system_config() cfg = self._load_system_config()
if int(cfg.get("kdocs_enabled", 0) or 0) != 1: if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
return return
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
doc_url = self._resolve_doc_url(cfg)
if not doc_url: if not doc_url:
return return
unit = (payload.get("unit") or "").strip() upload_data = self._parse_upload_payload(payload)
name = (payload.get("name") or "").strip() if not upload_data:
image_path = payload.get("image_path")
user_id = payload.get("user_id")
account_id = payload.get("account_id")
if not unit or not name:
return
if not image_path or not os.path.exists(image_path):
return return
account = None unit = upload_data["unit"]
prev_status = None name = upload_data["name"]
status_tracked = False image_path = upload_data["image_path"]
user_id = upload_data["user_id"]
account_id = upload_data["account_id"]
account, prev_status, status_tracked = self._mark_upload_tracking(user_id, account_id)
try: try:
try: doc_error = self._ensure_doc_access(doc_url)
account = safe_get_account(user_id, account_id) if doc_error:
if account and self._should_mark_upload(account): self._notify_admin(unit, name, image_path, doc_error)
prev_status = getattr(account, "status", None)
account.status = "上传截图"
self._emit_account_update(user_id, account)
status_tracked = True
except Exception:
prev_status = None
if not self._ensure_playwright():
self._notify_admin(unit, name, image_path, self._last_error or "浏览器不可用")
return
if not self._open_document(doc_url):
self._notify_admin(unit, name, image_path, self._last_error or "打开文档失败")
return return
if not self._is_logged_in(): if not self._is_logged_in():
self._login_required = True self._handle_upload_login_invalid(
self._last_login_ok = False unit=unit,
self._notify_admin(unit, name, image_path, "登录已失效,请管理员重新扫码登录") name=name,
try: image_path=image_path,
log_to_client("表格上传失败: 登录已失效,请管理员重新扫码登录", user_id, account_id) user_id=user_id,
except Exception: account_id=account_id,
pass )
return return
self._login_required = False self._login_required = False
self._last_login_ok = True self._last_login_ok = True
sheet_name = (cfg.get("kdocs_sheet_name") or "").strip() success, error_msg = self._try_upload_to_sheet(cfg, unit, name, image_path)
sheet_index = int(cfg.get("kdocs_sheet_index") or 0)
unit_col = (cfg.get("kdocs_unit_column") or "A").strip().upper()
image_col = (cfg.get("kdocs_image_column") or "D").strip().upper()
row_start = int(cfg.get("kdocs_row_start") or 0)
row_end = int(cfg.get("kdocs_row_end") or 0)
success = False
error_msg = ""
for attempt in range(2):
try:
if sheet_name or sheet_index:
self._select_sheet(sheet_name, sheet_index)
row_num = self._find_person_with_unit(unit, name, unit_col, row_start=row_start, row_end=row_end)
if row_num < 0:
error_msg = f"未找到人员: {unit}-{name}"
break
success = self._upload_image_to_cell(row_num, image_path, image_col)
if success:
break
except Exception as e:
error_msg = str(e)
if success: if success:
self._last_success_at = time.time() self._last_success_at = time.time()
self._last_error = None self._last_error = None
@@ -1048,10 +1142,7 @@ class KDocsUploader:
error_msg = "上传失败" error_msg = "上传失败"
self._last_error = error_msg self._last_error = error_msg
self._notify_admin(unit, name, image_path, error_msg) self._notify_admin(unit, name, image_path, error_msg)
try: self._log_upload_failure(error_msg, user_id, account_id)
log_to_client(f"表格上传失败: {error_msg}", user_id, account_id)
except Exception:
pass
finally: finally:
if status_tracked: if status_tracked:
self._restore_account_status(user_id, account, prev_status) self._restore_account_status(user_id, account, prev_status)

View File

@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
import os
import threading import threading
import time import time
from datetime import datetime from datetime import datetime
@@ -10,6 +11,8 @@ from app_config import get_config
from app_logger import get_logger from app_logger import get_logger
from services.state import ( from services.state import (
cleanup_expired_ip_rate_limits, cleanup_expired_ip_rate_limits,
cleanup_expired_ip_request_rates,
cleanup_expired_login_security_state,
safe_cleanup_expired_batches, safe_cleanup_expired_batches,
safe_cleanup_expired_captcha, safe_cleanup_expired_captcha,
safe_cleanup_expired_pending_random, safe_cleanup_expired_pending_random,
@@ -31,6 +34,69 @@ PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECON
_kdocs_offline_notified: bool = False _kdocs_offline_notified: bool = False
def _to_int(value, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _collect_active_user_ids() -> set[int]:
active_user_ids: set[int] = set()
for _, info in safe_iter_task_status_items():
user_id = info.get("user_id") if isinstance(info, dict) else None
if user_id is None:
continue
try:
active_user_ids.add(int(user_id))
except Exception:
continue
return active_user_ids
def _find_expired_user_cache_ids(current_time: float, active_user_ids: set[int]) -> list[int]:
expired_users = []
for user_id, last_access in (safe_get_user_accounts_last_access_items() or []):
try:
user_id_int = int(user_id)
last_access_ts = float(last_access)
except Exception:
continue
if (current_time - last_access_ts) <= USER_ACCOUNTS_EXPIRE_SECONDS:
continue
if user_id_int in active_user_ids:
continue
if safe_has_user(user_id_int):
expired_users.append(user_id_int)
return expired_users
def _find_completed_task_status_ids(current_time: float) -> list[str]:
completed_task_ids = []
for account_id, status_data in safe_iter_task_status_items():
status = status_data.get("status") if isinstance(status_data, dict) else None
if status not in ["已完成", "失败", "已停止"]:
continue
start_time = float(status_data.get("start_time", 0) or 0)
if (current_time - start_time) > 600: # 10分钟
completed_task_ids.append(account_id)
return completed_task_ids
def _reap_zombie_processes() -> None:
while True:
try:
pid, _ = os.waitpid(-1, os.WNOHANG)
if pid == 0:
break
logger.debug(f"已回收僵尸进程: PID={pid}")
except ChildProcessError:
break
except Exception:
break
def cleanup_expired_data() -> None: def cleanup_expired_data() -> None:
"""定期清理过期数据,防止内存泄漏(逻辑保持不变)。""" """定期清理过期数据,防止内存泄漏(逻辑保持不变)。"""
current_time = time.time() current_time = time.time()
@@ -43,48 +109,36 @@ def cleanup_expired_data() -> None:
if deleted_ips: if deleted_ips:
logger.debug(f"已清理 {deleted_ips} 个过期IP限流记录") logger.debug(f"已清理 {deleted_ips} 个过期IP限流记录")
expired_users = [] deleted_ip_requests = cleanup_expired_ip_request_rates(current_time)
last_access_items = safe_get_user_accounts_last_access_items() if deleted_ip_requests:
if last_access_items: logger.debug(f"已清理 {deleted_ip_requests} 个过期IP请求频率记录")
task_items = safe_iter_task_status_items()
active_user_ids = {int(info.get("user_id")) for _, info in task_items if info.get("user_id")}
for user_id, last_access in last_access_items:
if (current_time - float(last_access)) <= USER_ACCOUNTS_EXPIRE_SECONDS:
continue
if int(user_id) in active_user_ids:
continue
if safe_has_user(user_id):
expired_users.append(int(user_id))
login_cleanup_stats = cleanup_expired_login_security_state(current_time)
login_cleanup_total = sum(int(v or 0) for v in login_cleanup_stats.values())
if login_cleanup_total:
logger.debug(
"已清理登录风控缓存: "
f"失败计数={login_cleanup_stats.get('failures', 0)}, "
f"限流桶={login_cleanup_stats.get('rate_limits', 0)}, "
f"扫描状态={login_cleanup_stats.get('scan_states', 0)}, "
f"短时锁={login_cleanup_stats.get('ip_user_locks', 0)}, "
f"告警状态={login_cleanup_stats.get('alerts', 0)}"
)
active_user_ids = _collect_active_user_ids()
expired_users = _find_expired_user_cache_ids(current_time, active_user_ids)
for user_id in expired_users: for user_id in expired_users:
safe_remove_user_accounts(user_id) safe_remove_user_accounts(user_id)
if expired_users: if expired_users:
logger.debug(f"已清理 {len(expired_users)} 个过期用户账号缓存") logger.debug(f"已清理 {len(expired_users)} 个过期用户账号缓存")
completed_tasks = [] completed_task_ids = _find_completed_task_status_ids(current_time)
for account_id, status_data in safe_iter_task_status_items(): for account_id in completed_task_ids:
if status_data.get("status") in ["已完成", "失败", "已停止"]:
start_time = float(status_data.get("start_time", 0) or 0)
if (current_time - start_time) > 600: # 10分钟
completed_tasks.append(account_id)
for account_id in completed_tasks:
safe_remove_task_status(account_id) safe_remove_task_status(account_id)
if completed_tasks: if completed_task_ids:
logger.debug(f"已清理 {len(completed_tasks)} 个已完成任务状态") logger.debug(f"已清理 {len(completed_task_ids)} 个已完成任务状态")
try: _reap_zombie_processes()
import os
while True:
try:
pid, status = os.waitpid(-1, os.WNOHANG)
if pid == 0:
break
logger.debug(f"已回收僵尸进程: PID={pid}")
except ChildProcessError:
break
except Exception:
pass
deleted_batches = safe_cleanup_expired_batches(BATCH_TASK_EXPIRE_SECONDS, current_time) deleted_batches = safe_cleanup_expired_batches(BATCH_TASK_EXPIRE_SECONDS, current_time)
if deleted_batches: if deleted_batches:
@@ -95,52 +149,39 @@ def cleanup_expired_data() -> None:
logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务") logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务")
def check_kdocs_online_status() -> None: def _load_kdocs_monitor_config():
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)""" import database
global _kdocs_offline_notified
cfg = database.get_system_config()
if not cfg:
return None
kdocs_enabled = _to_int(cfg.get("kdocs_enabled"), 0)
if not kdocs_enabled:
return None
admin_notify_enabled = _to_int(cfg.get("kdocs_admin_notify_enabled"), 0)
admin_notify_email = str(cfg.get("kdocs_admin_notify_email") or "").strip()
if (not admin_notify_enabled) or (not admin_notify_email):
return None
return admin_notify_email
def _is_kdocs_offline(status: dict) -> tuple[bool, bool, bool | None]:
login_required = bool(status.get("login_required", False))
last_login_ok = status.get("last_login_ok")
is_offline = login_required or (last_login_ok is False)
return is_offline, login_required, last_login_ok
def _send_kdocs_offline_alert(admin_notify_email: str, *, login_required: bool, last_login_ok) -> bool:
try: try:
import database import email_service
from services.kdocs_uploader import get_kdocs_uploader
# 获取系统配置 now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
cfg = database.get_system_config() subject = "【金山文档离线告警】需要重新登录"
if not cfg: body = f"""
return
# 检查是否启用了金山文档功能
kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
if not kdocs_enabled:
return
# 检查是否启用了管理员通知
admin_notify_enabled = int(cfg.get("kdocs_admin_notify_enabled") or 0)
admin_notify_email = (cfg.get("kdocs_admin_notify_email") or "").strip()
if not admin_notify_enabled or not admin_notify_email:
return
# 获取金山文档状态
kdocs = get_kdocs_uploader()
status = kdocs.get_status()
login_required = status.get("login_required", False)
last_login_ok = status.get("last_login_ok")
# 如果需要登录或最后登录状态不是成功
is_offline = login_required or (last_login_ok is False)
if is_offline:
# 已经通知过了,不再重复通知
if _kdocs_offline_notified:
logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
return
# 发送邮件通知
try:
import email_service
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
subject = "【金山文档离线告警】需要重新登录"
body = f"""
您好, 您好,
系统检测到金山文档上传功能已离线,需要重新扫码登录。 系统检测到金山文档上传功能已离线,需要重新扫码登录。
@@ -155,58 +196,92 @@ def check_kdocs_online_status() -> None:
--- ---
此邮件由系统自动发送,请勿直接回复。 此邮件由系统自动发送,请勿直接回复。
""" """
email_service.send_email_async( email_service.send_email_async(
to_email=admin_notify_email, to_email=admin_notify_email,
subject=subject, subject=subject,
body=body, body=body,
email_type="kdocs_offline_alert", email_type="kdocs_offline_alert",
) )
_kdocs_offline_notified = True # 标记为已通知 logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}") return True
except Exception as e: except Exception as e:
logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}") logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
else: return False
# 恢复在线,重置通知状态
def check_kdocs_online_status() -> None:
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
global _kdocs_offline_notified
try:
admin_notify_email = _load_kdocs_monitor_config()
if not admin_notify_email:
return
from services.kdocs_uploader import get_kdocs_uploader
kdocs = get_kdocs_uploader()
status = kdocs.get_status() or {}
is_offline, login_required, last_login_ok = _is_kdocs_offline(status)
if is_offline:
if _kdocs_offline_notified: if _kdocs_offline_notified:
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态") logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
_kdocs_offline_notified = False return
logger.debug("[KDocs监控] 金山文档状态正常")
if _send_kdocs_offline_alert(
admin_notify_email,
login_required=login_required,
last_login_ok=last_login_ok,
):
_kdocs_offline_notified = True
return
if _kdocs_offline_notified:
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
_kdocs_offline_notified = False
logger.debug("[KDocs监控] 金山文档状态正常")
except Exception as e: except Exception as e:
logger.error(f"[KDocs监控] 检测失败: {e}") logger.error(f"[KDocs监控] 检测失败: {e}")
def start_cleanup_scheduler() -> None: def _start_daemon_loop(name: str, *, startup_delay: float, interval_seconds: float, job, error_tag: str):
"""启动定期清理调度器""" def loop():
if startup_delay > 0:
def cleanup_loop(): time.sleep(startup_delay)
while True: while True:
try: try:
time.sleep(300) # 每5分钟执行一次清理 job()
cleanup_expired_data() time.sleep(interval_seconds)
except Exception as e: except Exception as e:
logger.error(f"清理任务执行失败: {e}") logger.error(f"{error_tag}: {e}")
time.sleep(min(60.0, max(1.0, interval_seconds / 5.0)))
cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True, name="cleanup-scheduler") thread = threading.Thread(target=loop, daemon=True, name=name)
cleanup_thread.start() thread.start()
return thread
def start_cleanup_scheduler() -> None:
"""启动定期清理调度器"""
_start_daemon_loop(
"cleanup-scheduler",
startup_delay=300,
interval_seconds=300,
job=cleanup_expired_data,
error_tag="清理任务执行失败",
)
logger.info("内存清理调度器已启动") logger.info("内存清理调度器已启动")
def start_kdocs_monitor() -> None: def start_kdocs_monitor() -> None:
"""启动金山文档状态监控""" """启动金山文档状态监控"""
_start_daemon_loop(
def monitor_loop(): "kdocs-monitor",
# 启动后等待 60 秒再开始检测(给系统初始化的时间) startup_delay=60,
time.sleep(60) interval_seconds=300,
while True: job=check_kdocs_online_status,
try: error_tag="[KDocs监控] 监控任务执行失败",
check_kdocs_online_status() )
time.sleep(300) # 每5分钟检测一次
except Exception as e:
logger.error(f"[KDocs监控] 监控任务执行失败: {e}")
time.sleep(60)
monitor_thread = threading.Thread(target=monitor_loop, daemon=True, name="kdocs-monitor")
monitor_thread.start()
logger.info("[KDocs监控] 金山文档状态监控已启动每5分钟检测一次") logger.info("[KDocs监控] 金山文档状态监控已启动每5分钟检测一次")

View File

@@ -27,6 +27,12 @@ from services.time_utils import get_beijing_now
logger = get_logger("app") logger = get_logger("app")
config = get_config() config = get_config()
try:
_SCHEDULE_SUBMIT_DELAY_SECONDS = float(os.environ.get("SCHEDULE_SUBMIT_DELAY_SECONDS", "0.2"))
except Exception:
_SCHEDULE_SUBMIT_DELAY_SECONDS = 0.2
_SCHEDULE_SUBMIT_DELAY_SECONDS = max(0.0, _SCHEDULE_SUBMIT_DELAY_SECONDS)
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
os.makedirs(SCREENSHOTS_DIR, exist_ok=True) os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
@@ -55,6 +61,150 @@ def _normalize_hhmm(value: object, *, default: str) -> str:
return f"{hour:02d}:{minute:02d}" return f"{hour:02d}:{minute:02d}"
def _safe_recompute_schedule_next_run(schedule_id: int) -> None:
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
def _load_accounts_for_users(approved_users: list[dict]) -> tuple[dict[int, dict], list[str]]:
"""批量加载用户账号快照。"""
user_accounts: dict[int, dict] = {}
account_ids: list[str] = []
for user in approved_users:
user_id = user["id"]
accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
user_accounts[user_id] = accounts
account_ids.extend(list(accounts.keys()))
return user_accounts, account_ids
def _should_skip_suspended_account(account_status_info, account, username: str) -> bool:
"""判断是否应跳过暂停账号,并输出日志。"""
if not account_status_info:
return False
status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
if status != "suspended":
return False
fail_count = account_status_info["login_fail_count"] if "login_fail_count" in account_status_info.keys() else 0
logger.info(
f"[定时任务] 跳过暂停账号: {account.username} (用户:{username}) - 连续{fail_count}次密码错误,需修改密码"
)
return True
def _parse_schedule_account_ids(schedule_config: dict, schedule_id: int):
import json
try:
account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
account_ids = json.loads(account_ids_raw)
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
return []
if isinstance(account_ids, list):
return account_ids
return []
def _create_user_schedule_batch(*, batch_id: str, user_id: int, browse_type: str, schedule_name: str, now_ts: float) -> None:
safe_create_batch(
batch_id,
{
"user_id": user_id,
"browse_type": browse_type,
"schedule_name": schedule_name,
"screenshots": [],
"total_accounts": 0,
"completed": 0,
"created_at": now_ts,
"updated_at": now_ts,
},
)
def _build_user_schedule_done_callback(
*,
completion_lock: threading.Lock,
remaining: dict,
counters: dict,
execution_start_time: float,
log_id: int,
schedule_id: int,
total_accounts: int,
):
def on_browse_done():
with completion_lock:
remaining["count"] -= 1
if remaining["done"] or remaining["count"] > 0:
return
remaining["done"] = True
execution_duration = int(time.time() - execution_start_time)
started_count = int(counters.get("started", 0) or 0)
database.update_schedule_execution_log(
log_id,
total_accounts=total_accounts,
success_accounts=started_count,
failed_accounts=total_accounts - started_count,
duration_seconds=execution_duration,
status="completed",
)
logger.info(f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件")
return on_browse_done
def _submit_user_schedule_accounts(
*,
user_id: int,
account_ids: list,
browse_type: str,
enable_screenshot,
task_source: str,
done_callback,
completion_lock: threading.Lock,
remaining: dict,
counters: dict,
) -> tuple[int, int]:
started_count = 0
skipped_count = 0
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
if (not account) or account.is_running:
skipped_count += 1
continue
with completion_lock:
remaining["count"] += 1
ok, msg = submit_account_task(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
enable_screenshot=enable_screenshot,
source=task_source,
done_callback=done_callback,
)
if ok:
started_count += 1
counters["started"] = started_count
else:
with completion_lock:
remaining["count"] -= 1
skipped_count += 1
logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
return started_count, skipped_count
def run_scheduled_task(skip_weekday_check: bool = False) -> None: def run_scheduled_task(skip_weekday_check: bool = False) -> None:
"""执行所有账号的浏览任务(可被手动调用,过滤重复账号)""" """执行所有账号的浏览任务(可被手动调用,过滤重复账号)"""
try: try:
@@ -87,17 +237,7 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
cfg = database.get_system_config() cfg = database.get_system_config()
enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1 enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1
user_accounts = {} user_accounts, account_ids = _load_accounts_for_users(approved_users)
account_ids = []
for user in approved_users:
user_id = user["id"]
accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
user_accounts[user_id] = accounts
account_ids.extend(list(accounts.keys()))
account_statuses = database.get_account_status_batch(account_ids) account_statuses = database.get_account_status_batch(account_ids)
@@ -113,18 +253,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
continue continue
account_status_info = account_statuses.get(str(account_id)) account_status_info = account_statuses.get(str(account_id))
if account_status_info: if _should_skip_suspended_account(account_status_info, account, user["username"]):
status = account_status_info["status"] if "status" in account_status_info.keys() else "active" continue
if status == "suspended":
fail_count = (
account_status_info["login_fail_count"]
if "login_fail_count" in account_status_info.keys()
else 0
)
logger.info(
f"[定时任务] 跳过暂停账号: {account.username} (用户:{user['username']}) - 连续{fail_count}次密码错误,需修改密码"
)
continue
if account.username in executed_usernames: if account.username in executed_usernames:
skipped_duplicates += 1 skipped_duplicates += 1
@@ -149,7 +279,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
else: else:
logger.warning(f"[定时任务] 启动失败({account.username}): {msg}") logger.warning(f"[定时任务] 启动失败({account.username}): {msg}")
time.sleep(2) if _SCHEDULE_SUBMIT_DELAY_SECONDS > 0:
time.sleep(_SCHEDULE_SUBMIT_DELAY_SECONDS)
logger.info( logger.info(
f"[定时任务] 执行完成 - 总账号数:{total_accounts}, 已执行:{executed_accounts}, 跳过重复:{skipped_duplicates}" f"[定时任务] 执行完成 - 总账号数:{total_accounts}, 已执行:{executed_accounts}, 跳过重复:{skipped_duplicates}"
@@ -198,15 +329,16 @@ def scheduled_task_worker() -> None:
deleted_screenshots = 0 deleted_screenshots = 0
if os.path.exists(SCREENSHOTS_DIR): if os.path.exists(SCREENSHOTS_DIR):
cutoff_time = time.time() - (7 * 24 * 60 * 60) cutoff_time = time.time() - (7 * 24 * 60 * 60)
for filename in os.listdir(SCREENSHOTS_DIR): with os.scandir(SCREENSHOTS_DIR) as entries:
if filename.lower().endswith((".png", ".jpg", ".jpeg")): for entry in entries:
filepath = os.path.join(SCREENSHOTS_DIR, filename) if (not entry.is_file()) or (not entry.name.lower().endswith((".png", ".jpg", ".jpeg"))):
continue
try: try:
if os.path.getmtime(filepath) < cutoff_time: if entry.stat().st_mtime < cutoff_time:
os.remove(filepath) os.remove(entry.path)
deleted_screenshots += 1 deleted_screenshots += 1
except Exception as e: except Exception as e:
logger.warning(f"[定时清理] 删除截图失败 {filename}: {str(e)}") logger.warning(f"[定时清理] 删除截图失败 {entry.name}: {str(e)}")
logger.info(f"[定时清理] 已删除 {deleted_screenshots} 个截图文件") logger.info(f"[定时清理] 已删除 {deleted_screenshots} 个截图文件")
logger.info("[定时清理] 清理完成!") logger.info("[定时清理] 清理完成!")
@@ -214,10 +346,97 @@ def scheduled_task_worker() -> None:
except Exception as e: except Exception as e:
logger.exception(f"[定时清理] 清理任务出错: {str(e)}") logger.exception(f"[定时清理] 清理任务出错: {str(e)}")
def _parse_due_schedule_weekdays(schedule_config: dict, schedule_id: int):
weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5")
try:
return [int(d) for d in weekdays_str.split(",") if d.strip()]
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
_safe_recompute_schedule_next_run(schedule_id)
return None
def _execute_due_user_schedule(schedule_config: dict) -> None:
schedule_name = schedule_config.get("name", "未命名任务")
schedule_id = schedule_config["id"]
user_id = schedule_config["user_id"]
browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
enable_screenshot = schedule_config.get("enable_screenshot", 1)
account_ids = _parse_schedule_account_ids(schedule_config, schedule_id)
if not account_ids:
_safe_recompute_schedule_next_run(schedule_id)
return
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
import uuid
execution_start_time = time.time()
log_id = database.create_schedule_execution_log(
schedule_id=schedule_id,
user_id=user_id,
schedule_name=schedule_name,
)
batch_id = f"batch_{uuid.uuid4().hex[:12]}"
now_ts = time.time()
_create_user_schedule_batch(
batch_id=batch_id,
user_id=user_id,
browse_type=browse_type,
schedule_name=schedule_name,
now_ts=now_ts,
)
completion_lock = threading.Lock()
remaining = {"count": 0, "done": False}
counters = {"started": 0}
on_browse_done = _build_user_schedule_done_callback(
completion_lock=completion_lock,
remaining=remaining,
counters=counters,
execution_start_time=execution_start_time,
log_id=log_id,
schedule_id=schedule_id,
total_accounts=len(account_ids),
)
task_source = f"user_scheduled:{batch_id}"
started_count, skipped_count = _submit_user_schedule_accounts(
user_id=user_id,
account_ids=account_ids,
browse_type=browse_type,
enable_screenshot=enable_screenshot,
task_source=task_source,
done_callback=on_browse_done,
completion_lock=completion_lock,
remaining=remaining,
counters=counters,
)
batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time.time())
if batch_info:
_send_batch_task_email_if_configured(batch_info)
database.update_schedule_last_run(schedule_id)
logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号批次ID: {batch_id}")
if started_count <= 0:
database.update_schedule_execution_log(
log_id,
total_accounts=len(account_ids),
success_accounts=0,
failed_accounts=len(account_ids),
duration_seconds=0,
status="completed",
)
if started_count == 0 and len(account_ids) > 0:
logger.warning("[用户定时任务] ⚠️ 警告所有账号都被跳过了请检查user_accounts状态")
def check_user_schedules(): def check_user_schedules():
"""检查并执行用户定时任务O-08next_run_at 索引驱动)。""" """检查并执行用户定时任务O-08next_run_at 索引驱动)。"""
import json
try: try:
now = get_beijing_now() now = get_beijing_now()
now_str = now.strftime("%Y-%m-%d %H:%M:%S") now_str = now.strftime("%Y-%m-%d %H:%M:%S")
@@ -226,145 +445,22 @@ def scheduled_task_worker() -> None:
due_schedules = database.get_due_user_schedules(now_str, limit=50) or [] due_schedules = database.get_due_user_schedules(now_str, limit=50) or []
for schedule_config in due_schedules: for schedule_config in due_schedules:
schedule_name = schedule_config.get("name", "未命名任务")
schedule_id = schedule_config["id"] schedule_id = schedule_config["id"]
schedule_name = schedule_config.get("name", "未命名任务")
weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5") allowed_weekdays = _parse_due_schedule_weekdays(schedule_config, schedule_id)
try: if allowed_weekdays is None:
allowed_weekdays = [int(d) for d in weekdays_str.split(",") if d.strip()]
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
continue continue
if current_weekday not in allowed_weekdays: if current_weekday not in allowed_weekdays:
try: _safe_recompute_schedule_next_run(schedule_id)
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
continue continue
logger.info(f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 (next_run_at={schedule_config.get('next_run_at')})") logger.info(
f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 "
user_id = schedule_config["user_id"] f"(next_run_at={schedule_config.get('next_run_at')})"
schedule_id = schedule_config["id"]
browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
enable_screenshot = schedule_config.get("enable_screenshot", 1)
try:
account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
account_ids = json.loads(account_ids_raw)
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
account_ids = []
if not account_ids:
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
continue
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
import time as time_mod
import uuid
execution_start_time = time_mod.time()
log_id = database.create_schedule_execution_log(
schedule_id=schedule_id, user_id=user_id, schedule_name=schedule_config.get("name", "未命名任务")
) )
_execute_due_user_schedule(schedule_config)
batch_id = f"batch_{uuid.uuid4().hex[:12]}"
now_ts = time_mod.time()
safe_create_batch(
batch_id,
{
"user_id": user_id,
"browse_type": browse_type,
"schedule_name": schedule_config.get("name", "未命名任务"),
"screenshots": [],
"total_accounts": 0,
"completed": 0,
"created_at": now_ts,
"updated_at": now_ts,
},
)
started_count = 0
skipped_count = 0
completion_lock = threading.Lock()
remaining = {"count": 0, "done": False}
def on_browse_done():
with completion_lock:
remaining["count"] -= 1
if remaining["done"] or remaining["count"] > 0:
return
remaining["done"] = True
execution_duration = int(time_mod.time() - execution_start_time)
database.update_schedule_execution_log(
log_id,
total_accounts=len(account_ids),
success_accounts=started_count,
failed_accounts=len(account_ids) - started_count,
duration_seconds=execution_duration,
status="completed",
)
logger.info(
f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件"
)
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
if not account:
skipped_count += 1
continue
if account.is_running:
skipped_count += 1
continue
task_source = f"user_scheduled:{batch_id}"
with completion_lock:
remaining["count"] += 1
ok, msg = submit_account_task(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
enable_screenshot=enable_screenshot,
source=task_source,
done_callback=on_browse_done,
)
if ok:
started_count += 1
else:
with completion_lock:
remaining["count"] -= 1
skipped_count += 1
logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time_mod.time())
if batch_info:
_send_batch_task_email_if_configured(batch_info)
database.update_schedule_last_run(schedule_id)
logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号批次ID: {batch_id}")
if started_count <= 0:
database.update_schedule_execution_log(
log_id,
total_accounts=len(account_ids),
success_accounts=0,
failed_accounts=len(account_ids),
duration_seconds=0,
status="completed",
)
if started_count == 0 and len(account_ids) > 0:
logger.warning("[用户定时任务] ⚠️ 警告所有账号都被跳过了请检查user_accounts状态")
except Exception as e: except Exception as e:
logger.exception(f"[用户定时任务] 检查出错: {str(e)}") logger.exception(f"[用户定时任务] 检查出错: {str(e)}")

View File

@@ -6,12 +6,14 @@ import os
import shutil import shutil
import subprocess import subprocess
import time import time
from urllib.parse import urlsplit
import database import database
import email_service import email_service
from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh
from app_config import get_config from app_config import get_config
from app_logger import get_logger from app_logger import get_logger
from app_security import sanitize_filename
from browser_pool_worker import get_browser_worker_pool from browser_pool_worker import get_browser_worker_pool
from services.client_log import log_to_client from services.client_log import log_to_client
from services.runtime import get_socketio from services.runtime import get_socketio
@@ -194,6 +196,293 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
pass pass
def _set_screenshot_running_status(user_id: int, account_id: str) -> None:
"""更新账号状态为截图中。"""
acc = safe_get_account(user_id, account_id)
if not acc:
return
acc.status = "截图中"
safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
def _get_worker_display_info(browser_instance) -> tuple[str, int]:
"""获取截图 worker 的展示信息。"""
if isinstance(browser_instance, dict):
return str(browser_instance.get("worker_id", "?")), int(browser_instance.get("use_count", 0) or 0)
return "?", 0
def _get_proxy_context(account) -> tuple[dict | None, str | None]:
"""提取截图阶段代理配置。"""
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
proxy_server = proxy_config.get("server") if proxy_config else None
return proxy_config, proxy_server
def _build_screenshot_targets(browse_type: str) -> tuple[str, str, str]:
"""构建截图目标 URL 与页面脚本。"""
parsed = urlsplit(config.ZSGL_LOGIN_URL)
base = f"{parsed.scheme}://{parsed.netloc}"
if "注册前" in str(browse_type):
bz = 0
else:
bz = 0
target_url = f"{base}/admin/center.aspx?bz={bz}"
index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
run_script = (
"(function(){"
"function done(){window.status='ready';}"
"function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
"function expandMenu(){"
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
"try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
"try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
"try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
"try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
"}"
"function navReady(){"
"try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
"}"
"function frameReady(){"
"try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
"}"
"function check(){"
"if(navReady() && frameReady()){done();return;}"
"setTimeout(check,300);"
"}"
"var f=document.getElementById('mainframe');"
"ensureNav();"
"expandMenu();"
"if(!f){done();return;}"
f"f.src='{target_url}';"
"f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
"setTimeout(check,5000);"
"})();"
)
return index_url, target_url, run_script
def _build_screenshot_output_path(username_prefix: str, account, browse_type: str) -> tuple[str, str]:
"""构建截图输出文件名与路径。"""
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
login_account = account.remark if account.remark else account.username
raw_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
screenshot_filename = sanitize_filename(raw_filename)
return screenshot_filename, os.path.join(SCREENSHOTS_DIR, screenshot_filename)
def _ensure_screenshot_login_state(
*,
account,
proxy_config,
cookie_path: str,
attempt: int,
max_retries: int,
user_id: int,
account_id: str,
custom_log,
) -> str:
"""确保截图前登录态有效。返回: ok/retry/fail。"""
should_refresh_login = not is_cookie_jar_fresh(cookie_path)
if not should_refresh_login:
return "ok"
log_to_client("正在刷新登录态...", user_id, account_id)
if _ensure_login_cookies(account, proxy_config, custom_log):
return "ok"
if attempt > 1:
log_to_client("截图登录失败", user_id, account_id)
if attempt < max_retries:
log_to_client("将重试...", user_id, account_id)
time.sleep(2)
return "retry"
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
return "fail"
def _take_screenshot_once(
*,
index_url: str,
target_url: str,
screenshot_path: str,
cookie_path: str,
proxy_server: str | None,
run_script: str,
log_callback,
) -> str:
"""执行一次截图尝试并验证输出文件。返回: success/invalid/failed。"""
cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
attempts = [
{
"url": index_url,
"run_script": run_script,
"window_status": "ready",
},
{
"url": target_url,
"run_script": None,
"window_status": None,
},
]
ok = False
for shot in attempts:
ok = take_screenshot_wkhtmltoimage(
shot["url"],
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server,
run_script=shot["run_script"],
window_status=shot["window_status"],
log_callback=log_callback,
)
if ok:
break
if not ok:
return "failed"
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
return "success"
if os.path.exists(screenshot_path):
os.remove(screenshot_path)
return "invalid"
def _get_result_screenshot_path(result) -> str | None:
"""从截图结果中提取截图文件绝对路径。"""
if result and result.get("success") and result.get("filename"):
return os.path.join(SCREENSHOTS_DIR, result["filename"])
return None
def _enqueue_kdocs_upload_if_needed(user_id: int, account_id: str, account, screenshot_path: str | None) -> None:
"""按配置提交金山文档上传任务。"""
if not screenshot_path:
return
cfg = database.get_system_config() or {}
if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
return
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
if not doc_url:
return
user_cfg = database.get_user_kdocs_settings(user_id) or {}
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) != 1:
return
unit = (user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "").strip()
name = (account.remark or "").strip()
if not unit:
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
return
if not name:
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
return
from services.kdocs_uploader import get_kdocs_uploader
ok = get_kdocs_uploader().enqueue_upload(
user_id=user_id,
account_id=account_id,
unit=unit,
name=name,
image_path=screenshot_path,
)
if not ok:
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
def _dispatch_screenshot_result(
*,
user_id: int,
account_id: str,
source: str,
browse_type: str,
browse_result: dict,
result,
account,
user_info,
) -> None:
"""将截图结果发送到批次统计/邮件通知链路。"""
batch_id = _get_batch_id_from_source(source)
screenshot_path = _get_result_screenshot_path(result)
account_name = account.remark if account.remark else account.username
try:
if result and result.get("success") and screenshot_path:
_enqueue_kdocs_upload_if_needed(user_id, account_id, account, screenshot_path)
except Exception as kdocs_error:
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
if batch_id:
_batch_task_record_result(
batch_id=batch_id,
account_name=account_name,
screenshot_path=screenshot_path,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
)
return
if source and source.startswith("user_scheduled"):
if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
email_service.send_task_complete_email_async(
user_id=user_id,
email=user_info["email"],
username=user_info["username"],
account_name=account_name,
browse_type=browse_type,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
screenshot_path=screenshot_path,
log_callback=lambda msg: log_to_client(msg, user_id, account_id),
)
def _finalize_screenshot_callback_state(user_id: int, account_id: str, account) -> None:
"""截图回调的通用收尾状态变更。"""
account.is_running = False
account.status = "未开始"
safe_remove_task_status(account_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
def _persist_browse_log_after_screenshot(
*,
user_id: int,
account_id: str,
account,
browse_type: str,
source: str,
task_start_time,
browse_result,
) -> None:
"""截图完成后写入任务日志(浏览完成日志)。"""
import time as time_module
total_elapsed = int(time_module.time() - task_start_time)
database.create_task_log(
user_id=user_id,
account_id=account_id,
username=account.username,
browse_type=browse_type,
status="success",
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
duration=total_elapsed,
source=source,
)
def take_screenshot_for_account( def take_screenshot_for_account(
user_id, user_id,
account_id, account_id,
@@ -213,21 +502,21 @@ def take_screenshot_for_account(
# 标记账号正在截图(防止重复提交截图任务) # 标记账号正在截图(防止重复提交截图任务)
account.is_running = True account.is_running = True
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
def screenshot_task( def screenshot_task(
browser_instance, user_id, account_id, account, browse_type, source, task_start_time, browse_result browser_instance, user_id, account_id, account, browse_type, source, task_start_time, browse_result
): ):
"""在worker线程中执行的截图任务""" """在worker线程中执行的截图任务"""
# ✅ 获得worker后立即更新状态为"截图中" # ✅ 获得worker后立即更新状态为"截图中"
acc = safe_get_account(user_id, account_id) _set_screenshot_running_status(user_id, account_id)
if acc:
acc.status = "截图中"
safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
max_retries = 3 max_retries = 3
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None proxy_config, proxy_server = _get_proxy_context(account)
proxy_server = proxy_config.get("server") if proxy_config else None
cookie_path = get_cookie_jar_path(account.username) cookie_path = get_cookie_jar_path(account.username)
index_url, target_url, run_script = _build_screenshot_targets(browse_type)
for attempt in range(1, max_retries + 1): for attempt in range(1, max_retries + 1):
try: try:
@@ -239,8 +528,7 @@ def take_screenshot_for_account(
if attempt > 1: if attempt > 1:
log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id) log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id)
worker_id = browser_instance.get("worker_id", "?") if isinstance(browser_instance, dict) else "?" worker_id, use_count = _get_worker_display_info(browser_instance)
use_count = browser_instance.get("use_count", 0) if isinstance(browser_instance, dict) else 0
log_to_client( log_to_client(
f"使用Worker-{worker_id}执行截图(已执行{use_count}次)", f"使用Worker-{worker_id}执行截图(已执行{use_count}次)",
user_id, user_id,
@@ -250,99 +538,39 @@ def take_screenshot_for_account(
def custom_log(message: str): def custom_log(message: str):
log_to_client(message, user_id, account_id) log_to_client(message, user_id, account_id)
# 智能登录状态检查:只在必要时才刷新登录 login_state = _ensure_screenshot_login_state(
should_refresh_login = not is_cookie_jar_fresh(cookie_path) account=account,
if should_refresh_login and attempt > 1: proxy_config=proxy_config,
# 重试时刷新登录attempt > 1 表示第2次及以后的尝试 cookie_path=cookie_path,
log_to_client("正在刷新登录态...", user_id, account_id) attempt=attempt,
if not _ensure_login_cookies(account, proxy_config, custom_log): max_retries=max_retries,
log_to_client("截图登录失败", user_id, account_id) user_id=user_id,
if attempt < max_retries: account_id=account_id,
log_to_client("将重试...", user_id, account_id) custom_log=custom_log,
time.sleep(2) )
continue if login_state == "retry":
log_to_client("❌ 截图失败: 登录失败", user_id, account_id) continue
return {"success": False, "error": "登录失败"} if login_state == "fail":
elif should_refresh_login: return {"success": False, "error": "登录失败"}
# 首次尝试时快速检查登录状态
log_to_client("正在刷新登录态...", user_id, account_id)
if not _ensure_login_cookies(account, proxy_config, custom_log):
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
return {"success": False, "error": "登录失败"}
log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id) log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id)
from urllib.parse import urlsplit screenshot_filename, screenshot_path = _build_screenshot_output_path(username_prefix, account, browse_type)
shot_state = _take_screenshot_once(
parsed = urlsplit(config.ZSGL_LOGIN_URL) index_url=index_url,
base = f"{parsed.scheme}://{parsed.netloc}" target_url=target_url,
if "注册前" in str(browse_type): screenshot_path=screenshot_path,
bz = 0 cookie_path=cookie_path,
else:
bz = 0 # 应读(网站更新后 bz=0 为应读)
target_url = f"{base}/admin/center.aspx?bz={bz}"
index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
run_script = (
"(function(){"
"function done(){window.status='ready';}"
"function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
"function expandMenu(){"
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
"try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
"try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
"try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
"try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
"}"
"function navReady(){"
"try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
"}"
"function frameReady(){"
"try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
"}"
"function check(){"
"if(navReady() && frameReady()){done();return;}"
"setTimeout(check,300);"
"}"
"var f=document.getElementById('mainframe');"
"ensureNav();"
"expandMenu();"
"if(!f){done();return;}"
f"f.src='{target_url}';"
"f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
"setTimeout(check,5000);"
"})();"
)
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
login_account = account.remark if account.remark else account.username
screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename)
cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
if take_screenshot_wkhtmltoimage(
index_url,
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server, proxy_server=proxy_server,
run_script=run_script, run_script=run_script,
window_status="ready",
log_callback=custom_log, log_callback=custom_log,
) or take_screenshot_wkhtmltoimage( )
target_url, if shot_state == "success":
screenshot_path, log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
cookies_path=cookies_for_shot, return {"success": True, "filename": screenshot_filename}
proxy_server=proxy_server,
log_callback=custom_log, if shot_state == "invalid":
):
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
return {"success": True, "filename": screenshot_filename}
log_to_client("截图文件异常,将重试", user_id, account_id) log_to_client("截图文件异常,将重试", user_id, account_id)
if os.path.exists(screenshot_path):
os.remove(screenshot_path)
else: else:
log_to_client("截图保存失败", user_id, account_id) log_to_client("截图保存失败", user_id, account_id)
@@ -361,12 +589,7 @@ def take_screenshot_for_account(
def screenshot_callback(result, error): def screenshot_callback(result, error):
"""截图完成回调""" """截图完成回调"""
try: try:
account.is_running = False _finalize_screenshot_callback_state(user_id, account_id, account)
account.status = "未开始"
safe_remove_task_status(account_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
if error: if error:
log_to_client(f"❌ 截图失败: {error}", user_id, account_id) log_to_client(f"❌ 截图失败: {error}", user_id, account_id)
@@ -375,84 +598,27 @@ def take_screenshot_for_account(
log_to_client(f"❌ 截图失败: {error_msg}", user_id, account_id) log_to_client(f"❌ 截图失败: {error_msg}", user_id, account_id)
if task_start_time and browse_result: if task_start_time and browse_result:
import time as time_module _persist_browse_log_after_screenshot(
total_elapsed = int(time_module.time() - task_start_time)
database.create_task_log(
user_id=user_id, user_id=user_id,
account_id=account_id, account_id=account_id,
username=account.username, account=account,
browse_type=browse_type, browse_type=browse_type,
status="success",
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
duration=total_elapsed,
source=source, source=source,
task_start_time=task_start_time,
browse_result=browse_result,
) )
try: try:
batch_id = _get_batch_id_from_source(source) _dispatch_screenshot_result(
user_id=user_id,
screenshot_path = None account_id=account_id,
if result and result.get("success") and result.get("filename"): source=source,
screenshot_path = os.path.join(SCREENSHOTS_DIR, result["filename"]) browse_type=browse_type,
browse_result=browse_result,
account_name = account.remark if account.remark else account.username result=result,
account=account,
try: user_info=user_info,
if screenshot_path and result and result.get("success"): )
cfg = database.get_system_config() or {}
if int(cfg.get("kdocs_enabled", 0) or 0) == 1:
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
if doc_url:
user_cfg = database.get_user_kdocs_settings(user_id) or {}
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) == 1:
unit = (
user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or ""
).strip()
name = (account.remark or "").strip()
if unit and name:
from services.kdocs_uploader import get_kdocs_uploader
ok = get_kdocs_uploader().enqueue_upload(
user_id=user_id,
account_id=account_id,
unit=unit,
name=name,
image_path=screenshot_path,
)
if not ok:
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
else:
if not unit:
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
if not name:
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
except Exception as kdocs_error:
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
if batch_id:
_batch_task_record_result(
batch_id=batch_id,
account_name=account_name,
screenshot_path=screenshot_path,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
)
elif source and source.startswith("user_scheduled"):
user_info = database.get_user_by_id(user_id)
if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
email_service.send_task_complete_email_async(
user_id=user_id,
email=user_info["email"],
username=user_info["username"],
account_name=account_name,
browse_type=browse_type,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
screenshot_path=screenshot_path,
log_callback=lambda msg: log_to_client(msg, user_id, account_id),
)
except Exception as email_error: except Exception as email_error:
logger.warning(f"发送任务完成邮件失败: {email_error}") logger.warning(f"发送任务完成邮件失败: {email_error}")
except Exception as e: except Exception as e:

View File

@@ -13,7 +13,7 @@ from __future__ import annotations
import threading import threading
import time import time
import random import random
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
from app_config import get_config from app_config import get_config
@@ -161,6 +161,36 @@ _log_cache_lock = threading.RLock()
_log_cache_total_count = 0 _log_cache_total_count = 0
def _pop_oldest_log_for_user(uid: int) -> bool:
global _log_cache_total_count
logs = _log_cache.get(uid)
if not logs:
_log_cache.pop(uid, None)
return False
logs.pop(0)
_log_cache_total_count = max(0, _log_cache_total_count - 1)
if not logs:
_log_cache.pop(uid, None)
return True
def _pop_oldest_log_from_largest_user() -> bool:
largest_uid = None
largest_size = 0
for uid, logs in _log_cache.items():
size = len(logs)
if size > largest_size:
largest_uid = uid
largest_size = size
if largest_uid is None or largest_size <= 0:
return False
return _pop_oldest_log_for_user(int(largest_uid))
def safe_add_log( def safe_add_log(
user_id: int, user_id: int,
log_entry: Dict[str, Any], log_entry: Dict[str, Any],
@@ -175,24 +205,17 @@ def safe_add_log(
max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS) max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS)
with _log_cache_lock: with _log_cache_lock:
if uid not in _log_cache: logs = _log_cache.setdefault(uid, [])
_log_cache[uid] = []
if len(_log_cache[uid]) >= max_logs_per_user: if len(logs) >= max_logs_per_user:
_log_cache[uid].pop(0) _pop_oldest_log_for_user(uid)
_log_cache_total_count = max(0, _log_cache_total_count - 1) logs = _log_cache.setdefault(uid, [])
_log_cache[uid].append(dict(log_entry or {})) logs.append(dict(log_entry or {}))
_log_cache_total_count += 1 _log_cache_total_count += 1
while _log_cache_total_count > max_total_logs: while _log_cache_total_count > max_total_logs:
if not _log_cache: if not _pop_oldest_log_from_largest_user():
break
max_user = max(_log_cache.keys(), key=lambda u: len(_log_cache[u]))
if _log_cache.get(max_user):
_log_cache[max_user].pop(0)
_log_cache_total_count -= 1
else:
break break
@@ -378,6 +401,34 @@ def _get_action_rate_limit(action: str) -> Tuple[int, int]:
return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS) return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS)
def _format_wait_hint(remaining_seconds: int) -> str:
remaining = max(1, int(remaining_seconds or 0))
if remaining >= 60:
return f"{remaining // 60 + 1}分钟"
return f"{remaining}"
def _check_and_increment_rate_bucket(
*,
buckets: Dict[str, Dict[str, Any]],
key: str,
now_ts: float,
max_requests: int,
window_seconds: int,
) -> Tuple[bool, Optional[str]]:
if not key or int(max_requests) <= 0:
return True, None
data = _get_or_reset_bucket(buckets.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= int(max_requests):
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
return False, f"请求过于频繁,请{_format_wait_hint(remaining)}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
buckets[key] = data
return True, None
def check_ip_request_rate( def check_ip_request_rate(
ip_address: str, ip_address: str,
action: str, action: str,
@@ -392,21 +443,13 @@ def check_ip_request_rate(
key = f"{action}:{ip_address}" key = f"{action}:{ip_address}"
with _ip_request_rate_lock: with _ip_request_rate_lock:
data = _ip_request_rate.get(key) return _check_and_increment_rate_bucket(
if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds: buckets=_ip_request_rate,
data = {"window_start": now_ts, "count": 0} key=key,
_ip_request_rate[key] = data now_ts=now_ts,
max_requests=max_requests,
if int(data.get("count", 0) or 0) >= max_requests: window_seconds=window_seconds,
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0)))) )
if remaining >= 60:
wait_hint = f"{remaining // 60 + 1}分钟"
else:
wait_hint = f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
return True, None
def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int: def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
@@ -417,8 +460,7 @@ def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
data = _ip_request_rate.get(key) or {} data = _ip_request_rate.get(key) or {}
action = key.split(":", 1)[0] action = key.split(":", 1)[0]
_, window_seconds = _get_action_rate_limit(action) _, window_seconds = _get_action_rate_limit(action)
window_start = float(data.get("window_start", 0) or 0) if _is_bucket_expired(data, now_ts, window_seconds):
if now_ts - window_start >= window_seconds:
_ip_request_rate.pop(key, None) _ip_request_rate.pop(key, None)
removed += 1 removed += 1
return removed return removed
@@ -487,6 +529,30 @@ def _get_or_reset_bucket(data: Optional[Dict[str, Any]], now_ts: float, window_s
return data return data
def _is_bucket_expired(
data: Optional[Dict[str, Any]],
now_ts: float,
window_seconds: int,
*,
time_field: str = "window_start",
) -> bool:
start_ts = float((data or {}).get(time_field, 0) or 0)
return (now_ts - start_ts) >= max(1, int(window_seconds))
def _cleanup_map_entries(
store: Dict[Any, Dict[str, Any]],
should_remove: Callable[[Dict[str, Any]], bool],
) -> int:
removed = 0
for key, value in list(store.items()):
item = value if isinstance(value, dict) else {}
if should_remove(item):
store.pop(key, None)
removed += 1
return removed
def record_login_username_attempt(ip_address: str, username: str) -> bool: def record_login_username_attempt(ip_address: str, username: str) -> bool:
now_ts = time.time() now_ts = time.time()
threshold, window_seconds, cooldown_seconds = _get_login_scan_config() threshold, window_seconds, cooldown_seconds = _get_login_scan_config()
@@ -527,26 +593,32 @@ def check_login_rate_limits(ip_address: str, username: str) -> Tuple[bool, Optio
user_key = _normalize_login_key("user", "", username) user_key = _normalize_login_key("user", "", username)
ip_user_key = _normalize_login_key("ipuser", ip_address, username) ip_user_key = _normalize_login_key("ipuser", ip_address, username)
def _check(key: str, max_requests: int) -> Tuple[bool, Optional[str]]:
if not key or max_requests <= 0:
return True, None
data = _get_or_reset_bucket(_login_rate_limits.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
_login_rate_limits[key] = data
return True, None
with _login_rate_limits_lock: with _login_rate_limits_lock:
allowed, msg = _check(ip_key, ip_max) allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=ip_key,
now_ts=now_ts,
max_requests=ip_max,
window_seconds=window_seconds,
)
if not allowed: if not allowed:
return False, msg return False, msg
allowed, msg = _check(ip_user_key, ip_user_max) allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=ip_user_key,
now_ts=now_ts,
max_requests=ip_user_max,
window_seconds=window_seconds,
)
if not allowed: if not allowed:
return False, msg return False, msg
allowed, msg = _check(user_key, user_max) allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=user_key,
now_ts=now_ts,
max_requests=user_max,
window_seconds=window_seconds,
)
if not allowed: if not allowed:
return False, msg return False, msg
@@ -622,15 +694,18 @@ def check_login_captcha_required(ip_address: str, username: Optional[str] = None
ip_key = _normalize_login_key("ip", ip_address) ip_key = _normalize_login_key("ip", ip_address)
ip_user_key = _normalize_login_key("ipuser", ip_address, username or "") ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
def _is_over_threshold(data: Optional[Dict[str, Any]]) -> bool:
if not data:
return False
if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
return False
return int(data.get("count", 0) or 0) >= max_failures
with _login_failures_lock: with _login_failures_lock:
ip_data = _login_failures.get(ip_key) if _is_over_threshold(_login_failures.get(ip_key)):
if ip_data and (now_ts - float(ip_data.get("first_failed", 0) or 0)) <= window_seconds: return True
if int(ip_data.get("count", 0) or 0) >= max_failures: if _is_over_threshold(_login_failures.get(ip_user_key)):
return True return True
ip_user_data = _login_failures.get(ip_user_key)
if ip_user_data and (now_ts - float(ip_user_data.get("first_failed", 0) or 0)) <= window_seconds:
if int(ip_user_data.get("count", 0) or 0) >= max_failures:
return True
if is_login_scan_locked(ip_address): if is_login_scan_locked(ip_address):
return True return True
@@ -685,6 +760,56 @@ def should_send_login_alert(user_id: int, ip_address: str) -> bool:
return False return False
def cleanup_expired_login_security_state(now_ts: Optional[float] = None) -> Dict[str, int]:
now_ts = float(now_ts if now_ts is not None else time.time())
_, captcha_window = _get_login_captcha_config()
_, _, _, rate_window = _get_login_rate_limit_config()
_, lock_window, _ = _get_login_lock_config()
_, scan_window, _ = _get_login_scan_config()
alert_expire_seconds = max(3600, int(config.LOGIN_ALERT_MIN_INTERVAL_SECONDS) * 3)
with _login_failures_lock:
failures_removed = _cleanup_map_entries(
_login_failures,
lambda data: (now_ts - float(data.get("first_failed", 0) or 0)) > max(captcha_window, lock_window),
)
with _login_rate_limits_lock:
rate_removed = _cleanup_map_entries(
_login_rate_limits,
lambda data: _is_bucket_expired(data, now_ts, rate_window),
)
with _login_scan_lock:
scan_removed = _cleanup_map_entries(
_login_scan_state,
lambda data: (
(now_ts - float(data.get("first_seen", 0) or 0)) > scan_window
and now_ts >= float(data.get("scan_until", 0) or 0)
),
)
with _login_ip_user_lock:
ip_user_locks_removed = _cleanup_map_entries(
_login_ip_user_locks,
lambda data: now_ts >= float(data.get("lock_until", 0) or 0),
)
with _login_alert_lock:
alerts_removed = _cleanup_map_entries(
_login_alert_state,
lambda data: (now_ts - float(data.get("last_sent", 0) or 0)) > alert_expire_seconds,
)
return {
"failures": failures_removed,
"rate_limits": rate_removed,
"scan_states": scan_removed,
"ip_user_locks": ip_user_locks_removed,
"alerts": alerts_removed,
}
# ==================== 邮箱维度限流 ==================== # ==================== 邮箱维度限流 ====================
_email_rate_limit: Dict[str, Dict[str, Any]] = {} _email_rate_limit: Dict[str, Dict[str, Any]] = {}
@@ -701,14 +826,13 @@ def check_email_rate_limit(email: str, action: str) -> Tuple[bool, Optional[str]
key = f"{action}:{email_key}" key = f"{action}:{email_key}"
with _email_rate_limit_lock: with _email_rate_limit_lock:
data = _get_or_reset_bucket(_email_rate_limit.get(key), now_ts, window_seconds) return _check_and_increment_rate_bucket(
if int(data.get("count", 0) or 0) >= max_requests: buckets=_email_rate_limit,
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0)))) key=key,
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}" now_ts=now_ts,
return False, f"请求过于频繁,请{wait_hint}后再试" max_requests=max_requests,
data["count"] = int(data.get("count", 0) or 0) + 1 window_seconds=window_seconds,
_email_rate_limit[key] = data )
return True, None
# ==================== Batch screenshots批次任务截图收集 ==================== # ==================== Batch screenshots批次任务截图收集 ====================

365
services/task_scheduler.py Normal file
View File

@@ -0,0 +1,365 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import heapq
import threading
import time
from concurrent.futures import ThreadPoolExecutor, wait
from dataclasses import dataclass
import database
from app_logger import get_logger
from services.state import safe_get_account, safe_get_task, safe_remove_task, safe_set_task
from services.task_batches import _batch_task_record_result, _get_batch_id_from_source
logger = get_logger("app")
# VIP优先级队列仅用于可视化/调试)
vip_task_queue = [] # VIP用户任务队列
normal_task_queue = [] # 普通用户任务队列
task_queue_lock = threading.Lock()
@dataclass
class _TaskRequest:
user_id: int
account_id: str
browse_type: str
enable_screenshot: bool
source: str
retry_count: int
submitted_at: float
is_vip: bool
seq: int
canceled: bool = False
done_callback: object = None
class TaskScheduler:
"""全局任务调度器:队列排队,不为每个任务单独创建线程。"""
def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000, run_task_fn=None):
self.max_global = max(1, int(max_global))
self.max_per_user = max(1, int(max_per_user))
self.max_queue_size = max(1, int(max_queue_size))
self._cond = threading.Condition()
self._pending = [] # heap: (priority, submitted_at, seq, task)
self._pending_by_account = {} # {account_id: task}
self._seq = 0
self._known_account_ids = set()
self._running_global = 0
self._running_by_user = {} # {user_id: running_count}
self._executor_max_workers = self.max_global
self._executor = ThreadPoolExecutor(max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker")
self._futures_lock = threading.Lock()
self._active_futures = set()
self._running = True
self._run_task_fn = run_task_fn
self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, daemon=True, name="TaskDispatcher")
self._dispatcher_thread.start()
def _track_future(self, future) -> None:
with self._futures_lock:
self._active_futures.add(future)
try:
future.add_done_callback(self._untrack_future)
except Exception:
pass
def _untrack_future(self, future) -> None:
with self._futures_lock:
self._active_futures.discard(future)
def shutdown(self, timeout: float = 5.0):
"""停止调度器(用于进程退出清理)"""
with self._cond:
self._running = False
self._cond.notify_all()
try:
self._dispatcher_thread.join(timeout=timeout)
except Exception:
pass
# 等待已提交的任务收尾(最多等待 timeout 秒),避免遗留 active_task 干扰后续调度/测试
try:
deadline = time.time() + max(0.0, float(timeout or 0))
while True:
with self._futures_lock:
pending = [f for f in self._active_futures if not f.done()]
if not pending:
break
remaining = deadline - time.time()
if remaining <= 0:
break
wait(pending, timeout=remaining)
except Exception:
pass
try:
self._executor.shutdown(wait=False)
except Exception:
pass
# 最后兜底:清理本调度器提交过的 active_task避免测试/重启时被“任务已在运行中”误拦截
try:
with self._cond:
known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys())
self._pending.clear()
self._pending_by_account.clear()
self._cond.notify_all()
for account_id in known_ids:
safe_remove_task(account_id)
except Exception:
pass
def update_limits(self, max_global: int = None, max_per_user: int = None, max_queue_size: int = None):
"""动态更新并发/队列上限(不影响已在运行的任务)"""
with self._cond:
if max_per_user is not None:
self.max_per_user = max(1, int(max_per_user))
if max_queue_size is not None:
self.max_queue_size = max(1, int(max_queue_size))
if max_global is not None:
new_max_global = max(1, int(max_global))
self.max_global = new_max_global
if new_max_global > self._executor_max_workers:
# 立即关闭旧线程池,防止资源泄漏
old_executor = self._executor
self._executor_max_workers = new_max_global
self._executor = ThreadPoolExecutor(
max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker"
)
# 立即关闭旧线程池
try:
old_executor.shutdown(wait=False)
logger.info(f"线程池已扩容:{old_executor._max_workers} -> {self._executor_max_workers}")
except Exception as e:
logger.warning(f"关闭旧线程池失败: {e}")
self._cond.notify_all()
def get_queue_state_snapshot(self) -> dict:
"""获取调度器队列/运行状态快照(用于前端展示/监控)。"""
with self._cond:
pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled]
pending_tasks.sort(key=lambda t: (0 if t.is_vip else 1, t.submitted_at, t.seq))
positions = {}
for idx, t in enumerate(pending_tasks):
positions[t.account_id] = {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)}
return {
"pending_total": len(pending_tasks),
"running_total": int(self._running_global),
"running_by_user": dict(self._running_by_user),
"positions": positions,
}
def submit_task(
self,
user_id: int,
account_id: str,
browse_type: str,
enable_screenshot: bool = True,
source: str = "manual",
retry_count: int = 0,
is_vip: bool = None,
done_callback=None,
):
"""提交任务进入队列(返回: (ok, message)"""
if not user_id or not account_id:
return False, "参数错误"
submitted_at = time.time()
if is_vip is None:
try:
is_vip = bool(database.is_user_vip(user_id))
except Exception:
is_vip = False
else:
is_vip = bool(is_vip)
with self._cond:
if not self._running:
return False, "调度器未运行"
if len(self._pending_by_account) >= self.max_queue_size:
return False, "任务队列已满,请稍后再试"
if account_id in self._pending_by_account:
return False, "任务已在队列中"
if safe_get_task(account_id) is not None:
return False, "任务已在运行中"
self._seq += 1
task = _TaskRequest(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
enable_screenshot=bool(enable_screenshot),
source=source,
retry_count=int(retry_count or 0),
submitted_at=submitted_at,
is_vip=is_vip,
seq=self._seq,
done_callback=done_callback,
)
self._pending_by_account[account_id] = task
self._known_account_ids.add(account_id)
priority = 0 if is_vip else 1
heapq.heappush(self._pending, (priority, task.submitted_at, task.seq, task))
self._cond.notify_all()
# 用于可视化/调试:记录队列
with task_queue_lock:
if is_vip:
vip_task_queue.append(account_id)
else:
normal_task_queue.append(account_id)
return True, "已加入队列"
def cancel_pending_task(self, user_id: int, account_id: str) -> bool:
"""取消尚未开始的排队任务(已运行的任务由 should_stop 控制)"""
canceled_task = None
with self._cond:
task = self._pending_by_account.pop(account_id, None)
if not task:
return False
task.canceled = True
canceled_task = task
self._cond.notify_all()
# 从可视化队列移除
with task_queue_lock:
if account_id in vip_task_queue:
vip_task_queue.remove(account_id)
if account_id in normal_task_queue:
normal_task_queue.remove(account_id)
# 批次任务:取消也要推进完成计数,避免批次缓存常驻
try:
batch_id = _get_batch_id_from_source(canceled_task.source)
if batch_id:
acc = safe_get_account(user_id, account_id)
if acc:
account_name = acc.remark if acc.remark else acc.username
else:
account_name = account_id
_batch_task_record_result(
batch_id=batch_id,
account_name=account_name,
screenshot_path=None,
total_items=0,
total_attachments=0,
)
except Exception:
pass
return True
def _dispatch_loop(self):
while True:
task = None
with self._cond:
if not self._running:
return
if not self._pending or self._running_global >= self.max_global:
self._cond.wait(timeout=0.5)
continue
task = self._pop_next_runnable_locked()
if task is None:
self._cond.wait(timeout=0.5)
continue
self._running_global += 1
self._running_by_user[task.user_id] = self._running_by_user.get(task.user_id, 0) + 1
# 从队列移除(可视化)
with task_queue_lock:
if task.account_id in vip_task_queue:
vip_task_queue.remove(task.account_id)
if task.account_id in normal_task_queue:
normal_task_queue.remove(task.account_id)
try:
future = self._executor.submit(self._run_task_wrapper, task)
self._track_future(future)
safe_set_task(task.account_id, future)
except Exception:
with self._cond:
self._running_global = max(0, self._running_global - 1)
# 使用默认值 0 与增加时保持一致
self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
if self._running_by_user.get(task.user_id) == 0:
self._running_by_user.pop(task.user_id, None)
self._cond.notify_all()
def _pop_next_runnable_locked(self):
"""在锁内从优先队列取出“可运行”的任务避免VIP任务占位阻塞普通任务。"""
if not self._pending:
return None
skipped = []
selected = None
while self._pending:
_, _, _, task = heapq.heappop(self._pending)
if task.canceled:
continue
if self._pending_by_account.get(task.account_id) is not task:
continue
running_for_user = self._running_by_user.get(task.user_id, 0)
if running_for_user >= self.max_per_user:
skipped.append(task)
continue
selected = task
break
for t in skipped:
priority = 0 if t.is_vip else 1
heapq.heappush(self._pending, (priority, t.submitted_at, t.seq, t))
if selected is None:
return None
self._pending_by_account.pop(selected.account_id, None)
return selected
def _run_task_wrapper(self, task: _TaskRequest):
try:
if callable(self._run_task_fn):
self._run_task_fn(
user_id=task.user_id,
account_id=task.account_id,
browse_type=task.browse_type,
enable_screenshot=task.enable_screenshot,
source=task.source,
retry_count=task.retry_count,
)
finally:
try:
if callable(task.done_callback):
task.done_callback()
except Exception:
pass
safe_remove_task(task.account_id)
with self._cond:
self._running_global = max(0, self._running_global - 1)
# 使用默认值 0 与增加时保持一致
self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
if self._running_by_user.get(task.user_id) == 0:
self._running_by_user.pop(task.user_id, None)
self._cond.notify_all()

File diff suppressed because it is too large Load Diff