refactor: optimize structure, stability and runtime performance

This commit is contained in:
2026-02-07 00:35:11 +08:00
parent fae21329d7
commit bf29ac1924
44 changed files with 6894 additions and 4792 deletions

191
app.py
View File

@@ -173,10 +173,28 @@ def serve_static(filename):
if not is_safe_path("static", filename):
return jsonify({"error": "非法路径"}), 403
response = send_from_directory("static", filename)
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
cache_ttl = 3600
lowered = filename.lower()
if "/assets/" in lowered or lowered.endswith((".js", ".css", ".woff", ".woff2", ".ttf", ".svg")):
cache_ttl = 604800 # 7天
if request.args.get("v"):
cache_ttl = max(cache_ttl, 604800)
response = send_from_directory("static", filename, max_age=cache_ttl, conditional=True)
# 协商缓存:确保存在 ETag并基于 If-None-Match/If-Modified-Since 返回 304
try:
response.add_etag(overwrite=False)
except Exception:
pass
try:
response.make_conditional(request)
except Exception:
pass
response.headers.setdefault("Vary", "Accept-Encoding")
response.headers["Cache-Control"] = f"public, max-age={cache_ttl}"
return response
@@ -232,6 +250,93 @@ def _signal_handler(sig, frame):
sys.exit(0)
def _cleanup_stale_task_state() -> None:
logger.info("清理遗留任务状态...")
try:
from services.state import safe_get_active_task_ids, safe_remove_task, safe_remove_task_status
for _, accounts in safe_iter_user_accounts_items():
for acc in accounts.values():
if not getattr(acc, "is_running", False):
continue
acc.is_running = False
acc.should_stop = False
acc.status = "未开始"
for account_id in list(safe_get_active_task_ids()):
safe_remove_task(account_id)
safe_remove_task_status(account_id)
logger.info("[OK] 遗留任务状态已清理")
except Exception as e:
logger.warning(f"清理遗留任务状态失败: {e}")
def _init_optional_email_service() -> None:
try:
email_service.init_email_service()
logger.info("[OK] 邮件服务已初始化")
except Exception as e:
logger.warning(f"警告: 邮件服务初始化失败: {e}")
def _load_and_apply_scheduler_limits() -> None:
try:
system_config = database.get_system_config() or {}
max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT))
get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account)
logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
except Exception as e:
logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
def _start_background_workers() -> None:
logger.info("启动定时任务调度器...")
threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
logger.info("[OK] 定时任务调度器已启动")
logger.info("[OK] 状态推送线程已启动默认2秒/次)")
threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
def _init_screenshot_worker_pool() -> None:
try:
pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3))
except Exception:
pool_size = 3
try:
logger.info(f"初始化截图线程池({pool_size}个worker按需启动执行环境空闲5分钟后自动释放...")
init_browser_worker_pool(pool_size=pool_size)
logger.info("[OK] 截图线程池初始化完成")
except Exception as e:
logger.warning(f"警告: 截图线程池初始化失败: {e}")
def _warmup_api_connection() -> None:
logger.info("预热 API 连接...")
try:
from api_browser import warmup_api_connection
threading.Thread(
target=warmup_api_connection,
kwargs={"log_callback": lambda msg: logger.info(msg)},
daemon=True,
name="api-warmup",
).start()
except Exception as e:
logger.warning(f"API 预热失败: {e}")
def _log_startup_urls() -> None:
logger.info("服务器启动中...")
logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
logger.info("默认管理员: admin (首次运行随机密码见日志)")
logger.info("=" * 60)
if __name__ == "__main__":
atexit.register(cleanup_on_exit)
signal.signal(signal.SIGINT, _signal_handler)
@@ -245,81 +350,17 @@ if __name__ == "__main__":
init_checkpoint_manager()
logger.info("[OK] 任务断点管理器已初始化")
# 【新增】容器重启时清理遗留的任务状态
logger.info("清理遗留任务状态...")
try:
from services.state import safe_remove_task, safe_get_active_task_ids, safe_remove_task_status
# 重置所有账号的运行状态
for _, accounts in safe_iter_user_accounts_items():
for acc in accounts.values():
if getattr(acc, "is_running", False):
acc.is_running = False
acc.should_stop = False
acc.status = "未开始"
# 清理活跃任务句柄
for account_id in list(safe_get_active_task_ids()):
safe_remove_task(account_id)
safe_remove_task_status(account_id)
logger.info("[OK] 遗留任务状态已清理")
except Exception as e:
logger.warning(f"清理遗留任务状态失败: {e}")
try:
email_service.init_email_service()
logger.info("[OK] 邮件服务已初始化")
except Exception as e:
logger.warning(f"警告: 邮件服务初始化失败: {e}")
_cleanup_stale_task_state()
_init_optional_email_service()
start_cleanup_scheduler()
start_kdocs_monitor()
try:
system_config = database.get_system_config() or {}
max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT))
get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account)
logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
except Exception as e:
logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
logger.info("启动定时任务调度器...")
threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
logger.info("[OK] 定时任务调度器已启动")
logger.info("[OK] 状态推送线程已启动默认2秒/次)")
threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
logger.info("服务器启动中...")
logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
logger.info("默认管理员: admin (首次运行随机密码见日志)")
logger.info("=" * 60)
try:
pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3))
except Exception:
pool_size = 3
try:
logger.info(f"初始化截图线程池({pool_size}个worker按需启动执行环境空闲5分钟后自动释放...")
init_browser_worker_pool(pool_size=pool_size)
logger.info("[OK] 截图线程池初始化完成")
except Exception as e:
logger.warning(f"警告: 截图线程池初始化失败: {e}")
# 预热 API 连接(后台进行,不阻塞启动)
logger.info("预热 API 连接...")
try:
from api_browser import warmup_api_connection
import threading
threading.Thread(
target=warmup_api_connection,
kwargs={"log_callback": lambda msg: logger.info(msg)},
daemon=True,
name="api-warmup",
).start()
except Exception as e:
logger.warning(f"API 预热失败: {e}")
_load_and_apply_scheduler_limits()
_start_background_workers()
_log_startup_urls()
_init_screenshot_worker_pool()
_warmup_api_connection()
socketio.run(
app,

View File

@@ -120,7 +120,7 @@ config = get_config()
DB_FILE = config.DB_FILE
# 数据库版本 (用于迁移管理)
DB_VERSION = 17
DB_VERSION = 18
# ==================== 系统配置缓存P1 / O-03 ====================
@@ -142,6 +142,37 @@ def invalidate_system_config_cache() -> None:
_system_config_cache_loaded_at = 0.0
def _normalize_system_config_value(value) -> dict:
try:
return dict(value or {})
except Exception:
return {}
def _is_system_config_cache_valid(now_ts: float) -> bool:
if _system_config_cache_value is None:
return False
if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0:
return True
return (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS
def _read_system_config_cache(now_ts: float, *, ignore_ttl: bool = False) -> Optional[dict]:
with _system_config_cache_lock:
if _system_config_cache_value is None:
return None
if (not ignore_ttl) and (not _is_system_config_cache_valid(now_ts)):
return None
return dict(_system_config_cache_value)
def _write_system_config_cache(value: dict, now_ts: float) -> None:
global _system_config_cache_value, _system_config_cache_loaded_at
with _system_config_cache_lock:
_system_config_cache_value = dict(value)
_system_config_cache_loaded_at = now_ts
def init_database():
"""初始化数据库表结构 + 迁移(入口统一)。"""
db_pool.init_pool(DB_FILE, pool_size=config.DB_POOL_SIZE)
@@ -165,19 +196,21 @@ def migrate_database():
def get_system_config():
"""获取系统配置(带进程内缓存)。"""
global _system_config_cache_value, _system_config_cache_loaded_at
now_ts = time.time()
with _system_config_cache_lock:
if _system_config_cache_value is not None:
if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0 or (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS:
return dict(_system_config_cache_value)
value = _get_system_config_raw()
cached_value = _read_system_config_cache(now_ts)
if cached_value is not None:
return cached_value
with _system_config_cache_lock:
_system_config_cache_value = dict(value)
_system_config_cache_loaded_at = now_ts
try:
value = _normalize_system_config_value(_get_system_config_raw())
except Exception:
fallback_value = _read_system_config_cache(now_ts, ignore_ttl=True)
if fallback_value is not None:
return fallback_value
raise
_write_system_config_cache(value, now_ts)
return dict(value)

View File

@@ -6,19 +6,51 @@ import db_pool
from crypto_utils import decrypt_password, encrypt_password
from db.utils import get_cst_now_str
_ACCOUNT_STATUS_QUERY_SQL = """
SELECT status, login_fail_count, last_login_error
FROM accounts
WHERE id = ?
"""
def _decode_account_password(account_dict: dict) -> dict:
account_dict["password"] = decrypt_password(account_dict.get("password", ""))
return account_dict
def _normalize_account_ids(account_ids) -> list[str]:
normalized = []
seen = set()
for account_id in account_ids or []:
if not account_id:
continue
account_key = str(account_id)
if account_key in seen:
continue
seen.add(account_key)
normalized.append(account_key)
return normalized
def create_account(user_id, account_id, username, password, remember=True, remark=""):
"""创建账号(密码加密存储)"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
encrypted_password = encrypt_password(password)
cursor.execute(
"""
INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(account_id, user_id, username, encrypted_password, 1 if remember else 0, remark, cst_time),
(
account_id,
user_id,
username,
encrypted_password,
1 if remember else 0,
remark,
get_cst_now_str(),
),
)
conn.commit()
return cursor.lastrowid
@@ -29,12 +61,7 @@ def get_user_accounts(user_id):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,))
accounts = []
for row in cursor.fetchall():
account = dict(row)
account["password"] = decrypt_password(account.get("password", ""))
accounts.append(account)
return accounts
return [_decode_account_password(dict(row)) for row in cursor.fetchall()]
def get_account(account_id):
@@ -43,11 +70,9 @@ def get_account(account_id):
cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone()
if row:
account = dict(row)
account["password"] = decrypt_password(account.get("password", ""))
return account
return None
if not row:
return None
return _decode_account_password(dict(row))
def update_account_remark(account_id, remark):
@@ -78,33 +103,21 @@ def increment_account_login_fail(account_id, error_message):
if not row:
return False
fail_count = (row["login_fail_count"] or 0) + 1
if fail_count >= 3:
cursor.execute(
"""
UPDATE accounts
SET login_fail_count = ?,
last_login_error = ?,
status = 'suspended'
WHERE id = ?
""",
(fail_count, error_message, account_id),
)
conn.commit()
return True
fail_count = int(row["login_fail_count"] or 0) + 1
is_suspended = fail_count >= 3
cursor.execute(
"""
UPDATE accounts
SET login_fail_count = ?,
last_login_error = ?
last_login_error = ?,
status = CASE WHEN ? = 1 THEN 'suspended' ELSE status END
WHERE id = ?
""",
(fail_count, error_message, account_id),
(fail_count, error_message, 1 if is_suspended else 0, account_id),
)
conn.commit()
return False
return is_suspended
def reset_account_login_status(account_id):
@@ -129,29 +142,22 @@ def get_account_status(account_id):
"""获取账号状态信息"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT status, login_fail_count, last_login_error
FROM accounts
WHERE id = ?
""",
(account_id,),
)
cursor.execute(_ACCOUNT_STATUS_QUERY_SQL, (account_id,))
return cursor.fetchone()
def get_account_status_batch(account_ids):
"""批量获取账号状态信息"""
account_ids = [str(account_id) for account_id in (account_ids or []) if account_id]
if not account_ids:
normalized_ids = _normalize_account_ids(account_ids)
if not normalized_ids:
return {}
results = {}
chunk_size = 900 # 避免触发 SQLite 绑定参数上限
with db_pool.get_db() as conn:
cursor = conn.cursor()
for idx in range(0, len(account_ids), chunk_size):
chunk = account_ids[idx : idx + chunk_size]
for idx in range(0, len(normalized_ids), chunk_size):
chunk = normalized_ids[idx : idx + chunk_size]
placeholders = ",".join("?" for _ in chunk)
cursor.execute(
f"""

View File

@@ -3,9 +3,6 @@
from __future__ import annotations
import sqlite3
from datetime import datetime, timedelta
import pytz
import db_pool
from db.utils import get_cst_now_str
@@ -16,6 +13,99 @@ from password_utils import (
verify_password_sha256,
)
_DEFAULT_SYSTEM_CONFIG = {
"max_concurrent_global": 2,
"max_concurrent_per_account": 1,
"max_screenshot_concurrent": 3,
"schedule_enabled": 0,
"schedule_time": "02:00",
"schedule_browse_type": "应读",
"schedule_weekdays": "1,2,3,4,5,6,7",
"proxy_enabled": 0,
"proxy_api_url": "",
"proxy_expire_minutes": 3,
"enable_screenshot": 1,
"auto_approve_enabled": 0,
"auto_approve_hourly_limit": 10,
"auto_approve_vip_days": 7,
"kdocs_enabled": 0,
"kdocs_doc_url": "",
"kdocs_default_unit": "",
"kdocs_sheet_name": "",
"kdocs_sheet_index": 0,
"kdocs_unit_column": "A",
"kdocs_image_column": "D",
"kdocs_admin_notify_enabled": 0,
"kdocs_admin_notify_email": "",
"kdocs_row_start": 0,
"kdocs_row_end": 0,
}
_SYSTEM_CONFIG_UPDATERS = (
("max_concurrent_global", "max_concurrent"),
("schedule_enabled", "schedule_enabled"),
("schedule_time", "schedule_time"),
("schedule_browse_type", "schedule_browse_type"),
("schedule_weekdays", "schedule_weekdays"),
("max_concurrent_per_account", "max_concurrent_per_account"),
("max_screenshot_concurrent", "max_screenshot_concurrent"),
("enable_screenshot", "enable_screenshot"),
("proxy_enabled", "proxy_enabled"),
("proxy_api_url", "proxy_api_url"),
("proxy_expire_minutes", "proxy_expire_minutes"),
("auto_approve_enabled", "auto_approve_enabled"),
("auto_approve_hourly_limit", "auto_approve_hourly_limit"),
("auto_approve_vip_days", "auto_approve_vip_days"),
("kdocs_enabled", "kdocs_enabled"),
("kdocs_doc_url", "kdocs_doc_url"),
("kdocs_default_unit", "kdocs_default_unit"),
("kdocs_sheet_name", "kdocs_sheet_name"),
("kdocs_sheet_index", "kdocs_sheet_index"),
("kdocs_unit_column", "kdocs_unit_column"),
("kdocs_image_column", "kdocs_image_column"),
("kdocs_admin_notify_enabled", "kdocs_admin_notify_enabled"),
("kdocs_admin_notify_email", "kdocs_admin_notify_email"),
("kdocs_row_start", "kdocs_row_start"),
("kdocs_row_end", "kdocs_row_end"),
)
def _count_scalar(cursor, sql: str, params=()) -> int:
cursor.execute(sql, params)
row = cursor.fetchone()
if not row:
return 0
try:
if "count" in row.keys():
return int(row["count"] or 0)
except Exception:
pass
try:
return int(row[0] or 0)
except Exception:
return 0
def _table_exists(cursor, table_name: str) -> bool:
cursor.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table' AND name=?
""",
(table_name,),
)
return bool(cursor.fetchone())
def _normalize_days(days, default: int = 30) -> int:
try:
value = int(days)
except Exception:
value = default
if value < 0:
return 0
return value
def ensure_default_admin() -> bool:
"""确保存在默认管理员账号(行为保持不变)。"""
@@ -24,10 +114,9 @@ def ensure_default_admin() -> bool:
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM admins")
result = cursor.fetchone()
count = _count_scalar(cursor, "SELECT COUNT(*) as count FROM admins")
if result["count"] == 0:
if count == 0:
alphabet = string.ascii_letters + string.digits
random_password = "".join(secrets.choice(alphabet) for _ in range(12))
@@ -101,41 +190,33 @@ def get_system_stats() -> dict:
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM users")
total_users = cursor.fetchone()["count"]
cursor.execute("SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
approved_users = cursor.fetchone()["count"]
cursor.execute(
total_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users")
approved_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
new_users_today = _count_scalar(
cursor,
"""
SELECT COUNT(*) as count
FROM users
WHERE date(created_at) = date('now', 'localtime')
"""
""",
)
new_users_today = cursor.fetchone()["count"]
cursor.execute(
new_users_7d = _count_scalar(
cursor,
"""
SELECT COUNT(*) as count
FROM users
WHERE datetime(created_at) >= datetime('now', 'localtime', '-7 days')
"""
""",
)
new_users_7d = cursor.fetchone()["count"]
cursor.execute("SELECT COUNT(*) as count FROM accounts")
total_accounts = cursor.fetchone()["count"]
cursor.execute(
total_accounts = _count_scalar(cursor, "SELECT COUNT(*) as count FROM accounts")
vip_users = _count_scalar(
cursor,
"""
SELECT COUNT(*) as count FROM users
WHERE vip_expire_time IS NOT NULL
AND datetime(vip_expire_time) > datetime('now', 'localtime')
"""
""",
)
vip_users = cursor.fetchone()["count"]
return {
"total_users": total_users,
@@ -153,37 +234,9 @@ def get_system_config_raw() -> dict:
cursor = conn.cursor()
cursor.execute("SELECT * FROM system_config WHERE id = 1")
row = cursor.fetchone()
if row:
return dict(row)
return {
"max_concurrent_global": 2,
"max_concurrent_per_account": 1,
"max_screenshot_concurrent": 3,
"schedule_enabled": 0,
"schedule_time": "02:00",
"schedule_browse_type": "应读",
"schedule_weekdays": "1,2,3,4,5,6,7",
"proxy_enabled": 0,
"proxy_api_url": "",
"proxy_expire_minutes": 3,
"enable_screenshot": 1,
"auto_approve_enabled": 0,
"auto_approve_hourly_limit": 10,
"auto_approve_vip_days": 7,
"kdocs_enabled": 0,
"kdocs_doc_url": "",
"kdocs_default_unit": "",
"kdocs_sheet_name": "",
"kdocs_sheet_index": 0,
"kdocs_unit_column": "A",
"kdocs_image_column": "D",
"kdocs_admin_notify_enabled": 0,
"kdocs_admin_notify_email": "",
"kdocs_row_start": 0,
"kdocs_row_end": 0,
}
return dict(_DEFAULT_SYSTEM_CONFIG)
def update_system_config(
@@ -215,127 +268,51 @@ def update_system_config(
kdocs_row_end=None,
) -> bool:
"""更新系统配置仅更新DB不做缓存处理"""
allowed_fields = {
"max_concurrent_global",
"schedule_enabled",
"schedule_time",
"schedule_browse_type",
"schedule_weekdays",
"max_concurrent_per_account",
"max_screenshot_concurrent",
"enable_screenshot",
"proxy_enabled",
"proxy_api_url",
"proxy_expire_minutes",
"auto_approve_enabled",
"auto_approve_hourly_limit",
"auto_approve_vip_days",
"kdocs_enabled",
"kdocs_doc_url",
"kdocs_default_unit",
"kdocs_sheet_name",
"kdocs_sheet_index",
"kdocs_unit_column",
"kdocs_image_column",
"kdocs_admin_notify_enabled",
"kdocs_admin_notify_email",
"kdocs_row_start",
"kdocs_row_end",
"updated_at",
arg_values = {
"max_concurrent": max_concurrent,
"schedule_enabled": schedule_enabled,
"schedule_time": schedule_time,
"schedule_browse_type": schedule_browse_type,
"schedule_weekdays": schedule_weekdays,
"max_concurrent_per_account": max_concurrent_per_account,
"max_screenshot_concurrent": max_screenshot_concurrent,
"enable_screenshot": enable_screenshot,
"proxy_enabled": proxy_enabled,
"proxy_api_url": proxy_api_url,
"proxy_expire_minutes": proxy_expire_minutes,
"auto_approve_enabled": auto_approve_enabled,
"auto_approve_hourly_limit": auto_approve_hourly_limit,
"auto_approve_vip_days": auto_approve_vip_days,
"kdocs_enabled": kdocs_enabled,
"kdocs_doc_url": kdocs_doc_url,
"kdocs_default_unit": kdocs_default_unit,
"kdocs_sheet_name": kdocs_sheet_name,
"kdocs_sheet_index": kdocs_sheet_index,
"kdocs_unit_column": kdocs_unit_column,
"kdocs_image_column": kdocs_image_column,
"kdocs_admin_notify_enabled": kdocs_admin_notify_enabled,
"kdocs_admin_notify_email": kdocs_admin_notify_email,
"kdocs_row_start": kdocs_row_start,
"kdocs_row_end": kdocs_row_end,
}
updates = []
params = []
for db_field, arg_name in _SYSTEM_CONFIG_UPDATERS:
value = arg_values.get(arg_name)
if value is None:
continue
updates.append(f"{db_field} = ?")
params.append(value)
if not updates:
return False
updates.append("updated_at = ?")
params.append(get_cst_now_str())
with db_pool.get_db() as conn:
cursor = conn.cursor()
updates = []
params = []
if max_concurrent is not None:
updates.append("max_concurrent_global = ?")
params.append(max_concurrent)
if schedule_enabled is not None:
updates.append("schedule_enabled = ?")
params.append(schedule_enabled)
if schedule_time is not None:
updates.append("schedule_time = ?")
params.append(schedule_time)
if schedule_browse_type is not None:
updates.append("schedule_browse_type = ?")
params.append(schedule_browse_type)
if max_concurrent_per_account is not None:
updates.append("max_concurrent_per_account = ?")
params.append(max_concurrent_per_account)
if max_screenshot_concurrent is not None:
updates.append("max_screenshot_concurrent = ?")
params.append(max_screenshot_concurrent)
if enable_screenshot is not None:
updates.append("enable_screenshot = ?")
params.append(enable_screenshot)
if schedule_weekdays is not None:
updates.append("schedule_weekdays = ?")
params.append(schedule_weekdays)
if proxy_enabled is not None:
updates.append("proxy_enabled = ?")
params.append(proxy_enabled)
if proxy_api_url is not None:
updates.append("proxy_api_url = ?")
params.append(proxy_api_url)
if proxy_expire_minutes is not None:
updates.append("proxy_expire_minutes = ?")
params.append(proxy_expire_minutes)
if auto_approve_enabled is not None:
updates.append("auto_approve_enabled = ?")
params.append(auto_approve_enabled)
if auto_approve_hourly_limit is not None:
updates.append("auto_approve_hourly_limit = ?")
params.append(auto_approve_hourly_limit)
if auto_approve_vip_days is not None:
updates.append("auto_approve_vip_days = ?")
params.append(auto_approve_vip_days)
if kdocs_enabled is not None:
updates.append("kdocs_enabled = ?")
params.append(kdocs_enabled)
if kdocs_doc_url is not None:
updates.append("kdocs_doc_url = ?")
params.append(kdocs_doc_url)
if kdocs_default_unit is not None:
updates.append("kdocs_default_unit = ?")
params.append(kdocs_default_unit)
if kdocs_sheet_name is not None:
updates.append("kdocs_sheet_name = ?")
params.append(kdocs_sheet_name)
if kdocs_sheet_index is not None:
updates.append("kdocs_sheet_index = ?")
params.append(kdocs_sheet_index)
if kdocs_unit_column is not None:
updates.append("kdocs_unit_column = ?")
params.append(kdocs_unit_column)
if kdocs_image_column is not None:
updates.append("kdocs_image_column = ?")
params.append(kdocs_image_column)
if kdocs_admin_notify_enabled is not None:
updates.append("kdocs_admin_notify_enabled = ?")
params.append(kdocs_admin_notify_enabled)
if kdocs_admin_notify_email is not None:
updates.append("kdocs_admin_notify_email = ?")
params.append(kdocs_admin_notify_email)
if kdocs_row_start is not None:
updates.append("kdocs_row_start = ?")
params.append(kdocs_row_start)
if kdocs_row_end is not None:
updates.append("kdocs_row_end = ?")
params.append(kdocs_row_end)
if not updates:
return False
updates.append("updated_at = ?")
params.append(get_cst_now_str())
for update_clause in updates:
field_name = update_clause.split("=")[0].strip()
if field_name not in allowed_fields:
raise ValueError(f"非法字段名: {field_name}")
sql = f"UPDATE system_config SET {', '.join(updates)} WHERE id = 1"
cursor.execute(sql, params)
conn.commit()
@@ -346,13 +323,13 @@ def get_hourly_registration_count() -> int:
"""获取最近一小时内的注册用户数"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
return _count_scalar(
cursor,
"""
SELECT COUNT(*) FROM users
SELECT COUNT(*) as count FROM users
WHERE created_at >= datetime('now', 'localtime', '-1 hour')
"""
""",
)
return cursor.fetchone()[0]
# ==================== 密码重置(管理员) ====================
@@ -374,17 +351,12 @@ def admin_reset_user_password(user_id: int, new_password: str) -> bool:
def clean_old_operation_logs(days: int = 30) -> int:
"""清理指定天数前的操作日志如果存在operation_logs表"""
safe_days = _normalize_days(days, default=30)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table' AND name='operation_logs'
"""
)
if not cursor.fetchone():
if not _table_exists(cursor, "operation_logs"):
return 0
try:
@@ -393,11 +365,11 @@ def clean_old_operation_logs(days: int = 30) -> int:
DELETE FROM operation_logs
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
""",
(days,),
(safe_days,),
)
deleted_count = cursor.rowcount
conn.commit()
print(f"已清理 {deleted_count} 条旧操作日志 (>{days}天)")
print(f"已清理 {deleted_count} 条旧操作日志 (>{safe_days}天)")
return deleted_count
except Exception as e:
print(f"清理旧操作日志失败: {e}")

View File

@@ -6,12 +6,38 @@ import db_pool
from db.utils import get_cst_now_str
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
parsed = max(minimum, parsed)
parsed = min(maximum, parsed)
return parsed
def _normalize_offset(value, default: int = 0) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
return max(0, parsed)
def _normalize_announcement_payload(title, content, image_url):
normalized_title = str(title or "").strip()
normalized_content = str(content or "").strip()
normalized_image = str(image_url or "").strip() or None
return normalized_title, normalized_content, normalized_image
def _deactivate_all_active_announcements(cursor, cst_time: str) -> None:
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
def create_announcement(title, content, image_url=None, is_active=True):
"""创建公告(默认启用;启用时会自动停用其他公告)"""
title = (title or "").strip()
content = (content or "").strip()
image_url = (image_url or "").strip()
image_url = image_url or None
title, content, image_url = _normalize_announcement_payload(title, content, image_url)
if not title or not content:
return None
@@ -20,7 +46,7 @@ def create_announcement(title, content, image_url=None, is_active=True):
cst_time = get_cst_now_str()
if is_active:
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
_deactivate_all_active_announcements(cursor, cst_time)
cursor.execute(
"""
@@ -44,6 +70,9 @@ def get_announcement_by_id(announcement_id):
def get_announcements(limit=50, offset=0):
"""获取公告列表(管理员用)"""
safe_limit = _normalize_limit(limit, 50)
safe_offset = _normalize_offset(offset, 0)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -52,7 +81,7 @@ def get_announcements(limit=50, offset=0):
ORDER BY created_at DESC, id DESC
LIMIT ? OFFSET ?
""",
(limit, offset),
(safe_limit, safe_offset),
)
return [dict(row) for row in cursor.fetchall()]
@@ -64,7 +93,7 @@ def set_announcement_active(announcement_id, is_active):
cst_time = get_cst_now_str()
if is_active:
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
_deactivate_all_active_announcements(cursor, cst_time)
cursor.execute(
"""
UPDATE announcements
@@ -121,13 +150,12 @@ def dismiss_announcement_for_user(user_id, announcement_id):
"""用户永久关闭某条公告(幂等)"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
cursor.execute(
"""
INSERT OR IGNORE INTO announcement_dismissals (user_id, announcement_id, dismissed_at)
VALUES (?, ?, ?)
""",
(user_id, announcement_id, cst_time),
(user_id, announcement_id, get_cst_now_str()),
)
conn.commit()
return cursor.rowcount >= 0

View File

@@ -5,6 +5,27 @@ from __future__ import annotations
import db_pool
def _to_bool_with_default(value, default: bool = True) -> bool:
if value is None:
return default
try:
return bool(int(value))
except Exception:
try:
return bool(value)
except Exception:
return default
def _normalize_notify_enabled(enabled) -> int:
if isinstance(enabled, bool):
return 1 if enabled else 0
try:
return 1 if int(enabled) else 0
except Exception:
return 1
def get_user_by_email(email):
"""根据邮箱获取用户"""
with db_pool.get_db() as conn:
@@ -25,7 +46,7 @@ def update_user_email(user_id, email, verified=False):
SET email = ?, email_verified = ?
WHERE id = ?
""",
(email, int(verified), user_id),
(email, 1 if verified else 0, user_id),
)
conn.commit()
return cursor.rowcount > 0
@@ -42,7 +63,7 @@ def update_user_email_notify(user_id, enabled):
SET email_notify_enabled = ?
WHERE id = ?
""",
(int(enabled), user_id),
(_normalize_notify_enabled(enabled), user_id),
)
conn.commit()
return cursor.rowcount > 0
@@ -57,6 +78,6 @@ def get_user_email_notify(user_id):
row = cursor.fetchone()
if row is None:
return True
return bool(row[0]) if row[0] is not None else True
return _to_bool_with_default(row[0], default=True)
except Exception:
return True

View File

@@ -2,32 +2,73 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import datetime
import pytz
import db_pool
from db.utils import escape_html
from db.utils import escape_html, get_cst_now_str
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
parsed = max(minimum, parsed)
parsed = min(maximum, parsed)
return parsed
def _normalize_offset(value, default: int = 0) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
return max(0, parsed)
def _safe_text(value) -> str:
if value is None:
return ""
text = str(value)
return escape_html(text) if text else ""
def _build_feedback_filter_sql(status_filter=None) -> tuple[str, list]:
where_clauses = ["1=1"]
params = []
if status_filter:
where_clauses.append("status = ?")
params.append(status_filter)
return " AND ".join(where_clauses), params
def _normalize_feedback_stats_row(row) -> dict:
row_dict = dict(row) if row else {}
return {
"total": int(row_dict.get("total") or 0),
"pending": int(row_dict.get("pending") or 0),
"replied": int(row_dict.get("replied") or 0),
"closed": int(row_dict.get("closed") or 0),
}
def create_bug_feedback(user_id, username, title, description, contact=""):
"""创建Bug反馈带XSS防护"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
safe_title = escape_html(title) if title else ""
safe_description = escape_html(description) if description else ""
safe_contact = escape_html(contact) if contact else ""
safe_username = escape_html(username) if username else ""
cursor.execute(
"""
INSERT INTO bug_feedbacks (user_id, username, title, description, contact, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
(user_id, safe_username, safe_title, safe_description, safe_contact, cst_time),
(
user_id,
_safe_text(username),
_safe_text(title),
_safe_text(description),
_safe_text(contact),
get_cst_now_str(),
),
)
conn.commit()
@@ -36,25 +77,25 @@ def create_bug_feedback(user_id, username, title, description, contact=""):
def get_bug_feedbacks(limit=100, offset=0, status_filter=None):
"""获取Bug反馈列表管理员用"""
safe_limit = _normalize_limit(limit, 100, minimum=1, maximum=1000)
safe_offset = _normalize_offset(offset, 0)
with db_pool.get_db() as conn:
cursor = conn.cursor()
sql = "SELECT * FROM bug_feedbacks WHERE 1=1"
params = []
if status_filter:
sql += " AND status = ?"
params.append(status_filter)
sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
params.extend([limit, offset])
cursor.execute(sql, params)
where_sql, params = _build_feedback_filter_sql(status_filter=status_filter)
sql = f"""
SELECT * FROM bug_feedbacks
WHERE {where_sql}
ORDER BY created_at DESC
LIMIT ? OFFSET ?
"""
cursor.execute(sql, params + [safe_limit, safe_offset])
return [dict(row) for row in cursor.fetchall()]
def get_user_feedbacks(user_id, limit=50):
"""获取用户自己的反馈列表"""
safe_limit = _normalize_limit(limit, 50, minimum=1, maximum=1000)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -64,7 +105,7 @@ def get_user_feedbacks(user_id, limit=50):
ORDER BY created_at DESC
LIMIT ?
""",
(user_id, limit),
(user_id, safe_limit),
)
return [dict(row) for row in cursor.fetchall()]
@@ -82,18 +123,13 @@ def reply_feedback(feedback_id, admin_reply):
"""管理员回复反馈带XSS防护"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
safe_reply = escape_html(admin_reply) if admin_reply else ""
cursor.execute(
"""
UPDATE bug_feedbacks
SET admin_reply = ?, status = 'replied', replied_at = ?
WHERE id = ?
""",
(safe_reply, cst_time, feedback_id),
(_safe_text(admin_reply), get_cst_now_str(), feedback_id),
)
conn.commit()
@@ -139,6 +175,4 @@ def get_feedback_stats():
FROM bug_feedbacks
"""
)
row = cursor.fetchone()
return dict(row) if row else {"total": 0, "pending": 0, "replied": 0, "closed": 0}
return _normalize_feedback_stats_row(cursor.fetchone())

View File

@@ -28,105 +28,136 @@ def set_current_version(conn, version: int) -> None:
conn.commit()
def _table_exists(cursor, table_name: str) -> bool:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (str(table_name),))
return cursor.fetchone() is not None
def _get_table_columns(cursor, table_name: str) -> set[str]:
cursor.execute(f"PRAGMA table_info({table_name})")
return {col[1] for col in cursor.fetchall()}
def _add_column_if_missing(cursor, table_name: str, columns: set[str], column_name: str, column_ddl: str, *, ok_message: str) -> bool:
if column_name in columns:
return False
cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_ddl}")
columns.add(column_name)
print(ok_message)
return True
def _read_row_value(row, key: str, index: int):
if isinstance(row, sqlite3.Row):
return row[key]
return row[index]
def _get_migration_steps():
return [
(1, _migrate_to_v1),
(2, _migrate_to_v2),
(3, _migrate_to_v3),
(4, _migrate_to_v4),
(5, _migrate_to_v5),
(6, _migrate_to_v6),
(7, _migrate_to_v7),
(8, _migrate_to_v8),
(9, _migrate_to_v9),
(10, _migrate_to_v10),
(11, _migrate_to_v11),
(12, _migrate_to_v12),
(13, _migrate_to_v13),
(14, _migrate_to_v14),
(15, _migrate_to_v15),
(16, _migrate_to_v16),
(17, _migrate_to_v17),
(18, _migrate_to_v18),
]
def migrate_database(conn, target_version: int) -> None:
"""数据库迁移:按版本增量升级(向前兼容)。"""
cursor = conn.cursor()
cursor.execute("INSERT OR IGNORE INTO db_version (id, version, updated_at) VALUES (1, 0, ?)", (get_cst_now_str(),))
conn.commit()
target_version = int(target_version)
current_version = get_current_version(conn)
if current_version < 1:
_migrate_to_v1(conn)
current_version = 1
if current_version < 2:
_migrate_to_v2(conn)
current_version = 2
if current_version < 3:
_migrate_to_v3(conn)
current_version = 3
if current_version < 4:
_migrate_to_v4(conn)
current_version = 4
if current_version < 5:
_migrate_to_v5(conn)
current_version = 5
if current_version < 6:
_migrate_to_v6(conn)
current_version = 6
if current_version < 7:
_migrate_to_v7(conn)
current_version = 7
if current_version < 8:
_migrate_to_v8(conn)
current_version = 8
if current_version < 9:
_migrate_to_v9(conn)
current_version = 9
if current_version < 10:
_migrate_to_v10(conn)
current_version = 10
if current_version < 11:
_migrate_to_v11(conn)
current_version = 11
if current_version < 12:
_migrate_to_v12(conn)
current_version = 12
if current_version < 13:
_migrate_to_v13(conn)
current_version = 13
if current_version < 14:
_migrate_to_v14(conn)
current_version = 14
if current_version < 15:
_migrate_to_v15(conn)
current_version = 15
if current_version < 16:
_migrate_to_v16(conn)
current_version = 16
if current_version < 17:
_migrate_to_v17(conn)
current_version = 17
if current_version < 18:
_migrate_to_v18(conn)
current_version = 18
for version, migrate_fn in _get_migration_steps():
if version > target_version or current_version >= version:
continue
migrate_fn(conn)
current_version = version
if current_version != int(target_version):
set_current_version(conn, int(target_version))
if current_version != target_version:
set_current_version(conn, target_version)
def _migrate_to_v1(conn):
"""迁移到版本1 - 添加缺失字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
system_columns = _get_table_columns(cursor, "system_config")
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"schedule_weekdays",
'TEXT DEFAULT "1,2,3,4,5,6,7"',
ok_message=" [OK] 添加 schedule_weekdays 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"max_screenshot_concurrent",
"INTEGER DEFAULT 3",
ok_message=" [OK] 添加 max_screenshot_concurrent 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"max_concurrent_per_account",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 max_concurrent_per_account 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"auto_approve_enabled",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 auto_approve_enabled 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"auto_approve_hourly_limit",
"INTEGER DEFAULT 10",
ok_message=" [OK] 添加 auto_approve_hourly_limit 字段",
)
_add_column_if_missing(
cursor,
"system_config",
system_columns,
"auto_approve_vip_days",
"INTEGER DEFAULT 7",
ok_message=" [OK] 添加 auto_approve_vip_days 字段",
)
if "schedule_weekdays" not in columns:
cursor.execute('ALTER TABLE system_config ADD COLUMN schedule_weekdays TEXT DEFAULT "1,2,3,4,5,6,7"')
print(" [OK] 添加 schedule_weekdays 字段")
if "max_screenshot_concurrent" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN max_screenshot_concurrent INTEGER DEFAULT 3")
print(" [OK] 添加 max_screenshot_concurrent 字段")
if "max_concurrent_per_account" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN max_concurrent_per_account INTEGER DEFAULT 1")
print(" [OK] 添加 max_concurrent_per_account 字段")
if "auto_approve_enabled" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_enabled INTEGER DEFAULT 0")
print(" [OK] 添加 auto_approve_enabled 字段")
if "auto_approve_hourly_limit" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_hourly_limit INTEGER DEFAULT 10")
print(" [OK] 添加 auto_approve_hourly_limit 字段")
if "auto_approve_vip_days" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_vip_days INTEGER DEFAULT 7")
print(" [OK] 添加 auto_approve_vip_days 字段")
cursor.execute("PRAGMA table_info(task_logs)")
columns = [col[1] for col in cursor.fetchall()]
if "duration" not in columns:
cursor.execute("ALTER TABLE task_logs ADD COLUMN duration INTEGER")
print(" [OK] 添加 duration 字段到 task_logs")
task_log_columns = _get_table_columns(cursor, "task_logs")
_add_column_if_missing(
cursor,
"task_logs",
task_log_columns,
"duration",
"INTEGER",
ok_message=" [OK] 添加 duration 字段到 task_logs",
)
conn.commit()
@@ -135,24 +166,39 @@ def _migrate_to_v2(conn):
"""迁移到版本2 - 添加代理配置字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
if "proxy_enabled" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_enabled INTEGER DEFAULT 0")
print(" [OK] 添加 proxy_enabled 字段")
if "proxy_api_url" not in columns:
cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_api_url TEXT DEFAULT ""')
print(" [OK] 添加 proxy_api_url 字段")
if "proxy_expire_minutes" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_expire_minutes INTEGER DEFAULT 3")
print(" [OK] 添加 proxy_expire_minutes 字段")
if "enable_screenshot" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN enable_screenshot INTEGER DEFAULT 1")
print(" [OK] 添加 enable_screenshot 字段")
columns = _get_table_columns(cursor, "system_config")
_add_column_if_missing(
cursor,
"system_config",
columns,
"proxy_enabled",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 proxy_enabled 字段",
)
_add_column_if_missing(
cursor,
"system_config",
columns,
"proxy_api_url",
'TEXT DEFAULT ""',
ok_message=" [OK] 添加 proxy_api_url 字段",
)
_add_column_if_missing(
cursor,
"system_config",
columns,
"proxy_expire_minutes",
"INTEGER DEFAULT 3",
ok_message=" [OK] 添加 proxy_expire_minutes 字段",
)
_add_column_if_missing(
cursor,
"system_config",
columns,
"enable_screenshot",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 enable_screenshot 字段",
)
conn.commit()
@@ -161,20 +207,31 @@ def _migrate_to_v3(conn):
"""迁移到版本3 - 添加账号状态和登录失败计数字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(accounts)")
columns = [col[1] for col in cursor.fetchall()]
if "status" not in columns:
cursor.execute('ALTER TABLE accounts ADD COLUMN status TEXT DEFAULT "active"')
print(" [OK] 添加 accounts.status 字段 (账号状态)")
if "login_fail_count" not in columns:
cursor.execute("ALTER TABLE accounts ADD COLUMN login_fail_count INTEGER DEFAULT 0")
print(" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)")
if "last_login_error" not in columns:
cursor.execute("ALTER TABLE accounts ADD COLUMN last_login_error TEXT")
print(" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)")
columns = _get_table_columns(cursor, "accounts")
_add_column_if_missing(
cursor,
"accounts",
columns,
"status",
'TEXT DEFAULT "active"',
ok_message=" [OK] 添加 accounts.status 字段 (账号状态)",
)
_add_column_if_missing(
cursor,
"accounts",
columns,
"login_fail_count",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)",
)
_add_column_if_missing(
cursor,
"accounts",
columns,
"last_login_error",
"TEXT",
ok_message=" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)",
)
conn.commit()
@@ -183,12 +240,15 @@ def _migrate_to_v4(conn):
"""迁移到版本4 - 添加任务来源字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(task_logs)")
columns = [col[1] for col in cursor.fetchall()]
if "source" not in columns:
cursor.execute('ALTER TABLE task_logs ADD COLUMN source TEXT DEFAULT "manual"')
print(" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)")
columns = _get_table_columns(cursor, "task_logs")
_add_column_if_missing(
cursor,
"task_logs",
columns,
"source",
'TEXT DEFAULT "manual"',
ok_message=" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)",
)
conn.commit()
@@ -300,20 +360,17 @@ def _migrate_to_v6(conn):
def _migrate_to_v7(conn):
"""迁移到版本7 - 统一存储北京时间将历史UTC时间字段整体+8小时"""
cursor = conn.cursor()
def table_exists(table_name: str) -> bool:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
return cursor.fetchone() is not None
def column_exists(table_name: str, column_name: str) -> bool:
cursor.execute(f"PRAGMA table_info({table_name})")
return any(row[1] == column_name for row in cursor.fetchall())
columns_cache: dict[str, set[str]] = {}
def shift_utc_to_cst(table_name: str, column_name: str) -> None:
if not table_exists(table_name):
if not _table_exists(cursor, table_name):
return
if not column_exists(table_name, column_name):
if table_name not in columns_cache:
columns_cache[table_name] = _get_table_columns(cursor, table_name)
if column_name not in columns_cache[table_name]:
return
cursor.execute(
f"""
UPDATE {table_name}
@@ -329,10 +386,6 @@ def _migrate_to_v7(conn):
("accounts", "created_at"),
("password_reset_requests", "created_at"),
("password_reset_requests", "processed_at"),
]:
shift_utc_to_cst(table, col)
for table, col in [
("smtp_configs", "created_at"),
("smtp_configs", "updated_at"),
("smtp_configs", "last_success_at"),
@@ -340,10 +393,6 @@ def _migrate_to_v7(conn):
("email_tokens", "created_at"),
("email_logs", "created_at"),
("email_stats", "last_updated"),
]:
shift_utc_to_cst(table, col)
for table, col in [
("task_checkpoints", "created_at"),
("task_checkpoints", "updated_at"),
("task_checkpoints", "completed_at"),
@@ -359,15 +408,23 @@ def _migrate_to_v8(conn):
cursor = conn.cursor()
# 1) 增量字段random_delay旧库可能不存在
cursor.execute("PRAGMA table_info(user_schedules)")
columns = [col[1] for col in cursor.fetchall()]
if "random_delay" not in columns:
cursor.execute("ALTER TABLE user_schedules ADD COLUMN random_delay INTEGER DEFAULT 0")
print(" [OK] 添加 user_schedules.random_delay 字段")
if "next_run_at" not in columns:
cursor.execute("ALTER TABLE user_schedules ADD COLUMN next_run_at TIMESTAMP")
print(" [OK] 添加 user_schedules.next_run_at 字段")
columns = _get_table_columns(cursor, "user_schedules")
_add_column_if_missing(
cursor,
"user_schedules",
columns,
"random_delay",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 user_schedules.random_delay 字段",
)
_add_column_if_missing(
cursor,
"user_schedules",
columns,
"next_run_at",
"TIMESTAMP",
ok_message=" [OK] 添加 user_schedules.next_run_at 字段",
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_schedules_next_run ON user_schedules(next_run_at)")
conn.commit()
@@ -392,12 +449,12 @@ def _migrate_to_v8(conn):
fixed = 0
for row in rows:
try:
schedule_id = row["id"] if isinstance(row, sqlite3.Row) else row[0]
schedule_time = row["schedule_time"] if isinstance(row, sqlite3.Row) else row[1]
weekdays = row["weekdays"] if isinstance(row, sqlite3.Row) else row[2]
random_delay = row["random_delay"] if isinstance(row, sqlite3.Row) else row[3]
last_run_at = row["last_run_at"] if isinstance(row, sqlite3.Row) else row[4]
next_run_at = row["next_run_at"] if isinstance(row, sqlite3.Row) else row[5]
schedule_id = _read_row_value(row, "id", 0)
schedule_time = _read_row_value(row, "schedule_time", 1)
weekdays = _read_row_value(row, "weekdays", 2)
random_delay = _read_row_value(row, "random_delay", 3)
last_run_at = _read_row_value(row, "last_run_at", 4)
next_run_at = _read_row_value(row, "next_run_at", 5)
except Exception:
continue
@@ -430,27 +487,46 @@ def _migrate_to_v9(conn):
"""迁移到版本9 - 邮件设置字段迁移(清理 email_service scattered ALTER TABLE"""
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
if not cursor.fetchone():
if not _table_exists(cursor, "email_settings"):
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return
cursor.execute("PRAGMA table_info(email_settings)")
columns = [col[1] for col in cursor.fetchall()]
columns = _get_table_columns(cursor, "email_settings")
changed = False
if "register_verify_enabled" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN register_verify_enabled INTEGER DEFAULT 0")
print(" [OK] 添加 email_settings.register_verify_enabled 字段")
changed = True
if "base_url" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN base_url TEXT DEFAULT ''")
print(" [OK] 添加 email_settings.base_url 字段")
changed = True
if "task_notify_enabled" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN task_notify_enabled INTEGER DEFAULT 0")
print(" [OK] 添加 email_settings.task_notify_enabled 字段")
changed = True
changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"register_verify_enabled",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 email_settings.register_verify_enabled 字段",
)
or changed
)
changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"base_url",
"TEXT DEFAULT ''",
ok_message=" [OK] 添加 email_settings.base_url 字段",
)
or changed
)
changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"task_notify_enabled",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 email_settings.task_notify_enabled 字段",
)
or changed
)
if changed:
conn.commit()
@@ -459,18 +535,31 @@ def _migrate_to_v9(conn):
def _migrate_to_v10(conn):
"""迁移到版本10 - users 邮箱字段迁移(避免运行时 ALTER TABLE"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(users)")
columns = [col[1] for col in cursor.fetchall()]
columns = _get_table_columns(cursor, "users")
changed = False
if "email_verified" not in columns:
cursor.execute("ALTER TABLE users ADD COLUMN email_verified INTEGER DEFAULT 0")
print(" [OK] 添加 users.email_verified 字段")
changed = True
if "email_notify_enabled" not in columns:
cursor.execute("ALTER TABLE users ADD COLUMN email_notify_enabled INTEGER DEFAULT 1")
print(" [OK] 添加 users.email_notify_enabled 字段")
changed = True
changed = (
_add_column_if_missing(
cursor,
"users",
columns,
"email_verified",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 users.email_verified 字段",
)
or changed
)
changed = (
_add_column_if_missing(
cursor,
"users",
columns,
"email_notify_enabled",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 users.email_notify_enabled 字段",
)
or changed
)
if changed:
conn.commit()
@@ -657,19 +746,24 @@ def _migrate_to_v15(conn):
"""迁移到版本15 - 邮件设置:新设备登录提醒全局开关"""
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
if not cursor.fetchone():
if not _table_exists(cursor, "email_settings"):
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return
cursor.execute("PRAGMA table_info(email_settings)")
columns = [col[1] for col in cursor.fetchall()]
columns = _get_table_columns(cursor, "email_settings")
changed = False
if "login_alert_enabled" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1")
print(" [OK] 添加 email_settings.login_alert_enabled 字段")
changed = True
changed = (
_add_column_if_missing(
cursor,
"email_settings",
columns,
"login_alert_enabled",
"INTEGER DEFAULT 1",
ok_message=" [OK] 添加 email_settings.login_alert_enabled 字段",
)
or changed
)
try:
cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL")
@@ -686,22 +780,24 @@ def _migrate_to_v15(conn):
def _migrate_to_v16(conn):
"""迁移到版本16 - 公告支持图片字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(announcements)")
columns = [col[1] for col in cursor.fetchall()]
columns = _get_table_columns(cursor, "announcements")
if "image_url" not in columns:
cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT")
if _add_column_if_missing(
cursor,
"announcements",
columns,
"image_url",
"TEXT",
ok_message=" [OK] 添加 announcements.image_url 字段",
):
conn.commit()
print(" [OK] 添加 announcements.image_url 字段")
def _migrate_to_v17(conn):
"""迁移到版本17 - 金山文档上传配置与用户开关"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
system_columns = _get_table_columns(cursor, "system_config")
system_fields = [
("kdocs_enabled", "INTEGER DEFAULT 0"),
("kdocs_doc_url", "TEXT DEFAULT ''"),
@@ -714,21 +810,29 @@ def _migrate_to_v17(conn):
("kdocs_admin_notify_email", "TEXT DEFAULT ''"),
]
for field, ddl in system_fields:
if field not in columns:
cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}")
print(f" [OK] 添加 system_config.{field} 字段")
cursor.execute("PRAGMA table_info(users)")
columns = [col[1] for col in cursor.fetchall()]
_add_column_if_missing(
cursor,
"system_config",
system_columns,
field,
ddl,
ok_message=f" [OK] 添加 system_config.{field} 字段",
)
user_columns = _get_table_columns(cursor, "users")
user_fields = [
("kdocs_unit", "TEXT DEFAULT ''"),
("kdocs_auto_upload", "INTEGER DEFAULT 0"),
]
for field, ddl in user_fields:
if field not in columns:
cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}")
print(f" [OK] 添加 users.{field} 字段")
_add_column_if_missing(
cursor,
"users",
user_columns,
field,
ddl,
ok_message=f" [OK] 添加 users.{field} 字段",
)
conn.commit()
@@ -737,15 +841,22 @@ def _migrate_to_v18(conn):
"""迁移到版本18 - 金山文档上传:有效行范围配置"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
if "kdocs_row_start" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0")
print(" [OK] 添加 system_config.kdocs_row_start 字段")
if "kdocs_row_end" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0")
print(" [OK] 添加 system_config.kdocs_row_end 字段")
columns = _get_table_columns(cursor, "system_config")
_add_column_if_missing(
cursor,
"system_config",
columns,
"kdocs_row_start",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 system_config.kdocs_row_start 字段",
)
_add_column_if_missing(
cursor,
"system_config",
columns,
"kdocs_row_end",
"INTEGER DEFAULT 0",
ok_message=" [OK] 添加 system_config.kdocs_row_end 字段",
)
conn.commit()

View File

@@ -2,12 +2,93 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import datetime
import json
from datetime import datetime, timedelta
import db_pool
from services.schedule_utils import compute_next_run_at, format_cst
from services.time_utils import get_beijing_now
_SCHEDULE_DEFAULT_TIME = "08:00"
_SCHEDULE_DEFAULT_WEEKDAYS = "1,2,3,4,5"
_ALLOWED_SCHEDULE_UPDATE_FIELDS = (
"name",
"enabled",
"schedule_time",
"weekdays",
"browse_type",
"enable_screenshot",
"random_delay",
"account_ids",
)
_ALLOWED_EXEC_LOG_UPDATE_FIELDS = (
"total_accounts",
"success_accounts",
"failed_accounts",
"total_items",
"total_attachments",
"total_screenshots",
"duration_seconds",
"status",
"error_message",
)
def _normalize_limit(limit, default: int, *, minimum: int = 1) -> int:
try:
parsed = int(limit)
except Exception:
parsed = default
if parsed < minimum:
return minimum
return parsed
def _to_int(value, default: int = 0) -> int:
try:
return int(value)
except Exception:
return default
def _format_optional_datetime(dt: datetime | None) -> str | None:
if dt is None:
return None
return format_cst(dt)
def _serialize_account_ids(account_ids) -> str:
return json.dumps(account_ids) if account_ids else "[]"
def _compute_schedule_next_run_str(
*,
now_dt,
schedule_time,
weekdays,
random_delay,
last_run_at,
) -> str:
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(schedule_time or _SCHEDULE_DEFAULT_TIME),
weekdays=str(weekdays or _SCHEDULE_DEFAULT_WEEKDAYS),
random_delay=_to_int(random_delay, 0),
last_run_at=str(last_run_at or "") if last_run_at else None,
)
return format_cst(next_dt)
def _map_schedule_log_row(row) -> dict:
log = dict(row)
log["created_at"] = log.get("execute_time")
log["success_count"] = log.get("success_accounts", 0)
log["failed_count"] = log.get("failed_accounts", 0)
log["duration"] = log.get("duration_seconds", 0)
return log
def get_user_schedules(user_id):
"""获取用户的所有定时任务"""
@@ -44,14 +125,10 @@ def create_user_schedule(
account_ids=None,
):
"""创建用户定时任务"""
import json
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = format_cst(get_beijing_now())
account_ids_str = json.dumps(account_ids) if account_ids else "[]"
cursor.execute(
"""
INSERT INTO user_schedules (
@@ -66,8 +143,8 @@ def create_user_schedule(
weekdays,
browse_type,
enable_screenshot,
int(random_delay or 0),
account_ids_str,
_to_int(random_delay, 0),
_serialize_account_ids(account_ids),
cst_time,
cst_time,
),
@@ -79,28 +156,11 @@ def create_user_schedule(
def update_user_schedule(schedule_id, **kwargs):
"""更新用户定时任务"""
import json
with db_pool.get_db() as conn:
cursor = conn.cursor()
now_dt = get_beijing_now()
now_str = format_cst(now_dt)
updates = []
params = []
allowed_fields = [
"name",
"enabled",
"schedule_time",
"weekdays",
"browse_type",
"enable_screenshot",
"random_delay",
"account_ids",
]
# 读取旧值,用于决定是否需要重算 next_run_at
cursor.execute(
"""
SELECT enabled, schedule_time, weekdays, random_delay, last_run_at
@@ -112,10 +172,11 @@ def update_user_schedule(schedule_id, **kwargs):
current = cursor.fetchone()
if not current:
return False
current_enabled = int(current[0] or 0)
current_enabled = _to_int(current[0], 0)
current_time = current[1]
current_weekdays = current[2]
current_random_delay = int(current[3] or 0)
current_random_delay = _to_int(current[3], 0)
current_last_run_at = current[4]
will_enabled = current_enabled
@@ -123,21 +184,28 @@ def update_user_schedule(schedule_id, **kwargs):
next_weekdays = current_weekdays
next_random_delay = current_random_delay
for field in allowed_fields:
if field in kwargs:
value = kwargs[field]
if field == "account_ids" and isinstance(value, list):
value = json.dumps(value)
if field == "enabled":
will_enabled = 1 if value else 0
if field == "schedule_time":
next_time = value
if field == "weekdays":
next_weekdays = value
if field == "random_delay":
next_random_delay = int(value or 0)
updates.append(f"{field} = ?")
params.append(value)
updates = []
params = []
for field in _ALLOWED_SCHEDULE_UPDATE_FIELDS:
if field not in kwargs:
continue
value = kwargs[field]
if field == "account_ids" and isinstance(value, list):
value = json.dumps(value)
if field == "enabled":
will_enabled = 1 if value else 0
if field == "schedule_time":
next_time = value
if field == "weekdays":
next_weekdays = value
if field == "random_delay":
next_random_delay = int(value or 0)
updates.append(f"{field} = ?")
params.append(value)
if not updates:
return False
@@ -145,30 +213,26 @@ def update_user_schedule(schedule_id, **kwargs):
updates.append("updated_at = ?")
params.append(now_str)
# 关键字段变更后重算 next_run_at确保索引驱动不会跑偏
#
# 需求:当用户修改“执行时间/执行日期/随机±15分钟”后即使今天已经执行过也允许按新配置在今天再次触发。
# 做法:这些关键字段发生变更时,重算 next_run_at 时忽略 last_run_at 的“同日仅一次”限制。
config_changed = any(key in kwargs for key in ["schedule_time", "weekdays", "random_delay"])
config_changed = any(key in kwargs for key in ("schedule_time", "weekdays", "random_delay"))
enabled_toggled = "enabled" in kwargs
should_recompute_next = config_changed or (enabled_toggled and will_enabled == 1)
if should_recompute_next:
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(next_time or "08:00"),
weekdays=str(next_weekdays or "1,2,3,4,5"),
random_delay=int(next_random_delay or 0),
last_run_at=None if config_changed else (str(current_last_run_at or "") if current_last_run_at else None),
next_run_at = _compute_schedule_next_run_str(
now_dt=now_dt,
schedule_time=next_time,
weekdays=next_weekdays,
random_delay=next_random_delay,
last_run_at=None if config_changed else current_last_run_at,
)
updates.append("next_run_at = ?")
params.append(format_cst(next_dt))
params.append(next_run_at)
# 若本次显式禁用任务,则 next_run_at 清空(与 toggle 行为保持一致)
if enabled_toggled and will_enabled == 0:
updates.append("next_run_at = ?")
params.append(None)
params.append(schedule_id)
params.append(schedule_id)
sql = f"UPDATE user_schedules SET {', '.join(updates)} WHERE id = ?"
cursor.execute(sql, params)
conn.commit()
@@ -203,28 +267,19 @@ def toggle_user_schedule(schedule_id, enabled):
)
row = cursor.fetchone()
if row:
schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = (
row[0],
row[1],
row[2],
row[3],
row[4],
)
schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = row
existing_next_run_at = str(existing_next_run_at or "").strip() or None
# 若 next_run_at 已经被“修改配置”逻辑预先计算好且仍在未来,则优先沿用,
# 避免 last_run_at 的“同日仅一次”限制阻塞用户把任务调整到今天再次触发。
if existing_next_run_at and existing_next_run_at > now_str:
next_run_at = existing_next_run_at
else:
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(schedule_time or "08:00"),
weekdays=str(weekdays or "1,2,3,4,5"),
random_delay=int(random_delay or 0),
last_run_at=str(last_run_at or "") if last_run_at else None,
next_run_at = _compute_schedule_next_run_str(
now_dt=now_dt,
schedule_time=schedule_time,
weekdays=weekdays,
random_delay=random_delay,
last_run_at=last_run_at,
)
next_run_at = format_cst(next_dt)
cursor.execute(
"""
@@ -272,16 +327,15 @@ def update_schedule_last_run(schedule_id):
row = cursor.fetchone()
if not row:
return False
schedule_time, weekdays, random_delay = row[0], row[1], row[2]
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(schedule_time or "08:00"),
weekdays=str(weekdays or "1,2,3,4,5"),
random_delay=int(random_delay or 0),
schedule_time, weekdays, random_delay = row
next_run_at = _compute_schedule_next_run_str(
now_dt=now_dt,
schedule_time=schedule_time,
weekdays=weekdays,
random_delay=random_delay,
last_run_at=now_str,
)
next_run_at = format_cst(next_dt)
cursor.execute(
"""
@@ -305,7 +359,11 @@ def update_schedule_next_run(schedule_id: int, next_run_at: str) -> bool:
SET next_run_at = ?, updated_at = ?
WHERE id = ?
""",
(str(next_run_at or "").strip() or None, format_cst(get_beijing_now()), int(schedule_id)),
(
str(next_run_at or "").strip() or None,
format_cst(get_beijing_now()),
int(schedule_id),
),
)
conn.commit()
return cursor.rowcount > 0
@@ -328,15 +386,15 @@ def recompute_schedule_next_run(schedule_id: int, *, now_dt=None) -> bool:
if not row:
return False
schedule_time, weekdays, random_delay, last_run_at = row[0], row[1], row[2], row[3]
next_dt = compute_next_run_at(
now=now_dt,
schedule_time=str(schedule_time or "08:00"),
weekdays=str(weekdays or "1,2,3,4,5"),
random_delay=int(random_delay or 0),
last_run_at=str(last_run_at or "") if last_run_at else None,
schedule_time, weekdays, random_delay, last_run_at = row
next_run_at = _compute_schedule_next_run_str(
now_dt=now_dt,
schedule_time=schedule_time,
weekdays=weekdays,
random_delay=random_delay,
last_run_at=last_run_at,
)
return update_schedule_next_run(int(schedule_id), format_cst(next_dt))
return update_schedule_next_run(int(schedule_id), next_run_at)
def get_due_user_schedules(now_cst: str, limit: int = 50):
@@ -345,6 +403,8 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
if not now_cst:
now_cst = format_cst(get_beijing_now())
safe_limit = _normalize_limit(limit, 50, minimum=1)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -358,7 +418,7 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
ORDER BY us.next_run_at ASC
LIMIT ?
""",
(now_cst, int(limit)),
(now_cst, safe_limit),
)
return [dict(row) for row in cursor.fetchall()]
@@ -370,15 +430,13 @@ def create_schedule_execution_log(schedule_id, user_id, schedule_name):
"""创建定时任务执行日志"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
execute_time = format_cst(get_beijing_now())
cursor.execute(
"""
INSERT INTO schedule_execution_logs (
schedule_id, user_id, schedule_name, execute_time, status
) VALUES (?, ?, ?, ?, 'running')
""",
(schedule_id, user_id, schedule_name, execute_time),
(schedule_id, user_id, schedule_name, format_cst(get_beijing_now())),
)
conn.commit()
@@ -393,22 +451,11 @@ def update_schedule_execution_log(log_id, **kwargs):
updates = []
params = []
allowed_fields = [
"total_accounts",
"success_accounts",
"failed_accounts",
"total_items",
"total_attachments",
"total_screenshots",
"duration_seconds",
"status",
"error_message",
]
for field in allowed_fields:
if field in kwargs:
updates.append(f"{field} = ?")
params.append(kwargs[field])
for field in _ALLOWED_EXEC_LOG_UPDATE_FIELDS:
if field not in kwargs:
continue
updates.append(f"{field} = ?")
params.append(kwargs[field])
if not updates:
return False
@@ -424,6 +471,7 @@ def update_schedule_execution_log(log_id, **kwargs):
def get_schedule_execution_logs(schedule_id, limit=10):
"""获取定时任务执行日志"""
try:
safe_limit = _normalize_limit(limit, 10, minimum=1)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -433,24 +481,16 @@ def get_schedule_execution_logs(schedule_id, limit=10):
ORDER BY execute_time DESC
LIMIT ?
""",
(schedule_id, limit),
(schedule_id, safe_limit),
)
logs = []
rows = cursor.fetchall()
for row in rows:
for row in cursor.fetchall():
try:
log = dict(row)
log["created_at"] = log.get("execute_time")
log["success_count"] = log.get("success_accounts", 0)
log["failed_count"] = log.get("failed_accounts", 0)
log["duration"] = log.get("duration_seconds", 0)
logs.append(log)
logs.append(_map_schedule_log_row(row))
except Exception as e:
print(f"[数据库] 处理日志行时出错: {e}")
continue
return logs
except Exception as e:
print(f"[数据库] 查询定时任务日志时出错: {e}")
@@ -462,6 +502,7 @@ def get_schedule_execution_logs(schedule_id, limit=10):
def get_user_all_schedule_logs(user_id, limit=50):
"""获取用户所有定时任务的执行日志"""
safe_limit = _normalize_limit(limit, 50, minimum=1)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -471,7 +512,7 @@ def get_user_all_schedule_logs(user_id, limit=50):
ORDER BY execute_time DESC
LIMIT ?
""",
(user_id, limit),
(user_id, safe_limit),
)
return [dict(row) for row in cursor.fetchall()]
@@ -493,14 +534,21 @@ def delete_schedule_logs(schedule_id, user_id):
def clean_old_schedule_logs(days=30):
"""清理指定天数前的定时任务执行日志"""
safe_days = _to_int(days, 30)
if safe_days < 0:
safe_days = 0
cutoff_dt = get_beijing_now() - timedelta(days=safe_days)
cutoff_str = format_cst(cutoff_dt)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
DELETE FROM schedule_execution_logs
WHERE execute_time < datetime('now', 'localtime', '-' || ? || ' days')
WHERE execute_time < ?
""",
(days,),
(cutoff_str,),
)
conn.commit()
return cursor.rowcount

View File

@@ -362,6 +362,8 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_status ON users(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_vip_expire ON users(vip_expire_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
@@ -391,6 +393,8 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_id ON task_logs(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_status ON task_logs(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source ON task_logs(source)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source_created_at ON task_logs(source, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)")
@@ -409,6 +413,9 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_id ON schedule_execution_logs(schedule_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_id ON schedule_execution_logs(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_status ON schedule_execution_logs(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_execute_time ON schedule_execution_logs(execute_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_time ON schedule_execution_logs(schedule_id, execute_time)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_time ON schedule_execution_logs(user_id, execute_time)")
# 初始化VIP配置幂等
try:

View File

@@ -3,13 +3,82 @@
from __future__ import annotations
from datetime import timedelta
from typing import Any, Optional
from typing import Dict
from typing import Any, Dict, Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str
_THREAT_EVENT_SELECT_COLUMNS = """
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
"""
def _normalize_page(page: int) -> int:
try:
page_i = int(page)
except Exception:
page_i = 1
return max(1, page_i)
def _normalize_per_page(per_page: int, default: int = 20) -> int:
try:
value = int(per_page)
except Exception:
value = default
return max(1, min(200, value))
def _normalize_limit(limit: int, default: int = 50) -> int:
try:
value = int(limit)
except Exception:
value = default
return max(1, min(200, value))
def _row_value(row, key: str, index: int = 0, default=None):
if row is None:
return default
try:
return row[key]
except Exception:
try:
return row[index]
except Exception:
return default
def _fetch_threat_events_history(where_clause: str, params: tuple[Any, ...], limit_i: int) -> list[dict]:
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT
{_THREAT_EVENT_SELECT_COLUMNS}
FROM threat_events
WHERE {where_clause}
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
tuple(params) + (limit_i,),
)
return [dict(r) for r in cursor.fetchall()]
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
"""记录登录环境信息,返回是否新设备/新IP。"""
user_id = int(user_id)
@@ -36,7 +105,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
SET last_seen = ?, last_ip = ?
WHERE id = ?
""",
(now_str, ip_text, row["id"] if isinstance(row, dict) else row[0]),
(now_str, ip_text, _row_value(row, "id", 0)),
)
else:
cursor.execute(
@@ -61,7 +130,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
SET last_seen = ?
WHERE id = ?
""",
(now_str, row["id"] if isinstance(row, dict) else row[0]),
(now_str, _row_value(row, "id", 0)),
)
else:
cursor.execute(
@@ -166,15 +235,8 @@ def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, lis
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
"""分页获取威胁事件。"""
try:
page_i = max(1, int(page))
except Exception:
page_i = 1
try:
per_page_i = int(per_page)
except Exception:
per_page_i = 20
per_page_i = max(1, min(200, per_page_i))
page_i = _normalize_page(page)
per_page_i = _normalize_per_page(per_page, default=20)
where_sql, params = _build_threat_events_where_clause(filters)
offset = (page_i - 1) * per_page_i
@@ -188,19 +250,7 @@ def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = N
cursor.execute(
f"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
{_THREAT_EVENT_SELECT_COLUMNS}
FROM threat_events
{where_sql}
ORDER BY created_at DESC, id DESC
@@ -218,75 +268,20 @@ def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE ip = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(ip_text, limit_i),
)
return [dict(r) for r in cursor.fetchall()]
limit_i = _normalize_limit(limit, default=50)
return _fetch_threat_events_history("ip = ?", (ip_text,), limit_i)
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
"""获取用户的威胁历史最近limit条"""
if user_id is None:
return []
try:
user_id_int = int(user_id)
except Exception:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE user_id = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(user_id_int, limit_i),
)
return [dict(r) for r in cursor.fetchall()]
limit_i = _normalize_limit(limit, default=50)
return _fetch_threat_events_history("user_id = ?", (user_id_int,), limit_i)

View File

@@ -2,12 +2,135 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import datetime
import pytz
from datetime import datetime, timedelta
import db_pool
from db.utils import sanitize_sql_like_pattern
from db.utils import get_cst_now, get_cst_now_str, sanitize_sql_like_pattern
_TASK_STATS_SELECT_SQL = """
SELECT
COUNT(*) as total_tasks,
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
"""
_USER_RUN_STATS_SELECT_SQL = """
SELECT
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
"""
def _build_day_bounds(date_filter: str) -> tuple[str | None, str | None]:
"""将 YYYY-MM-DD 转换为 [day_start, day_end) 区间。"""
try:
day_start = datetime.strptime(str(date_filter), "%Y-%m-%d")
except Exception:
return None, None
day_end = day_start + timedelta(days=1)
return day_start.strftime("%Y-%m-%d %H:%M:%S"), day_end.strftime("%Y-%m-%d %H:%M:%S")
def _normalize_int(value, default: int, *, minimum: int | None = None) -> int:
try:
parsed = int(value)
except Exception:
parsed = default
if minimum is not None and parsed < minimum:
return minimum
return parsed
def _stat_value(row, key: str) -> int:
try:
value = row[key] if row else 0
except Exception:
value = 0
return int(value or 0)
def _build_task_logs_where_sql(
*,
date_filter=None,
status_filter=None,
source_filter=None,
user_id_filter=None,
account_filter=None,
) -> tuple[str, list]:
where_clauses = ["1=1"]
params = []
if date_filter:
day_start, day_end = _build_day_bounds(date_filter)
if day_start and day_end:
where_clauses.append("tl.created_at >= ? AND tl.created_at < ?")
params.extend([day_start, day_end])
else:
where_clauses.append("date(tl.created_at) = ?")
params.append(date_filter)
if status_filter:
where_clauses.append("tl.status = ?")
params.append(status_filter)
if source_filter:
source_filter = str(source_filter or "").strip()
if source_filter == "user_scheduled":
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append("user_scheduled:%")
elif source_filter.endswith("*"):
prefix = source_filter[:-1]
safe_prefix = sanitize_sql_like_pattern(prefix)
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append(f"{safe_prefix}%")
else:
where_clauses.append("tl.source = ?")
params.append(source_filter)
if user_id_filter:
where_clauses.append("tl.user_id = ?")
params.append(user_id_filter)
if account_filter:
safe_filter = sanitize_sql_like_pattern(account_filter)
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
params.append(f"%{safe_filter}%")
return " AND ".join(where_clauses), params
def _fetch_task_stats_row(cursor, *, where_clause: str = "", params: tuple | list = ()) -> dict:
sql = _TASK_STATS_SELECT_SQL
if where_clause:
sql = f"{sql}\nWHERE {where_clause}"
cursor.execute(sql, params)
row = cursor.fetchone()
return {
"total_tasks": _stat_value(row, "total_tasks"),
"success_tasks": _stat_value(row, "success_tasks"),
"failed_tasks": _stat_value(row, "failed_tasks"),
"total_items": _stat_value(row, "total_items"),
"total_attachments": _stat_value(row, "total_attachments"),
}
def _fetch_user_run_stats_row(cursor, *, where_clause: str, params: tuple | list) -> dict:
sql = f"{_USER_RUN_STATS_SELECT_SQL}\nWHERE {where_clause}"
cursor.execute(sql, params)
row = cursor.fetchone()
return {
"completed": _stat_value(row, "completed"),
"failed": _stat_value(row, "failed"),
"total_items": _stat_value(row, "total_items"),
"total_attachments": _stat_value(row, "total_attachments"),
}
def create_task_log(
@@ -25,8 +148,6 @@ def create_task_log(
"""创建任务日志记录"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
cursor.execute(
"""
@@ -45,7 +166,7 @@ def create_task_log(
total_attachments,
error_message,
duration,
cst_time,
get_cst_now_str(),
source,
),
)
@@ -64,54 +185,27 @@ def get_task_logs(
account_filter=None,
):
"""获取任务日志列表(支持分页和多种筛选)"""
limit = _normalize_int(limit, 100, minimum=1)
offset = _normalize_int(offset, 0, minimum=0)
with db_pool.get_db() as conn:
cursor = conn.cursor()
where_clauses = ["1=1"]
params = []
if date_filter:
where_clauses.append("date(tl.created_at) = ?")
params.append(date_filter)
if status_filter:
where_clauses.append("tl.status = ?")
params.append(status_filter)
if source_filter:
source_filter = str(source_filter or "").strip()
# 兼容“虚拟来源”:用于筛选 user_scheduled:batch_xxx 这类动态值
if source_filter == "user_scheduled":
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append("user_scheduled:%")
elif source_filter.endswith("*"):
prefix = source_filter[:-1]
safe_prefix = sanitize_sql_like_pattern(prefix)
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
params.append(f"{safe_prefix}%")
else:
where_clauses.append("tl.source = ?")
params.append(source_filter)
if user_id_filter:
where_clauses.append("tl.user_id = ?")
params.append(user_id_filter)
if account_filter:
safe_filter = sanitize_sql_like_pattern(account_filter)
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
params.append(f"%{safe_filter}%")
where_sql = " AND ".join(where_clauses)
where_sql, params = _build_task_logs_where_sql(
date_filter=date_filter,
status_filter=status_filter,
source_filter=source_filter,
user_id_filter=user_id_filter,
account_filter=account_filter,
)
count_sql = f"""
SELECT COUNT(*) as total
FROM task_logs tl
LEFT JOIN users u ON tl.user_id = u.id
WHERE {where_sql}
"""
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
total = _stat_value(cursor.fetchone(), "total")
data_sql = f"""
SELECT
@@ -123,9 +217,10 @@ def get_task_logs(
ORDER BY tl.created_at DESC
LIMIT ? OFFSET ?
"""
params.extend([limit, offset])
data_params = list(params)
data_params.extend([limit, offset])
cursor.execute(data_sql, params)
cursor.execute(data_sql, data_params)
logs = [dict(row) for row in cursor.fetchall()]
return {"logs": logs, "total": total}
@@ -133,61 +228,39 @@ def get_task_logs(
def get_task_stats(date_filter=None):
"""获取任务统计信息"""
if date_filter is None:
date_filter = get_cst_now().strftime("%Y-%m-%d")
day_start, day_end = _build_day_bounds(date_filter)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_tz = pytz.timezone("Asia/Shanghai")
if date_filter is None:
date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d")
if day_start and day_end:
today_stats = _fetch_task_stats_row(
cursor,
where_clause="created_at >= ? AND created_at < ?",
params=(day_start, day_end),
)
else:
today_stats = _fetch_task_stats_row(
cursor,
where_clause="date(created_at) = ?",
params=(date_filter,),
)
cursor.execute(
"""
SELECT
COUNT(*) as total_tasks,
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
WHERE date(created_at) = ?
""",
(date_filter,),
)
today_stats = cursor.fetchone()
total_stats = _fetch_task_stats_row(cursor)
cursor.execute(
"""
SELECT
COUNT(*) as total_tasks,
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
"""
)
total_stats = cursor.fetchone()
return {
"today": {
"total_tasks": today_stats["total_tasks"] or 0,
"success_tasks": today_stats["success_tasks"] or 0,
"failed_tasks": today_stats["failed_tasks"] or 0,
"total_items": today_stats["total_items"] or 0,
"total_attachments": today_stats["total_attachments"] or 0,
},
"total": {
"total_tasks": total_stats["total_tasks"] or 0,
"success_tasks": total_stats["success_tasks"] or 0,
"failed_tasks": total_stats["failed_tasks"] or 0,
"total_items": total_stats["total_items"] or 0,
"total_attachments": total_stats["total_attachments"] or 0,
},
}
return {"today": today_stats, "total": total_stats}
def delete_old_task_logs(days=30, batch_size=1000):
"""删除N天前的任务日志分批删除避免长时间锁表"""
days = _normalize_int(days, 30, minimum=0)
batch_size = _normalize_int(batch_size, 1000, minimum=1)
cutoff = (get_cst_now() - timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
total_deleted = 0
while True:
with db_pool.get_db() as conn:
@@ -197,16 +270,16 @@ def delete_old_task_logs(days=30, batch_size=1000):
DELETE FROM task_logs
WHERE rowid IN (
SELECT rowid FROM task_logs
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
WHERE created_at < ?
LIMIT ?
)
""",
(days, batch_size),
(cutoff, batch_size),
)
deleted = cursor.rowcount
conn.commit()
if deleted == 0:
if deleted <= 0:
break
total_deleted += deleted
@@ -215,31 +288,23 @@ def delete_old_task_logs(days=30, batch_size=1000):
def get_user_run_stats(user_id, date_filter=None):
"""获取用户的运行统计信息"""
if date_filter is None:
date_filter = get_cst_now().strftime("%Y-%m-%d")
day_start, day_end = _build_day_bounds(date_filter)
with db_pool.get_db() as conn:
cst_tz = pytz.timezone("Asia/Shanghai")
cursor = conn.cursor()
if date_filter is None:
date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d")
if day_start and day_end:
return _fetch_user_run_stats_row(
cursor,
where_clause="user_id = ? AND created_at >= ? AND created_at < ?",
params=(user_id, day_start, day_end),
)
cursor.execute(
"""
SELECT
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
SUM(total_items) as total_items,
SUM(total_attachments) as total_attachments
FROM task_logs
WHERE user_id = ? AND date(created_at) = ?
""",
(user_id, date_filter),
return _fetch_user_run_stats_row(
cursor,
where_clause="user_id = ? AND date(created_at) = ?",
params=(user_id, date_filter),
)
stats = cursor.fetchone()
return {
"completed": stats["completed"] or 0,
"failed": stats["failed"] or 0,
"total_items": stats["total_items"] or 0,
"total_attachments": stats["total_attachments"] or 0,
}

View File

@@ -16,8 +16,41 @@ from password_utils import (
verify_password_bcrypt,
verify_password_sha256,
)
logger = get_logger(__name__)
_CST_TZ = pytz.timezone("Asia/Shanghai")
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
def _row_to_dict(row):
return dict(row) if row else None
def _get_user_by_field(field_name: str, field_value):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
return _row_to_dict(cursor.fetchone())
def _parse_cst_datetime(datetime_str: str | None):
if not datetime_str:
return None
try:
naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S")
return _CST_TZ.localize(naive_dt)
except Exception:
return None
def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str:
if int(days) == 999999:
return _PERMANENT_VIP_EXPIRE
if base_dt is None:
base_dt = datetime.now(_CST_TZ)
return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S")
def get_vip_config():
"""获取VIP配置"""
@@ -32,13 +65,12 @@ def set_default_vip_days(days):
"""设置默认VIP天数"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
cursor.execute(
"""
INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
VALUES (1, ?, ?)
""",
(days, cst_time),
(days, get_cst_now_str()),
)
conn.commit()
return True
@@ -47,14 +79,8 @@ def set_default_vip_days(days):
def set_user_vip(user_id, days):
"""设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久"""
with db_pool.get_db() as conn:
cst_tz = pytz.timezone("Asia/Shanghai")
cursor = conn.cursor()
if days == 999999:
expire_time = "2099-12-31 23:59:59"
else:
expire_time = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
expire_time = _format_vip_expire(days)
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id))
conn.commit()
return cursor.rowcount > 0
@@ -63,29 +89,26 @@ def set_user_vip(user_id, days):
def extend_user_vip(user_id, days):
"""延长用户VIP时间"""
user = get_user_by_id(user_id)
cst_tz = pytz.timezone("Asia/Shanghai")
if not user:
return False
current_expire = user.get("vip_expire_time")
now_dt = datetime.now(_CST_TZ)
if current_expire and current_expire != _PERMANENT_VIP_EXPIRE:
expire_time = _parse_cst_datetime(current_expire)
if expire_time is not None:
if expire_time < now_dt:
expire_time = now_dt
new_expire = _format_vip_expire(days, base_dt=expire_time)
else:
logger.warning("解析VIP过期时间失败使用当前时间")
new_expire = _format_vip_expire(days, base_dt=now_dt)
else:
new_expire = _format_vip_expire(days, base_dt=now_dt)
with db_pool.get_db() as conn:
cursor = conn.cursor()
current_expire = user.get("vip_expire_time")
if current_expire and current_expire != "2099-12-31 23:59:59":
try:
expire_time_naive = datetime.strptime(current_expire, "%Y-%m-%d %H:%M:%S")
expire_time = cst_tz.localize(expire_time_naive)
now = datetime.now(cst_tz)
if expire_time < now:
expire_time = now
new_expire = (expire_time + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, AttributeError) as e:
logger.warning(f"解析VIP过期时间失败: {e}, 使用当前时间")
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
else:
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
conn.commit()
return cursor.rowcount > 0
@@ -105,45 +128,49 @@ def is_user_vip(user_id):
注意数据库中存储的时间统一使用CSTAsia/Shanghai时区
"""
cst_tz = pytz.timezone("Asia/Shanghai")
user = get_user_by_id(user_id)
if not user or not user.get("vip_expire_time"):
if not user:
return False
try:
expire_time_naive = datetime.strptime(user["vip_expire_time"], "%Y-%m-%d %H:%M:%S")
expire_time = cst_tz.localize(expire_time_naive)
now = datetime.now(cst_tz)
return now < expire_time
except (ValueError, AttributeError) as e:
logger.warning(f"检查VIP状态失败 (user_id={user_id}): {e}")
vip_expire_time = user.get("vip_expire_time")
if not vip_expire_time:
return False
expire_time = _parse_cst_datetime(vip_expire_time)
if expire_time is None:
logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间")
return False
return datetime.now(_CST_TZ) < expire_time
def get_user_vip_info(user_id):
"""获取用户VIP信息"""
cst_tz = pytz.timezone("Asia/Shanghai")
user = get_user_by_id(user_id)
if not user:
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""}
vip_expire_time = user.get("vip_expire_time")
username = user.get("username", "")
if not vip_expire_time:
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
try:
expire_time_naive = datetime.strptime(vip_expire_time, "%Y-%m-%d %H:%M:%S")
expire_time = cst_tz.localize(expire_time_naive)
now = datetime.now(cst_tz)
is_vip = now < expire_time
days_left = (expire_time - now).days if is_vip else 0
expire_time = _parse_cst_datetime(vip_expire_time)
if expire_time is None:
logger.warning("VIP信息获取错误: 无法解析过期时间")
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
return {"username": user.get("username", ""), "is_vip": is_vip, "expire_time": vip_expire_time, "days_left": max(0, days_left)}
except Exception as e:
logger.warning(f"VIP信息获取错误: {e}")
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
now_dt = datetime.now(_CST_TZ)
is_vip = now_dt < expire_time
days_left = (expire_time - now_dt).days if is_vip else 0
return {
"username": username,
"is_vip": is_vip,
"expire_time": vip_expire_time,
"days_left": max(0, days_left),
}
# ==================== 用户相关 ====================
@@ -151,8 +178,6 @@ def get_user_vip_info(user_id):
def create_user(username, password, email=""):
"""创建新用户(默认直接通过,赠送默认VIP)"""
cst_tz = pytz.timezone("Asia/Shanghai")
with db_pool.get_db() as conn:
cursor = conn.cursor()
password_hash = hash_password_bcrypt(password)
@@ -160,12 +185,8 @@ def create_user(username, password, email=""):
default_vip_days = get_vip_config()["default_vip_days"]
vip_expire_time = None
if default_vip_days > 0:
if default_vip_days == 999999:
vip_expire_time = "2099-12-31 23:59:59"
else:
vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime("%Y-%m-%d %H:%M:%S")
if int(default_vip_days or 0) > 0:
vip_expire_time = _format_vip_expire(int(default_vip_days))
try:
cursor.execute(
@@ -210,28 +231,28 @@ def verify_user(username, password):
def get_user_by_id(user_id):
"""根据ID获取用户"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
user = cursor.fetchone()
return dict(user) if user else None
return _get_user_by_field("id", user_id)
def get_user_kdocs_settings(user_id):
"""获取用户的金山文档配置"""
user = get_user_by_id(user_id)
if not user:
return None
return {
"kdocs_unit": user.get("kdocs_unit") or "",
"kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0,
}
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,))
row = cursor.fetchone()
if not row:
return None
return {
"kdocs_unit": (row["kdocs_unit"] or "") if isinstance(row, dict) else (row[0] or ""),
"kdocs_auto_upload": 1 if ((row["kdocs_auto_upload"] if isinstance(row, dict) else row[1]) or 0) else 0,
}
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
"""更新用户的金山文档配置"""
updates = []
params = []
if kdocs_unit is not None:
updates.append("kdocs_unit = ?")
params.append(kdocs_unit)
@@ -252,11 +273,7 @@ def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=No
def get_user_by_username(username):
"""根据用户名获取用户"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
user = cursor.fetchone()
return dict(user) if user else None
return _get_user_by_field("username", username)
def get_all_users():
@@ -279,14 +296,13 @@ def approve_user(user_id):
"""审核通过用户"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
cursor.execute(
"""
UPDATE users
SET status = 'approved', approved_at = ?
WHERE id = ?
""",
(cst_time, user_id),
(get_cst_now_str(), user_id),
)
conn.commit()
return cursor.rowcount > 0
@@ -315,5 +331,5 @@ def get_user_stats(user_id):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,))
account_count = cursor.fetchone()["count"]
return {"account_count": account_count}
row = cursor.fetchone()
return {"account_count": int((row["count"] if row else 0) or 0)}

View File

@@ -7,8 +7,12 @@
import sqlite3
import threading
from queue import Queue, Empty
import time
from queue import Empty, Full, Queue
from app_logger import get_logger
logger = get_logger("database")
class ConnectionPool:
@@ -44,12 +48,55 @@ class ConnectionPool:
"""创建新的数据库连接"""
conn = sqlite3.connect(self.database, check_same_thread=False)
conn.row_factory = sqlite3.Row
# 启用外键约束,确保 ON DELETE CASCADE 等约束生效
conn.execute("PRAGMA foreign_keys=ON")
# 设置WAL模式提高并发性能
conn.execute("PRAGMA journal_mode=WAL")
# 在WAL模式下使用NORMAL同步兼顾性能与可靠性
conn.execute("PRAGMA synchronous=NORMAL")
# 设置合理的超时时间
conn.execute("PRAGMA busy_timeout=5000")
return conn
def _close_connection(self, conn) -> None:
if conn is None:
return
try:
conn.close()
except Exception as e:
logger.warning(f"关闭连接失败: {e}")
def _is_connection_healthy(self, conn) -> bool:
if conn is None:
return False
try:
conn.rollback()
conn.execute("SELECT 1")
return True
except sqlite3.Error as e:
logger.warning(f"连接健康检查失败(数据库错误): {e}")
except Exception as e:
logger.warning(f"连接健康检查失败(未知错误): {e}")
return False
def _replenish_pool_if_needed(self) -> None:
with self._lock:
if self._pool.qsize() >= self.pool_size:
return
new_conn = None
try:
new_conn = self._create_connection()
self._pool.put(new_conn, block=False)
self._created_connections += 1
except Full:
if new_conn:
self._close_connection(new_conn)
except Exception as e:
if new_conn:
self._close_connection(new_conn)
logger.warning(f"重建连接失败: {e}")
def get_connection(self):
"""
从连接池获取连接
@@ -70,66 +117,20 @@ class ConnectionPool:
Args:
conn: 要归还的连接
"""
import sqlite3
from queue import Full
if conn is None:
return
connection_healthy = False
try:
# 回滚任何未提交的事务
conn.rollback()
# 安全修复:验证连接是否健康,防止损坏的连接污染连接池
conn.execute("SELECT 1")
connection_healthy = True
except sqlite3.Error as e:
# 数据库相关错误,连接可能损坏
print(f"连接健康检查失败(数据库错误): {e}")
except Exception as e:
print(f"连接健康检查失败(未知错误): {e}")
if connection_healthy:
if self._is_connection_healthy(conn):
try:
self._pool.put(conn, block=False)
return # 成功归还
return
except Full:
# 队列已满(不应该发生,但处理它)
print(f"警告: 连接池已满,关闭多余连接")
connection_healthy = False # 标记为需要关闭
logger.warning("连接池已满,关闭多余连接")
self._close_connection(conn)
return
# 连接不健康或队列已满,关闭它
try:
conn.close()
except Exception as close_error:
print(f"关闭连接失败: {close_error}")
# 如果连接不健康,尝试创建新连接补充池
if not connection_healthy:
with self._lock:
# 双重检查:确保池确实需要补充
if self._pool.qsize() < self.pool_size:
new_conn = None
try:
new_conn = self._create_connection()
self._pool.put(new_conn, block=False)
# 只有成功放入池后才增加计数
self._created_connections += 1
except Full:
# 在获取锁期间池被填满了,关闭新建的连接
if new_conn:
try:
new_conn.close()
except Exception:
pass
except Exception as create_error:
# 创建连接失败,确保关闭已创建的连接
if new_conn:
try:
new_conn.close()
except Exception:
pass
print(f"重建连接失败: {create_error}")
self._close_connection(conn)
self._replenish_pool_if_needed()
def close_all(self):
"""关闭所有连接"""
@@ -138,7 +139,7 @@ class ConnectionPool:
conn = self._pool.get(block=False)
conn.close()
except Exception as e:
print(f"关闭连接失败: {e}")
logger.warning(f"关闭连接失败: {e}")
def get_stats(self):
"""获取连接池统计信息"""
@@ -175,14 +176,14 @@ class PooledConnection:
if exc_type is not None:
# 发生异常,回滚事务
self._conn.rollback()
print(f"数据库事务已回滚: {exc_type.__name__}")
logger.warning(f"数据库事务已回滚: {exc_type.__name__}")
# 注意: 不自动commit要求用户显式调用conn.commit()
if self._cursor:
self._cursor.close()
self._cursor = None
except Exception as e:
print(f"关闭游标失败: {e}")
logger.warning(f"关闭游标失败: {e}")
finally:
# 归还连接
self._pool.return_connection(self._conn)
@@ -254,7 +255,7 @@ def init_pool(database, pool_size=5):
with _pool_lock:
if _pool is None:
_pool = ConnectionPool(database, pool_size)
print(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
logger.info(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
def get_db():

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import json
import os
import time
@@ -9,8 +10,40 @@ from services.runtime import get_logger, get_socketio
from services.state import safe_get_account, safe_iter_task_status_items
def _to_int(value, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _payload_signature(payload: dict) -> str:
try:
return json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=str)
except Exception:
return repr(payload)
def _should_emit(
*,
last_sig: str | None,
last_ts: float,
new_sig: str,
now_ts: float,
min_interval: float,
force_interval: float,
) -> bool:
if last_sig is None:
return True
if (now_ts - last_ts) >= force_interval:
return True
if new_sig != last_sig and (now_ts - last_ts) >= min_interval:
return True
return False
def status_push_worker() -> None:
"""后台线程:按间隔推送排队/运行中任务状态更新(可节流)。"""
"""后台线程:按间隔推送排队/运行中任务状态(变更驱动+心跳兜底)。"""
logger = get_logger()
try:
push_interval = float(os.environ.get("STATUS_PUSH_INTERVAL_SECONDS", "1"))
@@ -18,18 +51,41 @@ def status_push_worker() -> None:
push_interval = 1.0
push_interval = max(0.5, push_interval)
try:
queue_min_interval = float(os.environ.get("STATUS_PUSH_MIN_QUEUE_INTERVAL_SECONDS", str(push_interval)))
except Exception:
queue_min_interval = push_interval
queue_min_interval = max(push_interval, queue_min_interval)
try:
progress_min_interval = float(
os.environ.get("STATUS_PUSH_MIN_PROGRESS_INTERVAL_SECONDS", str(push_interval))
)
except Exception:
progress_min_interval = push_interval
progress_min_interval = max(push_interval, progress_min_interval)
try:
force_interval = float(os.environ.get("STATUS_PUSH_FORCE_INTERVAL_SECONDS", "10"))
except Exception:
force_interval = 10.0
force_interval = max(push_interval, force_interval)
socketio = get_socketio()
from services.tasks import get_task_scheduler
scheduler = get_task_scheduler()
emitted_state: dict[str, dict] = {}
while True:
try:
now_ts = time.time()
queue_snapshot = scheduler.get_queue_state_snapshot()
pending_total = int(queue_snapshot.get("pending_total", 0) or 0)
running_total = int(queue_snapshot.get("running_total", 0) or 0)
running_by_user = queue_snapshot.get("running_by_user") or {}
positions = queue_snapshot.get("positions") or {}
active_account_ids = set()
status_items = safe_iter_task_status_items()
for account_id, status_info in status_items:
@@ -39,11 +95,15 @@ def status_push_worker() -> None:
user_id = status_info.get("user_id")
if not user_id:
continue
active_account_ids.add(str(account_id))
account = safe_get_account(user_id, account_id)
if not account:
continue
user_id_int = _to_int(user_id)
account_data = account.to_dict()
pos = positions.get(account_id) or {}
pos = positions.get(account_id) or positions.get(str(account_id)) or {}
account_data.update(
{
"queue_pending_total": pending_total,
@@ -51,10 +111,23 @@ def status_push_worker() -> None:
"queue_ahead": pos.get("queue_ahead"),
"queue_position": pos.get("queue_position"),
"queue_is_vip": pos.get("is_vip"),
"queue_running_user": int(running_by_user.get(int(user_id), 0) or 0),
"queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))),
}
)
socketio.emit("account_update", account_data, room=f"user_{user_id}")
cache_entry = emitted_state.setdefault(str(account_id), {})
account_sig = _payload_signature(account_data)
if _should_emit(
last_sig=cache_entry.get("account_sig"),
last_ts=float(cache_entry.get("account_ts", 0) or 0),
new_sig=account_sig,
now_ts=now_ts,
min_interval=queue_min_interval,
force_interval=force_interval,
):
socketio.emit("account_update", account_data, room=f"user_{user_id}")
cache_entry["account_sig"] = account_sig
cache_entry["account_ts"] = now_ts
if status != "运行中":
continue
@@ -74,9 +147,26 @@ def status_push_worker() -> None:
"queue_running_total": running_total,
"queue_ahead": pos.get("queue_ahead"),
"queue_position": pos.get("queue_position"),
"queue_running_user": int(running_by_user.get(int(user_id), 0) or 0),
"queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))),
}
socketio.emit("task_progress", progress_data, room=f"user_{user_id}")
progress_sig = _payload_signature(progress_data)
if _should_emit(
last_sig=cache_entry.get("progress_sig"),
last_ts=float(cache_entry.get("progress_ts", 0) or 0),
new_sig=progress_sig,
now_ts=now_ts,
min_interval=progress_min_interval,
force_interval=force_interval,
):
socketio.emit("task_progress", progress_data, room=f"user_{user_id}")
cache_entry["progress_sig"] = progress_sig
cache_entry["progress_ts"] = now_ts
if emitted_state:
stale_ids = [account_id for account_id in emitted_state.keys() if account_id not in active_account_ids]
for account_id in stale_ids:
emitted_state.pop(account_id, None)
time.sleep(push_interval)
except Exception as e:

View File

@@ -8,6 +8,15 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api")
# Import side effects: register routes on blueprint
from routes.admin_api import core as _core # noqa: F401
from routes.admin_api import system_config_api as _system_config_api # noqa: F401
from routes.admin_api import operations_api as _operations_api # noqa: F401
from routes.admin_api import announcements_api as _announcements_api # noqa: F401
from routes.admin_api import users_api as _users_api # noqa: F401
from routes.admin_api import account_api as _account_api # noqa: F401
from routes.admin_api import feedback_api as _feedback_api # noqa: F401
from routes.admin_api import infra_api as _infra_api # noqa: F401
from routes.admin_api import tasks_api as _tasks_api # noqa: F401
from routes.admin_api import email_api as _email_api # noqa: F401
# Export security blueprint for app registration
from routes.admin_api.security import security_bp # noqa: F401

View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from app_security import validate_password
from flask import jsonify, request, session
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
# ==================== 密码重置 / 反馈(管理员) ====================
@admin_api_bp.route("/admin/password", methods=["PUT"])
@admin_required
def update_admin_password():
"""修改管理员密码"""
data = request.json or {}
new_password = (data.get("new_password") or "").strip()
if not new_password:
return jsonify({"error": "密码不能为空"}), 400
username = session.get("admin_username")
if database.update_admin_password(username, new_password):
return jsonify({"success": True})
return jsonify({"error": "修改失败"}), 400
@admin_api_bp.route("/admin/username", methods=["PUT"])
@admin_required
def update_admin_username():
"""修改管理员用户名"""
data = request.json or {}
new_username = (data.get("new_username") or "").strip()
if not new_username:
return jsonify({"error": "用户名不能为空"}), 400
old_username = session.get("admin_username")
if database.update_admin_username(old_username, new_username):
session["admin_username"] = new_username
return jsonify({"success": True})
return jsonify({"error": "修改失败,用户名可能已存在"}), 400
@admin_api_bp.route("/users/<int:user_id>/reset_password", methods=["POST"])
@admin_required
def admin_reset_password_route(user_id):
"""管理员直接重置用户密码(无需审核)"""
data = request.json or {}
new_password = (data.get("new_password") or "").strip()
if not new_password:
return jsonify({"error": "新密码不能为空"}), 400
is_valid, error_msg = validate_password(new_password)
if not is_valid:
return jsonify({"error": error_msg}), 400
if database.admin_reset_user_password(user_id, new_password):
return jsonify({"message": "密码重置成功"})
return jsonify({"error": "重置失败,用户不存在"}), 400

View File

@@ -0,0 +1,144 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import posixpath
import secrets
import time
import database
from app_config import get_config
from app_logger import get_logger
from app_security import is_safe_path, sanitize_filename
from flask import current_app, jsonify, request, url_for
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
logger = get_logger("app")
config = get_config()
def _get_upload_dir():
rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements")
if not is_safe_path(current_app.root_path, rel_dir):
rel_dir = "static/announcements"
abs_dir = os.path.join(current_app.root_path, rel_dir)
os.makedirs(abs_dir, exist_ok=True)
return abs_dir, rel_dir
def _get_file_size(file_storage):
try:
file_storage.stream.seek(0, os.SEEK_END)
size = file_storage.stream.tell()
file_storage.stream.seek(0)
return size
except Exception:
return None
# ==================== 公告管理API管理员 ====================
@admin_api_bp.route("/announcements/upload_image", methods=["POST"])
@admin_required
def admin_upload_announcement_image():
"""上传公告图片返回可访问URL"""
file = request.files.get("file")
if not file or not file.filename:
return jsonify({"error": "请选择图片"}), 400
filename = sanitize_filename(file.filename)
ext = os.path.splitext(filename)[1].lower()
allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"})
if not ext or ext not in allowed_exts:
return jsonify({"error": "不支持的图片格式"}), 400
if file.mimetype and not str(file.mimetype).startswith("image/"):
return jsonify({"error": "文件类型无效"}), 400
size = _get_file_size(file)
max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024))
if size is not None and size > max_size:
max_mb = max_size // 1024 // 1024
return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400
abs_dir, rel_dir = _get_upload_dir()
token = secrets.token_hex(6)
name = f"announcement_{int(time.time())}_{token}{ext}"
save_path = os.path.join(abs_dir, name)
file.save(save_path)
static_root = os.path.join(current_app.root_path, "static")
rel_to_static = os.path.relpath(abs_dir, static_root)
if rel_to_static.startswith(".."):
rel_to_static = "announcements"
url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name)
return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)})
@admin_api_bp.route("/announcements", methods=["GET"])
@admin_required
def admin_get_announcements():
"""获取公告列表"""
try:
limit = int(request.args.get("limit", 50))
offset = int(request.args.get("offset", 0))
except (TypeError, ValueError):
limit, offset = 50, 0
limit = max(1, min(200, limit))
offset = max(0, offset)
return jsonify(database.get_announcements(limit=limit, offset=offset))
@admin_api_bp.route("/announcements", methods=["POST"])
@admin_required
def admin_create_announcement():
"""创建公告(默认启用并替换旧公告)"""
data = request.json or {}
title = (data.get("title") or "").strip()
content = (data.get("content") or "").strip()
image_url = (data.get("image_url") or "").strip()
is_active = bool(data.get("is_active", True))
if image_url and len(image_url) > 1000:
return jsonify({"error": "图片地址过长"}), 400
announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active)
if not announcement_id:
return jsonify({"error": "标题和内容不能为空"}), 400
return jsonify({"success": True, "id": announcement_id})
@admin_api_bp.route("/announcements/<int:announcement_id>/activate", methods=["POST"])
@admin_required
def admin_activate_announcement(announcement_id):
"""启用公告(会自动停用其他公告)"""
if not database.get_announcement_by_id(announcement_id):
return jsonify({"error": "公告不存在"}), 404
ok = database.set_announcement_active(announcement_id, True)
return jsonify({"success": ok})
@admin_api_bp.route("/announcements/<int:announcement_id>/deactivate", methods=["POST"])
@admin_required
def admin_deactivate_announcement(announcement_id):
"""停用公告"""
if not database.get_announcement_by_id(announcement_id):
return jsonify({"error": "公告不存在"}), 404
ok = database.set_announcement_active(announcement_id, False)
return jsonify({"success": ok})
@admin_api_bp.route("/announcements/<int:announcement_id>", methods=["DELETE"])
@admin_required
def admin_delete_announcement(announcement_id):
"""删除公告"""
if not database.get_announcement_by_id(announcement_id):
return jsonify({"error": "公告不存在"}), 404
ok = database.delete_announcement(announcement_id)
return jsonify({"success": ok})

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,214 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import email_service
from app_logger import get_logger
from app_security import validate_email
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
logger = get_logger("app")
@admin_api_bp.route("/email/settings", methods=["GET"])
@admin_required
def get_email_settings_api():
"""获取全局邮件设置"""
try:
settings = email_service.get_email_settings()
return jsonify(settings)
except Exception as e:
logger.error(f"获取邮件设置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/email/settings", methods=["POST"])
@admin_required
def update_email_settings_api():
"""更新全局邮件设置"""
try:
data = request.json or {}
enabled = data.get("enabled", False)
failover_enabled = data.get("failover_enabled", True)
register_verify_enabled = data.get("register_verify_enabled")
login_alert_enabled = data.get("login_alert_enabled")
base_url = data.get("base_url")
task_notify_enabled = data.get("task_notify_enabled")
email_service.update_email_settings(
enabled=enabled,
failover_enabled=failover_enabled,
register_verify_enabled=register_verify_enabled,
login_alert_enabled=login_alert_enabled,
base_url=base_url,
task_notify_enabled=task_notify_enabled,
)
return jsonify({"success": True})
except Exception as e:
logger.error(f"更新邮件设置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs", methods=["GET"])
@admin_required
def get_smtp_configs_api():
"""获取所有SMTP配置列表"""
try:
configs = email_service.get_smtp_configs(include_password=False)
return jsonify(configs)
except Exception as e:
logger.error(f"获取SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs", methods=["POST"])
@admin_required
def create_smtp_config_api():
"""创建SMTP配置"""
try:
data = request.json or {}
if not data.get("host"):
return jsonify({"error": "SMTP服务器地址不能为空"}), 400
if not data.get("username"):
return jsonify({"error": "SMTP用户名不能为空"}), 400
config_id = email_service.create_smtp_config(data)
return jsonify({"success": True, "id": config_id})
except Exception as e:
logger.error(f"创建SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["GET"])
@admin_required
def get_smtp_config_api(config_id):
"""获取单个SMTP配置详情"""
try:
config_data = email_service.get_smtp_config(config_id, include_password=False)
if not config_data:
return jsonify({"error": "配置不存在"}), 404
return jsonify(config_data)
except Exception as e:
logger.error(f"获取SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["PUT"])
@admin_required
def update_smtp_config_api(config_id):
"""更新SMTP配置"""
try:
data = request.json or {}
if email_service.update_smtp_config(config_id, data):
return jsonify({"success": True})
return jsonify({"error": "更新失败"}), 400
except Exception as e:
logger.error(f"更新SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["DELETE"])
@admin_required
def delete_smtp_config_api(config_id):
"""删除SMTP配置"""
try:
if email_service.delete_smtp_config(config_id):
return jsonify({"success": True})
return jsonify({"error": "删除失败"}), 400
except Exception as e:
logger.error(f"删除SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>/test", methods=["POST"])
@admin_required
def test_smtp_config_api(config_id):
"""测试SMTP配置"""
try:
data = request.json or {}
test_email = str(data.get("email", "") or "").strip()
if not test_email:
return jsonify({"error": "请提供测试邮箱"}), 400
is_valid, error_msg = validate_email(test_email)
if not is_valid:
return jsonify({"error": error_msg}), 400
result = email_service.test_smtp_config(config_id, test_email)
return jsonify(result)
except Exception as e:
logger.error(f"测试SMTP配置失败: {e}")
return jsonify({"success": False, "error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/<int:config_id>/primary", methods=["POST"])
@admin_required
def set_primary_smtp_config_api(config_id):
"""设置主SMTP配置"""
try:
if email_service.set_primary_smtp_config(config_id):
return jsonify({"success": True})
return jsonify({"error": "设置失败"}), 400
except Exception as e:
logger.error(f"设置主SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/smtp/configs/primary/clear", methods=["POST"])
@admin_required
def clear_primary_smtp_config_api():
"""取消主SMTP配置"""
try:
email_service.clear_primary_smtp_config()
return jsonify({"success": True})
except Exception as e:
logger.error(f"取消主SMTP配置失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/email/stats", methods=["GET"])
@admin_required
def get_email_stats_api():
"""获取邮件发送统计"""
try:
stats = email_service.get_email_stats()
return jsonify(stats)
except Exception as e:
logger.error(f"获取邮件统计失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/email/logs", methods=["GET"])
@admin_required
def get_email_logs_api():
"""获取邮件发送日志"""
try:
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
email_type = request.args.get("type", None)
status = request.args.get("status", None)
page_size = min(max(page_size, 10), 100)
result = email_service.get_email_logs(page, page_size, email_type, status)
return jsonify(result)
except Exception as e:
logger.error(f"获取邮件日志失败: {e}")
return jsonify({"error": str(e)}), 500
@admin_api_bp.route("/email/logs/cleanup", methods=["POST"])
@admin_required
def cleanup_email_logs_api():
"""清理过期邮件日志"""
try:
data = request.json or {}
days = data.get("days", 30)
days = min(max(days, 7), 365)
deleted = email_service.cleanup_email_logs(days)
return jsonify({"success": True, "deleted": deleted})
except Exception as e:
logger.error(f"清理邮件日志失败: {e}")
return jsonify({"error": str(e)}), 500

View File

@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
@admin_api_bp.route("/feedbacks", methods=["GET"])
@admin_required
def get_all_feedbacks():
"""管理员获取所有反馈"""
status = request.args.get("status")
try:
limit = int(request.args.get("limit", 100))
offset = int(request.args.get("offset", 0))
limit = min(max(1, limit), 1000)
offset = max(0, offset)
except (ValueError, TypeError):
return jsonify({"error": "无效的分页参数"}), 400
feedbacks = database.get_bug_feedbacks(limit=limit, offset=offset, status_filter=status)
stats = database.get_feedback_stats()
return jsonify({"feedbacks": feedbacks, "stats": stats})
@admin_api_bp.route("/feedbacks/<int:feedback_id>/reply", methods=["POST"])
@admin_required
def reply_to_feedback(feedback_id):
"""管理员回复反馈"""
data = request.get_json() or {}
reply = (data.get("reply") or "").strip()
if not reply:
return jsonify({"error": "回复内容不能为空"}), 400
if database.reply_feedback(feedback_id, reply):
return jsonify({"message": "回复成功"})
return jsonify({"error": "反馈不存在"}), 404
@admin_api_bp.route("/feedbacks/<int:feedback_id>/close", methods=["POST"])
@admin_required
def close_feedback_api(feedback_id):
"""管理员关闭反馈"""
if database.close_feedback(feedback_id):
return jsonify({"message": "已关闭"})
return jsonify({"error": "反馈不存在"}), 404
@admin_api_bp.route("/feedbacks/<int:feedback_id>", methods=["DELETE"])
@admin_required
def delete_feedback_api(feedback_id):
"""管理员删除反馈"""
if database.delete_feedback(feedback_id):
return jsonify({"message": "已删除"})
return jsonify({"error": "反馈不存在"}), 404

View File

@@ -0,0 +1,226 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import time
from datetime import datetime
import database
from app_logger import get_logger
from flask import jsonify, session
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.time_utils import BEIJING_TZ, get_beijing_now
logger = get_logger("app")
@admin_api_bp.route("/stats", methods=["GET"])
@admin_required
def get_system_stats():
"""获取系统统计"""
stats = database.get_system_stats()
stats["admin_username"] = session.get("admin_username", "admin")
return jsonify(stats)
@admin_api_bp.route("/browser_pool/stats", methods=["GET"])
@admin_required
def get_browser_pool_stats():
"""获取截图线程池状态"""
try:
from browser_pool_worker import get_browser_worker_pool
pool = get_browser_worker_pool()
stats = pool.get_stats() or {}
worker_details = []
for w in stats.get("workers") or []:
last_ts = float(w.get("last_active_ts") or 0)
last_active_at = None
if last_ts > 0:
try:
last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
last_active_at = None
created_ts = w.get("browser_created_at")
created_at = None
if created_ts:
try:
created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
created_at = None
worker_details.append(
{
"worker_id": w.get("worker_id"),
"idle": bool(w.get("idle")),
"has_browser": bool(w.get("has_browser")),
"total_tasks": int(w.get("total_tasks") or 0),
"failed_tasks": int(w.get("failed_tasks") or 0),
"browser_use_count": int(w.get("browser_use_count") or 0),
"browser_created_at": created_at,
"browser_created_ts": created_ts,
"last_active_at": last_active_at,
"last_active_ts": last_ts,
"thread_alive": bool(w.get("thread_alive")),
}
)
total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0)
return jsonify(
{
"total_workers": total_workers,
"active_workers": int(stats.get("busy_workers") or 0),
"idle_workers": int(stats.get("idle_workers") or 0),
"queue_size": int(stats.get("queue_size") or 0),
"workers": worker_details,
"summary": {
"total_tasks": int(stats.get("total_tasks") or 0),
"failed_tasks": int(stats.get("failed_tasks") or 0),
"success_rate": stats.get("success_rate"),
},
"server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
}
)
except Exception as e:
logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}")
return jsonify({"error": "获取截图线程池状态失败"}), 500
@admin_api_bp.route("/docker_stats", methods=["GET"])
@admin_required
def get_docker_stats():
"""获取Docker容器运行状态"""
import subprocess
docker_status = {
"running": False,
"container_name": "N/A",
"uptime": "N/A",
"memory_usage": "N/A",
"memory_limit": "N/A",
"memory_percent": "N/A",
"cpu_percent": "N/A",
"status": "Unknown",
}
try:
if os.path.exists("/.dockerenv"):
docker_status["running"] = True
try:
with open("/etc/hostname", "r") as f:
docker_status["container_name"] = f.read().strip()
except Exception as e:
logger.debug(f"读取容器名称失败: {e}")
try:
if os.path.exists("/sys/fs/cgroup/memory.current"):
with open("/sys/fs/cgroup/memory.current", "r") as f:
mem_total = int(f.read().strip())
cache = 0
if os.path.exists("/sys/fs/cgroup/memory.stat"):
with open("/sys/fs/cgroup/memory.stat", "r") as f:
for line in f:
if line.startswith("inactive_file "):
cache = int(line.split()[1])
break
mem_bytes = mem_total - cache
docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
if os.path.exists("/sys/fs/cgroup/memory.max"):
with open("/sys/fs/cgroup/memory.max", "r") as f:
limit_str = f.read().strip()
if limit_str != "max":
limit_bytes = int(limit_str)
docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
elif os.path.exists("/sys/fs/cgroup/memory/memory.usage_in_bytes"):
with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f:
mem_bytes = int(f.read().strip())
docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f:
limit_bytes = int(f.read().strip())
if limit_bytes < 1e18:
docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
except Exception as e:
logger.debug(f"读取内存信息失败: {e}")
try:
if os.path.exists("/sys/fs/cgroup/cpu.stat"):
cpu_usage = 0
with open("/sys/fs/cgroup/cpu.stat", "r") as f:
for line in f:
if line.startswith("usage_usec"):
cpu_usage = int(line.split()[1])
break
time.sleep(0.1)
cpu_usage2 = 0
with open("/sys/fs/cgroup/cpu.stat", "r") as f:
for line in f:
if line.startswith("usage_usec"):
cpu_usage2 = int(line.split()[1])
break
cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e6 * 100
docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
elif os.path.exists("/sys/fs/cgroup/cpu/cpuacct.usage"):
with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
cpu_usage = int(f.read().strip())
time.sleep(0.1)
with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
cpu_usage2 = int(f.read().strip())
cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e9 * 100
docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
except Exception as e:
logger.debug(f"读取CPU信息失败: {e}")
try:
# 读取系统运行时间
with open('/proc/uptime', 'r') as f:
system_uptime = float(f.read().split()[0])
# 读取 PID 1 的启动时间 (jiffies)
with open('/proc/1/stat', 'r') as f:
stat = f.read().split()
starttime_jiffies = int(stat[21])
# 获取 CLK_TCK (通常是 100)
clk_tck = os.sysconf(os.sysconf_names['SC_CLK_TCK'])
# 计算容器运行时长(秒)
container_uptime_seconds = system_uptime - (starttime_jiffies / clk_tck)
# 格式化为可读字符串
days = int(container_uptime_seconds // 86400)
hours = int((container_uptime_seconds % 86400) // 3600)
minutes = int((container_uptime_seconds % 3600) // 60)
if days > 0:
docker_status["uptime"] = f"{days}{hours}小时{minutes}分钟"
elif hours > 0:
docker_status["uptime"] = f"{hours}小时{minutes}分钟"
else:
docker_status["uptime"] = f"{minutes}分钟"
except Exception as e:
logger.debug(f"获取容器运行时间失败: {e}")
docker_status["status"] = "Running"
else:
docker_status["status"] = "Not in Docker"
except Exception as e:
docker_status["status"] = f"Error: {str(e)}"
return jsonify(docker_status)

View File

@@ -0,0 +1,228 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import threading
import time
from datetime import datetime
import database
import requests
from app_logger import get_logger
from app_security import is_safe_outbound_url
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.scheduler import run_scheduled_task
from services.time_utils import BEIJING_TZ, get_beijing_now
logger = get_logger("app")
_server_cpu_percent_lock = threading.Lock()
_server_cpu_percent_last: float | None = None
_server_cpu_percent_last_ts = 0.0
def _get_server_cpu_percent() -> float:
import psutil
global _server_cpu_percent_last, _server_cpu_percent_last_ts
now = time.time()
with _server_cpu_percent_lock:
if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5:
return _server_cpu_percent_last
try:
if _server_cpu_percent_last is None:
cpu_percent = float(psutil.cpu_percent(interval=0.1))
else:
cpu_percent = float(psutil.cpu_percent(interval=None))
except Exception:
cpu_percent = float(_server_cpu_percent_last or 0.0)
if cpu_percent < 0:
cpu_percent = 0.0
_server_cpu_percent_last = cpu_percent
_server_cpu_percent_last_ts = now
return cpu_percent
@admin_api_bp.route("/kdocs/status", methods=["GET"])
@admin_required
def get_kdocs_status_api():
"""获取金山文档上传状态"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
status = uploader.get_status()
live = str(request.args.get("live", "")).lower() in ("1", "true", "yes")
if live:
live_status = uploader.refresh_login_status()
if live_status.get("success"):
logged_in = bool(live_status.get("logged_in"))
status["logged_in"] = logged_in
status["last_login_ok"] = logged_in
status["login_required"] = not logged_in
if live_status.get("error"):
status["last_error"] = live_status.get("error")
else:
status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None
if status.get("last_login_ok") is True and status.get("last_error") == "操作超时":
status["last_error"] = None
return jsonify(status)
except Exception as e:
return jsonify({"error": f"获取状态失败: {e}"}), 500
@admin_api_bp.route("/kdocs/qr", methods=["POST"])
@admin_required
def get_kdocs_qr_api():
"""获取金山文档登录二维码"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
data = request.get_json(silent=True) or {}
force = bool(data.get("force"))
if not force:
force = str(request.args.get("force", "")).lower() in ("1", "true", "yes")
result = uploader.request_qr(force=force)
if not result.get("success"):
return jsonify({"error": result.get("error", "获取二维码失败")}), 400
return jsonify(result)
except Exception as e:
return jsonify({"error": f"获取二维码失败: {e}"}), 500
@admin_api_bp.route("/kdocs/clear-login", methods=["POST"])
@admin_required
def clear_kdocs_login_api():
"""清除金山文档登录态"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
result = uploader.clear_login()
if not result.get("success"):
return jsonify({"error": result.get("error", "清除失败")}), 400
return jsonify({"success": True})
except Exception as e:
return jsonify({"error": f"清除失败: {e}"}), 500
@admin_api_bp.route("/schedule/execute", methods=["POST"])
@admin_required
def execute_schedule_now():
"""立即执行定时任务(无视定时时间和星期限制)"""
try:
threading.Thread(target=run_scheduled_task, args=(True,), daemon=True).start()
logger.info("[立即执行定时任务] 管理员手动触发定时任务执行(跳过星期检查)")
return jsonify({"message": "定时任务已开始执行,请查看任务列表获取进度"})
except Exception as e:
logger.error(f"[立即执行定时任务] 启动失败: {str(e)}")
return jsonify({"error": f"启动失败: {str(e)}"}), 500
@admin_api_bp.route("/proxy/config", methods=["GET"])
@admin_required
def get_proxy_config_api():
"""获取代理配置"""
config_data = database.get_system_config()
return jsonify(
{
"proxy_enabled": config_data.get("proxy_enabled", 0),
"proxy_api_url": config_data.get("proxy_api_url", ""),
"proxy_expire_minutes": config_data.get("proxy_expire_minutes", 3),
}
)
@admin_api_bp.route("/proxy/config", methods=["POST"])
@admin_required
def update_proxy_config_api():
"""更新代理配置"""
data = request.json or {}
proxy_enabled = data.get("proxy_enabled")
proxy_api_url = (data.get("proxy_api_url", "") or "").strip()
proxy_expire_minutes = data.get("proxy_expire_minutes")
if proxy_enabled is not None and proxy_enabled not in [0, 1]:
return jsonify({"error": "proxy_enabled必须是0或1"}), 400
if proxy_expire_minutes is not None:
if not isinstance(proxy_expire_minutes, int) or proxy_expire_minutes < 1:
return jsonify({"error": "代理有效期必须是大于0的整数"}), 400
if database.update_system_config(
proxy_enabled=proxy_enabled,
proxy_api_url=proxy_api_url,
proxy_expire_minutes=proxy_expire_minutes,
):
return jsonify({"message": "代理配置已更新"})
return jsonify({"error": "更新失败"}), 400
@admin_api_bp.route("/proxy/test", methods=["POST"])
@admin_required
def test_proxy_api():
"""测试代理连接"""
data = request.json or {}
api_url = (data.get("api_url") or "").strip()
if not api_url:
return jsonify({"error": "请提供API地址"}), 400
if not is_safe_outbound_url(api_url):
return jsonify({"error": "API地址不可用或不安全"}), 400
try:
response = requests.get(api_url, timeout=10)
if response.status_code == 200:
ip_port = response.text.strip()
if ip_port and ":" in ip_port:
return jsonify({"success": True, "proxy": ip_port, "message": f"代理获取成功: {ip_port}"})
return jsonify({"success": False, "message": f"代理格式错误: {ip_port}"}), 400
return jsonify({"success": False, "message": f"HTTP错误: {response.status_code}"}), 400
except Exception as e:
return jsonify({"success": False, "message": f"连接失败: {str(e)}"}), 500
@admin_api_bp.route("/server/info", methods=["GET"])
@admin_required
def get_server_info_api():
"""获取服务器信息"""
import psutil
cpu_percent = _get_server_cpu_percent()
memory = psutil.virtual_memory()
memory_total = f"{memory.total / (1024**3):.1f}GB"
memory_used = f"{memory.used / (1024**3):.1f}GB"
memory_percent = memory.percent
disk = psutil.disk_usage("/")
disk_total = f"{disk.total / (1024**3):.1f}GB"
disk_used = f"{disk.used / (1024**3):.1f}GB"
disk_percent = disk.percent
boot_time = datetime.fromtimestamp(psutil.boot_time(), tz=BEIJING_TZ)
uptime_delta = get_beijing_now() - boot_time
days = uptime_delta.days
hours = uptime_delta.seconds // 3600
uptime = f"{days}{hours}小时"
return jsonify(
{
"cpu_percent": cpu_percent,
"memory_total": memory_total,
"memory_used": memory_used,
"memory_percent": memory_percent,
"disk_total": disk_total,
"disk_used": disk_used,
"disk_percent": disk_percent,
"uptime": uptime,
}
)

View File

@@ -62,6 +62,19 @@ def _parse_bool(value: Any) -> bool:
return text in {"1", "true", "yes", "y", "on"}
def _parse_int(value: Any, *, default: int | None = None, min_value: int | None = None) -> int | None:
try:
parsed = int(value)
except Exception:
parsed = default
if parsed is None:
return None
if min_value is not None:
parsed = max(int(min_value), parsed)
return parsed
def _sanitize_threat_event(event: dict) -> dict:
return {
"id": event.get("id"),
@@ -199,10 +212,7 @@ def ban_ip():
if not reason:
return jsonify({"error": "reason不能为空"}), 400
try:
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24
ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
@@ -235,20 +245,14 @@ def ban_user():
duration_hours_raw = data.get("duration_hours", 24)
permanent = _parse_bool(data.get("permanent", False))
try:
user_id = int(user_id_raw)
except Exception:
user_id = None
user_id = _parse_int(user_id_raw)
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400
if not reason:
return jsonify({"error": "reason不能为空"}), 400
try:
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24
ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
@@ -262,10 +266,7 @@ def unban_user():
"""解除用户封禁"""
data = _parse_json()
user_id_raw = data.get("user_id")
try:
user_id = int(user_id_raw)
except Exception:
user_id = None
user_id = _parse_int(user_id_raw)
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400

View File

@@ -0,0 +1,228 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from app_logger import get_logger
from app_security import is_safe_outbound_url, validate_email
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.browse_types import BROWSE_TYPE_SHOULD_READ, validate_browse_type
from services.tasks import get_task_scheduler
logger = get_logger("app")
@admin_api_bp.route("/system/config", methods=["GET"])
@admin_required
def get_system_config_api():
"""获取系统配置"""
return jsonify(database.get_system_config())
@admin_api_bp.route("/system/config", methods=["POST"])
@admin_required
def update_system_config_api():
"""更新系统配置"""
data = request.json or {}
max_concurrent = data.get("max_concurrent_global")
schedule_enabled = data.get("schedule_enabled")
schedule_time = data.get("schedule_time")
schedule_browse_type = data.get("schedule_browse_type")
schedule_weekdays = data.get("schedule_weekdays")
new_max_concurrent_per_account = data.get("max_concurrent_per_account")
new_max_screenshot_concurrent = data.get("max_screenshot_concurrent")
enable_screenshot = data.get("enable_screenshot")
auto_approve_enabled = data.get("auto_approve_enabled")
auto_approve_hourly_limit = data.get("auto_approve_hourly_limit")
auto_approve_vip_days = data.get("auto_approve_vip_days")
kdocs_enabled = data.get("kdocs_enabled")
kdocs_doc_url = data.get("kdocs_doc_url")
kdocs_default_unit = data.get("kdocs_default_unit")
kdocs_sheet_name = data.get("kdocs_sheet_name")
kdocs_sheet_index = data.get("kdocs_sheet_index")
kdocs_unit_column = data.get("kdocs_unit_column")
kdocs_image_column = data.get("kdocs_image_column")
kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled")
kdocs_admin_notify_email = data.get("kdocs_admin_notify_email")
kdocs_row_start = data.get("kdocs_row_start")
kdocs_row_end = data.get("kdocs_row_end")
if max_concurrent is not None:
if not isinstance(max_concurrent, int) or max_concurrent < 1:
return jsonify({"error": "全局并发数必须大于0建议小型服务器2-5中型5-10大型10-20"}), 400
if new_max_concurrent_per_account is not None:
if not isinstance(new_max_concurrent_per_account, int) or new_max_concurrent_per_account < 1:
return jsonify({"error": "单账号并发数必须大于0建议设为1避免同一用户任务相互影响"}), 400
if new_max_screenshot_concurrent is not None:
if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1:
return jsonify({"error": "截图并发数必须大于0建议根据服务器配置设置wkhtmltoimage 资源占用较低)"}), 400
if enable_screenshot is not None:
if isinstance(enable_screenshot, bool):
enable_screenshot = 1 if enable_screenshot else 0
if enable_screenshot not in (0, 1):
return jsonify({"error": "截图开关必须是0或1"}), 400
if schedule_time is not None:
import re
if not re.match(r"^([01]\\d|2[0-3]):([0-5]\\d)$", schedule_time):
return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400
if schedule_browse_type is not None:
normalized = validate_browse_type(schedule_browse_type, default=BROWSE_TYPE_SHOULD_READ)
if not normalized:
return jsonify({"error": "浏览类型无效"}), 400
schedule_browse_type = normalized
if schedule_weekdays is not None:
try:
days = [int(d.strip()) for d in schedule_weekdays.split(",") if d.strip()]
if not all(1 <= d <= 7 for d in days):
return jsonify({"error": "星期数字必须在1-7之间"}), 400
except (ValueError, AttributeError):
return jsonify({"error": "星期格式错误"}), 400
if auto_approve_hourly_limit is not None:
if not isinstance(auto_approve_hourly_limit, int) or auto_approve_hourly_limit < 1:
return jsonify({"error": "每小时注册限制必须大于0"}), 400
if auto_approve_vip_days is not None:
if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0:
return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400
if kdocs_enabled is not None:
if isinstance(kdocs_enabled, bool):
kdocs_enabled = 1 if kdocs_enabled else 0
if kdocs_enabled not in (0, 1):
return jsonify({"error": "表格上传开关必须是0或1"}), 400
if kdocs_doc_url is not None:
kdocs_doc_url = str(kdocs_doc_url or "").strip()
if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url):
return jsonify({"error": "文档链接格式不正确"}), 400
if kdocs_default_unit is not None:
kdocs_default_unit = str(kdocs_default_unit or "").strip()
if len(kdocs_default_unit) > 50:
return jsonify({"error": "默认县区长度不能超过50"}), 400
if kdocs_sheet_name is not None:
kdocs_sheet_name = str(kdocs_sheet_name or "").strip()
if len(kdocs_sheet_name) > 50:
return jsonify({"error": "Sheet名称长度不能超过50"}), 400
if kdocs_sheet_index is not None:
try:
kdocs_sheet_index = int(kdocs_sheet_index)
except Exception:
return jsonify({"error": "Sheet序号必须是数字"}), 400
if kdocs_sheet_index < 0:
return jsonify({"error": "Sheet序号不能为负数"}), 400
if kdocs_unit_column is not None:
kdocs_unit_column = str(kdocs_unit_column or "").strip().upper()
if not kdocs_unit_column:
return jsonify({"error": "县区列不能为空"}), 400
import re
if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column):
return jsonify({"error": "县区列格式错误"}), 400
if kdocs_image_column is not None:
kdocs_image_column = str(kdocs_image_column or "").strip().upper()
if not kdocs_image_column:
return jsonify({"error": "图片列不能为空"}), 400
import re
if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column):
return jsonify({"error": "图片列格式错误"}), 400
if kdocs_admin_notify_enabled is not None:
if isinstance(kdocs_admin_notify_enabled, bool):
kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0
if kdocs_admin_notify_enabled not in (0, 1):
return jsonify({"error": "管理员通知开关必须是0或1"}), 400
if kdocs_admin_notify_email is not None:
kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip()
if kdocs_admin_notify_email:
is_valid, error_msg = validate_email(kdocs_admin_notify_email)
if not is_valid:
return jsonify({"error": error_msg}), 400
if kdocs_row_start is not None:
try:
kdocs_row_start = int(kdocs_row_start)
except (ValueError, TypeError):
return jsonify({"error": "起始行必须是数字"}), 400
if kdocs_row_start < 0:
return jsonify({"error": "起始行不能为负数"}), 400
if kdocs_row_end is not None:
try:
kdocs_row_end = int(kdocs_row_end)
except (ValueError, TypeError):
return jsonify({"error": "结束行必须是数字"}), 400
if kdocs_row_end < 0:
return jsonify({"error": "结束行不能为负数"}), 400
old_config = database.get_system_config() or {}
if not database.update_system_config(
max_concurrent=max_concurrent,
schedule_enabled=schedule_enabled,
schedule_time=schedule_time,
schedule_browse_type=schedule_browse_type,
schedule_weekdays=schedule_weekdays,
max_concurrent_per_account=new_max_concurrent_per_account,
max_screenshot_concurrent=new_max_screenshot_concurrent,
enable_screenshot=enable_screenshot,
auto_approve_enabled=auto_approve_enabled,
auto_approve_hourly_limit=auto_approve_hourly_limit,
auto_approve_vip_days=auto_approve_vip_days,
kdocs_enabled=kdocs_enabled,
kdocs_doc_url=kdocs_doc_url,
kdocs_default_unit=kdocs_default_unit,
kdocs_sheet_name=kdocs_sheet_name,
kdocs_sheet_index=kdocs_sheet_index,
kdocs_unit_column=kdocs_unit_column,
kdocs_image_column=kdocs_image_column,
kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
kdocs_admin_notify_email=kdocs_admin_notify_email,
kdocs_row_start=kdocs_row_start,
kdocs_row_end=kdocs_row_end,
):
return jsonify({"error": "更新失败"}), 400
try:
new_config = database.get_system_config() or {}
scheduler = get_task_scheduler()
scheduler.update_limits(
max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))),
max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))),
)
if new_max_screenshot_concurrent is not None:
try:
from browser_pool_worker import resize_browser_worker_pool
if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))):
logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}")
except Exception as pool_error:
logger.warning(f"截图线程池并发更新失败: {pool_error}")
except Exception:
pass
if max_concurrent is not None and max_concurrent != old_config.get("max_concurrent_global"):
logger.info(f"全局并发数已更新为: {max_concurrent}")
if new_max_concurrent_per_account is not None and new_max_concurrent_per_account != old_config.get("max_concurrent_per_account"):
logger.info(f"单用户并发数已更新为: {new_max_concurrent_per_account}")
if new_max_screenshot_concurrent is not None:
logger.info(f"截图并发数已更新为: {new_max_screenshot_concurrent}")
return jsonify({"message": "系统配置已更新"})

View File

@@ -0,0 +1,138 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from app_logger import get_logger
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.state import safe_iter_task_status_items
from services.tasks import get_task_scheduler
logger = get_logger("app")
def _parse_page_int(name: str, default: int, *, minimum: int, maximum: int) -> int:
try:
value = int(request.args.get(name, default))
return max(minimum, min(value, maximum))
except (ValueError, TypeError):
return default
@admin_api_bp.route("/task/stats", methods=["GET"])
@admin_required
def get_task_stats_api():
"""获取任务统计数据"""
date_filter = request.args.get("date")
stats = database.get_task_stats(date_filter)
return jsonify(stats)
@admin_api_bp.route("/task/running", methods=["GET"])
@admin_required
def get_running_tasks_api():
"""获取当前运行中和排队中的任务"""
import time as time_mod
current_time = time_mod.time()
running = []
queuing = []
user_cache = {}
for account_id, info in safe_iter_task_status_items():
elapsed = int(current_time - info.get("start_time", current_time))
info_user_id = info.get("user_id")
if info_user_id not in user_cache:
user_cache[info_user_id] = database.get_user_by_id(info_user_id)
user = user_cache.get(info_user_id)
user_username = user["username"] if user else "N/A"
progress = info.get("progress", {"items": 0, "attachments": 0})
task_info = {
"account_id": account_id,
"user_id": info.get("user_id"),
"user_username": user_username,
"username": info.get("username"),
"browse_type": info.get("browse_type"),
"source": info.get("source", "manual"),
"detail_status": info.get("detail_status", "未知"),
"progress_items": progress.get("items", 0),
"progress_attachments": progress.get("attachments", 0),
"elapsed_seconds": elapsed,
"elapsed_display": f"{elapsed // 60}{elapsed % 60}" if elapsed >= 60 else f"{elapsed}",
}
if info.get("status") == "运行中":
running.append(task_info)
else:
queuing.append(task_info)
running.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
queuing.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
try:
max_concurrent = int(get_task_scheduler().max_global)
except Exception:
max_concurrent = int((database.get_system_config() or {}).get("max_concurrent_global", 2))
return jsonify(
{
"running": running,
"queuing": queuing,
"running_count": len(running),
"queuing_count": len(queuing),
"max_concurrent": max_concurrent,
}
)
@admin_api_bp.route("/task/logs", methods=["GET"])
@admin_required
def get_task_logs_api():
"""获取任务日志列表(支持分页和多种筛选)"""
limit = _parse_page_int("limit", 20, minimum=1, maximum=200)
offset = _parse_page_int("offset", 0, minimum=0, maximum=10**9)
date_filter = request.args.get("date")
status_filter = request.args.get("status")
source_filter = request.args.get("source")
user_id_filter = request.args.get("user_id")
account_filter = (request.args.get("account") or "").strip()
if user_id_filter:
try:
user_id_filter = int(user_id_filter)
except (ValueError, TypeError):
user_id_filter = None
try:
result = database.get_task_logs(
limit=limit,
offset=offset,
date_filter=date_filter,
status_filter=status_filter,
source_filter=source_filter,
user_id_filter=user_id_filter,
account_filter=account_filter if account_filter else None,
)
return jsonify(result)
except Exception as e:
logger.error(f"获取任务日志失败: {e}")
return jsonify({"logs": [], "total": 0, "error": "查询失败"})
@admin_api_bp.route("/task/logs/clear", methods=["POST"])
@admin_required
def clear_old_task_logs_api():
"""清理旧的任务日志"""
data = request.json or {}
days = data.get("days", 30)
if not isinstance(days, int) or days < 1:
return jsonify({"error": "天数必须是大于0的整数"}), 400
deleted_count = database.delete_old_task_logs(days)
return jsonify({"message": f"已删除{days}天前的{deleted_count}条日志"})

View File

@@ -0,0 +1,117 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import database
from flask import jsonify, request
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.state import safe_clear_user_logs, safe_remove_user_accounts
# ==================== 用户管理/统计(管理员) ====================
@admin_api_bp.route("/users", methods=["GET"])
@admin_required
def get_all_users():
"""获取所有用户"""
users = database.get_all_users()
return jsonify(users)
@admin_api_bp.route("/users/pending", methods=["GET"])
@admin_required
def get_pending_users():
"""获取待审核用户"""
users = database.get_pending_users()
return jsonify(users)
@admin_api_bp.route("/users/<int:user_id>/approve", methods=["POST"])
@admin_required
def approve_user_route(user_id):
"""审核通过用户"""
if database.approve_user(user_id):
return jsonify({"success": True})
return jsonify({"error": "审核失败"}), 400
@admin_api_bp.route("/users/<int:user_id>/reject", methods=["POST"])
@admin_required
def reject_user_route(user_id):
"""拒绝用户"""
if database.reject_user(user_id):
return jsonify({"success": True})
return jsonify({"error": "拒绝失败"}), 400
@admin_api_bp.route("/users/<int:user_id>", methods=["DELETE"])
@admin_required
def delete_user_route(user_id):
"""删除用户"""
if database.delete_user(user_id):
safe_remove_user_accounts(user_id)
safe_clear_user_logs(user_id)
return jsonify({"success": True})
return jsonify({"error": "删除失败"}), 400
# ==================== VIP 管理(管理员) ====================
@admin_api_bp.route("/vip/config", methods=["GET"])
@admin_required
def get_vip_config_api():
"""获取VIP配置"""
config = database.get_vip_config()
return jsonify(config)
@admin_api_bp.route("/vip/config", methods=["POST"])
@admin_required
def set_vip_config_api():
"""设置默认VIP天数"""
data = request.json or {}
days = data.get("default_vip_days", 0)
if not isinstance(days, int) or days < 0:
return jsonify({"error": "VIP天数必须是非负整数"}), 400
database.set_default_vip_days(days)
return jsonify({"message": "VIP配置已更新", "default_vip_days": days})
@admin_api_bp.route("/users/<int:user_id>/vip", methods=["POST"])
@admin_required
def set_user_vip_api(user_id):
"""设置用户VIP"""
data = request.json or {}
days = data.get("days", 30)
valid_days = [7, 30, 365, 999999]
if days not in valid_days:
return jsonify({"error": "VIP天数必须是 7/30/365/999999 之一"}), 400
if database.set_user_vip(user_id, days):
vip_type = {7: "一周", 30: "一个月", 365: "一年", 999999: "永久"}[days]
return jsonify({"message": f"VIP设置成功: {vip_type}"})
return jsonify({"error": "设置失败,用户不存在"}), 400
@admin_api_bp.route("/users/<int:user_id>/vip", methods=["DELETE"])
@admin_required
def remove_user_vip_api(user_id):
"""移除用户VIP"""
if database.remove_user_vip(user_id):
return jsonify({"message": "VIP已移除"})
return jsonify({"error": "移除失败"}), 400
@admin_api_bp.route("/users/<int:user_id>/vip", methods=["GET"])
@admin_required
def get_user_vip_info_api(user_id):
"""获取用户VIP信息(管理员)"""
vip_info = database.get_user_vip_info(user_id)
return jsonify(vip_info)

View File

@@ -40,6 +40,48 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
pass
def _emit_account_update(user_id: int, account) -> None:
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
def _request_json(default=None):
if default is None:
default = {}
data = request.get_json(silent=True)
return data if isinstance(data, dict) else default
def _ensure_accounts_loaded(user_id: int) -> dict:
accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
return accounts
load_user_accounts(user_id)
return safe_get_user_accounts_snapshot(user_id)
def _get_user_account(user_id: int, account_id: str, *, refresh_if_missing: bool = False):
account = safe_get_account(user_id, account_id)
if account or (not refresh_if_missing):
return account
load_user_accounts(user_id)
return safe_get_account(user_id, account_id)
def _validate_browse_type_input(raw_browse_type, *, default=BROWSE_TYPE_SHOULD_READ):
browse_type = validate_browse_type(raw_browse_type, default=default)
if not browse_type:
return None, (jsonify({"error": "浏览类型无效"}), 400)
return browse_type, None
def _cancel_pending_account_task(user_id: int, account_id: str) -> bool:
try:
scheduler = get_task_scheduler()
return bool(scheduler.cancel_pending_task(user_id=user_id, account_id=account_id))
except Exception:
return False
@api_accounts_bp.route("/api/accounts", methods=["GET"])
@login_required
def get_accounts():
@@ -49,8 +91,7 @@ def get_accounts():
accounts = safe_get_user_accounts_snapshot(user_id)
if refresh or not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
accounts = _ensure_accounts_loaded(user_id)
return jsonify([acc.to_dict() for acc in accounts.values()])
@@ -63,20 +104,18 @@ def add_account():
current_count = len(database.get_user_accounts(user_id))
is_vip = database.is_user_vip(user_id)
if not is_vip and current_count >= 3:
if (not is_vip) and current_count >= 3:
return jsonify({"error": "普通用户最多添加3个账号升级VIP可无限添加"}), 403
data = request.json
username = data.get("username", "").strip()
password = data.get("password", "").strip()
remark = data.get("remark", "").strip()[:200]
data = _request_json()
username = str(data.get("username", "")).strip()
password = str(data.get("password", "")).strip()
remark = str(data.get("remark", "")).strip()[:200]
if not username or not password:
return jsonify({"error": "用户名和密码不能为空"}), 400
accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
accounts = _ensure_accounts_loaded(user_id)
for acc in accounts.values():
if acc.username == username:
return jsonify({"error": f"账号 '{username}' 已存在"}), 400
@@ -92,7 +131,7 @@ def add_account():
safe_set_account(user_id, account_id, account)
log_to_client(f"添加账号: {username}", user_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
_emit_account_update(user_id, account)
return jsonify(account.to_dict())
@@ -103,15 +142,15 @@ def update_account(account_id):
"""更新账号信息(密码等)"""
user_id = current_user.id
account = safe_get_account(user_id, account_id)
account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
if account.is_running:
return jsonify({"error": "账号正在运行中,请先停止"}), 400
data = request.json
new_password = data.get("password", "").strip()
data = _request_json()
new_password = str(data.get("password", "")).strip()
new_remember = data.get("remember", account.remember)
if not new_password:
@@ -147,7 +186,7 @@ def delete_account(account_id):
"""删除账号"""
user_id = current_user.id
account = safe_get_account(user_id, account_id)
account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
@@ -159,7 +198,6 @@ def delete_account(account_id):
username = account.username
database.delete_account(account_id)
safe_remove_account(user_id, account_id)
log_to_client(f"删除账号: {username}", user_id)
@@ -196,12 +234,12 @@ def update_remark(account_id):
"""更新备注"""
user_id = current_user.id
account = safe_get_account(user_id, account_id)
account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
data = request.json
remark = data.get("remark", "").strip()[:200]
data = _request_json()
remark = str(data.get("remark", "")).strip()[:200]
database.update_account_remark(account_id, remark)
@@ -217,17 +255,18 @@ def start_account(account_id):
"""启动账号任务"""
user_id = current_user.id
account = safe_get_account(user_id, account_id)
account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
if account.is_running:
return jsonify({"error": "任务已在运行中"}), 400
data = request.json or {}
browse_type = validate_browse_type(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
if not browse_type:
return jsonify({"error": "浏览类型无效"}), 400
data = _request_json()
browse_type, browse_error = _validate_browse_type_input(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
if browse_error:
return browse_error
enable_screenshot = data.get("enable_screenshot", True)
ok, message = submit_account_task(
user_id=user_id,
@@ -249,7 +288,7 @@ def stop_account(account_id):
"""停止账号任务"""
user_id = current_user.id
account = safe_get_account(user_id, account_id)
account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
@@ -259,20 +298,16 @@ def stop_account(account_id):
account.should_stop = True
account.status = "正在停止"
try:
scheduler = get_task_scheduler()
if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id):
account.status = "已停止"
account.is_running = False
safe_remove_task_status(account_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
log_to_client(f"任务已取消: {account.username}", user_id)
return jsonify({"success": True, "canceled": True})
except Exception:
pass
if _cancel_pending_account_task(user_id, account_id):
account.status = "已停止"
account.is_running = False
safe_remove_task_status(account_id)
_emit_account_update(user_id, account)
log_to_client(f"任务已取消: {account.username}", user_id)
return jsonify({"success": True, "canceled": True})
log_to_client(f"停止任务: {account.username}", user_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
_emit_account_update(user_id, account)
return jsonify({"success": True})
@@ -283,23 +318,20 @@ def manual_screenshot(account_id):
"""手动为指定账号截图"""
user_id = current_user.id
account = safe_get_account(user_id, account_id)
if not account:
load_user_accounts(user_id)
account = safe_get_account(user_id, account_id)
account = _get_user_account(user_id, account_id, refresh_if_missing=True)
if not account:
return jsonify({"error": "账号不存在"}), 404
if account.is_running:
return jsonify({"error": "任务运行中,无法截图"}), 400
data = request.json or {}
data = _request_json()
requested_browse_type = data.get("browse_type", None)
if requested_browse_type is None:
browse_type = normalize_browse_type(account.last_browse_type)
else:
browse_type = validate_browse_type(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ)
if not browse_type:
return jsonify({"error": "浏览类型无效"}), 400
browse_type, browse_error = _validate_browse_type_input(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ)
if browse_error:
return browse_error
account.last_browse_type = browse_type
@@ -317,12 +349,16 @@ def manual_screenshot(account_id):
def batch_start_accounts():
"""批量启动账号"""
user_id = current_user.id
data = request.json or {}
data = _request_json()
account_ids = data.get("account_ids", [])
browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ)
if not browse_type:
return jsonify({"error": "浏览类型无效"}), 400
browse_type, browse_error = _validate_browse_type_input(
data.get("browse_type", BROWSE_TYPE_SHOULD_READ),
default=BROWSE_TYPE_SHOULD_READ,
)
if browse_error:
return browse_error
enable_screenshot = data.get("enable_screenshot", True)
if not account_ids:
@@ -331,11 +367,10 @@ def batch_start_accounts():
started = []
failed = []
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
_ensure_accounts_loaded(user_id)
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
account = _get_user_account(user_id, account_id)
if not account:
failed.append({"id": account_id, "reason": "账号不存在"})
continue
@@ -357,7 +392,13 @@ def batch_start_accounts():
failed.append({"id": account_id, "reason": msg})
return jsonify(
{"success": True, "started_count": len(started), "failed_count": len(failed), "started": started, "failed": failed}
{
"success": True,
"started_count": len(started),
"failed_count": len(failed),
"started": started,
"failed": failed,
}
)
@@ -366,39 +407,29 @@ def batch_start_accounts():
def batch_stop_accounts():
"""批量停止账号"""
user_id = current_user.id
data = request.json
data = _request_json()
account_ids = data.get("account_ids", [])
if not account_ids:
return jsonify({"error": "请选择要停止的账号"}), 400
stopped = []
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
_ensure_accounts_loaded(user_id)
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
if not account:
continue
if not account.is_running:
account = _get_user_account(user_id, account_id)
if (not account) or (not account.is_running):
continue
account.should_stop = True
account.status = "正在停止"
stopped.append(account_id)
try:
scheduler = get_task_scheduler()
if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id):
account.status = "已停止"
account.is_running = False
safe_remove_task_status(account_id)
except Exception:
pass
if _cancel_pending_account_task(user_id, account_id):
account.status = "已停止"
account.is_running = False
safe_remove_task_status(account_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
_emit_account_update(user_id, account)
return jsonify({"success": True, "stopped_count": len(stopped), "stopped": stopped})

View File

@@ -2,16 +2,20 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import base64
import random
import secrets
import threading
import time
import uuid
from io import BytesIO
import database
import email_service
from app_config import get_config
from app_logger import get_logger
from app_security import get_rate_limit_ip, require_ip_not_locked, validate_email, validate_password, validate_username
from flask import Blueprint, jsonify, redirect, render_template, request, url_for
from flask import Blueprint, jsonify, request
from flask_login import login_required, login_user, logout_user
from routes.pages import render_app_spa_or_legacy
from services.accounts_service import load_user_accounts
@@ -39,12 +43,162 @@ config = get_config()
api_auth_bp = Blueprint("api_auth", __name__)
_CAPTCHA_FONT_PATHS = [
"/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
"/usr/share/fonts/truetype/freefont/FreeSansBold.ttf",
]
_CAPTCHA_FONT = None
_CAPTCHA_FONT_LOCK = threading.Lock()
def _get_json_payload() -> dict:
data = request.get_json(silent=True)
return data if isinstance(data, dict) else {}
def _load_captcha_font(image_font_module):
global _CAPTCHA_FONT
if _CAPTCHA_FONT is not None:
return _CAPTCHA_FONT
with _CAPTCHA_FONT_LOCK:
if _CAPTCHA_FONT is not None:
return _CAPTCHA_FONT
for font_path in _CAPTCHA_FONT_PATHS:
try:
_CAPTCHA_FONT = image_font_module.truetype(font_path, 42)
break
except Exception:
continue
if _CAPTCHA_FONT is None:
_CAPTCHA_FONT = image_font_module.load_default()
return _CAPTCHA_FONT
def _generate_captcha_image_data_uri(code: str) -> str:
from PIL import Image, ImageDraw, ImageFont
width, height = 160, 60
image = Image.new("RGB", (width, height), color=(255, 255, 255))
draw = ImageDraw.Draw(image)
for _ in range(6):
x1 = random.randint(0, width)
y1 = random.randint(0, height)
x2 = random.randint(0, width)
y2 = random.randint(0, height)
draw.line(
[(x1, y1), (x2, y2)],
fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)),
width=1,
)
for _ in range(80):
x = random.randint(0, width)
y = random.randint(0, height)
draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)))
font = _load_captcha_font(ImageFont)
for i, char in enumerate(code):
x = 12 + i * 35 + random.randint(-3, 3)
y = random.randint(5, 12)
color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150))
draw.text((x, y), char, font=font, fill=color)
buffer = BytesIO()
image.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return f"data:image/png;base64,{img_base64}"
def _with_vip_suffix(message: str, auto_approve_enabled: bool, auto_approve_vip_days: int) -> str:
if auto_approve_enabled and auto_approve_vip_days > 0:
return f"{message},赠送{auto_approve_vip_days}天VIP"
return message
def _verify_common_captcha(client_ip: str, captcha_session: str, captcha_code: str):
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
if success:
return True, None
is_locked = record_failed_captcha(client_ip)
if is_locked:
return False, (jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429)
return False, (jsonify({"error": message}), 400)
def _verify_login_captcha_if_needed(
*,
captcha_required: bool,
captcha_session: str,
captcha_code: str,
client_ip: str,
username_key: str,
):
if not captcha_required:
return True, None
if not captcha_session or not captcha_code:
return False, (jsonify({"error": "请填写验证码", "need_captcha": True}), 400)
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
if success:
return True, None
record_login_failure(client_ip, username_key)
return False, (jsonify({"error": message, "need_captcha": True}), 400)
def _send_password_reset_email_if_possible(email: str, username: str, user_id: int) -> None:
result = email_service.send_password_reset_email(email=email, username=username, user_id=user_id)
if not result["success"]:
logger.error(f"密码重置邮件发送失败: {result['error']}")
def _send_login_security_alert_if_needed(user: dict, username: str, client_ip: str) -> None:
try:
user_agent = request.headers.get("User-Agent", "")
context = database.record_login_context(user["id"], client_ip, user_agent)
if not context or (not context.get("new_ip") and not context.get("new_device")):
return
if not config.LOGIN_ALERT_ENABLED:
return
if not should_send_login_alert(user["id"], client_ip):
return
if not email_service.get_email_settings().get("login_alert_enabled", True):
return
user_info = database.get_user_by_id(user["id"]) or {}
if (not user_info.get("email")) or (not user_info.get("email_verified")):
return
if not database.get_user_email_notify(user["id"]):
return
email_service.send_security_alert_email(
email=user_info.get("email"),
username=user_info.get("username") or username,
ip_address=client_ip,
user_agent=user_agent,
new_ip=context.get("new_ip", False),
new_device=context.get("new_device", False),
user_id=user["id"],
)
except Exception:
pass
@api_auth_bp.route("/api/register", methods=["POST"])
@require_ip_not_locked
def register():
"""用户注册"""
data = request.json or {}
data = _get_json_payload()
username = data.get("username", "").strip()
password = data.get("password", "").strip()
email = data.get("email", "").strip().lower()
@@ -67,12 +221,9 @@ def register():
if not allowed:
return jsonify({"error": error_msg}), 429
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
if not success:
is_locked = record_failed_captcha(client_ip)
if is_locked:
return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
return jsonify({"error": message}), 400
captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
if not captcha_ok:
return captcha_error_response
email_settings = email_service.get_email_settings()
email_verify_enabled = email_settings.get("register_verify_enabled", False) and email_settings.get("enabled", False)
@@ -105,20 +256,22 @@ def register():
if email_verify_enabled and email:
result = email_service.send_register_verification_email(email=email, username=username, user_id=user_id)
if result["success"]:
message = "注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)"
if auto_approve_enabled and auto_approve_vip_days > 0:
message += f",赠送{auto_approve_vip_days}天VIP"
message = _with_vip_suffix(
"注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)",
auto_approve_enabled,
auto_approve_vip_days,
)
return jsonify({"success": True, "message": message, "need_verify": True})
logger.error(f"注册验证邮件发送失败: {result['error']}")
message = f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录"
if auto_approve_enabled and auto_approve_vip_days > 0:
message += f",赠送{auto_approve_vip_days}天VIP"
message = _with_vip_suffix(
f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录",
auto_approve_enabled,
auto_approve_vip_days,
)
return jsonify({"success": True, "message": message, "need_verify": True})
message = "注册成功!可直接登录"
if auto_approve_enabled and auto_approve_vip_days > 0:
message += f",赠送{auto_approve_vip_days}天VIP"
message = _with_vip_suffix("注册成功!可直接登录", auto_approve_enabled, auto_approve_vip_days)
return jsonify({"success": True, "message": message})
return jsonify({"error": "用户名已存在"}), 400
@@ -175,7 +328,7 @@ def verify_email(token):
@require_ip_not_locked
def resend_verify_email():
"""重发验证邮件"""
data = request.json or {}
data = _get_json_payload()
email = data.get("email", "").strip().lower()
captcha_session = data.get("captcha_session", "")
captcha_code = data.get("captcha", "").strip()
@@ -195,12 +348,9 @@ def resend_verify_email():
if not allowed:
return jsonify({"error": error_msg}), 429
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
if not success:
is_locked = record_failed_captcha(client_ip)
if is_locked:
return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
return jsonify({"error": message}), 400
captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
if not captcha_ok:
return captcha_error_response
user = database.get_user_by_email(email)
if not user:
@@ -235,7 +385,7 @@ def get_email_verify_status():
@require_ip_not_locked
def forgot_password():
"""发送密码重置邮件"""
data = request.json or {}
data = _get_json_payload()
email = data.get("email", "").strip().lower()
username = data.get("username", "").strip()
captcha_session = data.get("captcha_session", "")
@@ -263,12 +413,9 @@ def forgot_password():
if not allowed:
return jsonify({"error": error_msg}), 429
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
if not success:
is_locked = record_failed_captcha(client_ip)
if is_locked:
return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
return jsonify({"error": message}), 400
captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
if not captcha_ok:
return captcha_error_response
email_settings = email_service.get_email_settings()
if not email_settings.get("enabled", False):
@@ -293,20 +440,16 @@ def forgot_password():
if not allowed:
return jsonify({"error": error_msg}), 429
result = email_service.send_password_reset_email(
_send_password_reset_email_if_possible(
email=bound_email,
username=user["username"],
user_id=user["id"],
)
if not result["success"]:
logger.error(f"密码重置邮件发送失败: {result['error']}")
return jsonify({"success": True, "message": "如果该账号已绑定邮箱,您将收到密码重置邮件"})
user = database.get_user_by_email(email)
if user and user.get("status") == "approved":
result = email_service.send_password_reset_email(email=email, username=user["username"], user_id=user["id"])
if not result["success"]:
logger.error(f"密码重置邮件发送失败: {result['error']}")
_send_password_reset_email_if_possible(email=email, username=user["username"], user_id=user["id"])
return jsonify({"success": True, "message": "如果该邮箱已注册,您将收到密码重置邮件"})
@@ -331,7 +474,7 @@ def reset_password_page(token):
@api_auth_bp.route("/api/reset-password-confirm", methods=["POST"])
def reset_password_confirm():
"""确认密码重置"""
data = request.json or {}
data = _get_json_payload()
token = data.get("token", "").strip()
new_password = data.get("new_password", "").strip()
@@ -356,67 +499,15 @@ def reset_password_confirm():
@api_auth_bp.route("/api/generate_captcha", methods=["POST"])
def generate_captcha():
"""生成4位数字验证码图片"""
import base64
import uuid
from io import BytesIO
session_id = str(uuid.uuid4())
code = "".join([str(secrets.randbelow(10)) for _ in range(4)])
code = "".join(str(secrets.randbelow(10)) for _ in range(4))
safe_set_captcha(session_id, {"code": code, "expire_time": time.time() + 300, "failed_attempts": 0})
safe_cleanup_expired_captcha()
try:
from PIL import Image, ImageDraw, ImageFont
import io
width, height = 160, 60
image = Image.new("RGB", (width, height), color=(255, 255, 255))
draw = ImageDraw.Draw(image)
for _ in range(6):
x1 = random.randint(0, width)
y1 = random.randint(0, height)
x2 = random.randint(0, width)
y2 = random.randint(0, height)
draw.line(
[(x1, y1), (x2, y2)],
fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)),
width=1,
)
for _ in range(80):
x = random.randint(0, width)
y = random.randint(0, height)
draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)))
font = None
font_paths = [
"/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
"/usr/share/fonts/truetype/freefont/FreeSansBold.ttf",
]
for font_path in font_paths:
try:
font = ImageFont.truetype(font_path, 42)
break
except Exception:
continue
if font is None:
font = ImageFont.load_default()
for i, char in enumerate(code):
x = 12 + i * 35 + random.randint(-3, 3)
y = random.randint(5, 12)
color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150))
draw.text((x, y), char, font=font, fill=color)
buffer = io.BytesIO()
image.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return jsonify({"session_id": session_id, "captcha_image": f"data:image/png;base64,{img_base64}"})
captcha_image = _generate_captcha_image_data_uri(code)
return jsonify({"session_id": session_id, "captcha_image": captcha_image})
except ImportError as e:
logger.error(f"PIL库未安装验证码功能不可用: {e}")
safe_delete_captcha(session_id)
@@ -427,7 +518,7 @@ def generate_captcha():
@require_ip_not_locked
def login():
"""用户登录"""
data = request.json or {}
data = _get_json_payload()
username = data.get("username", "").strip()
password = data.get("password", "").strip()
captcha_session = data.get("captcha_session", "")
@@ -452,13 +543,15 @@ def login():
return jsonify({"error": error_msg, "need_captcha": True}), 429
captcha_required = check_login_captcha_required(client_ip, username_key) or scan_locked or bool(need_captcha)
if captcha_required:
if not captcha_session or not captcha_code:
return jsonify({"error": "请填写验证码", "need_captcha": True}), 400
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
if not success:
record_login_failure(client_ip, username_key)
return jsonify({"error": message, "need_captcha": True}), 400
captcha_ok, captcha_error_response = _verify_login_captcha_if_needed(
captcha_required=captcha_required,
captcha_session=captcha_session,
captcha_code=captcha_code,
client_ip=client_ip,
username_key=username_key,
)
if not captcha_ok:
return captcha_error_response
user = database.verify_user(username, password)
if not user:
@@ -476,29 +569,7 @@ def login():
login_user(user_obj)
load_user_accounts(user["id"])
try:
user_agent = request.headers.get("User-Agent", "")
context = database.record_login_context(user["id"], client_ip, user_agent)
if context and (context.get("new_ip") or context.get("new_device")):
if (
config.LOGIN_ALERT_ENABLED
and should_send_login_alert(user["id"], client_ip)
and email_service.get_email_settings().get("login_alert_enabled", True)
):
user_info = database.get_user_by_id(user["id"]) or {}
if user_info.get("email") and user_info.get("email_verified"):
if database.get_user_email_notify(user["id"]):
email_service.send_security_alert_email(
email=user_info.get("email"),
username=user_info.get("username") or username,
ip_address=client_ip,
user_agent=user_agent,
new_ip=context.get("new_ip", False),
new_device=context.get("new_device", False),
user_id=user["id"],
)
except Exception:
pass
_send_login_security_alert_if_needed(user=user, username=username, client_ip=client_ip)
return jsonify({"success": True})

View File

@@ -2,7 +2,11 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import json
import re
import threading
import time as time_mod
import uuid
import database
from flask import Blueprint, jsonify, request
@@ -17,6 +21,13 @@ api_schedules_bp = Blueprint("api_schedules", __name__)
_HHMM_RE = re.compile(r"^(\d{1,2}):(\d{2})$")
def _request_json(default=None):
if default is None:
default = {}
data = request.get_json(silent=True)
return data if isinstance(data, dict) else default
def _normalize_hhmm(value: object) -> str | None:
match = _HHMM_RE.match(str(value or "").strip())
if not match:
@@ -28,18 +39,53 @@ def _normalize_hhmm(value: object) -> str | None:
return f"{hour:02d}:{minute:02d}"
def _normalize_random_delay(value) -> tuple[int | None, str | None]:
try:
normalized = int(value or 0)
except Exception:
return None, "random_delay必须是0或1"
if normalized not in (0, 1):
return None, "random_delay必须是0或1"
return normalized, None
def _parse_schedule_account_ids(raw_value) -> list:
try:
parsed = json.loads(raw_value or "[]")
except (json.JSONDecodeError, TypeError):
return []
return parsed if isinstance(parsed, list) else []
def _get_owned_schedule_or_error(schedule_id: int):
schedule = database.get_schedule_by_id(schedule_id)
if not schedule:
return None, (jsonify({"error": "定时任务不存在"}), 404)
if schedule.get("user_id") != current_user.id:
return None, (jsonify({"error": "无权访问"}), 403)
return schedule, None
def _ensure_user_accounts_loaded(user_id: int) -> None:
if safe_get_user_accounts_snapshot(user_id):
return
load_user_accounts(user_id)
def _parse_browse_type_or_error(raw_value, *, default=BROWSE_TYPE_SHOULD_READ):
browse_type = validate_browse_type(raw_value, default=default)
if not browse_type:
return None, (jsonify({"error": "浏览类型无效"}), 400)
return browse_type, None
@api_schedules_bp.route("/api/schedules", methods=["GET"])
@login_required
def get_user_schedules_api():
"""获取当前用户的所有定时任务"""
schedules = database.get_user_schedules(current_user.id)
import json
for s in schedules:
try:
s["account_ids"] = json.loads(s.get("account_ids", "[]") or "[]")
except (json.JSONDecodeError, TypeError):
s["account_ids"] = []
for schedule in schedules:
schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids"))
return jsonify(schedules)
@@ -47,23 +93,26 @@ def get_user_schedules_api():
@login_required
def create_user_schedule_api():
"""创建用户定时任务"""
data = request.json or {}
data = _request_json()
name = data.get("name", "我的定时任务")
schedule_time = data.get("schedule_time", "08:00")
weekdays = data.get("weekdays", "1,2,3,4,5")
browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ)
if not browse_type:
return jsonify({"error": "浏览类型无效"}), 400
browse_type, browse_error = _parse_browse_type_or_error(data.get("browse_type", BROWSE_TYPE_SHOULD_READ))
if browse_error:
return browse_error
enable_screenshot = data.get("enable_screenshot", 1)
random_delay = int(data.get("random_delay", 0) or 0)
random_delay, delay_error = _normalize_random_delay(data.get("random_delay", 0))
if delay_error:
return jsonify({"error": delay_error}), 400
account_ids = data.get("account_ids", [])
normalized_time = _normalize_hhmm(schedule_time)
if not normalized_time:
return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400
if random_delay not in (0, 1):
return jsonify({"error": "random_delay必须是0或1"}), 400
schedule_id = database.create_user_schedule(
user_id=current_user.id,
@@ -85,18 +134,11 @@ def create_user_schedule_api():
@login_required
def get_schedule_detail_api(schedule_id):
"""获取定时任务详情"""
schedule = database.get_schedule_by_id(schedule_id)
if not schedule:
return jsonify({"error": "定时任务不存在"}), 404
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
schedule, error_response = _get_owned_schedule_or_error(schedule_id)
if error_response:
return error_response
import json
try:
schedule["account_ids"] = json.loads(schedule.get("account_ids", "[]") or "[]")
except (json.JSONDecodeError, TypeError):
schedule["account_ids"] = []
schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids"))
return jsonify(schedule)
@@ -104,14 +146,12 @@ def get_schedule_detail_api(schedule_id):
@login_required
def update_schedule_api(schedule_id):
"""更新定时任务"""
schedule = database.get_schedule_by_id(schedule_id)
if not schedule:
return jsonify({"error": "定时任务不存在"}), 404
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
_, error_response = _get_owned_schedule_or_error(schedule_id)
if error_response:
return error_response
data = request.json or {}
allowed_fields = [
data = _request_json()
allowed_fields = {
"name",
"schedule_time",
"weekdays",
@@ -120,27 +160,26 @@ def update_schedule_api(schedule_id):
"random_delay",
"account_ids",
"enabled",
]
update_data = {k: v for k, v in data.items() if k in allowed_fields}
}
update_data = {key: value for key, value in data.items() if key in allowed_fields}
if "schedule_time" in update_data:
normalized_time = _normalize_hhmm(update_data["schedule_time"])
if not normalized_time:
return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400
update_data["schedule_time"] = normalized_time
if "random_delay" in update_data:
try:
update_data["random_delay"] = int(update_data.get("random_delay") or 0)
except Exception:
return jsonify({"error": "random_delay必须是0或1"}), 400
if update_data["random_delay"] not in (0, 1):
return jsonify({"error": "random_delay必须是0或1"}), 400
random_delay, delay_error = _normalize_random_delay(update_data.get("random_delay"))
if delay_error:
return jsonify({"error": delay_error}), 400
update_data["random_delay"] = random_delay
if "browse_type" in update_data:
normalized = validate_browse_type(update_data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
if not normalized:
return jsonify({"error": "浏览类型无效"}), 400
update_data["browse_type"] = normalized
normalized_browse_type, browse_error = _parse_browse_type_or_error(update_data.get("browse_type"))
if browse_error:
return browse_error
update_data["browse_type"] = normalized_browse_type
success = database.update_user_schedule(schedule_id, **update_data)
if success:
@@ -152,11 +191,9 @@ def update_schedule_api(schedule_id):
@login_required
def delete_schedule_api(schedule_id):
"""删除定时任务"""
schedule = database.get_schedule_by_id(schedule_id)
if not schedule:
return jsonify({"error": "定时任务不存在"}), 404
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
_, error_response = _get_owned_schedule_or_error(schedule_id)
if error_response:
return error_response
success = database.delete_user_schedule(schedule_id)
if success:
@@ -168,13 +205,11 @@ def delete_schedule_api(schedule_id):
@login_required
def toggle_schedule_api(schedule_id):
"""启用/禁用定时任务"""
schedule = database.get_schedule_by_id(schedule_id)
if not schedule:
return jsonify({"error": "定时任务不存在"}), 404
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
schedule, error_response = _get_owned_schedule_or_error(schedule_id)
if error_response:
return error_response
data = request.json
data = _request_json()
enabled = data.get("enabled", not schedule["enabled"])
success = database.toggle_user_schedule(schedule_id, enabled)
@@ -187,22 +222,11 @@ def toggle_schedule_api(schedule_id):
@login_required
def run_schedule_now_api(schedule_id):
"""立即执行定时任务"""
import json
import threading
import time as time_mod
import uuid
schedule = database.get_schedule_by_id(schedule_id)
if not schedule:
return jsonify({"error": "定时任务不存在"}), 404
if schedule["user_id"] != current_user.id:
return jsonify({"error": "无权访问"}), 403
try:
account_ids = json.loads(schedule.get("account_ids", "[]") or "[]")
except (json.JSONDecodeError, TypeError):
account_ids = []
schedule, error_response = _get_owned_schedule_or_error(schedule_id)
if error_response:
return error_response
account_ids = _parse_schedule_account_ids(schedule.get("account_ids"))
if not account_ids:
return jsonify({"error": "没有配置账号"}), 400
@@ -210,8 +234,7 @@ def run_schedule_now_api(schedule_id):
browse_type = normalize_browse_type(schedule.get("browse_type", BROWSE_TYPE_SHOULD_READ))
enable_screenshot = schedule["enable_screenshot"]
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
_ensure_user_accounts_loaded(user_id)
from services.state import safe_create_batch, safe_finalize_batch_after_dispatch
from services.task_batches import _send_batch_task_email_if_configured
@@ -250,6 +273,7 @@ def run_schedule_now_api(schedule_id):
if remaining["done"] or remaining["count"] > 0:
return
remaining["done"] = True
execution_duration = int(time_mod.time() - execution_start_time)
database.update_schedule_execution_log(
log_id,
@@ -260,19 +284,17 @@ def run_schedule_now_api(schedule_id):
status="completed",
)
task_source = f"user_scheduled:{batch_id}"
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
if not account:
skipped_count += 1
continue
if account.is_running:
if (not account) or account.is_running:
skipped_count += 1
continue
task_source = f"user_scheduled:{batch_id}"
with completion_lock:
remaining["count"] += 1
ok, msg = submit_account_task(
ok, _ = submit_account_task(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import os
from datetime import datetime
from typing import Iterator
import database
from app_config import get_config
@@ -15,41 +16,67 @@ from services.time_utils import BEIJING_TZ
config = get_config()
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
api_screenshots_bp = Blueprint("api_screenshots", __name__)
def _get_user_prefix(user_id: int) -> str:
user_info = database.get_user_by_id(user_id)
return user_info["username"] if user_info else f"user{user_id}"
def _is_user_screenshot(filename: str, username_prefix: str) -> bool:
return filename.startswith(username_prefix + "_") and filename.lower().endswith(_IMAGE_EXTENSIONS)
def _iter_user_screenshot_entries(username_prefix: str) -> Iterator[os.DirEntry]:
if not os.path.exists(SCREENSHOTS_DIR):
return
with os.scandir(SCREENSHOTS_DIR) as entries:
for entry in entries:
if (not entry.is_file()) or (not _is_user_screenshot(entry.name, username_prefix)):
continue
yield entry
def _build_display_name(filename: str) -> str:
base_name, ext = filename.rsplit(".", 1)
parts = base_name.split("_", 1)
if len(parts) > 1:
return f"{parts[1]}.{ext}"
return filename
@api_screenshots_bp.route("/api/screenshots", methods=["GET"])
@login_required
def get_screenshots():
"""获取当前用户的截图列表"""
user_id = current_user.id
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
username_prefix = _get_user_prefix(user_id)
try:
screenshots = []
if os.path.exists(SCREENSHOTS_DIR):
for filename in os.listdir(SCREENSHOTS_DIR):
if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"):
filepath = os.path.join(SCREENSHOTS_DIR, filename)
stat = os.stat(filepath)
created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ)
parts = filename.rsplit(".", 1)[0].split("_", 1)
if len(parts) > 1:
display_name = parts[1] + "." + filename.rsplit(".", 1)[1]
else:
display_name = filename
for entry in _iter_user_screenshot_entries(username_prefix):
filename = entry.name
stat = entry.stat()
created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ)
screenshots.append(
{
"filename": filename,
"display_name": _build_display_name(filename),
"size": stat.st_size,
"created": created_time.strftime("%Y-%m-%d %H:%M:%S"),
"_created_ts": stat.st_mtime,
}
)
screenshots.sort(key=lambda item: item.get("_created_ts", 0), reverse=True)
for item in screenshots:
item.pop("_created_ts", None)
screenshots.append(
{
"filename": filename,
"display_name": display_name,
"size": stat.st_size,
"created": created_time.strftime("%Y-%m-%d %H:%M:%S"),
}
)
screenshots.sort(key=lambda x: x["created"], reverse=True)
return jsonify(screenshots)
except Exception as e:
return jsonify({"error": str(e)}), 500
@@ -60,10 +87,9 @@ def get_screenshots():
def serve_screenshot(filename):
"""提供截图文件访问"""
user_id = current_user.id
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
username_prefix = _get_user_prefix(user_id)
if not filename.startswith(username_prefix + "_"):
if not _is_user_screenshot(filename, username_prefix):
return jsonify({"error": "无权访问"}), 403
if not is_safe_path(SCREENSHOTS_DIR, filename):
@@ -77,12 +103,14 @@ def serve_screenshot(filename):
def delete_screenshot(filename):
"""删除指定截图"""
user_id = current_user.id
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
username_prefix = _get_user_prefix(user_id)
if not filename.startswith(username_prefix + "_"):
if not _is_user_screenshot(filename, username_prefix):
return jsonify({"error": "无权删除"}), 403
if not is_safe_path(SCREENSHOTS_DIR, filename):
return jsonify({"error": "非法路径"}), 403
try:
filepath = os.path.join(SCREENSHOTS_DIR, filename)
if os.path.exists(filepath):
@@ -99,19 +127,15 @@ def delete_screenshot(filename):
def clear_all_screenshots():
"""清空当前用户的所有截图"""
user_id = current_user.id
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
username_prefix = _get_user_prefix(user_id)
try:
deleted_count = 0
if os.path.exists(SCREENSHOTS_DIR):
for filename in os.listdir(SCREENSHOTS_DIR):
if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"):
filepath = os.path.join(SCREENSHOTS_DIR, filename)
os.remove(filepath)
deleted_count += 1
for entry in _iter_user_screenshot_entries(username_prefix):
os.remove(entry.path)
deleted_count += 1
log_to_client(f"清理了 {deleted_count} 个截图文件", user_id)
return jsonify({"success": True, "deleted": deleted_count})
except Exception as e:
return jsonify({"error": str(e)}), 500

View File

@@ -10,12 +10,96 @@ from flask import Blueprint, jsonify, request
from flask_login import current_user, login_required
from routes.pages import render_app_spa_or_legacy
from services.state import check_email_rate_limit, check_ip_request_rate, safe_iter_task_status_items
from services.tasks import get_task_scheduler
logger = get_logger("app")
api_user_bp = Blueprint("api_user", __name__)
def _get_current_user_record():
return database.get_user_by_id(current_user.id)
def _get_current_user_or_404():
user = _get_current_user_record()
if user:
return user, None
return None, (jsonify({"error": "用户不存在"}), 404)
def _get_current_username(*, fallback: str) -> str:
user = _get_current_user_record()
username = (user or {}).get("username", "")
return username if username else fallback
def _coerce_binary_flag(value, *, field_label: str):
if isinstance(value, bool):
value = 1 if value else 0
try:
value = int(value)
except Exception:
return None, f"{field_label}必须是0或1"
if value not in (0, 1):
return None, f"{field_label}必须是0或1"
return value, None
def _check_bind_email_rate_limits(email: str):
client_ip = get_rate_limit_ip()
allowed, error_msg = check_ip_request_rate(client_ip, "email")
if not allowed:
return False, error_msg, 429
allowed, error_msg = check_email_rate_limit(email, "bind_email")
if not allowed:
return False, error_msg, 429
return True, "", 200
def _render_verify_bind_failed(*, title: str, error_message: str):
spa_initial_state = {
"page": "verify_result",
"success": False,
"title": title,
"error_message": error_message,
"primary_label": "返回登录",
"primary_url": "/login",
}
return render_app_spa_or_legacy(
"verify_failed.html",
legacy_context={"error_message": error_message},
spa_initial_state=spa_initial_state,
)
def _render_verify_bind_success(email: str):
spa_initial_state = {
"page": "verify_result",
"success": True,
"title": "邮箱绑定成功",
"message": f"邮箱 {email} 已成功绑定到您的账号!",
"primary_label": "返回登录",
"primary_url": "/login",
"redirect_url": "/login",
"redirect_seconds": 5,
}
return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
def _get_current_running_count(user_id: int) -> int:
try:
queue_snapshot = get_task_scheduler().get_queue_state_snapshot() or {}
running_by_user = queue_snapshot.get("running_by_user") or {}
return int(running_by_user.get(int(user_id), running_by_user.get(str(user_id), 0)) or 0)
except Exception:
current_running = 0
for _, info in safe_iter_task_status_items():
if info.get("user_id") == user_id and info.get("status") == "运行中":
current_running += 1
return current_running
@api_user_bp.route("/api/announcements/active", methods=["GET"])
@login_required
def get_active_announcement():
@@ -77,8 +161,7 @@ def submit_feedback():
if len(description) > 2000:
return jsonify({"error": "描述不能超过2000个字符"}), 400
user_info = database.get_user_by_id(current_user.id)
username = user_info["username"] if user_info else f"用户{current_user.id}"
username = _get_current_username(fallback=f"用户{current_user.id}")
feedback_id = database.create_bug_feedback(
user_id=current_user.id,
@@ -104,8 +187,7 @@ def get_my_feedbacks():
def get_current_user_vip():
"""获取当前用户VIP信息"""
vip_info = database.get_user_vip_info(current_user.id)
user_info = database.get_user_by_id(current_user.id)
vip_info["username"] = user_info["username"] if user_info else "Unknown"
vip_info["username"] = _get_current_username(fallback="Unknown")
return jsonify(vip_info)
@@ -124,9 +206,9 @@ def change_user_password():
if not is_valid:
return jsonify({"error": error_msg}), 400
user = database.get_user_by_id(current_user.id)
if not user:
return jsonify({"error": "用户不存在"}), 404
user, error_response = _get_current_user_or_404()
if error_response:
return error_response
username = user.get("username", "")
if not username or not database.verify_user(username, current_password):
@@ -141,9 +223,9 @@ def change_user_password():
@login_required
def get_user_email():
"""获取当前用户的邮箱信息"""
user = database.get_user_by_id(current_user.id)
if not user:
return jsonify({"error": "用户不存在"}), 404
user, error_response = _get_current_user_or_404()
if error_response:
return error_response
return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)})
@@ -172,14 +254,9 @@ def update_user_kdocs_settings():
return jsonify({"error": "县区长度不能超过50"}), 400
if kdocs_auto_upload is not None:
if isinstance(kdocs_auto_upload, bool):
kdocs_auto_upload = 1 if kdocs_auto_upload else 0
try:
kdocs_auto_upload = int(kdocs_auto_upload)
except Exception:
return jsonify({"error": "自动上传开关必须是0或1"}), 400
if kdocs_auto_upload not in (0, 1):
return jsonify({"error": "自动上传开关必须是0或1"}), 400
kdocs_auto_upload, parse_error = _coerce_binary_flag(kdocs_auto_upload, field_label="自动上传开关")
if parse_error:
return jsonify({"error": parse_error}), 400
if not database.update_user_kdocs_settings(
current_user.id,
@@ -207,13 +284,9 @@ def bind_user_email():
if not is_valid:
return jsonify({"error": error_msg}), 400
client_ip = get_rate_limit_ip()
allowed, error_msg = check_ip_request_rate(client_ip, "email")
allowed, error_msg, status_code = _check_bind_email_rate_limits(email)
if not allowed:
return jsonify({"error": error_msg}), 429
allowed, error_msg = check_email_rate_limit(email, "bind_email")
if not allowed:
return jsonify({"error": error_msg}), 429
return jsonify({"error": error_msg}), status_code
settings = email_service.get_email_settings()
if not settings.get("enabled", False):
@@ -223,9 +296,9 @@ def bind_user_email():
if existing_user and existing_user["id"] != current_user.id:
return jsonify({"error": "该邮箱已被其他用户绑定"}), 400
user = database.get_user_by_id(current_user.id)
if not user:
return jsonify({"error": "用户不存在"}), 404
user, error_response = _get_current_user_or_404()
if error_response:
return error_response
if user.get("email") == email and user.get("email_verified"):
return jsonify({"error": "该邮箱已绑定并验证"}), 400
@@ -247,56 +320,20 @@ def verify_bind_email(token):
email = result["email"]
if database.update_user_email(user_id, email, verified=True):
spa_initial_state = {
"page": "verify_result",
"success": True,
"title": "邮箱绑定成功",
"message": f"邮箱 {email} 已成功绑定到您的账号!",
"primary_label": "返回登录",
"primary_url": "/login",
"redirect_url": "/login",
"redirect_seconds": 5,
}
return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
return _render_verify_bind_success(email)
error_message = "邮箱绑定失败,请重试"
spa_initial_state = {
"page": "verify_result",
"success": False,
"title": "绑定失败",
"error_message": error_message,
"primary_label": "返回登录",
"primary_url": "/login",
}
return render_app_spa_or_legacy(
"verify_failed.html",
legacy_context={"error_message": error_message},
spa_initial_state=spa_initial_state,
)
return _render_verify_bind_failed(title="绑定失败", error_message="邮箱绑定失败,请重试")
error_message = "验证链接已过期或无效,请重新发送验证邮件"
spa_initial_state = {
"page": "verify_result",
"success": False,
"title": "链接无效",
"error_message": error_message,
"primary_label": "返回登录",
"primary_url": "/login",
}
return render_app_spa_or_legacy(
"verify_failed.html",
legacy_context={"error_message": error_message},
spa_initial_state=spa_initial_state,
)
return _render_verify_bind_failed(title="链接无效", error_message="验证链接已过期或无效,请重新发送验证邮件")
@api_user_bp.route("/api/user/unbind-email", methods=["POST"])
@login_required
def unbind_user_email():
"""解绑用户邮箱"""
user = database.get_user_by_id(current_user.id)
if not user:
return jsonify({"error": "用户不存在"}), 404
user, error_response = _get_current_user_or_404()
if error_response:
return error_response
if not user.get("email"):
return jsonify({"error": "当前未绑定邮箱"}), 400
@@ -334,10 +371,7 @@ def get_run_stats():
stats = database.get_user_run_stats(user_id)
current_running = 0
for _, info in safe_iter_task_status_items():
if info.get("user_id") == user_id and info.get("status") == "运行中":
current_running += 1
current_running = _get_current_running_count(user_id)
return jsonify(
{

View File

@@ -31,7 +31,7 @@ def admin_required(f):
if is_api:
return jsonify({"error": "需要管理员权限"}), 403
return redirect(url_for("pages.admin_login_page"))
logger.info(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}")
logger.debug(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}")
return f(*args, **kwargs)
return decorated_function

View File

@@ -2,12 +2,62 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import time
from flask import Blueprint, jsonify
import database
import db_pool
from services.time_utils import get_beijing_now
health_bp = Blueprint("health", __name__)
_PROCESS_START_TS = time.time()
def _build_runtime_metrics() -> dict:
metrics = {
"uptime_seconds": max(0, int(time.time() - _PROCESS_START_TS)),
}
try:
pool_stats = db_pool.get_pool_stats() or {}
metrics["db_pool"] = {
"pool_size": int(pool_stats.get("pool_size", 0) or 0),
"available": int(pool_stats.get("available", 0) or 0),
"in_use": int(pool_stats.get("in_use", 0) or 0),
}
except Exception:
pass
try:
import psutil
proc = psutil.Process(os.getpid())
with proc.oneshot():
mem_info = proc.memory_info()
metrics["process"] = {
"rss_mb": round(float(mem_info.rss) / 1024 / 1024, 2),
"cpu_percent": round(float(proc.cpu_percent(interval=None)), 2),
"threads": int(proc.num_threads()),
}
except Exception:
pass
try:
from services import tasks as tasks_module
scheduler = getattr(tasks_module, "_task_scheduler", None)
if scheduler is not None:
queue_snapshot = scheduler.get_queue_state_snapshot() or {}
metrics["task_queue"] = {
"pending_total": int(queue_snapshot.get("pending_total", 0) or 0),
"running_total": int(queue_snapshot.get("running_total", 0) or 0),
}
except Exception:
pass
return metrics
@health_bp.route("/health", methods=["GET"])
@@ -26,6 +76,6 @@ def health_check():
"time": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
"db_ok": db_ok,
"db_error": db_error,
"metrics": _build_runtime_metrics(),
}
return jsonify(payload), (200 if db_ok else 500)

View File

@@ -0,0 +1,60 @@
# 健康监控(邮件版)
本目录提供 `health_email_monitor.py`,通过调用 `/health` 接口并使用**容器内已有邮件配置**发告警邮件。
## 1) 快速试跑
```bash
cd /root/zsglpt
python3 scripts/health_email_monitor.py \
--to 你的告警邮箱@example.com \
--container knowledge-automation-multiuser \
--url http://127.0.0.1:51232/health \
--dry-run
```
去掉 `--dry-run` 即会实际发邮件。
## 2) 建议 cron每分钟
```bash
* * * * * cd /root/zsglpt && /usr/bin/python3 scripts/health_email_monitor.py \
--to 你的告警邮箱@example.com \
--container knowledge-automation-multiuser \
--url http://127.0.0.1:51232/health \
>> /root/zsglpt/logs/health_monitor.log 2>&1
```
## 3) 支持的规则
- `service_down`:健康接口请求失败(立即告警)
- `health_fail`:返回 `ok/db_ok` 异常或 HTTP 5xx立即告警
- `db_pool_exhausted`:连接池耗尽(默认连续 3 次才告警)
- `queue_backlog_high`:任务堆积过高(默认 `pending_total >= 50` 且连续 5 次)
脚本支持恢复通知(规则恢复正常会发“恢复”邮件)。
## 4) 常用参数
- `--to`:收件人(必填)
- `--container`Docker 容器名(默认 `knowledge-automation-multiuser`
- `--url`:健康地址(默认 `http://127.0.0.1:51232/health`
- `--state-file`:状态文件路径(默认 `/tmp/zsglpt_health_monitor_state.json`
- `--remind-seconds`:重复告警间隔(默认 3600 秒)
- `--queue-threshold`:队列告警阈值(默认 50
- `--queue-streak`:队列连续次数阈值(默认 5
- `--db-pool-streak`:连接池连续次数阈值(默认 3
## 5) 环境变量方式(可选)
也可不用命令行参数,改用环境变量:
- `MONITOR_EMAIL_TO`
- `MONITOR_DOCKER_CONTAINER`
- `HEALTH_URL`
- `MONITOR_STATE_FILE`
- `MONITOR_REMIND_SECONDS`
- `MONITOR_QUEUE_THRESHOLD`
- `MONITOR_QUEUE_STREAK`
- `MONITOR_DB_POOL_STREAK`

View File

@@ -0,0 +1,348 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import json
import os
import subprocess
import time
from datetime import datetime
from typing import Any, Dict, Tuple
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
DEFAULT_STATE_FILE = "/tmp/zsglpt_health_monitor_state.json"
def _now_text() -> str:
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def _safe_int(value: Any, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except Exception:
return float(default)
def _load_state(path: str) -> Dict[str, Any]:
if not path or not os.path.exists(path):
return {
"version": 1,
"rules": {},
"counters": {},
}
try:
with open(path, "r", encoding="utf-8") as f:
raw = json.load(f)
if not isinstance(raw, dict):
raise ValueError("state is not dict")
raw.setdefault("version", 1)
raw.setdefault("rules", {})
raw.setdefault("counters", {})
return raw
except Exception:
return {
"version": 1,
"rules": {},
"counters": {},
}
def _save_state(path: str, state: Dict[str, Any]) -> None:
if not path:
return
state_dir = os.path.dirname(path)
if state_dir:
os.makedirs(state_dir, exist_ok=True)
tmp_path = f"{path}.tmp"
with open(tmp_path, "w", encoding="utf-8") as f:
json.dump(state, f, ensure_ascii=False, indent=2)
os.replace(tmp_path, path)
def _fetch_health(url: str, timeout: int) -> Tuple[int | None, Dict[str, Any], str | None]:
req = Request(
url,
headers={
"User-Agent": "zsglpt-health-email-monitor/1.0",
"Accept": "application/json",
},
method="GET",
)
try:
with urlopen(req, timeout=max(1, int(timeout))) as resp:
status = int(resp.getcode())
body = resp.read().decode("utf-8", errors="ignore")
except HTTPError as e:
status = int(getattr(e, "code", 0) or 0)
body = ""
try:
body = e.read().decode("utf-8", errors="ignore")
except Exception:
pass
data = {}
if body:
try:
data = json.loads(body)
if not isinstance(data, dict):
data = {}
except Exception:
data = {}
return status, data, f"HTTPError: {e}"
except URLError as e:
return None, {}, f"URLError: {e}"
except Exception as e:
return None, {}, f"RequestError: {e}"
data: Dict[str, Any] = {}
if body:
try:
loaded = json.loads(body)
if isinstance(loaded, dict):
data = loaded
except Exception:
data = {}
return status, data, None
def _inc_streak(state: Dict[str, Any], key: str, bad: bool) -> int:
counters = state.setdefault("counters", {})
current = _safe_int(counters.get(key), 0)
current = (current + 1) if bad else 0
counters[key] = current
return current
def _rule_transition(
state: Dict[str, Any],
*,
rule_name: str,
bad: bool,
streak: int,
threshold: int,
remind_seconds: int,
now_ts: float,
) -> str | None:
rules = state.setdefault("rules", {})
rule_state = rules.setdefault(rule_name, {"active": False, "last_sent": 0})
is_active = bool(rule_state.get("active", False))
last_sent = _safe_float(rule_state.get("last_sent", 0), 0.0)
threshold = max(1, int(threshold))
remind_seconds = max(60, int(remind_seconds))
if bad and streak >= threshold:
if not is_active:
rule_state["active"] = True
rule_state["last_sent"] = now_ts
return "alert"
if (now_ts - last_sent) >= remind_seconds:
rule_state["last_sent"] = now_ts
return "alert"
return None
if is_active and (not bad):
rule_state["active"] = False
rule_state["last_sent"] = now_ts
return "recover"
return None
def _send_email_via_container(
*,
container_name: str,
to_email: str,
subject: str,
body: str,
timeout_seconds: int = 45,
) -> Tuple[bool, str]:
code = (
"import sys,email_service;"
"res=email_service.send_email(to_email=sys.argv[1],subject=sys.argv[2],body=sys.argv[3],email_type='health_monitor');"
"ok=bool(res.get('success'));"
"print('ok' if ok else ('error:'+str(res.get('error'))));"
"raise SystemExit(0 if ok else 2)"
)
cmd = [
"docker",
"exec",
container_name,
"python",
"-c",
code,
to_email,
subject,
body,
]
try:
proc = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=max(5, int(timeout_seconds)),
check=False,
)
except Exception as e:
return False, str(e)
output = (proc.stdout or "") + (proc.stderr or "")
output = output.strip()
return proc.returncode == 0, output
def _build_common_lines(status: int | None, data: Dict[str, Any], fetch_error: str | None) -> list[str]:
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {}
db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {}
process = metrics.get("process") if isinstance(metrics.get("process"), dict) else {}
task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {}
lines = [
f"时间: {_now_text()}",
f"健康地址: {data.get('_monitor_url', '')}",
f"HTTP状态: {status if status is not None else '请求失败'}",
f"ok/db_ok: {data.get('ok')} / {data.get('db_ok')}",
]
if fetch_error:
lines.append(f"请求错误: {fetch_error}")
lines.extend(
[
f"队列: pending={task_queue.get('pending_total', 'N/A')}, running={task_queue.get('running_total', 'N/A')}",
f"连接池: size={db_pool.get('pool_size', 'N/A')}, available={db_pool.get('available', 'N/A')}, in_use={db_pool.get('in_use', 'N/A')}",
f"进程: rss_mb={process.get('rss_mb', 'N/A')}, cpu%={process.get('cpu_percent', 'N/A')}, threads={process.get('threads', 'N/A')}",
f"运行时长: {metrics.get('uptime_seconds', 'N/A')}",
]
)
return lines
def main() -> int:
parser = argparse.ArgumentParser(description="zsglpt 邮件健康监控(基于 /health")
parser.add_argument("--url", default=os.environ.get("HEALTH_URL", "http://127.0.0.1:51232/health"))
parser.add_argument("--to", default=os.environ.get("MONITOR_EMAIL_TO", ""))
parser.add_argument(
"--container",
default=os.environ.get("MONITOR_DOCKER_CONTAINER", "knowledge-automation-multiuser"),
)
parser.add_argument("--state-file", default=os.environ.get("MONITOR_STATE_FILE", DEFAULT_STATE_FILE))
parser.add_argument("--timeout", type=int, default=_safe_int(os.environ.get("MONITOR_TIMEOUT", 8), 8))
parser.add_argument(
"--remind-seconds",
type=int,
default=_safe_int(os.environ.get("MONITOR_REMIND_SECONDS", 3600), 3600),
)
parser.add_argument(
"--queue-threshold",
type=int,
default=_safe_int(os.environ.get("MONITOR_QUEUE_THRESHOLD", 50), 50),
)
parser.add_argument(
"--queue-streak",
type=int,
default=_safe_int(os.environ.get("MONITOR_QUEUE_STREAK", 5), 5),
)
parser.add_argument(
"--db-pool-streak",
type=int,
default=_safe_int(os.environ.get("MONITOR_DB_POOL_STREAK", 3), 3),
)
parser.add_argument("--dry-run", action="store_true", help="仅打印,不实际发邮件")
args = parser.parse_args()
if not args.to:
print("[monitor] 缺少收件人,请设置 --to 或 MONITOR_EMAIL_TO", flush=True)
return 2
state = _load_state(args.state_file)
now_ts = time.time()
status, data, fetch_error = _fetch_health(args.url, args.timeout)
if not isinstance(data, dict):
data = {}
data["_monitor_url"] = args.url
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {}
db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {}
task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {}
service_down = status is None
health_fail = bool(status is not None and (status >= 500 or (not data.get("ok", False)) or (not data.get("db_ok", False))))
db_pool_exhausted = (
_safe_int(db_pool.get("pool_size"), 0) > 0
and _safe_int(db_pool.get("available"), 0) <= 0
and _safe_int(db_pool.get("in_use"), 0) >= _safe_int(db_pool.get("pool_size"), 0)
)
queue_backlog_high = _safe_int(task_queue.get("pending_total"), 0) >= max(1, int(args.queue_threshold))
rule_defs = [
("service_down", service_down, 1),
("health_fail", health_fail, 1),
("db_pool_exhausted", db_pool_exhausted, max(1, int(args.db_pool_streak))),
("queue_backlog_high", queue_backlog_high, max(1, int(args.queue_streak))),
]
pending_notifications: list[tuple[str, str]] = []
for rule_name, bad, threshold in rule_defs:
streak = _inc_streak(state, rule_name, bad)
action = _rule_transition(
state,
rule_name=rule_name,
bad=bad,
streak=streak,
threshold=threshold,
remind_seconds=args.remind_seconds,
now_ts=now_ts,
)
if action:
pending_notifications.append((rule_name, action))
_save_state(args.state_file, state)
if not pending_notifications:
print(f"[monitor] {_now_text()} 正常,无需发送邮件")
return 0
common_lines = _build_common_lines(status, data, fetch_error)
for rule_name, action in pending_notifications:
level = "告警" if action == "alert" else "恢复"
subject = f"[zsglpt健康监控][{level}] {rule_name}"
body_lines = [
f"规则: {rule_name}",
f"状态: {level}",
"",
*common_lines,
]
body = "\n".join(body_lines)
if args.dry_run:
print(f"[monitor][dry-run] subject={subject}\n{body}\n")
continue
ok, msg = _send_email_via_container(
container_name=args.container,
to_email=args.to,
subject=subject,
body=body,
)
if ok:
print(f"[monitor] 邮件已发送: {subject}")
else:
print(f"[monitor] 邮件发送失败: {subject} | {msg}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -243,6 +243,35 @@ class KDocsUploader:
except queue.Empty:
return {"success": False, "error": "操作超时"}
def _put_task_response(self, task: Dict[str, Any], result: Dict[str, Any]) -> None:
response_queue = task.get("response")
if not response_queue:
return
try:
response_queue.put(result)
except Exception:
return
def _process_task(self, task: Dict[str, Any]) -> bool:
action = task.get("action")
payload = task.get("payload") or {}
if action == "shutdown":
return False
if action == "upload":
self._handle_upload(payload)
return True
if action == "qr":
self._put_task_response(task, self._handle_qr(payload))
return True
if action == "clear_login":
self._put_task_response(task, self._handle_clear_login())
return True
if action == "status":
self._put_task_response(task, self._handle_status_check())
return True
return True
def _run(self) -> None:
thread_id = self._thread_id
logger.info(f"[KDocs] 上传线程启动 (ID={thread_id})")
@@ -261,34 +290,17 @@ class KDocsUploader:
# 更新最后活动时间
self._last_activity = time.time()
action = task.get("action")
if action == "shutdown":
break
try:
if action == "upload":
self._handle_upload(task.get("payload") or {})
elif action == "qr":
result = self._handle_qr(task.get("payload") or {})
task.get("response").put(result)
elif action == "clear_login":
result = self._handle_clear_login()
task.get("response").put(result)
elif action == "status":
result = self._handle_status_check()
task.get("response").put(result)
should_continue = self._process_task(task)
if not should_continue:
break
# 任务处理完成后更新活动时间
self._last_activity = time.time()
except Exception as e:
logger.warning(f"[KDocs] 处理任务失败: {e}")
# 如果有响应队列,返回错误
if "response" in task and task.get("response"):
try:
task["response"].put({"success": False, "error": str(e)})
except Exception:
pass
self._put_task_response(task, {"success": False, "error": str(e)})
except Exception as e:
logger.warning(f"[KDocs] 线程主循环异常: {e}")
@@ -830,18 +842,180 @@ class KDocsUploader:
except Exception as e:
logger.warning(f"[KDocs] 保存登录态失败: {e}")
def _resolve_doc_url(self, cfg: Dict[str, Any]) -> str:
return (cfg.get("kdocs_doc_url") or "").strip()
def _ensure_doc_access(
self,
doc_url: str,
*,
fast: bool = False,
use_storage_state: bool = True,
) -> Optional[str]:
if not self._ensure_playwright(use_storage_state=use_storage_state):
return self._last_error or "浏览器不可用"
if not self._open_document(doc_url, fast=fast):
return self._last_error or "打开文档失败"
return None
def _trigger_fast_login_dialog(self, timeout_ms: int) -> None:
self._ensure_login_dialog(
timeout_ms=timeout_ms,
frame_timeout_ms=timeout_ms,
quick=True,
)
def _capture_qr_with_retry(self, fast_login_timeout: int) -> Tuple[Optional[bytes], Optional[bytes]]:
qr_image = None
invalid_qr = None
for attempt in range(10):
if attempt in (3, 7):
self._trigger_fast_login_dialog(fast_login_timeout)
candidate = self._capture_qr_image()
if candidate and self._is_valid_qr_image(candidate):
qr_image = candidate
break
if candidate:
invalid_qr = candidate
time.sleep(0.8) # 优化: 1 -> 0.8
return qr_image, invalid_qr
def _save_qr_debug_artifacts(self, invalid_qr: Optional[bytes]) -> None:
try:
pages = self._iter_pages()
page_urls = [getattr(p, "url", "") for p in pages]
logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
ts = int(time.time())
saved = []
for idx, page in enumerate(pages[:3]):
try:
path = f"data/kdocs_debug_{ts}_{idx}.png"
page.screenshot(path=path, full_page=True)
saved.append(path)
except Exception:
continue
if saved:
logger.warning(f"[KDocs] 已保存调试截图: {saved}")
if invalid_qr:
try:
path = f"data/kdocs_invalid_qr_{ts}.png"
with open(path, "wb") as handle:
handle.write(invalid_qr)
logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
except Exception:
pass
except Exception:
pass
def _log_upload_failure(self, message: str, user_id: Any, account_id: Any) -> None:
try:
log_to_client(f"表格上传失败: {message}", user_id, account_id)
except Exception:
pass
def _mark_upload_tracking(self, user_id: Any, account_id: Any) -> Tuple[Any, Optional[str], bool]:
account = None
prev_status = None
status_tracked = False
try:
account = safe_get_account(user_id, account_id)
if account and self._should_mark_upload(account):
prev_status = getattr(account, "status", None)
account.status = "上传截图"
self._emit_account_update(user_id, account)
status_tracked = True
except Exception:
prev_status = None
return account, prev_status, status_tracked
def _parse_upload_payload(self, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
unit = (payload.get("unit") or "").strip()
name = (payload.get("name") or "").strip()
image_path = payload.get("image_path")
if not unit or not name:
return None
if not image_path or not os.path.exists(image_path):
return None
return {
"unit": unit,
"name": name,
"image_path": image_path,
"user_id": payload.get("user_id"),
"account_id": payload.get("account_id"),
}
def _resolve_upload_sheet_config(self, cfg: Dict[str, Any]) -> Dict[str, Any]:
return {
"sheet_name": (cfg.get("kdocs_sheet_name") or "").strip(),
"sheet_index": int(cfg.get("kdocs_sheet_index") or 0),
"unit_col": (cfg.get("kdocs_unit_column") or "A").strip().upper(),
"image_col": (cfg.get("kdocs_image_column") or "D").strip().upper(),
"row_start": int(cfg.get("kdocs_row_start") or 0),
"row_end": int(cfg.get("kdocs_row_end") or 0),
}
def _try_upload_to_sheet(self, cfg: Dict[str, Any], unit: str, name: str, image_path: str) -> Tuple[bool, str]:
sheet_cfg = self._resolve_upload_sheet_config(cfg)
success = False
error_msg = ""
for _ in range(2):
try:
if sheet_cfg["sheet_name"] or sheet_cfg["sheet_index"]:
self._select_sheet(sheet_cfg["sheet_name"], sheet_cfg["sheet_index"])
row_num = self._find_person_with_unit(
unit,
name,
sheet_cfg["unit_col"],
row_start=sheet_cfg["row_start"],
row_end=sheet_cfg["row_end"],
)
if row_num < 0:
error_msg = f"未找到人员: {unit}-{name}"
break
success = self._upload_image_to_cell(row_num, image_path, sheet_cfg["image_col"])
if success:
break
except Exception as e:
error_msg = str(e)
return success, error_msg
def _handle_upload_login_invalid(
self,
*,
unit: str,
name: str,
image_path: str,
user_id: Any,
account_id: Any,
) -> None:
error_msg = "登录已失效,请管理员重新扫码登录"
self._login_required = True
self._last_login_ok = False
self._notify_admin(unit, name, image_path, error_msg)
self._log_upload_failure(error_msg, user_id, account_id)
def _handle_qr(self, payload: Dict[str, Any]) -> Dict[str, Any]:
cfg = self._load_system_config()
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
doc_url = self._resolve_doc_url(cfg)
if not doc_url:
return {"success": False, "error": "未配置金山文档链接"}
force = bool(payload.get("force"))
if force:
self._handle_clear_login()
if not self._ensure_playwright(use_storage_state=not force):
return {"success": False, "error": self._last_error or "浏览器不可用"}
if not self._open_document(doc_url, fast=True):
return {"success": False, "error": self._last_error or "打开文档失败"}
doc_error = self._ensure_doc_access(doc_url, fast=True, use_storage_state=not force)
if doc_error:
return {"success": False, "error": doc_error}
if not force and self._has_saved_login_state() and self._is_logged_in():
self._login_required = False
@@ -850,54 +1024,12 @@ class KDocsUploader:
return {"success": True, "logged_in": True, "qr_image": ""}
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
self._ensure_login_dialog(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
qr_image = None
invalid_qr = None
for attempt in range(10):
if attempt in (3, 7):
self._ensure_login_dialog(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
candidate = self._capture_qr_image()
if candidate and self._is_valid_qr_image(candidate):
qr_image = candidate
break
if candidate:
invalid_qr = candidate
time.sleep(0.8) # 优化: 1 -> 0.8
self._trigger_fast_login_dialog(fast_login_timeout)
qr_image, invalid_qr = self._capture_qr_with_retry(fast_login_timeout)
if not qr_image:
self._last_error = "二维码识别异常" if invalid_qr else "二维码获取失败"
try:
pages = self._iter_pages()
page_urls = [getattr(p, "url", "") for p in pages]
logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
ts = int(time.time())
saved = []
for idx, page in enumerate(pages[:3]):
try:
path = f"data/kdocs_debug_{ts}_{idx}.png"
page.screenshot(path=path, full_page=True)
saved.append(path)
except Exception:
continue
if saved:
logger.warning(f"[KDocs] 已保存调试截图: {saved}")
if invalid_qr:
try:
path = f"data/kdocs_invalid_qr_{ts}.png"
with open(path, "wb") as handle:
handle.write(invalid_qr)
logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
except Exception:
pass
except Exception:
pass
self._save_qr_debug_artifacts(invalid_qr)
return {"success": False, "error": self._last_error}
try:
@@ -933,24 +1065,22 @@ class KDocsUploader:
def _handle_status_check(self) -> Dict[str, Any]:
cfg = self._load_system_config()
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
doc_url = self._resolve_doc_url(cfg)
if not doc_url:
return {"success": True, "logged_in": False, "error": "未配置文档链接"}
if not self._ensure_playwright():
return {"success": False, "logged_in": False, "error": self._last_error or "浏览器不可用"}
if not self._open_document(doc_url, fast=True):
return {"success": False, "logged_in": False, "error": self._last_error or "打开文档失败"}
doc_error = self._ensure_doc_access(doc_url, fast=True)
if doc_error:
return {"success": False, "logged_in": False, "error": doc_error}
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
self._ensure_login_dialog(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
self._trigger_fast_login_dialog(fast_login_timeout)
self._try_confirm_login(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
logged_in = self._is_logged_in()
self._last_login_ok = logged_in
self._login_required = not logged_in
@@ -962,79 +1092,43 @@ class KDocsUploader:
cfg = self._load_system_config()
if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
return
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
doc_url = self._resolve_doc_url(cfg)
if not doc_url:
return
unit = (payload.get("unit") or "").strip()
name = (payload.get("name") or "").strip()
image_path = payload.get("image_path")
user_id = payload.get("user_id")
account_id = payload.get("account_id")
if not unit or not name:
return
if not image_path or not os.path.exists(image_path):
upload_data = self._parse_upload_payload(payload)
if not upload_data:
return
account = None
prev_status = None
status_tracked = False
unit = upload_data["unit"]
name = upload_data["name"]
image_path = upload_data["image_path"]
user_id = upload_data["user_id"]
account_id = upload_data["account_id"]
account, prev_status, status_tracked = self._mark_upload_tracking(user_id, account_id)
try:
try:
account = safe_get_account(user_id, account_id)
if account and self._should_mark_upload(account):
prev_status = getattr(account, "status", None)
account.status = "上传截图"
self._emit_account_update(user_id, account)
status_tracked = True
except Exception:
prev_status = None
if not self._ensure_playwright():
self._notify_admin(unit, name, image_path, self._last_error or "浏览器不可用")
return
if not self._open_document(doc_url):
self._notify_admin(unit, name, image_path, self._last_error or "打开文档失败")
doc_error = self._ensure_doc_access(doc_url)
if doc_error:
self._notify_admin(unit, name, image_path, doc_error)
return
if not self._is_logged_in():
self._login_required = True
self._last_login_ok = False
self._notify_admin(unit, name, image_path, "登录已失效,请管理员重新扫码登录")
try:
log_to_client("表格上传失败: 登录已失效,请管理员重新扫码登录", user_id, account_id)
except Exception:
pass
self._handle_upload_login_invalid(
unit=unit,
name=name,
image_path=image_path,
user_id=user_id,
account_id=account_id,
)
return
self._login_required = False
self._last_login_ok = True
sheet_name = (cfg.get("kdocs_sheet_name") or "").strip()
sheet_index = int(cfg.get("kdocs_sheet_index") or 0)
unit_col = (cfg.get("kdocs_unit_column") or "A").strip().upper()
image_col = (cfg.get("kdocs_image_column") or "D").strip().upper()
row_start = int(cfg.get("kdocs_row_start") or 0)
row_end = int(cfg.get("kdocs_row_end") or 0)
success = False
error_msg = ""
for attempt in range(2):
try:
if sheet_name or sheet_index:
self._select_sheet(sheet_name, sheet_index)
row_num = self._find_person_with_unit(unit, name, unit_col, row_start=row_start, row_end=row_end)
if row_num < 0:
error_msg = f"未找到人员: {unit}-{name}"
break
success = self._upload_image_to_cell(row_num, image_path, image_col)
if success:
break
except Exception as e:
error_msg = str(e)
success, error_msg = self._try_upload_to_sheet(cfg, unit, name, image_path)
if success:
self._last_success_at = time.time()
self._last_error = None
@@ -1048,10 +1142,7 @@ class KDocsUploader:
error_msg = "上传失败"
self._last_error = error_msg
self._notify_admin(unit, name, image_path, error_msg)
try:
log_to_client(f"表格上传失败: {error_msg}", user_id, account_id)
except Exception:
pass
self._log_upload_failure(error_msg, user_id, account_id)
finally:
if status_tracked:
self._restore_account_status(user_id, account, prev_status)

View File

@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import threading
import time
from datetime import datetime
@@ -10,6 +11,8 @@ from app_config import get_config
from app_logger import get_logger
from services.state import (
cleanup_expired_ip_rate_limits,
cleanup_expired_ip_request_rates,
cleanup_expired_login_security_state,
safe_cleanup_expired_batches,
safe_cleanup_expired_captcha,
safe_cleanup_expired_pending_random,
@@ -31,6 +34,69 @@ PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECON
_kdocs_offline_notified: bool = False
def _to_int(value, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _collect_active_user_ids() -> set[int]:
active_user_ids: set[int] = set()
for _, info in safe_iter_task_status_items():
user_id = info.get("user_id") if isinstance(info, dict) else None
if user_id is None:
continue
try:
active_user_ids.add(int(user_id))
except Exception:
continue
return active_user_ids
def _find_expired_user_cache_ids(current_time: float, active_user_ids: set[int]) -> list[int]:
expired_users = []
for user_id, last_access in (safe_get_user_accounts_last_access_items() or []):
try:
user_id_int = int(user_id)
last_access_ts = float(last_access)
except Exception:
continue
if (current_time - last_access_ts) <= USER_ACCOUNTS_EXPIRE_SECONDS:
continue
if user_id_int in active_user_ids:
continue
if safe_has_user(user_id_int):
expired_users.append(user_id_int)
return expired_users
def _find_completed_task_status_ids(current_time: float) -> list[str]:
completed_task_ids = []
for account_id, status_data in safe_iter_task_status_items():
status = status_data.get("status") if isinstance(status_data, dict) else None
if status not in ["已完成", "失败", "已停止"]:
continue
start_time = float(status_data.get("start_time", 0) or 0)
if (current_time - start_time) > 600: # 10分钟
completed_task_ids.append(account_id)
return completed_task_ids
def _reap_zombie_processes() -> None:
while True:
try:
pid, _ = os.waitpid(-1, os.WNOHANG)
if pid == 0:
break
logger.debug(f"已回收僵尸进程: PID={pid}")
except ChildProcessError:
break
except Exception:
break
def cleanup_expired_data() -> None:
"""定期清理过期数据,防止内存泄漏(逻辑保持不变)。"""
current_time = time.time()
@@ -43,48 +109,36 @@ def cleanup_expired_data() -> None:
if deleted_ips:
logger.debug(f"已清理 {deleted_ips} 个过期IP限流记录")
expired_users = []
last_access_items = safe_get_user_accounts_last_access_items()
if last_access_items:
task_items = safe_iter_task_status_items()
active_user_ids = {int(info.get("user_id")) for _, info in task_items if info.get("user_id")}
for user_id, last_access in last_access_items:
if (current_time - float(last_access)) <= USER_ACCOUNTS_EXPIRE_SECONDS:
continue
if int(user_id) in active_user_ids:
continue
if safe_has_user(user_id):
expired_users.append(int(user_id))
deleted_ip_requests = cleanup_expired_ip_request_rates(current_time)
if deleted_ip_requests:
logger.debug(f"已清理 {deleted_ip_requests} 个过期IP请求频率记录")
login_cleanup_stats = cleanup_expired_login_security_state(current_time)
login_cleanup_total = sum(int(v or 0) for v in login_cleanup_stats.values())
if login_cleanup_total:
logger.debug(
"已清理登录风控缓存: "
f"失败计数={login_cleanup_stats.get('failures', 0)}, "
f"限流桶={login_cleanup_stats.get('rate_limits', 0)}, "
f"扫描状态={login_cleanup_stats.get('scan_states', 0)}, "
f"短时锁={login_cleanup_stats.get('ip_user_locks', 0)}, "
f"告警状态={login_cleanup_stats.get('alerts', 0)}"
)
active_user_ids = _collect_active_user_ids()
expired_users = _find_expired_user_cache_ids(current_time, active_user_ids)
for user_id in expired_users:
safe_remove_user_accounts(user_id)
if expired_users:
logger.debug(f"已清理 {len(expired_users)} 个过期用户账号缓存")
completed_tasks = []
for account_id, status_data in safe_iter_task_status_items():
if status_data.get("status") in ["已完成", "失败", "已停止"]:
start_time = float(status_data.get("start_time", 0) or 0)
if (current_time - start_time) > 600: # 10分钟
completed_tasks.append(account_id)
for account_id in completed_tasks:
completed_task_ids = _find_completed_task_status_ids(current_time)
for account_id in completed_task_ids:
safe_remove_task_status(account_id)
if completed_tasks:
logger.debug(f"已清理 {len(completed_tasks)} 个已完成任务状态")
if completed_task_ids:
logger.debug(f"已清理 {len(completed_task_ids)} 个已完成任务状态")
try:
import os
while True:
try:
pid, status = os.waitpid(-1, os.WNOHANG)
if pid == 0:
break
logger.debug(f"已回收僵尸进程: PID={pid}")
except ChildProcessError:
break
except Exception:
pass
_reap_zombie_processes()
deleted_batches = safe_cleanup_expired_batches(BATCH_TASK_EXPIRE_SECONDS, current_time)
if deleted_batches:
@@ -95,52 +149,39 @@ def cleanup_expired_data() -> None:
logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务")
def check_kdocs_online_status() -> None:
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
global _kdocs_offline_notified
def _load_kdocs_monitor_config():
import database
cfg = database.get_system_config()
if not cfg:
return None
kdocs_enabled = _to_int(cfg.get("kdocs_enabled"), 0)
if not kdocs_enabled:
return None
admin_notify_enabled = _to_int(cfg.get("kdocs_admin_notify_enabled"), 0)
admin_notify_email = str(cfg.get("kdocs_admin_notify_email") or "").strip()
if (not admin_notify_enabled) or (not admin_notify_email):
return None
return admin_notify_email
def _is_kdocs_offline(status: dict) -> tuple[bool, bool, bool | None]:
login_required = bool(status.get("login_required", False))
last_login_ok = status.get("last_login_ok")
is_offline = login_required or (last_login_ok is False)
return is_offline, login_required, last_login_ok
def _send_kdocs_offline_alert(admin_notify_email: str, *, login_required: bool, last_login_ok) -> bool:
try:
import database
from services.kdocs_uploader import get_kdocs_uploader
import email_service
# 获取系统配置
cfg = database.get_system_config()
if not cfg:
return
# 检查是否启用了金山文档功能
kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
if not kdocs_enabled:
return
# 检查是否启用了管理员通知
admin_notify_enabled = int(cfg.get("kdocs_admin_notify_enabled") or 0)
admin_notify_email = (cfg.get("kdocs_admin_notify_email") or "").strip()
if not admin_notify_enabled or not admin_notify_email:
return
# 获取金山文档状态
kdocs = get_kdocs_uploader()
status = kdocs.get_status()
login_required = status.get("login_required", False)
last_login_ok = status.get("last_login_ok")
# 如果需要登录或最后登录状态不是成功
is_offline = login_required or (last_login_ok is False)
if is_offline:
# 已经通知过了,不再重复通知
if _kdocs_offline_notified:
logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
return
# 发送邮件通知
try:
import email_service
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
subject = "【金山文档离线告警】需要重新登录"
body = f"""
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
subject = "【金山文档离线告警】需要重新登录"
body = f"""
您好,
系统检测到金山文档上传功能已离线,需要重新扫码登录。
@@ -155,58 +196,92 @@ def check_kdocs_online_status() -> None:
---
此邮件由系统自动发送,请勿直接回复。
"""
email_service.send_email_async(
to_email=admin_notify_email,
subject=subject,
body=body,
email_type="kdocs_offline_alert",
)
_kdocs_offline_notified = True # 标记为已通知
logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
except Exception as e:
logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
else:
# 恢复在线,重置通知状态
email_service.send_email_async(
to_email=admin_notify_email,
subject=subject,
body=body,
email_type="kdocs_offline_alert",
)
logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
return True
except Exception as e:
logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
return False
def check_kdocs_online_status() -> None:
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
global _kdocs_offline_notified
try:
admin_notify_email = _load_kdocs_monitor_config()
if not admin_notify_email:
return
from services.kdocs_uploader import get_kdocs_uploader
kdocs = get_kdocs_uploader()
status = kdocs.get_status() or {}
is_offline, login_required, last_login_ok = _is_kdocs_offline(status)
if is_offline:
if _kdocs_offline_notified:
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
_kdocs_offline_notified = False
logger.debug("[KDocs监控] 金山文档状态正常")
logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
return
if _send_kdocs_offline_alert(
admin_notify_email,
login_required=login_required,
last_login_ok=last_login_ok,
):
_kdocs_offline_notified = True
return
if _kdocs_offline_notified:
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
_kdocs_offline_notified = False
logger.debug("[KDocs监控] 金山文档状态正常")
except Exception as e:
logger.error(f"[KDocs监控] 检测失败: {e}")
def start_cleanup_scheduler() -> None:
"""启动定期清理调度器"""
def cleanup_loop():
def _start_daemon_loop(name: str, *, startup_delay: float, interval_seconds: float, job, error_tag: str):
def loop():
if startup_delay > 0:
time.sleep(startup_delay)
while True:
try:
time.sleep(300) # 每5分钟执行一次清理
cleanup_expired_data()
job()
time.sleep(interval_seconds)
except Exception as e:
logger.error(f"清理任务执行失败: {e}")
logger.error(f"{error_tag}: {e}")
time.sleep(min(60.0, max(1.0, interval_seconds / 5.0)))
cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True, name="cleanup-scheduler")
cleanup_thread.start()
thread = threading.Thread(target=loop, daemon=True, name=name)
thread.start()
return thread
def start_cleanup_scheduler() -> None:
"""启动定期清理调度器"""
_start_daemon_loop(
"cleanup-scheduler",
startup_delay=300,
interval_seconds=300,
job=cleanup_expired_data,
error_tag="清理任务执行失败",
)
logger.info("内存清理调度器已启动")
def start_kdocs_monitor() -> None:
"""启动金山文档状态监控"""
def monitor_loop():
# 启动后等待 60 秒再开始检测(给系统初始化的时间)
time.sleep(60)
while True:
try:
check_kdocs_online_status()
time.sleep(300) # 每5分钟检测一次
except Exception as e:
logger.error(f"[KDocs监控] 监控任务执行失败: {e}")
time.sleep(60)
monitor_thread = threading.Thread(target=monitor_loop, daemon=True, name="kdocs-monitor")
monitor_thread.start()
_start_daemon_loop(
"kdocs-monitor",
startup_delay=60,
interval_seconds=300,
job=check_kdocs_online_status,
error_tag="[KDocs监控] 监控任务执行失败",
)
logger.info("[KDocs监控] 金山文档状态监控已启动每5分钟检测一次")

View File

@@ -27,6 +27,12 @@ from services.time_utils import get_beijing_now
logger = get_logger("app")
config = get_config()
try:
_SCHEDULE_SUBMIT_DELAY_SECONDS = float(os.environ.get("SCHEDULE_SUBMIT_DELAY_SECONDS", "0.2"))
except Exception:
_SCHEDULE_SUBMIT_DELAY_SECONDS = 0.2
_SCHEDULE_SUBMIT_DELAY_SECONDS = max(0.0, _SCHEDULE_SUBMIT_DELAY_SECONDS)
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
@@ -55,6 +61,150 @@ def _normalize_hhmm(value: object, *, default: str) -> str:
return f"{hour:02d}:{minute:02d}"
def _safe_recompute_schedule_next_run(schedule_id: int) -> None:
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
def _load_accounts_for_users(approved_users: list[dict]) -> tuple[dict[int, dict], list[str]]:
"""批量加载用户账号快照。"""
user_accounts: dict[int, dict] = {}
account_ids: list[str] = []
for user in approved_users:
user_id = user["id"]
accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
user_accounts[user_id] = accounts
account_ids.extend(list(accounts.keys()))
return user_accounts, account_ids
def _should_skip_suspended_account(account_status_info, account, username: str) -> bool:
"""判断是否应跳过暂停账号,并输出日志。"""
if not account_status_info:
return False
status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
if status != "suspended":
return False
fail_count = account_status_info["login_fail_count"] if "login_fail_count" in account_status_info.keys() else 0
logger.info(
f"[定时任务] 跳过暂停账号: {account.username} (用户:{username}) - 连续{fail_count}次密码错误,需修改密码"
)
return True
def _parse_schedule_account_ids(schedule_config: dict, schedule_id: int):
import json
try:
account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
account_ids = json.loads(account_ids_raw)
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
return []
if isinstance(account_ids, list):
return account_ids
return []
def _create_user_schedule_batch(*, batch_id: str, user_id: int, browse_type: str, schedule_name: str, now_ts: float) -> None:
safe_create_batch(
batch_id,
{
"user_id": user_id,
"browse_type": browse_type,
"schedule_name": schedule_name,
"screenshots": [],
"total_accounts": 0,
"completed": 0,
"created_at": now_ts,
"updated_at": now_ts,
},
)
def _build_user_schedule_done_callback(
*,
completion_lock: threading.Lock,
remaining: dict,
counters: dict,
execution_start_time: float,
log_id: int,
schedule_id: int,
total_accounts: int,
):
def on_browse_done():
with completion_lock:
remaining["count"] -= 1
if remaining["done"] or remaining["count"] > 0:
return
remaining["done"] = True
execution_duration = int(time.time() - execution_start_time)
started_count = int(counters.get("started", 0) or 0)
database.update_schedule_execution_log(
log_id,
total_accounts=total_accounts,
success_accounts=started_count,
failed_accounts=total_accounts - started_count,
duration_seconds=execution_duration,
status="completed",
)
logger.info(f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件")
return on_browse_done
def _submit_user_schedule_accounts(
*,
user_id: int,
account_ids: list,
browse_type: str,
enable_screenshot,
task_source: str,
done_callback,
completion_lock: threading.Lock,
remaining: dict,
counters: dict,
) -> tuple[int, int]:
started_count = 0
skipped_count = 0
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
if (not account) or account.is_running:
skipped_count += 1
continue
with completion_lock:
remaining["count"] += 1
ok, msg = submit_account_task(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
enable_screenshot=enable_screenshot,
source=task_source,
done_callback=done_callback,
)
if ok:
started_count += 1
counters["started"] = started_count
else:
with completion_lock:
remaining["count"] -= 1
skipped_count += 1
logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
return started_count, skipped_count
def run_scheduled_task(skip_weekday_check: bool = False) -> None:
"""执行所有账号的浏览任务(可被手动调用,过滤重复账号)"""
try:
@@ -87,17 +237,7 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
cfg = database.get_system_config()
enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1
user_accounts = {}
account_ids = []
for user in approved_users:
user_id = user["id"]
accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
user_accounts[user_id] = accounts
account_ids.extend(list(accounts.keys()))
user_accounts, account_ids = _load_accounts_for_users(approved_users)
account_statuses = database.get_account_status_batch(account_ids)
@@ -113,18 +253,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
continue
account_status_info = account_statuses.get(str(account_id))
if account_status_info:
status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
if status == "suspended":
fail_count = (
account_status_info["login_fail_count"]
if "login_fail_count" in account_status_info.keys()
else 0
)
logger.info(
f"[定时任务] 跳过暂停账号: {account.username} (用户:{user['username']}) - 连续{fail_count}次密码错误,需修改密码"
)
continue
if _should_skip_suspended_account(account_status_info, account, user["username"]):
continue
if account.username in executed_usernames:
skipped_duplicates += 1
@@ -149,7 +279,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
else:
logger.warning(f"[定时任务] 启动失败({account.username}): {msg}")
time.sleep(2)
if _SCHEDULE_SUBMIT_DELAY_SECONDS > 0:
time.sleep(_SCHEDULE_SUBMIT_DELAY_SECONDS)
logger.info(
f"[定时任务] 执行完成 - 总账号数:{total_accounts}, 已执行:{executed_accounts}, 跳过重复:{skipped_duplicates}"
@@ -198,15 +329,16 @@ def scheduled_task_worker() -> None:
deleted_screenshots = 0
if os.path.exists(SCREENSHOTS_DIR):
cutoff_time = time.time() - (7 * 24 * 60 * 60)
for filename in os.listdir(SCREENSHOTS_DIR):
if filename.lower().endswith((".png", ".jpg", ".jpeg")):
filepath = os.path.join(SCREENSHOTS_DIR, filename)
with os.scandir(SCREENSHOTS_DIR) as entries:
for entry in entries:
if (not entry.is_file()) or (not entry.name.lower().endswith((".png", ".jpg", ".jpeg"))):
continue
try:
if os.path.getmtime(filepath) < cutoff_time:
os.remove(filepath)
if entry.stat().st_mtime < cutoff_time:
os.remove(entry.path)
deleted_screenshots += 1
except Exception as e:
logger.warning(f"[定时清理] 删除截图失败 {filename}: {str(e)}")
logger.warning(f"[定时清理] 删除截图失败 {entry.name}: {str(e)}")
logger.info(f"[定时清理] 已删除 {deleted_screenshots} 个截图文件")
logger.info("[定时清理] 清理完成!")
@@ -214,10 +346,97 @@ def scheduled_task_worker() -> None:
except Exception as e:
logger.exception(f"[定时清理] 清理任务出错: {str(e)}")
def _parse_due_schedule_weekdays(schedule_config: dict, schedule_id: int):
weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5")
try:
return [int(d) for d in weekdays_str.split(",") if d.strip()]
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
_safe_recompute_schedule_next_run(schedule_id)
return None
def _execute_due_user_schedule(schedule_config: dict) -> None:
schedule_name = schedule_config.get("name", "未命名任务")
schedule_id = schedule_config["id"]
user_id = schedule_config["user_id"]
browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
enable_screenshot = schedule_config.get("enable_screenshot", 1)
account_ids = _parse_schedule_account_ids(schedule_config, schedule_id)
if not account_ids:
_safe_recompute_schedule_next_run(schedule_id)
return
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
import uuid
execution_start_time = time.time()
log_id = database.create_schedule_execution_log(
schedule_id=schedule_id,
user_id=user_id,
schedule_name=schedule_name,
)
batch_id = f"batch_{uuid.uuid4().hex[:12]}"
now_ts = time.time()
_create_user_schedule_batch(
batch_id=batch_id,
user_id=user_id,
browse_type=browse_type,
schedule_name=schedule_name,
now_ts=now_ts,
)
completion_lock = threading.Lock()
remaining = {"count": 0, "done": False}
counters = {"started": 0}
on_browse_done = _build_user_schedule_done_callback(
completion_lock=completion_lock,
remaining=remaining,
counters=counters,
execution_start_time=execution_start_time,
log_id=log_id,
schedule_id=schedule_id,
total_accounts=len(account_ids),
)
task_source = f"user_scheduled:{batch_id}"
started_count, skipped_count = _submit_user_schedule_accounts(
user_id=user_id,
account_ids=account_ids,
browse_type=browse_type,
enable_screenshot=enable_screenshot,
task_source=task_source,
done_callback=on_browse_done,
completion_lock=completion_lock,
remaining=remaining,
counters=counters,
)
batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time.time())
if batch_info:
_send_batch_task_email_if_configured(batch_info)
database.update_schedule_last_run(schedule_id)
logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号批次ID: {batch_id}")
if started_count <= 0:
database.update_schedule_execution_log(
log_id,
total_accounts=len(account_ids),
success_accounts=0,
failed_accounts=len(account_ids),
duration_seconds=0,
status="completed",
)
if started_count == 0 and len(account_ids) > 0:
logger.warning("[用户定时任务] ⚠️ 警告所有账号都被跳过了请检查user_accounts状态")
def check_user_schedules():
"""检查并执行用户定时任务O-08next_run_at 索引驱动)。"""
import json
try:
now = get_beijing_now()
now_str = now.strftime("%Y-%m-%d %H:%M:%S")
@@ -226,145 +445,22 @@ def scheduled_task_worker() -> None:
due_schedules = database.get_due_user_schedules(now_str, limit=50) or []
for schedule_config in due_schedules:
schedule_name = schedule_config.get("name", "未命名任务")
schedule_id = schedule_config["id"]
schedule_name = schedule_config.get("name", "未命名任务")
weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5")
try:
allowed_weekdays = [int(d) for d in weekdays_str.split(",") if d.strip()]
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
allowed_weekdays = _parse_due_schedule_weekdays(schedule_config, schedule_id)
if allowed_weekdays is None:
continue
if current_weekday not in allowed_weekdays:
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
_safe_recompute_schedule_next_run(schedule_id)
continue
logger.info(f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 (next_run_at={schedule_config.get('next_run_at')})")
user_id = schedule_config["user_id"]
schedule_id = schedule_config["id"]
browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
enable_screenshot = schedule_config.get("enable_screenshot", 1)
try:
account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
account_ids = json.loads(account_ids_raw)
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
account_ids = []
if not account_ids:
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
continue
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
import time as time_mod
import uuid
execution_start_time = time_mod.time()
log_id = database.create_schedule_execution_log(
schedule_id=schedule_id, user_id=user_id, schedule_name=schedule_config.get("name", "未命名任务")
logger.info(
f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 "
f"(next_run_at={schedule_config.get('next_run_at')})"
)
batch_id = f"batch_{uuid.uuid4().hex[:12]}"
now_ts = time_mod.time()
safe_create_batch(
batch_id,
{
"user_id": user_id,
"browse_type": browse_type,
"schedule_name": schedule_config.get("name", "未命名任务"),
"screenshots": [],
"total_accounts": 0,
"completed": 0,
"created_at": now_ts,
"updated_at": now_ts,
},
)
started_count = 0
skipped_count = 0
completion_lock = threading.Lock()
remaining = {"count": 0, "done": False}
def on_browse_done():
with completion_lock:
remaining["count"] -= 1
if remaining["done"] or remaining["count"] > 0:
return
remaining["done"] = True
execution_duration = int(time_mod.time() - execution_start_time)
database.update_schedule_execution_log(
log_id,
total_accounts=len(account_ids),
success_accounts=started_count,
failed_accounts=len(account_ids) - started_count,
duration_seconds=execution_duration,
status="completed",
)
logger.info(
f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件"
)
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
if not account:
skipped_count += 1
continue
if account.is_running:
skipped_count += 1
continue
task_source = f"user_scheduled:{batch_id}"
with completion_lock:
remaining["count"] += 1
ok, msg = submit_account_task(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
enable_screenshot=enable_screenshot,
source=task_source,
done_callback=on_browse_done,
)
if ok:
started_count += 1
else:
with completion_lock:
remaining["count"] -= 1
skipped_count += 1
logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time_mod.time())
if batch_info:
_send_batch_task_email_if_configured(batch_info)
database.update_schedule_last_run(schedule_id)
logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号批次ID: {batch_id}")
if started_count <= 0:
database.update_schedule_execution_log(
log_id,
total_accounts=len(account_ids),
success_accounts=0,
failed_accounts=len(account_ids),
duration_seconds=0,
status="completed",
)
if started_count == 0 and len(account_ids) > 0:
logger.warning("[用户定时任务] ⚠️ 警告所有账号都被跳过了请检查user_accounts状态")
_execute_due_user_schedule(schedule_config)
except Exception as e:
logger.exception(f"[用户定时任务] 检查出错: {str(e)}")

View File

@@ -6,12 +6,14 @@ import os
import shutil
import subprocess
import time
from urllib.parse import urlsplit
import database
import email_service
from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh
from app_config import get_config
from app_logger import get_logger
from app_security import sanitize_filename
from browser_pool_worker import get_browser_worker_pool
from services.client_log import log_to_client
from services.runtime import get_socketio
@@ -194,6 +196,293 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
pass
def _set_screenshot_running_status(user_id: int, account_id: str) -> None:
"""更新账号状态为截图中。"""
acc = safe_get_account(user_id, account_id)
if not acc:
return
acc.status = "截图中"
safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
def _get_worker_display_info(browser_instance) -> tuple[str, int]:
"""获取截图 worker 的展示信息。"""
if isinstance(browser_instance, dict):
return str(browser_instance.get("worker_id", "?")), int(browser_instance.get("use_count", 0) or 0)
return "?", 0
def _get_proxy_context(account) -> tuple[dict | None, str | None]:
"""提取截图阶段代理配置。"""
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
proxy_server = proxy_config.get("server") if proxy_config else None
return proxy_config, proxy_server
def _build_screenshot_targets(browse_type: str) -> tuple[str, str, str]:
"""构建截图目标 URL 与页面脚本。"""
parsed = urlsplit(config.ZSGL_LOGIN_URL)
base = f"{parsed.scheme}://{parsed.netloc}"
if "注册前" in str(browse_type):
bz = 0
else:
bz = 0
target_url = f"{base}/admin/center.aspx?bz={bz}"
index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
run_script = (
"(function(){"
"function done(){window.status='ready';}"
"function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
"function expandMenu(){"
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
"try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
"try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
"try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
"try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
"}"
"function navReady(){"
"try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
"}"
"function frameReady(){"
"try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
"}"
"function check(){"
"if(navReady() && frameReady()){done();return;}"
"setTimeout(check,300);"
"}"
"var f=document.getElementById('mainframe');"
"ensureNav();"
"expandMenu();"
"if(!f){done();return;}"
f"f.src='{target_url}';"
"f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
"setTimeout(check,5000);"
"})();"
)
return index_url, target_url, run_script
def _build_screenshot_output_path(username_prefix: str, account, browse_type: str) -> tuple[str, str]:
"""构建截图输出文件名与路径。"""
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
login_account = account.remark if account.remark else account.username
raw_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
screenshot_filename = sanitize_filename(raw_filename)
return screenshot_filename, os.path.join(SCREENSHOTS_DIR, screenshot_filename)
def _ensure_screenshot_login_state(
*,
account,
proxy_config,
cookie_path: str,
attempt: int,
max_retries: int,
user_id: int,
account_id: str,
custom_log,
) -> str:
"""确保截图前登录态有效。返回: ok/retry/fail。"""
should_refresh_login = not is_cookie_jar_fresh(cookie_path)
if not should_refresh_login:
return "ok"
log_to_client("正在刷新登录态...", user_id, account_id)
if _ensure_login_cookies(account, proxy_config, custom_log):
return "ok"
if attempt > 1:
log_to_client("截图登录失败", user_id, account_id)
if attempt < max_retries:
log_to_client("将重试...", user_id, account_id)
time.sleep(2)
return "retry"
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
return "fail"
def _take_screenshot_once(
*,
index_url: str,
target_url: str,
screenshot_path: str,
cookie_path: str,
proxy_server: str | None,
run_script: str,
log_callback,
) -> str:
"""执行一次截图尝试并验证输出文件。返回: success/invalid/failed。"""
cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
attempts = [
{
"url": index_url,
"run_script": run_script,
"window_status": "ready",
},
{
"url": target_url,
"run_script": None,
"window_status": None,
},
]
ok = False
for shot in attempts:
ok = take_screenshot_wkhtmltoimage(
shot["url"],
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server,
run_script=shot["run_script"],
window_status=shot["window_status"],
log_callback=log_callback,
)
if ok:
break
if not ok:
return "failed"
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
return "success"
if os.path.exists(screenshot_path):
os.remove(screenshot_path)
return "invalid"
def _get_result_screenshot_path(result) -> str | None:
"""从截图结果中提取截图文件绝对路径。"""
if result and result.get("success") and result.get("filename"):
return os.path.join(SCREENSHOTS_DIR, result["filename"])
return None
def _enqueue_kdocs_upload_if_needed(user_id: int, account_id: str, account, screenshot_path: str | None) -> None:
"""按配置提交金山文档上传任务。"""
if not screenshot_path:
return
cfg = database.get_system_config() or {}
if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
return
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
if not doc_url:
return
user_cfg = database.get_user_kdocs_settings(user_id) or {}
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) != 1:
return
unit = (user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "").strip()
name = (account.remark or "").strip()
if not unit:
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
return
if not name:
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
return
from services.kdocs_uploader import get_kdocs_uploader
ok = get_kdocs_uploader().enqueue_upload(
user_id=user_id,
account_id=account_id,
unit=unit,
name=name,
image_path=screenshot_path,
)
if not ok:
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
def _dispatch_screenshot_result(
*,
user_id: int,
account_id: str,
source: str,
browse_type: str,
browse_result: dict,
result,
account,
user_info,
) -> None:
"""将截图结果发送到批次统计/邮件通知链路。"""
batch_id = _get_batch_id_from_source(source)
screenshot_path = _get_result_screenshot_path(result)
account_name = account.remark if account.remark else account.username
try:
if result and result.get("success") and screenshot_path:
_enqueue_kdocs_upload_if_needed(user_id, account_id, account, screenshot_path)
except Exception as kdocs_error:
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
if batch_id:
_batch_task_record_result(
batch_id=batch_id,
account_name=account_name,
screenshot_path=screenshot_path,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
)
return
if source and source.startswith("user_scheduled"):
if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
email_service.send_task_complete_email_async(
user_id=user_id,
email=user_info["email"],
username=user_info["username"],
account_name=account_name,
browse_type=browse_type,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
screenshot_path=screenshot_path,
log_callback=lambda msg: log_to_client(msg, user_id, account_id),
)
def _finalize_screenshot_callback_state(user_id: int, account_id: str, account) -> None:
"""截图回调的通用收尾状态变更。"""
account.is_running = False
account.status = "未开始"
safe_remove_task_status(account_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
def _persist_browse_log_after_screenshot(
*,
user_id: int,
account_id: str,
account,
browse_type: str,
source: str,
task_start_time,
browse_result,
) -> None:
"""截图完成后写入任务日志(浏览完成日志)。"""
import time as time_module
total_elapsed = int(time_module.time() - task_start_time)
database.create_task_log(
user_id=user_id,
account_id=account_id,
username=account.username,
browse_type=browse_type,
status="success",
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
duration=total_elapsed,
source=source,
)
def take_screenshot_for_account(
user_id,
account_id,
@@ -213,21 +502,21 @@ def take_screenshot_for_account(
# 标记账号正在截图(防止重复提交截图任务)
account.is_running = True
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
def screenshot_task(
browser_instance, user_id, account_id, account, browse_type, source, task_start_time, browse_result
):
"""在worker线程中执行的截图任务"""
# ✅ 获得worker后立即更新状态为"截图中"
acc = safe_get_account(user_id, account_id)
if acc:
acc.status = "截图中"
safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
_set_screenshot_running_status(user_id, account_id)
max_retries = 3
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
proxy_server = proxy_config.get("server") if proxy_config else None
proxy_config, proxy_server = _get_proxy_context(account)
cookie_path = get_cookie_jar_path(account.username)
index_url, target_url, run_script = _build_screenshot_targets(browse_type)
for attempt in range(1, max_retries + 1):
try:
@@ -239,8 +528,7 @@ def take_screenshot_for_account(
if attempt > 1:
log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id)
worker_id = browser_instance.get("worker_id", "?") if isinstance(browser_instance, dict) else "?"
use_count = browser_instance.get("use_count", 0) if isinstance(browser_instance, dict) else 0
worker_id, use_count = _get_worker_display_info(browser_instance)
log_to_client(
f"使用Worker-{worker_id}执行截图(已执行{use_count}次)",
user_id,
@@ -250,99 +538,39 @@ def take_screenshot_for_account(
def custom_log(message: str):
log_to_client(message, user_id, account_id)
# 智能登录状态检查:只在必要时才刷新登录
should_refresh_login = not is_cookie_jar_fresh(cookie_path)
if should_refresh_login and attempt > 1:
# 重试时刷新登录attempt > 1 表示第2次及以后的尝试
log_to_client("正在刷新登录态...", user_id, account_id)
if not _ensure_login_cookies(account, proxy_config, custom_log):
log_to_client("截图登录失败", user_id, account_id)
if attempt < max_retries:
log_to_client("将重试...", user_id, account_id)
time.sleep(2)
continue
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
return {"success": False, "error": "登录失败"}
elif should_refresh_login:
# 首次尝试时快速检查登录状态
log_to_client("正在刷新登录态...", user_id, account_id)
if not _ensure_login_cookies(account, proxy_config, custom_log):
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
return {"success": False, "error": "登录失败"}
login_state = _ensure_screenshot_login_state(
account=account,
proxy_config=proxy_config,
cookie_path=cookie_path,
attempt=attempt,
max_retries=max_retries,
user_id=user_id,
account_id=account_id,
custom_log=custom_log,
)
if login_state == "retry":
continue
if login_state == "fail":
return {"success": False, "error": "登录失败"}
log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id)
from urllib.parse import urlsplit
parsed = urlsplit(config.ZSGL_LOGIN_URL)
base = f"{parsed.scheme}://{parsed.netloc}"
if "注册前" in str(browse_type):
bz = 0
else:
bz = 0 # 应读(网站更新后 bz=0 为应读)
target_url = f"{base}/admin/center.aspx?bz={bz}"
index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
run_script = (
"(function(){"
"function done(){window.status='ready';}"
"function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
"function expandMenu(){"
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
"try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
"try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
"try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
"try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
"}"
"function navReady(){"
"try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
"}"
"function frameReady(){"
"try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
"}"
"function check(){"
"if(navReady() && frameReady()){done();return;}"
"setTimeout(check,300);"
"}"
"var f=document.getElementById('mainframe');"
"ensureNav();"
"expandMenu();"
"if(!f){done();return;}"
f"f.src='{target_url}';"
"f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
"setTimeout(check,5000);"
"})();"
)
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
login_account = account.remark if account.remark else account.username
screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename)
cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
if take_screenshot_wkhtmltoimage(
index_url,
screenshot_path,
cookies_path=cookies_for_shot,
screenshot_filename, screenshot_path = _build_screenshot_output_path(username_prefix, account, browse_type)
shot_state = _take_screenshot_once(
index_url=index_url,
target_url=target_url,
screenshot_path=screenshot_path,
cookie_path=cookie_path,
proxy_server=proxy_server,
run_script=run_script,
window_status="ready",
log_callback=custom_log,
) or take_screenshot_wkhtmltoimage(
target_url,
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server,
log_callback=custom_log,
):
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
return {"success": True, "filename": screenshot_filename}
)
if shot_state == "success":
log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
return {"success": True, "filename": screenshot_filename}
if shot_state == "invalid":
log_to_client("截图文件异常,将重试", user_id, account_id)
if os.path.exists(screenshot_path):
os.remove(screenshot_path)
else:
log_to_client("截图保存失败", user_id, account_id)
@@ -361,12 +589,7 @@ def take_screenshot_for_account(
def screenshot_callback(result, error):
"""截图完成回调"""
try:
account.is_running = False
account.status = "未开始"
safe_remove_task_status(account_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
_finalize_screenshot_callback_state(user_id, account_id, account)
if error:
log_to_client(f"❌ 截图失败: {error}", user_id, account_id)
@@ -375,84 +598,27 @@ def take_screenshot_for_account(
log_to_client(f"❌ 截图失败: {error_msg}", user_id, account_id)
if task_start_time and browse_result:
import time as time_module
total_elapsed = int(time_module.time() - task_start_time)
database.create_task_log(
_persist_browse_log_after_screenshot(
user_id=user_id,
account_id=account_id,
username=account.username,
account=account,
browse_type=browse_type,
status="success",
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
duration=total_elapsed,
source=source,
task_start_time=task_start_time,
browse_result=browse_result,
)
try:
batch_id = _get_batch_id_from_source(source)
screenshot_path = None
if result and result.get("success") and result.get("filename"):
screenshot_path = os.path.join(SCREENSHOTS_DIR, result["filename"])
account_name = account.remark if account.remark else account.username
try:
if screenshot_path and result and result.get("success"):
cfg = database.get_system_config() or {}
if int(cfg.get("kdocs_enabled", 0) or 0) == 1:
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
if doc_url:
user_cfg = database.get_user_kdocs_settings(user_id) or {}
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) == 1:
unit = (
user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or ""
).strip()
name = (account.remark or "").strip()
if unit and name:
from services.kdocs_uploader import get_kdocs_uploader
ok = get_kdocs_uploader().enqueue_upload(
user_id=user_id,
account_id=account_id,
unit=unit,
name=name,
image_path=screenshot_path,
)
if not ok:
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
else:
if not unit:
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
if not name:
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
except Exception as kdocs_error:
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
if batch_id:
_batch_task_record_result(
batch_id=batch_id,
account_name=account_name,
screenshot_path=screenshot_path,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
)
elif source and source.startswith("user_scheduled"):
user_info = database.get_user_by_id(user_id)
if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
email_service.send_task_complete_email_async(
user_id=user_id,
email=user_info["email"],
username=user_info["username"],
account_name=account_name,
browse_type=browse_type,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
screenshot_path=screenshot_path,
log_callback=lambda msg: log_to_client(msg, user_id, account_id),
)
_dispatch_screenshot_result(
user_id=user_id,
account_id=account_id,
source=source,
browse_type=browse_type,
browse_result=browse_result,
result=result,
account=account,
user_info=user_info,
)
except Exception as email_error:
logger.warning(f"发送任务完成邮件失败: {email_error}")
except Exception as e:

View File

@@ -13,7 +13,7 @@ from __future__ import annotations
import threading
import time
import random
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from app_config import get_config
@@ -161,6 +161,36 @@ _log_cache_lock = threading.RLock()
_log_cache_total_count = 0
def _pop_oldest_log_for_user(uid: int) -> bool:
global _log_cache_total_count
logs = _log_cache.get(uid)
if not logs:
_log_cache.pop(uid, None)
return False
logs.pop(0)
_log_cache_total_count = max(0, _log_cache_total_count - 1)
if not logs:
_log_cache.pop(uid, None)
return True
def _pop_oldest_log_from_largest_user() -> bool:
largest_uid = None
largest_size = 0
for uid, logs in _log_cache.items():
size = len(logs)
if size > largest_size:
largest_uid = uid
largest_size = size
if largest_uid is None or largest_size <= 0:
return False
return _pop_oldest_log_for_user(int(largest_uid))
def safe_add_log(
user_id: int,
log_entry: Dict[str, Any],
@@ -175,24 +205,17 @@ def safe_add_log(
max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS)
with _log_cache_lock:
if uid not in _log_cache:
_log_cache[uid] = []
logs = _log_cache.setdefault(uid, [])
if len(_log_cache[uid]) >= max_logs_per_user:
_log_cache[uid].pop(0)
_log_cache_total_count = max(0, _log_cache_total_count - 1)
if len(logs) >= max_logs_per_user:
_pop_oldest_log_for_user(uid)
logs = _log_cache.setdefault(uid, [])
_log_cache[uid].append(dict(log_entry or {}))
logs.append(dict(log_entry or {}))
_log_cache_total_count += 1
while _log_cache_total_count > max_total_logs:
if not _log_cache:
break
max_user = max(_log_cache.keys(), key=lambda u: len(_log_cache[u]))
if _log_cache.get(max_user):
_log_cache[max_user].pop(0)
_log_cache_total_count -= 1
else:
if not _pop_oldest_log_from_largest_user():
break
@@ -378,6 +401,34 @@ def _get_action_rate_limit(action: str) -> Tuple[int, int]:
return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS)
def _format_wait_hint(remaining_seconds: int) -> str:
remaining = max(1, int(remaining_seconds or 0))
if remaining >= 60:
return f"{remaining // 60 + 1}分钟"
return f"{remaining}"
def _check_and_increment_rate_bucket(
*,
buckets: Dict[str, Dict[str, Any]],
key: str,
now_ts: float,
max_requests: int,
window_seconds: int,
) -> Tuple[bool, Optional[str]]:
if not key or int(max_requests) <= 0:
return True, None
data = _get_or_reset_bucket(buckets.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= int(max_requests):
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
return False, f"请求过于频繁,请{_format_wait_hint(remaining)}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
buckets[key] = data
return True, None
def check_ip_request_rate(
ip_address: str,
action: str,
@@ -392,21 +443,13 @@ def check_ip_request_rate(
key = f"{action}:{ip_address}"
with _ip_request_rate_lock:
data = _ip_request_rate.get(key)
if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds:
data = {"window_start": now_ts, "count": 0}
_ip_request_rate[key] = data
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
if remaining >= 60:
wait_hint = f"{remaining // 60 + 1}分钟"
else:
wait_hint = f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
return True, None
return _check_and_increment_rate_bucket(
buckets=_ip_request_rate,
key=key,
now_ts=now_ts,
max_requests=max_requests,
window_seconds=window_seconds,
)
def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
@@ -417,8 +460,7 @@ def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
data = _ip_request_rate.get(key) or {}
action = key.split(":", 1)[0]
_, window_seconds = _get_action_rate_limit(action)
window_start = float(data.get("window_start", 0) or 0)
if now_ts - window_start >= window_seconds:
if _is_bucket_expired(data, now_ts, window_seconds):
_ip_request_rate.pop(key, None)
removed += 1
return removed
@@ -487,6 +529,30 @@ def _get_or_reset_bucket(data: Optional[Dict[str, Any]], now_ts: float, window_s
return data
def _is_bucket_expired(
data: Optional[Dict[str, Any]],
now_ts: float,
window_seconds: int,
*,
time_field: str = "window_start",
) -> bool:
start_ts = float((data or {}).get(time_field, 0) or 0)
return (now_ts - start_ts) >= max(1, int(window_seconds))
def _cleanup_map_entries(
store: Dict[Any, Dict[str, Any]],
should_remove: Callable[[Dict[str, Any]], bool],
) -> int:
removed = 0
for key, value in list(store.items()):
item = value if isinstance(value, dict) else {}
if should_remove(item):
store.pop(key, None)
removed += 1
return removed
def record_login_username_attempt(ip_address: str, username: str) -> bool:
now_ts = time.time()
threshold, window_seconds, cooldown_seconds = _get_login_scan_config()
@@ -527,26 +593,32 @@ def check_login_rate_limits(ip_address: str, username: str) -> Tuple[bool, Optio
user_key = _normalize_login_key("user", "", username)
ip_user_key = _normalize_login_key("ipuser", ip_address, username)
def _check(key: str, max_requests: int) -> Tuple[bool, Optional[str]]:
if not key or max_requests <= 0:
return True, None
data = _get_or_reset_bucket(_login_rate_limits.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
_login_rate_limits[key] = data
return True, None
with _login_rate_limits_lock:
allowed, msg = _check(ip_key, ip_max)
allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=ip_key,
now_ts=now_ts,
max_requests=ip_max,
window_seconds=window_seconds,
)
if not allowed:
return False, msg
allowed, msg = _check(ip_user_key, ip_user_max)
allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=ip_user_key,
now_ts=now_ts,
max_requests=ip_user_max,
window_seconds=window_seconds,
)
if not allowed:
return False, msg
allowed, msg = _check(user_key, user_max)
allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=user_key,
now_ts=now_ts,
max_requests=user_max,
window_seconds=window_seconds,
)
if not allowed:
return False, msg
@@ -622,15 +694,18 @@ def check_login_captcha_required(ip_address: str, username: Optional[str] = None
ip_key = _normalize_login_key("ip", ip_address)
ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
def _is_over_threshold(data: Optional[Dict[str, Any]]) -> bool:
if not data:
return False
if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
return False
return int(data.get("count", 0) or 0) >= max_failures
with _login_failures_lock:
ip_data = _login_failures.get(ip_key)
if ip_data and (now_ts - float(ip_data.get("first_failed", 0) or 0)) <= window_seconds:
if int(ip_data.get("count", 0) or 0) >= max_failures:
return True
ip_user_data = _login_failures.get(ip_user_key)
if ip_user_data and (now_ts - float(ip_user_data.get("first_failed", 0) or 0)) <= window_seconds:
if int(ip_user_data.get("count", 0) or 0) >= max_failures:
return True
if _is_over_threshold(_login_failures.get(ip_key)):
return True
if _is_over_threshold(_login_failures.get(ip_user_key)):
return True
if is_login_scan_locked(ip_address):
return True
@@ -685,6 +760,56 @@ def should_send_login_alert(user_id: int, ip_address: str) -> bool:
return False
def cleanup_expired_login_security_state(now_ts: Optional[float] = None) -> Dict[str, int]:
now_ts = float(now_ts if now_ts is not None else time.time())
_, captcha_window = _get_login_captcha_config()
_, _, _, rate_window = _get_login_rate_limit_config()
_, lock_window, _ = _get_login_lock_config()
_, scan_window, _ = _get_login_scan_config()
alert_expire_seconds = max(3600, int(config.LOGIN_ALERT_MIN_INTERVAL_SECONDS) * 3)
with _login_failures_lock:
failures_removed = _cleanup_map_entries(
_login_failures,
lambda data: (now_ts - float(data.get("first_failed", 0) or 0)) > max(captcha_window, lock_window),
)
with _login_rate_limits_lock:
rate_removed = _cleanup_map_entries(
_login_rate_limits,
lambda data: _is_bucket_expired(data, now_ts, rate_window),
)
with _login_scan_lock:
scan_removed = _cleanup_map_entries(
_login_scan_state,
lambda data: (
(now_ts - float(data.get("first_seen", 0) or 0)) > scan_window
and now_ts >= float(data.get("scan_until", 0) or 0)
),
)
with _login_ip_user_lock:
ip_user_locks_removed = _cleanup_map_entries(
_login_ip_user_locks,
lambda data: now_ts >= float(data.get("lock_until", 0) or 0),
)
with _login_alert_lock:
alerts_removed = _cleanup_map_entries(
_login_alert_state,
lambda data: (now_ts - float(data.get("last_sent", 0) or 0)) > alert_expire_seconds,
)
return {
"failures": failures_removed,
"rate_limits": rate_removed,
"scan_states": scan_removed,
"ip_user_locks": ip_user_locks_removed,
"alerts": alerts_removed,
}
# ==================== 邮箱维度限流 ====================
_email_rate_limit: Dict[str, Dict[str, Any]] = {}
@@ -701,14 +826,13 @@ def check_email_rate_limit(email: str, action: str) -> Tuple[bool, Optional[str]
key = f"{action}:{email_key}"
with _email_rate_limit_lock:
data = _get_or_reset_bucket(_email_rate_limit.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
_email_rate_limit[key] = data
return True, None
return _check_and_increment_rate_bucket(
buckets=_email_rate_limit,
key=key,
now_ts=now_ts,
max_requests=max_requests,
window_seconds=window_seconds,
)
# ==================== Batch screenshots批次任务截图收集 ====================

365
services/task_scheduler.py Normal file
View File

@@ -0,0 +1,365 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import heapq
import threading
import time
from concurrent.futures import ThreadPoolExecutor, wait
from dataclasses import dataclass
import database
from app_logger import get_logger
from services.state import safe_get_account, safe_get_task, safe_remove_task, safe_set_task
from services.task_batches import _batch_task_record_result, _get_batch_id_from_source
logger = get_logger("app")
# VIP优先级队列仅用于可视化/调试)
vip_task_queue = [] # VIP用户任务队列
normal_task_queue = [] # 普通用户任务队列
task_queue_lock = threading.Lock()
@dataclass
class _TaskRequest:
user_id: int
account_id: str
browse_type: str
enable_screenshot: bool
source: str
retry_count: int
submitted_at: float
is_vip: bool
seq: int
canceled: bool = False
done_callback: object = None
class TaskScheduler:
"""全局任务调度器:队列排队,不为每个任务单独创建线程。"""
def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000, run_task_fn=None):
self.max_global = max(1, int(max_global))
self.max_per_user = max(1, int(max_per_user))
self.max_queue_size = max(1, int(max_queue_size))
self._cond = threading.Condition()
self._pending = [] # heap: (priority, submitted_at, seq, task)
self._pending_by_account = {} # {account_id: task}
self._seq = 0
self._known_account_ids = set()
self._running_global = 0
self._running_by_user = {} # {user_id: running_count}
self._executor_max_workers = self.max_global
self._executor = ThreadPoolExecutor(max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker")
self._futures_lock = threading.Lock()
self._active_futures = set()
self._running = True
self._run_task_fn = run_task_fn
self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, daemon=True, name="TaskDispatcher")
self._dispatcher_thread.start()
def _track_future(self, future) -> None:
with self._futures_lock:
self._active_futures.add(future)
try:
future.add_done_callback(self._untrack_future)
except Exception:
pass
def _untrack_future(self, future) -> None:
with self._futures_lock:
self._active_futures.discard(future)
def shutdown(self, timeout: float = 5.0):
"""停止调度器(用于进程退出清理)"""
with self._cond:
self._running = False
self._cond.notify_all()
try:
self._dispatcher_thread.join(timeout=timeout)
except Exception:
pass
# 等待已提交的任务收尾(最多等待 timeout 秒),避免遗留 active_task 干扰后续调度/测试
try:
deadline = time.time() + max(0.0, float(timeout or 0))
while True:
with self._futures_lock:
pending = [f for f in self._active_futures if not f.done()]
if not pending:
break
remaining = deadline - time.time()
if remaining <= 0:
break
wait(pending, timeout=remaining)
except Exception:
pass
try:
self._executor.shutdown(wait=False)
except Exception:
pass
# 最后兜底:清理本调度器提交过的 active_task避免测试/重启时被“任务已在运行中”误拦截
try:
with self._cond:
known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys())
self._pending.clear()
self._pending_by_account.clear()
self._cond.notify_all()
for account_id in known_ids:
safe_remove_task(account_id)
except Exception:
pass
def update_limits(self, max_global: int = None, max_per_user: int = None, max_queue_size: int = None):
"""动态更新并发/队列上限(不影响已在运行的任务)"""
with self._cond:
if max_per_user is not None:
self.max_per_user = max(1, int(max_per_user))
if max_queue_size is not None:
self.max_queue_size = max(1, int(max_queue_size))
if max_global is not None:
new_max_global = max(1, int(max_global))
self.max_global = new_max_global
if new_max_global > self._executor_max_workers:
# 立即关闭旧线程池,防止资源泄漏
old_executor = self._executor
self._executor_max_workers = new_max_global
self._executor = ThreadPoolExecutor(
max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker"
)
# 立即关闭旧线程池
try:
old_executor.shutdown(wait=False)
logger.info(f"线程池已扩容:{old_executor._max_workers} -> {self._executor_max_workers}")
except Exception as e:
logger.warning(f"关闭旧线程池失败: {e}")
self._cond.notify_all()
def get_queue_state_snapshot(self) -> dict:
"""获取调度器队列/运行状态快照(用于前端展示/监控)。"""
with self._cond:
pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled]
pending_tasks.sort(key=lambda t: (0 if t.is_vip else 1, t.submitted_at, t.seq))
positions = {}
for idx, t in enumerate(pending_tasks):
positions[t.account_id] = {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)}
return {
"pending_total": len(pending_tasks),
"running_total": int(self._running_global),
"running_by_user": dict(self._running_by_user),
"positions": positions,
}
def submit_task(
self,
user_id: int,
account_id: str,
browse_type: str,
enable_screenshot: bool = True,
source: str = "manual",
retry_count: int = 0,
is_vip: bool = None,
done_callback=None,
):
"""提交任务进入队列(返回: (ok, message)"""
if not user_id or not account_id:
return False, "参数错误"
submitted_at = time.time()
if is_vip is None:
try:
is_vip = bool(database.is_user_vip(user_id))
except Exception:
is_vip = False
else:
is_vip = bool(is_vip)
with self._cond:
if not self._running:
return False, "调度器未运行"
if len(self._pending_by_account) >= self.max_queue_size:
return False, "任务队列已满,请稍后再试"
if account_id in self._pending_by_account:
return False, "任务已在队列中"
if safe_get_task(account_id) is not None:
return False, "任务已在运行中"
self._seq += 1
task = _TaskRequest(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
enable_screenshot=bool(enable_screenshot),
source=source,
retry_count=int(retry_count or 0),
submitted_at=submitted_at,
is_vip=is_vip,
seq=self._seq,
done_callback=done_callback,
)
self._pending_by_account[account_id] = task
self._known_account_ids.add(account_id)
priority = 0 if is_vip else 1
heapq.heappush(self._pending, (priority, task.submitted_at, task.seq, task))
self._cond.notify_all()
# 用于可视化/调试:记录队列
with task_queue_lock:
if is_vip:
vip_task_queue.append(account_id)
else:
normal_task_queue.append(account_id)
return True, "已加入队列"
def cancel_pending_task(self, user_id: int, account_id: str) -> bool:
"""取消尚未开始的排队任务(已运行的任务由 should_stop 控制)"""
canceled_task = None
with self._cond:
task = self._pending_by_account.pop(account_id, None)
if not task:
return False
task.canceled = True
canceled_task = task
self._cond.notify_all()
# 从可视化队列移除
with task_queue_lock:
if account_id in vip_task_queue:
vip_task_queue.remove(account_id)
if account_id in normal_task_queue:
normal_task_queue.remove(account_id)
# 批次任务:取消也要推进完成计数,避免批次缓存常驻
try:
batch_id = _get_batch_id_from_source(canceled_task.source)
if batch_id:
acc = safe_get_account(user_id, account_id)
if acc:
account_name = acc.remark if acc.remark else acc.username
else:
account_name = account_id
_batch_task_record_result(
batch_id=batch_id,
account_name=account_name,
screenshot_path=None,
total_items=0,
total_attachments=0,
)
except Exception:
pass
return True
def _dispatch_loop(self):
while True:
task = None
with self._cond:
if not self._running:
return
if not self._pending or self._running_global >= self.max_global:
self._cond.wait(timeout=0.5)
continue
task = self._pop_next_runnable_locked()
if task is None:
self._cond.wait(timeout=0.5)
continue
self._running_global += 1
self._running_by_user[task.user_id] = self._running_by_user.get(task.user_id, 0) + 1
# 从队列移除(可视化)
with task_queue_lock:
if task.account_id in vip_task_queue:
vip_task_queue.remove(task.account_id)
if task.account_id in normal_task_queue:
normal_task_queue.remove(task.account_id)
try:
future = self._executor.submit(self._run_task_wrapper, task)
self._track_future(future)
safe_set_task(task.account_id, future)
except Exception:
with self._cond:
self._running_global = max(0, self._running_global - 1)
# 使用默认值 0 与增加时保持一致
self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
if self._running_by_user.get(task.user_id) == 0:
self._running_by_user.pop(task.user_id, None)
self._cond.notify_all()
def _pop_next_runnable_locked(self):
"""在锁内从优先队列取出“可运行”的任务避免VIP任务占位阻塞普通任务。"""
if not self._pending:
return None
skipped = []
selected = None
while self._pending:
_, _, _, task = heapq.heappop(self._pending)
if task.canceled:
continue
if self._pending_by_account.get(task.account_id) is not task:
continue
running_for_user = self._running_by_user.get(task.user_id, 0)
if running_for_user >= self.max_per_user:
skipped.append(task)
continue
selected = task
break
for t in skipped:
priority = 0 if t.is_vip else 1
heapq.heappush(self._pending, (priority, t.submitted_at, t.seq, t))
if selected is None:
return None
self._pending_by_account.pop(selected.account_id, None)
return selected
def _run_task_wrapper(self, task: _TaskRequest):
try:
if callable(self._run_task_fn):
self._run_task_fn(
user_id=task.user_id,
account_id=task.account_id,
browse_type=task.browse_type,
enable_screenshot=task.enable_screenshot,
source=task.source,
retry_count=task.retry_count,
)
finally:
try:
if callable(task.done_callback):
task.done_callback()
except Exception:
pass
safe_remove_task(task.account_id)
with self._cond:
self._running_global = max(0, self._running_global - 1)
# 使用默认值 0 与增加时保持一致
self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
if self._running_by_user.get(task.user_id) == 0:
self._running_by_user.pop(task.user_id, None)
self._cond.notify_all()

File diff suppressed because it is too large Load Diff