refactor: optimize structure, stability and runtime performance
This commit is contained in:
191
app.py
191
app.py
@@ -173,10 +173,28 @@ def serve_static(filename):
|
||||
if not is_safe_path("static", filename):
|
||||
return jsonify({"error": "非法路径"}), 403
|
||||
|
||||
response = send_from_directory("static", filename)
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
cache_ttl = 3600
|
||||
lowered = filename.lower()
|
||||
if "/assets/" in lowered or lowered.endswith((".js", ".css", ".woff", ".woff2", ".ttf", ".svg")):
|
||||
cache_ttl = 604800 # 7天
|
||||
|
||||
if request.args.get("v"):
|
||||
cache_ttl = max(cache_ttl, 604800)
|
||||
|
||||
response = send_from_directory("static", filename, max_age=cache_ttl, conditional=True)
|
||||
|
||||
# 协商缓存:确保存在 ETag,并基于 If-None-Match/If-Modified-Since 返回 304
|
||||
try:
|
||||
response.add_etag(overwrite=False)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
response.make_conditional(request)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
response.headers.setdefault("Vary", "Accept-Encoding")
|
||||
response.headers["Cache-Control"] = f"public, max-age={cache_ttl}"
|
||||
return response
|
||||
|
||||
|
||||
@@ -232,6 +250,93 @@ def _signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def _cleanup_stale_task_state() -> None:
|
||||
logger.info("清理遗留任务状态...")
|
||||
try:
|
||||
from services.state import safe_get_active_task_ids, safe_remove_task, safe_remove_task_status
|
||||
|
||||
for _, accounts in safe_iter_user_accounts_items():
|
||||
for acc in accounts.values():
|
||||
if not getattr(acc, "is_running", False):
|
||||
continue
|
||||
acc.is_running = False
|
||||
acc.should_stop = False
|
||||
acc.status = "未开始"
|
||||
|
||||
for account_id in list(safe_get_active_task_ids()):
|
||||
safe_remove_task(account_id)
|
||||
safe_remove_task_status(account_id)
|
||||
|
||||
logger.info("[OK] 遗留任务状态已清理")
|
||||
except Exception as e:
|
||||
logger.warning(f"清理遗留任务状态失败: {e}")
|
||||
|
||||
|
||||
def _init_optional_email_service() -> None:
|
||||
try:
|
||||
email_service.init_email_service()
|
||||
logger.info("[OK] 邮件服务已初始化")
|
||||
except Exception as e:
|
||||
logger.warning(f"警告: 邮件服务初始化失败: {e}")
|
||||
|
||||
|
||||
def _load_and_apply_scheduler_limits() -> None:
|
||||
try:
|
||||
system_config = database.get_system_config() or {}
|
||||
max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
|
||||
max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT))
|
||||
get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account)
|
||||
logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
|
||||
except Exception as e:
|
||||
logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
|
||||
|
||||
|
||||
def _start_background_workers() -> None:
|
||||
logger.info("启动定时任务调度器...")
|
||||
threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
|
||||
logger.info("[OK] 定时任务调度器已启动")
|
||||
|
||||
logger.info("[OK] 状态推送线程已启动(默认2秒/次)")
|
||||
threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
|
||||
|
||||
|
||||
def _init_screenshot_worker_pool() -> None:
|
||||
try:
|
||||
pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3))
|
||||
except Exception:
|
||||
pool_size = 3
|
||||
|
||||
try:
|
||||
logger.info(f"初始化截图线程池({pool_size}个worker,按需启动执行环境,空闲5分钟后自动释放)...")
|
||||
init_browser_worker_pool(pool_size=pool_size)
|
||||
logger.info("[OK] 截图线程池初始化完成")
|
||||
except Exception as e:
|
||||
logger.warning(f"警告: 截图线程池初始化失败: {e}")
|
||||
|
||||
|
||||
def _warmup_api_connection() -> None:
|
||||
logger.info("预热 API 连接...")
|
||||
try:
|
||||
from api_browser import warmup_api_connection
|
||||
|
||||
threading.Thread(
|
||||
target=warmup_api_connection,
|
||||
kwargs={"log_callback": lambda msg: logger.info(msg)},
|
||||
daemon=True,
|
||||
name="api-warmup",
|
||||
).start()
|
||||
except Exception as e:
|
||||
logger.warning(f"API 预热失败: {e}")
|
||||
|
||||
|
||||
def _log_startup_urls() -> None:
|
||||
logger.info("服务器启动中...")
|
||||
logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
|
||||
logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
|
||||
logger.info("默认管理员: admin (首次运行随机密码见日志)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
atexit.register(cleanup_on_exit)
|
||||
signal.signal(signal.SIGINT, _signal_handler)
|
||||
@@ -245,81 +350,17 @@ if __name__ == "__main__":
|
||||
init_checkpoint_manager()
|
||||
logger.info("[OK] 任务断点管理器已初始化")
|
||||
|
||||
# 【新增】容器重启时清理遗留的任务状态
|
||||
logger.info("清理遗留任务状态...")
|
||||
try:
|
||||
from services.state import safe_remove_task, safe_get_active_task_ids, safe_remove_task_status
|
||||
# 重置所有账号的运行状态
|
||||
for _, accounts in safe_iter_user_accounts_items():
|
||||
for acc in accounts.values():
|
||||
if getattr(acc, "is_running", False):
|
||||
acc.is_running = False
|
||||
acc.should_stop = False
|
||||
acc.status = "未开始"
|
||||
# 清理活跃任务句柄
|
||||
for account_id in list(safe_get_active_task_ids()):
|
||||
safe_remove_task(account_id)
|
||||
safe_remove_task_status(account_id)
|
||||
logger.info("[OK] 遗留任务状态已清理")
|
||||
except Exception as e:
|
||||
logger.warning(f"清理遗留任务状态失败: {e}")
|
||||
|
||||
try:
|
||||
email_service.init_email_service()
|
||||
logger.info("[OK] 邮件服务已初始化")
|
||||
except Exception as e:
|
||||
logger.warning(f"警告: 邮件服务初始化失败: {e}")
|
||||
_cleanup_stale_task_state()
|
||||
_init_optional_email_service()
|
||||
|
||||
start_cleanup_scheduler()
|
||||
start_kdocs_monitor()
|
||||
|
||||
try:
|
||||
system_config = database.get_system_config() or {}
|
||||
max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
|
||||
max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT))
|
||||
get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account)
|
||||
logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
|
||||
except Exception as e:
|
||||
logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
|
||||
|
||||
logger.info("启动定时任务调度器...")
|
||||
threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
|
||||
logger.info("[OK] 定时任务调度器已启动")
|
||||
|
||||
logger.info("[OK] 状态推送线程已启动(默认2秒/次)")
|
||||
threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
|
||||
|
||||
logger.info("服务器启动中...")
|
||||
logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
|
||||
logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
|
||||
logger.info("默认管理员: admin (首次运行随机密码见日志)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3))
|
||||
except Exception:
|
||||
pool_size = 3
|
||||
try:
|
||||
logger.info(f"初始化截图线程池({pool_size}个worker,按需启动执行环境,空闲5分钟后自动释放)...")
|
||||
init_browser_worker_pool(pool_size=pool_size)
|
||||
logger.info("[OK] 截图线程池初始化完成")
|
||||
except Exception as e:
|
||||
logger.warning(f"警告: 截图线程池初始化失败: {e}")
|
||||
|
||||
# 预热 API 连接(后台进行,不阻塞启动)
|
||||
logger.info("预热 API 连接...")
|
||||
try:
|
||||
from api_browser import warmup_api_connection
|
||||
import threading
|
||||
|
||||
threading.Thread(
|
||||
target=warmup_api_connection,
|
||||
kwargs={"log_callback": lambda msg: logger.info(msg)},
|
||||
daemon=True,
|
||||
name="api-warmup",
|
||||
).start()
|
||||
except Exception as e:
|
||||
logger.warning(f"API 预热失败: {e}")
|
||||
_load_and_apply_scheduler_limits()
|
||||
_start_background_workers()
|
||||
_log_startup_urls()
|
||||
_init_screenshot_worker_pool()
|
||||
_warmup_api_connection()
|
||||
|
||||
socketio.run(
|
||||
app,
|
||||
|
||||
55
database.py
55
database.py
@@ -120,7 +120,7 @@ config = get_config()
|
||||
DB_FILE = config.DB_FILE
|
||||
|
||||
# 数据库版本 (用于迁移管理)
|
||||
DB_VERSION = 17
|
||||
DB_VERSION = 18
|
||||
|
||||
|
||||
# ==================== 系统配置缓存(P1 / O-03) ====================
|
||||
@@ -142,6 +142,37 @@ def invalidate_system_config_cache() -> None:
|
||||
_system_config_cache_loaded_at = 0.0
|
||||
|
||||
|
||||
def _normalize_system_config_value(value) -> dict:
|
||||
try:
|
||||
return dict(value or {})
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _is_system_config_cache_valid(now_ts: float) -> bool:
|
||||
if _system_config_cache_value is None:
|
||||
return False
|
||||
if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0:
|
||||
return True
|
||||
return (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS
|
||||
|
||||
|
||||
def _read_system_config_cache(now_ts: float, *, ignore_ttl: bool = False) -> Optional[dict]:
|
||||
with _system_config_cache_lock:
|
||||
if _system_config_cache_value is None:
|
||||
return None
|
||||
if (not ignore_ttl) and (not _is_system_config_cache_valid(now_ts)):
|
||||
return None
|
||||
return dict(_system_config_cache_value)
|
||||
|
||||
|
||||
def _write_system_config_cache(value: dict, now_ts: float) -> None:
|
||||
global _system_config_cache_value, _system_config_cache_loaded_at
|
||||
with _system_config_cache_lock:
|
||||
_system_config_cache_value = dict(value)
|
||||
_system_config_cache_loaded_at = now_ts
|
||||
|
||||
|
||||
def init_database():
|
||||
"""初始化数据库表结构 + 迁移(入口统一)。"""
|
||||
db_pool.init_pool(DB_FILE, pool_size=config.DB_POOL_SIZE)
|
||||
@@ -165,19 +196,21 @@ def migrate_database():
|
||||
|
||||
def get_system_config():
|
||||
"""获取系统配置(带进程内缓存)。"""
|
||||
global _system_config_cache_value, _system_config_cache_loaded_at
|
||||
|
||||
now_ts = time.time()
|
||||
with _system_config_cache_lock:
|
||||
if _system_config_cache_value is not None:
|
||||
if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0 or (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS:
|
||||
return dict(_system_config_cache_value)
|
||||
|
||||
value = _get_system_config_raw()
|
||||
cached_value = _read_system_config_cache(now_ts)
|
||||
if cached_value is not None:
|
||||
return cached_value
|
||||
|
||||
with _system_config_cache_lock:
|
||||
_system_config_cache_value = dict(value)
|
||||
_system_config_cache_loaded_at = now_ts
|
||||
try:
|
||||
value = _normalize_system_config_value(_get_system_config_raw())
|
||||
except Exception:
|
||||
fallback_value = _read_system_config_cache(now_ts, ignore_ttl=True)
|
||||
if fallback_value is not None:
|
||||
return fallback_value
|
||||
raise
|
||||
|
||||
_write_system_config_cache(value, now_ts)
|
||||
return dict(value)
|
||||
|
||||
|
||||
|
||||
@@ -6,19 +6,51 @@ import db_pool
|
||||
from crypto_utils import decrypt_password, encrypt_password
|
||||
from db.utils import get_cst_now_str
|
||||
|
||||
_ACCOUNT_STATUS_QUERY_SQL = """
|
||||
SELECT status, login_fail_count, last_login_error
|
||||
FROM accounts
|
||||
WHERE id = ?
|
||||
"""
|
||||
|
||||
|
||||
def _decode_account_password(account_dict: dict) -> dict:
|
||||
account_dict["password"] = decrypt_password(account_dict.get("password", ""))
|
||||
return account_dict
|
||||
|
||||
|
||||
def _normalize_account_ids(account_ids) -> list[str]:
|
||||
normalized = []
|
||||
seen = set()
|
||||
for account_id in account_ids or []:
|
||||
if not account_id:
|
||||
continue
|
||||
account_key = str(account_id)
|
||||
if account_key in seen:
|
||||
continue
|
||||
seen.add(account_key)
|
||||
normalized.append(account_key)
|
||||
return normalized
|
||||
|
||||
|
||||
def create_account(user_id, account_id, username, password, remember=True, remark=""):
|
||||
"""创建账号(密码加密存储)"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = get_cst_now_str()
|
||||
encrypted_password = encrypt_password(password)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(account_id, user_id, username, encrypted_password, 1 if remember else 0, remark, cst_time),
|
||||
(
|
||||
account_id,
|
||||
user_id,
|
||||
username,
|
||||
encrypted_password,
|
||||
1 if remember else 0,
|
||||
remark,
|
||||
get_cst_now_str(),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
@@ -29,12 +61,7 @@ def get_user_accounts(user_id):
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,))
|
||||
accounts = []
|
||||
for row in cursor.fetchall():
|
||||
account = dict(row)
|
||||
account["password"] = decrypt_password(account.get("password", ""))
|
||||
accounts.append(account)
|
||||
return accounts
|
||||
return [_decode_account_password(dict(row)) for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def get_account(account_id):
|
||||
@@ -43,11 +70,9 @@ def get_account(account_id):
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
account = dict(row)
|
||||
account["password"] = decrypt_password(account.get("password", ""))
|
||||
return account
|
||||
return None
|
||||
if not row:
|
||||
return None
|
||||
return _decode_account_password(dict(row))
|
||||
|
||||
|
||||
def update_account_remark(account_id, remark):
|
||||
@@ -78,33 +103,21 @@ def increment_account_login_fail(account_id, error_message):
|
||||
if not row:
|
||||
return False
|
||||
|
||||
fail_count = (row["login_fail_count"] or 0) + 1
|
||||
|
||||
if fail_count >= 3:
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE accounts
|
||||
SET login_fail_count = ?,
|
||||
last_login_error = ?,
|
||||
status = 'suspended'
|
||||
WHERE id = ?
|
||||
""",
|
||||
(fail_count, error_message, account_id),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
fail_count = int(row["login_fail_count"] or 0) + 1
|
||||
is_suspended = fail_count >= 3
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE accounts
|
||||
SET login_fail_count = ?,
|
||||
last_login_error = ?
|
||||
last_login_error = ?,
|
||||
status = CASE WHEN ? = 1 THEN 'suspended' ELSE status END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(fail_count, error_message, account_id),
|
||||
(fail_count, error_message, 1 if is_suspended else 0, account_id),
|
||||
)
|
||||
conn.commit()
|
||||
return False
|
||||
return is_suspended
|
||||
|
||||
|
||||
def reset_account_login_status(account_id):
|
||||
@@ -129,29 +142,22 @@ def get_account_status(account_id):
|
||||
"""获取账号状态信息"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT status, login_fail_count, last_login_error
|
||||
FROM accounts
|
||||
WHERE id = ?
|
||||
""",
|
||||
(account_id,),
|
||||
)
|
||||
cursor.execute(_ACCOUNT_STATUS_QUERY_SQL, (account_id,))
|
||||
return cursor.fetchone()
|
||||
|
||||
|
||||
def get_account_status_batch(account_ids):
|
||||
"""批量获取账号状态信息"""
|
||||
account_ids = [str(account_id) for account_id in (account_ids or []) if account_id]
|
||||
if not account_ids:
|
||||
normalized_ids = _normalize_account_ids(account_ids)
|
||||
if not normalized_ids:
|
||||
return {}
|
||||
|
||||
results = {}
|
||||
chunk_size = 900 # 避免触发 SQLite 绑定参数上限
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
for idx in range(0, len(account_ids), chunk_size):
|
||||
chunk = account_ids[idx : idx + chunk_size]
|
||||
for idx in range(0, len(normalized_ids), chunk_size):
|
||||
chunk = normalized_ids[idx : idx + chunk_size]
|
||||
placeholders = ",".join("?" for _ in chunk)
|
||||
cursor.execute(
|
||||
f"""
|
||||
|
||||
344
db/admin.py
344
db/admin.py
@@ -3,9 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytz
|
||||
|
||||
import db_pool
|
||||
from db.utils import get_cst_now_str
|
||||
@@ -16,6 +13,99 @@ from password_utils import (
|
||||
verify_password_sha256,
|
||||
)
|
||||
|
||||
_DEFAULT_SYSTEM_CONFIG = {
|
||||
"max_concurrent_global": 2,
|
||||
"max_concurrent_per_account": 1,
|
||||
"max_screenshot_concurrent": 3,
|
||||
"schedule_enabled": 0,
|
||||
"schedule_time": "02:00",
|
||||
"schedule_browse_type": "应读",
|
||||
"schedule_weekdays": "1,2,3,4,5,6,7",
|
||||
"proxy_enabled": 0,
|
||||
"proxy_api_url": "",
|
||||
"proxy_expire_minutes": 3,
|
||||
"enable_screenshot": 1,
|
||||
"auto_approve_enabled": 0,
|
||||
"auto_approve_hourly_limit": 10,
|
||||
"auto_approve_vip_days": 7,
|
||||
"kdocs_enabled": 0,
|
||||
"kdocs_doc_url": "",
|
||||
"kdocs_default_unit": "",
|
||||
"kdocs_sheet_name": "",
|
||||
"kdocs_sheet_index": 0,
|
||||
"kdocs_unit_column": "A",
|
||||
"kdocs_image_column": "D",
|
||||
"kdocs_admin_notify_enabled": 0,
|
||||
"kdocs_admin_notify_email": "",
|
||||
"kdocs_row_start": 0,
|
||||
"kdocs_row_end": 0,
|
||||
}
|
||||
|
||||
_SYSTEM_CONFIG_UPDATERS = (
|
||||
("max_concurrent_global", "max_concurrent"),
|
||||
("schedule_enabled", "schedule_enabled"),
|
||||
("schedule_time", "schedule_time"),
|
||||
("schedule_browse_type", "schedule_browse_type"),
|
||||
("schedule_weekdays", "schedule_weekdays"),
|
||||
("max_concurrent_per_account", "max_concurrent_per_account"),
|
||||
("max_screenshot_concurrent", "max_screenshot_concurrent"),
|
||||
("enable_screenshot", "enable_screenshot"),
|
||||
("proxy_enabled", "proxy_enabled"),
|
||||
("proxy_api_url", "proxy_api_url"),
|
||||
("proxy_expire_minutes", "proxy_expire_minutes"),
|
||||
("auto_approve_enabled", "auto_approve_enabled"),
|
||||
("auto_approve_hourly_limit", "auto_approve_hourly_limit"),
|
||||
("auto_approve_vip_days", "auto_approve_vip_days"),
|
||||
("kdocs_enabled", "kdocs_enabled"),
|
||||
("kdocs_doc_url", "kdocs_doc_url"),
|
||||
("kdocs_default_unit", "kdocs_default_unit"),
|
||||
("kdocs_sheet_name", "kdocs_sheet_name"),
|
||||
("kdocs_sheet_index", "kdocs_sheet_index"),
|
||||
("kdocs_unit_column", "kdocs_unit_column"),
|
||||
("kdocs_image_column", "kdocs_image_column"),
|
||||
("kdocs_admin_notify_enabled", "kdocs_admin_notify_enabled"),
|
||||
("kdocs_admin_notify_email", "kdocs_admin_notify_email"),
|
||||
("kdocs_row_start", "kdocs_row_start"),
|
||||
("kdocs_row_end", "kdocs_row_end"),
|
||||
)
|
||||
|
||||
|
||||
def _count_scalar(cursor, sql: str, params=()) -> int:
|
||||
cursor.execute(sql, params)
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return 0
|
||||
try:
|
||||
if "count" in row.keys():
|
||||
return int(row["count"] or 0)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return int(row[0] or 0)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _table_exists(cursor, table_name: str) -> bool:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name=?
|
||||
""",
|
||||
(table_name,),
|
||||
)
|
||||
return bool(cursor.fetchone())
|
||||
|
||||
|
||||
def _normalize_days(days, default: int = 30) -> int:
|
||||
try:
|
||||
value = int(days)
|
||||
except Exception:
|
||||
value = default
|
||||
if value < 0:
|
||||
return 0
|
||||
return value
|
||||
|
||||
|
||||
def ensure_default_admin() -> bool:
|
||||
"""确保存在默认管理员账号(行为保持不变)。"""
|
||||
@@ -24,10 +114,9 @@ def ensure_default_admin() -> bool:
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) as count FROM admins")
|
||||
result = cursor.fetchone()
|
||||
count = _count_scalar(cursor, "SELECT COUNT(*) as count FROM admins")
|
||||
|
||||
if result["count"] == 0:
|
||||
if count == 0:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_password = "".join(secrets.choice(alphabet) for _ in range(12))
|
||||
|
||||
@@ -101,41 +190,33 @@ def get_system_stats() -> dict:
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT COUNT(*) as count FROM users")
|
||||
total_users = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute("SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
|
||||
approved_users = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute(
|
||||
total_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users")
|
||||
approved_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
|
||||
new_users_today = _count_scalar(
|
||||
cursor,
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM users
|
||||
WHERE date(created_at) = date('now', 'localtime')
|
||||
"""
|
||||
""",
|
||||
)
|
||||
new_users_today = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute(
|
||||
new_users_7d = _count_scalar(
|
||||
cursor,
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM users
|
||||
WHERE datetime(created_at) >= datetime('now', 'localtime', '-7 days')
|
||||
"""
|
||||
""",
|
||||
)
|
||||
new_users_7d = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute("SELECT COUNT(*) as count FROM accounts")
|
||||
total_accounts = cursor.fetchone()["count"]
|
||||
|
||||
cursor.execute(
|
||||
total_accounts = _count_scalar(cursor, "SELECT COUNT(*) as count FROM accounts")
|
||||
vip_users = _count_scalar(
|
||||
cursor,
|
||||
"""
|
||||
SELECT COUNT(*) as count FROM users
|
||||
WHERE vip_expire_time IS NOT NULL
|
||||
AND datetime(vip_expire_time) > datetime('now', 'localtime')
|
||||
"""
|
||||
""",
|
||||
)
|
||||
vip_users = cursor.fetchone()["count"]
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
@@ -153,37 +234,9 @@ def get_system_config_raw() -> dict:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM system_config WHERE id = 1")
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return dict(row)
|
||||
|
||||
return {
|
||||
"max_concurrent_global": 2,
|
||||
"max_concurrent_per_account": 1,
|
||||
"max_screenshot_concurrent": 3,
|
||||
"schedule_enabled": 0,
|
||||
"schedule_time": "02:00",
|
||||
"schedule_browse_type": "应读",
|
||||
"schedule_weekdays": "1,2,3,4,5,6,7",
|
||||
"proxy_enabled": 0,
|
||||
"proxy_api_url": "",
|
||||
"proxy_expire_minutes": 3,
|
||||
"enable_screenshot": 1,
|
||||
"auto_approve_enabled": 0,
|
||||
"auto_approve_hourly_limit": 10,
|
||||
"auto_approve_vip_days": 7,
|
||||
"kdocs_enabled": 0,
|
||||
"kdocs_doc_url": "",
|
||||
"kdocs_default_unit": "",
|
||||
"kdocs_sheet_name": "",
|
||||
"kdocs_sheet_index": 0,
|
||||
"kdocs_unit_column": "A",
|
||||
"kdocs_image_column": "D",
|
||||
"kdocs_admin_notify_enabled": 0,
|
||||
"kdocs_admin_notify_email": "",
|
||||
"kdocs_row_start": 0,
|
||||
"kdocs_row_end": 0,
|
||||
}
|
||||
return dict(_DEFAULT_SYSTEM_CONFIG)
|
||||
|
||||
|
||||
def update_system_config(
|
||||
@@ -215,127 +268,51 @@ def update_system_config(
|
||||
kdocs_row_end=None,
|
||||
) -> bool:
|
||||
"""更新系统配置(仅更新DB,不做缓存处理)。"""
|
||||
allowed_fields = {
|
||||
"max_concurrent_global",
|
||||
"schedule_enabled",
|
||||
"schedule_time",
|
||||
"schedule_browse_type",
|
||||
"schedule_weekdays",
|
||||
"max_concurrent_per_account",
|
||||
"max_screenshot_concurrent",
|
||||
"enable_screenshot",
|
||||
"proxy_enabled",
|
||||
"proxy_api_url",
|
||||
"proxy_expire_minutes",
|
||||
"auto_approve_enabled",
|
||||
"auto_approve_hourly_limit",
|
||||
"auto_approve_vip_days",
|
||||
"kdocs_enabled",
|
||||
"kdocs_doc_url",
|
||||
"kdocs_default_unit",
|
||||
"kdocs_sheet_name",
|
||||
"kdocs_sheet_index",
|
||||
"kdocs_unit_column",
|
||||
"kdocs_image_column",
|
||||
"kdocs_admin_notify_enabled",
|
||||
"kdocs_admin_notify_email",
|
||||
"kdocs_row_start",
|
||||
"kdocs_row_end",
|
||||
"updated_at",
|
||||
arg_values = {
|
||||
"max_concurrent": max_concurrent,
|
||||
"schedule_enabled": schedule_enabled,
|
||||
"schedule_time": schedule_time,
|
||||
"schedule_browse_type": schedule_browse_type,
|
||||
"schedule_weekdays": schedule_weekdays,
|
||||
"max_concurrent_per_account": max_concurrent_per_account,
|
||||
"max_screenshot_concurrent": max_screenshot_concurrent,
|
||||
"enable_screenshot": enable_screenshot,
|
||||
"proxy_enabled": proxy_enabled,
|
||||
"proxy_api_url": proxy_api_url,
|
||||
"proxy_expire_minutes": proxy_expire_minutes,
|
||||
"auto_approve_enabled": auto_approve_enabled,
|
||||
"auto_approve_hourly_limit": auto_approve_hourly_limit,
|
||||
"auto_approve_vip_days": auto_approve_vip_days,
|
||||
"kdocs_enabled": kdocs_enabled,
|
||||
"kdocs_doc_url": kdocs_doc_url,
|
||||
"kdocs_default_unit": kdocs_default_unit,
|
||||
"kdocs_sheet_name": kdocs_sheet_name,
|
||||
"kdocs_sheet_index": kdocs_sheet_index,
|
||||
"kdocs_unit_column": kdocs_unit_column,
|
||||
"kdocs_image_column": kdocs_image_column,
|
||||
"kdocs_admin_notify_enabled": kdocs_admin_notify_enabled,
|
||||
"kdocs_admin_notify_email": kdocs_admin_notify_email,
|
||||
"kdocs_row_start": kdocs_row_start,
|
||||
"kdocs_row_end": kdocs_row_end,
|
||||
}
|
||||
|
||||
updates = []
|
||||
params = []
|
||||
for db_field, arg_name in _SYSTEM_CONFIG_UPDATERS:
|
||||
value = arg_values.get(arg_name)
|
||||
if value is None:
|
||||
continue
|
||||
updates.append(f"{db_field} = ?")
|
||||
params.append(value)
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
|
||||
updates.append("updated_at = ?")
|
||||
params.append(get_cst_now_str())
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
if max_concurrent is not None:
|
||||
updates.append("max_concurrent_global = ?")
|
||||
params.append(max_concurrent)
|
||||
if schedule_enabled is not None:
|
||||
updates.append("schedule_enabled = ?")
|
||||
params.append(schedule_enabled)
|
||||
if schedule_time is not None:
|
||||
updates.append("schedule_time = ?")
|
||||
params.append(schedule_time)
|
||||
if schedule_browse_type is not None:
|
||||
updates.append("schedule_browse_type = ?")
|
||||
params.append(schedule_browse_type)
|
||||
if max_concurrent_per_account is not None:
|
||||
updates.append("max_concurrent_per_account = ?")
|
||||
params.append(max_concurrent_per_account)
|
||||
if max_screenshot_concurrent is not None:
|
||||
updates.append("max_screenshot_concurrent = ?")
|
||||
params.append(max_screenshot_concurrent)
|
||||
if enable_screenshot is not None:
|
||||
updates.append("enable_screenshot = ?")
|
||||
params.append(enable_screenshot)
|
||||
if schedule_weekdays is not None:
|
||||
updates.append("schedule_weekdays = ?")
|
||||
params.append(schedule_weekdays)
|
||||
if proxy_enabled is not None:
|
||||
updates.append("proxy_enabled = ?")
|
||||
params.append(proxy_enabled)
|
||||
if proxy_api_url is not None:
|
||||
updates.append("proxy_api_url = ?")
|
||||
params.append(proxy_api_url)
|
||||
if proxy_expire_minutes is not None:
|
||||
updates.append("proxy_expire_minutes = ?")
|
||||
params.append(proxy_expire_minutes)
|
||||
if auto_approve_enabled is not None:
|
||||
updates.append("auto_approve_enabled = ?")
|
||||
params.append(auto_approve_enabled)
|
||||
if auto_approve_hourly_limit is not None:
|
||||
updates.append("auto_approve_hourly_limit = ?")
|
||||
params.append(auto_approve_hourly_limit)
|
||||
if auto_approve_vip_days is not None:
|
||||
updates.append("auto_approve_vip_days = ?")
|
||||
params.append(auto_approve_vip_days)
|
||||
if kdocs_enabled is not None:
|
||||
updates.append("kdocs_enabled = ?")
|
||||
params.append(kdocs_enabled)
|
||||
if kdocs_doc_url is not None:
|
||||
updates.append("kdocs_doc_url = ?")
|
||||
params.append(kdocs_doc_url)
|
||||
if kdocs_default_unit is not None:
|
||||
updates.append("kdocs_default_unit = ?")
|
||||
params.append(kdocs_default_unit)
|
||||
if kdocs_sheet_name is not None:
|
||||
updates.append("kdocs_sheet_name = ?")
|
||||
params.append(kdocs_sheet_name)
|
||||
if kdocs_sheet_index is not None:
|
||||
updates.append("kdocs_sheet_index = ?")
|
||||
params.append(kdocs_sheet_index)
|
||||
if kdocs_unit_column is not None:
|
||||
updates.append("kdocs_unit_column = ?")
|
||||
params.append(kdocs_unit_column)
|
||||
if kdocs_image_column is not None:
|
||||
updates.append("kdocs_image_column = ?")
|
||||
params.append(kdocs_image_column)
|
||||
if kdocs_admin_notify_enabled is not None:
|
||||
updates.append("kdocs_admin_notify_enabled = ?")
|
||||
params.append(kdocs_admin_notify_enabled)
|
||||
if kdocs_admin_notify_email is not None:
|
||||
updates.append("kdocs_admin_notify_email = ?")
|
||||
params.append(kdocs_admin_notify_email)
|
||||
if kdocs_row_start is not None:
|
||||
updates.append("kdocs_row_start = ?")
|
||||
params.append(kdocs_row_start)
|
||||
if kdocs_row_end is not None:
|
||||
updates.append("kdocs_row_end = ?")
|
||||
params.append(kdocs_row_end)
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
|
||||
updates.append("updated_at = ?")
|
||||
params.append(get_cst_now_str())
|
||||
|
||||
for update_clause in updates:
|
||||
field_name = update_clause.split("=")[0].strip()
|
||||
if field_name not in allowed_fields:
|
||||
raise ValueError(f"非法字段名: {field_name}")
|
||||
|
||||
sql = f"UPDATE system_config SET {', '.join(updates)} WHERE id = 1"
|
||||
cursor.execute(sql, params)
|
||||
conn.commit()
|
||||
@@ -346,13 +323,13 @@ def get_hourly_registration_count() -> int:
|
||||
"""获取最近一小时内的注册用户数"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
return _count_scalar(
|
||||
cursor,
|
||||
"""
|
||||
SELECT COUNT(*) FROM users
|
||||
SELECT COUNT(*) as count FROM users
|
||||
WHERE created_at >= datetime('now', 'localtime', '-1 hour')
|
||||
"""
|
||||
""",
|
||||
)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
|
||||
# ==================== 密码重置(管理员) ====================
|
||||
@@ -374,17 +351,12 @@ def admin_reset_user_password(user_id: int, new_password: str) -> bool:
|
||||
|
||||
def clean_old_operation_logs(days: int = 30) -> int:
|
||||
"""清理指定天数前的操作日志(如果存在operation_logs表)"""
|
||||
safe_days = _normalize_days(days, default=30)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='operation_logs'
|
||||
"""
|
||||
)
|
||||
|
||||
if not cursor.fetchone():
|
||||
if not _table_exists(cursor, "operation_logs"):
|
||||
return 0
|
||||
|
||||
try:
|
||||
@@ -393,11 +365,11 @@ def clean_old_operation_logs(days: int = 30) -> int:
|
||||
DELETE FROM operation_logs
|
||||
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
|
||||
""",
|
||||
(days,),
|
||||
(safe_days,),
|
||||
)
|
||||
deleted_count = cursor.rowcount
|
||||
conn.commit()
|
||||
print(f"已清理 {deleted_count} 条旧操作日志 (>{days}天)")
|
||||
print(f"已清理 {deleted_count} 条旧操作日志 (>{safe_days}天)")
|
||||
return deleted_count
|
||||
except Exception as e:
|
||||
print(f"清理旧操作日志失败: {e}")
|
||||
|
||||
@@ -6,12 +6,38 @@ import db_pool
|
||||
from db.utils import get_cst_now_str
|
||||
|
||||
|
||||
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
parsed = max(minimum, parsed)
|
||||
parsed = min(maximum, parsed)
|
||||
return parsed
|
||||
|
||||
|
||||
def _normalize_offset(value, default: int = 0) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
return max(0, parsed)
|
||||
|
||||
|
||||
def _normalize_announcement_payload(title, content, image_url):
|
||||
normalized_title = str(title or "").strip()
|
||||
normalized_content = str(content or "").strip()
|
||||
normalized_image = str(image_url or "").strip() or None
|
||||
return normalized_title, normalized_content, normalized_image
|
||||
|
||||
|
||||
def _deactivate_all_active_announcements(cursor, cst_time: str) -> None:
|
||||
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
|
||||
|
||||
|
||||
def create_announcement(title, content, image_url=None, is_active=True):
|
||||
"""创建公告(默认启用;启用时会自动停用其他公告)"""
|
||||
title = (title or "").strip()
|
||||
content = (content or "").strip()
|
||||
image_url = (image_url or "").strip()
|
||||
image_url = image_url or None
|
||||
title, content, image_url = _normalize_announcement_payload(title, content, image_url)
|
||||
if not title or not content:
|
||||
return None
|
||||
|
||||
@@ -20,7 +46,7 @@ def create_announcement(title, content, image_url=None, is_active=True):
|
||||
cst_time = get_cst_now_str()
|
||||
|
||||
if is_active:
|
||||
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
|
||||
_deactivate_all_active_announcements(cursor, cst_time)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -44,6 +70,9 @@ def get_announcement_by_id(announcement_id):
|
||||
|
||||
def get_announcements(limit=50, offset=0):
|
||||
"""获取公告列表(管理员用)"""
|
||||
safe_limit = _normalize_limit(limit, 50)
|
||||
safe_offset = _normalize_offset(offset, 0)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -52,7 +81,7 @@ def get_announcements(limit=50, offset=0):
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(limit, offset),
|
||||
(safe_limit, safe_offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
@@ -64,7 +93,7 @@ def set_announcement_active(announcement_id, is_active):
|
||||
cst_time = get_cst_now_str()
|
||||
|
||||
if is_active:
|
||||
cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
|
||||
_deactivate_all_active_announcements(cursor, cst_time)
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE announcements
|
||||
@@ -121,13 +150,12 @@ def dismiss_announcement_for_user(user_id, announcement_id):
|
||||
"""用户永久关闭某条公告(幂等)"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = get_cst_now_str()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO announcement_dismissals (user_id, announcement_id, dismissed_at)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(user_id, announcement_id, cst_time),
|
||||
(user_id, announcement_id, get_cst_now_str()),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount >= 0
|
||||
|
||||
27
db/email.py
27
db/email.py
@@ -5,6 +5,27 @@ from __future__ import annotations
|
||||
import db_pool
|
||||
|
||||
|
||||
def _to_bool_with_default(value, default: bool = True) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return bool(int(value))
|
||||
except Exception:
|
||||
try:
|
||||
return bool(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _normalize_notify_enabled(enabled) -> int:
|
||||
if isinstance(enabled, bool):
|
||||
return 1 if enabled else 0
|
||||
try:
|
||||
return 1 if int(enabled) else 0
|
||||
except Exception:
|
||||
return 1
|
||||
|
||||
|
||||
def get_user_by_email(email):
|
||||
"""根据邮箱获取用户"""
|
||||
with db_pool.get_db() as conn:
|
||||
@@ -25,7 +46,7 @@ def update_user_email(user_id, email, verified=False):
|
||||
SET email = ?, email_verified = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(email, int(verified), user_id),
|
||||
(email, 1 if verified else 0, user_id),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -42,7 +63,7 @@ def update_user_email_notify(user_id, enabled):
|
||||
SET email_notify_enabled = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(int(enabled), user_id),
|
||||
(_normalize_notify_enabled(enabled), user_id),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -57,6 +78,6 @@ def get_user_email_notify(user_id):
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return True
|
||||
return bool(row[0]) if row[0] is not None else True
|
||||
return _to_bool_with_default(row[0], default=True)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
106
db/feedbacks.py
106
db/feedbacks.py
@@ -2,32 +2,73 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
import db_pool
|
||||
from db.utils import escape_html
|
||||
from db.utils import escape_html, get_cst_now_str
|
||||
|
||||
|
||||
def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
parsed = max(minimum, parsed)
|
||||
parsed = min(maximum, parsed)
|
||||
return parsed
|
||||
|
||||
|
||||
def _normalize_offset(value, default: int = 0) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
return max(0, parsed)
|
||||
|
||||
|
||||
def _safe_text(value) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
text = str(value)
|
||||
return escape_html(text) if text else ""
|
||||
|
||||
|
||||
def _build_feedback_filter_sql(status_filter=None) -> tuple[str, list]:
|
||||
where_clauses = ["1=1"]
|
||||
params = []
|
||||
|
||||
if status_filter:
|
||||
where_clauses.append("status = ?")
|
||||
params.append(status_filter)
|
||||
|
||||
return " AND ".join(where_clauses), params
|
||||
|
||||
|
||||
def _normalize_feedback_stats_row(row) -> dict:
|
||||
row_dict = dict(row) if row else {}
|
||||
return {
|
||||
"total": int(row_dict.get("total") or 0),
|
||||
"pending": int(row_dict.get("pending") or 0),
|
||||
"replied": int(row_dict.get("replied") or 0),
|
||||
"closed": int(row_dict.get("closed") or 0),
|
||||
}
|
||||
|
||||
|
||||
def create_bug_feedback(user_id, username, title, description, contact=""):
|
||||
"""创建Bug反馈(带XSS防护)"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
safe_title = escape_html(title) if title else ""
|
||||
safe_description = escape_html(description) if description else ""
|
||||
safe_contact = escape_html(contact) if contact else ""
|
||||
safe_username = escape_html(username) if username else ""
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO bug_feedbacks (user_id, username, title, description, contact, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, safe_username, safe_title, safe_description, safe_contact, cst_time),
|
||||
(
|
||||
user_id,
|
||||
_safe_text(username),
|
||||
_safe_text(title),
|
||||
_safe_text(description),
|
||||
_safe_text(contact),
|
||||
get_cst_now_str(),
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -36,25 +77,25 @@ def create_bug_feedback(user_id, username, title, description, contact=""):
|
||||
|
||||
def get_bug_feedbacks(limit=100, offset=0, status_filter=None):
|
||||
"""获取Bug反馈列表(管理员用)"""
|
||||
safe_limit = _normalize_limit(limit, 100, minimum=1, maximum=1000)
|
||||
safe_offset = _normalize_offset(offset, 0)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
sql = "SELECT * FROM bug_feedbacks WHERE 1=1"
|
||||
params = []
|
||||
|
||||
if status_filter:
|
||||
sql += " AND status = ?"
|
||||
params.append(status_filter)
|
||||
|
||||
sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
|
||||
cursor.execute(sql, params)
|
||||
where_sql, params = _build_feedback_filter_sql(status_filter=status_filter)
|
||||
sql = f"""
|
||||
SELECT * FROM bug_feedbacks
|
||||
WHERE {where_sql}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
cursor.execute(sql, params + [safe_limit, safe_offset])
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def get_user_feedbacks(user_id, limit=50):
|
||||
"""获取用户自己的反馈列表"""
|
||||
safe_limit = _normalize_limit(limit, 50, minimum=1, maximum=1000)
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -64,7 +105,7 @@ def get_user_feedbacks(user_id, limit=50):
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(user_id, limit),
|
||||
(user_id, safe_limit),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
@@ -82,18 +123,13 @@ def reply_feedback(feedback_id, admin_reply):
|
||||
"""管理员回复反馈(带XSS防护)"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
safe_reply = escape_html(admin_reply) if admin_reply else ""
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE bug_feedbacks
|
||||
SET admin_reply = ?, status = 'replied', replied_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(safe_reply, cst_time, feedback_id),
|
||||
(_safe_text(admin_reply), get_cst_now_str(), feedback_id),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -139,6 +175,4 @@ def get_feedback_stats():
|
||||
FROM bug_feedbacks
|
||||
"""
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else {"total": 0, "pending": 0, "replied": 0, "closed": 0}
|
||||
|
||||
return _normalize_feedback_stats_row(cursor.fetchone())
|
||||
|
||||
541
db/migrations.py
541
db/migrations.py
@@ -28,105 +28,136 @@ def set_current_version(conn, version: int) -> None:
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _table_exists(cursor, table_name: str) -> bool:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (str(table_name),))
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
|
||||
def _get_table_columns(cursor, table_name: str) -> set[str]:
|
||||
cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
return {col[1] for col in cursor.fetchall()}
|
||||
|
||||
|
||||
def _add_column_if_missing(cursor, table_name: str, columns: set[str], column_name: str, column_ddl: str, *, ok_message: str) -> bool:
|
||||
if column_name in columns:
|
||||
return False
|
||||
cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_ddl}")
|
||||
columns.add(column_name)
|
||||
print(ok_message)
|
||||
return True
|
||||
|
||||
|
||||
def _read_row_value(row, key: str, index: int):
|
||||
if isinstance(row, sqlite3.Row):
|
||||
return row[key]
|
||||
return row[index]
|
||||
|
||||
|
||||
def _get_migration_steps():
|
||||
return [
|
||||
(1, _migrate_to_v1),
|
||||
(2, _migrate_to_v2),
|
||||
(3, _migrate_to_v3),
|
||||
(4, _migrate_to_v4),
|
||||
(5, _migrate_to_v5),
|
||||
(6, _migrate_to_v6),
|
||||
(7, _migrate_to_v7),
|
||||
(8, _migrate_to_v8),
|
||||
(9, _migrate_to_v9),
|
||||
(10, _migrate_to_v10),
|
||||
(11, _migrate_to_v11),
|
||||
(12, _migrate_to_v12),
|
||||
(13, _migrate_to_v13),
|
||||
(14, _migrate_to_v14),
|
||||
(15, _migrate_to_v15),
|
||||
(16, _migrate_to_v16),
|
||||
(17, _migrate_to_v17),
|
||||
(18, _migrate_to_v18),
|
||||
]
|
||||
|
||||
|
||||
def migrate_database(conn, target_version: int) -> None:
|
||||
"""数据库迁移:按版本增量升级(向前兼容)。"""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("INSERT OR IGNORE INTO db_version (id, version, updated_at) VALUES (1, 0, ?)", (get_cst_now_str(),))
|
||||
conn.commit()
|
||||
|
||||
target_version = int(target_version)
|
||||
current_version = get_current_version(conn)
|
||||
|
||||
if current_version < 1:
|
||||
_migrate_to_v1(conn)
|
||||
current_version = 1
|
||||
if current_version < 2:
|
||||
_migrate_to_v2(conn)
|
||||
current_version = 2
|
||||
if current_version < 3:
|
||||
_migrate_to_v3(conn)
|
||||
current_version = 3
|
||||
if current_version < 4:
|
||||
_migrate_to_v4(conn)
|
||||
current_version = 4
|
||||
if current_version < 5:
|
||||
_migrate_to_v5(conn)
|
||||
current_version = 5
|
||||
if current_version < 6:
|
||||
_migrate_to_v6(conn)
|
||||
current_version = 6
|
||||
if current_version < 7:
|
||||
_migrate_to_v7(conn)
|
||||
current_version = 7
|
||||
if current_version < 8:
|
||||
_migrate_to_v8(conn)
|
||||
current_version = 8
|
||||
if current_version < 9:
|
||||
_migrate_to_v9(conn)
|
||||
current_version = 9
|
||||
if current_version < 10:
|
||||
_migrate_to_v10(conn)
|
||||
current_version = 10
|
||||
if current_version < 11:
|
||||
_migrate_to_v11(conn)
|
||||
current_version = 11
|
||||
if current_version < 12:
|
||||
_migrate_to_v12(conn)
|
||||
current_version = 12
|
||||
if current_version < 13:
|
||||
_migrate_to_v13(conn)
|
||||
current_version = 13
|
||||
if current_version < 14:
|
||||
_migrate_to_v14(conn)
|
||||
current_version = 14
|
||||
if current_version < 15:
|
||||
_migrate_to_v15(conn)
|
||||
current_version = 15
|
||||
if current_version < 16:
|
||||
_migrate_to_v16(conn)
|
||||
current_version = 16
|
||||
if current_version < 17:
|
||||
_migrate_to_v17(conn)
|
||||
current_version = 17
|
||||
if current_version < 18:
|
||||
_migrate_to_v18(conn)
|
||||
current_version = 18
|
||||
for version, migrate_fn in _get_migration_steps():
|
||||
if version > target_version or current_version >= version:
|
||||
continue
|
||||
migrate_fn(conn)
|
||||
current_version = version
|
||||
|
||||
if current_version != int(target_version):
|
||||
set_current_version(conn, int(target_version))
|
||||
if current_version != target_version:
|
||||
set_current_version(conn, target_version)
|
||||
|
||||
|
||||
def _migrate_to_v1(conn):
|
||||
"""迁移到版本1 - 添加缺失字段"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(system_config)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
system_columns = _get_table_columns(cursor, "system_config")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"schedule_weekdays",
|
||||
'TEXT DEFAULT "1,2,3,4,5,6,7"',
|
||||
ok_message=" [OK] 添加 schedule_weekdays 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"max_screenshot_concurrent",
|
||||
"INTEGER DEFAULT 3",
|
||||
ok_message=" [OK] 添加 max_screenshot_concurrent 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"max_concurrent_per_account",
|
||||
"INTEGER DEFAULT 1",
|
||||
ok_message=" [OK] 添加 max_concurrent_per_account 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"auto_approve_enabled",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 auto_approve_enabled 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"auto_approve_hourly_limit",
|
||||
"INTEGER DEFAULT 10",
|
||||
ok_message=" [OK] 添加 auto_approve_hourly_limit 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
"auto_approve_vip_days",
|
||||
"INTEGER DEFAULT 7",
|
||||
ok_message=" [OK] 添加 auto_approve_vip_days 字段",
|
||||
)
|
||||
|
||||
if "schedule_weekdays" not in columns:
|
||||
cursor.execute('ALTER TABLE system_config ADD COLUMN schedule_weekdays TEXT DEFAULT "1,2,3,4,5,6,7"')
|
||||
print(" [OK] 添加 schedule_weekdays 字段")
|
||||
|
||||
if "max_screenshot_concurrent" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN max_screenshot_concurrent INTEGER DEFAULT 3")
|
||||
print(" [OK] 添加 max_screenshot_concurrent 字段")
|
||||
if "max_concurrent_per_account" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN max_concurrent_per_account INTEGER DEFAULT 1")
|
||||
print(" [OK] 添加 max_concurrent_per_account 字段")
|
||||
if "auto_approve_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_enabled INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 auto_approve_enabled 字段")
|
||||
if "auto_approve_hourly_limit" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_hourly_limit INTEGER DEFAULT 10")
|
||||
print(" [OK] 添加 auto_approve_hourly_limit 字段")
|
||||
if "auto_approve_vip_days" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_vip_days INTEGER DEFAULT 7")
|
||||
print(" [OK] 添加 auto_approve_vip_days 字段")
|
||||
|
||||
cursor.execute("PRAGMA table_info(task_logs)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
if "duration" not in columns:
|
||||
cursor.execute("ALTER TABLE task_logs ADD COLUMN duration INTEGER")
|
||||
print(" [OK] 添加 duration 字段到 task_logs")
|
||||
task_log_columns = _get_table_columns(cursor, "task_logs")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"task_logs",
|
||||
task_log_columns,
|
||||
"duration",
|
||||
"INTEGER",
|
||||
ok_message=" [OK] 添加 duration 字段到 task_logs",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -135,24 +166,39 @@ def _migrate_to_v2(conn):
|
||||
"""迁移到版本2 - 添加代理配置字段"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(system_config)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
if "proxy_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_enabled INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 proxy_enabled 字段")
|
||||
|
||||
if "proxy_api_url" not in columns:
|
||||
cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_api_url TEXT DEFAULT ""')
|
||||
print(" [OK] 添加 proxy_api_url 字段")
|
||||
|
||||
if "proxy_expire_minutes" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_expire_minutes INTEGER DEFAULT 3")
|
||||
print(" [OK] 添加 proxy_expire_minutes 字段")
|
||||
|
||||
if "enable_screenshot" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN enable_screenshot INTEGER DEFAULT 1")
|
||||
print(" [OK] 添加 enable_screenshot 字段")
|
||||
columns = _get_table_columns(cursor, "system_config")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"proxy_enabled",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 proxy_enabled 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"proxy_api_url",
|
||||
'TEXT DEFAULT ""',
|
||||
ok_message=" [OK] 添加 proxy_api_url 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"proxy_expire_minutes",
|
||||
"INTEGER DEFAULT 3",
|
||||
ok_message=" [OK] 添加 proxy_expire_minutes 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"enable_screenshot",
|
||||
"INTEGER DEFAULT 1",
|
||||
ok_message=" [OK] 添加 enable_screenshot 字段",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -161,20 +207,31 @@ def _migrate_to_v3(conn):
|
||||
"""迁移到版本3 - 添加账号状态和登录失败计数字段"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(accounts)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
if "status" not in columns:
|
||||
cursor.execute('ALTER TABLE accounts ADD COLUMN status TEXT DEFAULT "active"')
|
||||
print(" [OK] 添加 accounts.status 字段 (账号状态)")
|
||||
|
||||
if "login_fail_count" not in columns:
|
||||
cursor.execute("ALTER TABLE accounts ADD COLUMN login_fail_count INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)")
|
||||
|
||||
if "last_login_error" not in columns:
|
||||
cursor.execute("ALTER TABLE accounts ADD COLUMN last_login_error TEXT")
|
||||
print(" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)")
|
||||
columns = _get_table_columns(cursor, "accounts")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"accounts",
|
||||
columns,
|
||||
"status",
|
||||
'TEXT DEFAULT "active"',
|
||||
ok_message=" [OK] 添加 accounts.status 字段 (账号状态)",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"accounts",
|
||||
columns,
|
||||
"login_fail_count",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"accounts",
|
||||
columns,
|
||||
"last_login_error",
|
||||
"TEXT",
|
||||
ok_message=" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -183,12 +240,15 @@ def _migrate_to_v4(conn):
|
||||
"""迁移到版本4 - 添加任务来源字段"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(task_logs)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
if "source" not in columns:
|
||||
cursor.execute('ALTER TABLE task_logs ADD COLUMN source TEXT DEFAULT "manual"')
|
||||
print(" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)")
|
||||
columns = _get_table_columns(cursor, "task_logs")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"task_logs",
|
||||
columns,
|
||||
"source",
|
||||
'TEXT DEFAULT "manual"',
|
||||
ok_message=" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -300,20 +360,17 @@ def _migrate_to_v6(conn):
|
||||
def _migrate_to_v7(conn):
|
||||
"""迁移到版本7 - 统一存储北京时间(将历史UTC时间字段整体+8小时)"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
def table_exists(table_name: str) -> bool:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def column_exists(table_name: str, column_name: str) -> bool:
|
||||
cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
return any(row[1] == column_name for row in cursor.fetchall())
|
||||
columns_cache: dict[str, set[str]] = {}
|
||||
|
||||
def shift_utc_to_cst(table_name: str, column_name: str) -> None:
|
||||
if not table_exists(table_name):
|
||||
if not _table_exists(cursor, table_name):
|
||||
return
|
||||
if not column_exists(table_name, column_name):
|
||||
|
||||
if table_name not in columns_cache:
|
||||
columns_cache[table_name] = _get_table_columns(cursor, table_name)
|
||||
if column_name not in columns_cache[table_name]:
|
||||
return
|
||||
|
||||
cursor.execute(
|
||||
f"""
|
||||
UPDATE {table_name}
|
||||
@@ -329,10 +386,6 @@ def _migrate_to_v7(conn):
|
||||
("accounts", "created_at"),
|
||||
("password_reset_requests", "created_at"),
|
||||
("password_reset_requests", "processed_at"),
|
||||
]:
|
||||
shift_utc_to_cst(table, col)
|
||||
|
||||
for table, col in [
|
||||
("smtp_configs", "created_at"),
|
||||
("smtp_configs", "updated_at"),
|
||||
("smtp_configs", "last_success_at"),
|
||||
@@ -340,10 +393,6 @@ def _migrate_to_v7(conn):
|
||||
("email_tokens", "created_at"),
|
||||
("email_logs", "created_at"),
|
||||
("email_stats", "last_updated"),
|
||||
]:
|
||||
shift_utc_to_cst(table, col)
|
||||
|
||||
for table, col in [
|
||||
("task_checkpoints", "created_at"),
|
||||
("task_checkpoints", "updated_at"),
|
||||
("task_checkpoints", "completed_at"),
|
||||
@@ -359,15 +408,23 @@ def _migrate_to_v8(conn):
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 1) 增量字段:random_delay(旧库可能不存在)
|
||||
cursor.execute("PRAGMA table_info(user_schedules)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
if "random_delay" not in columns:
|
||||
cursor.execute("ALTER TABLE user_schedules ADD COLUMN random_delay INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 user_schedules.random_delay 字段")
|
||||
|
||||
if "next_run_at" not in columns:
|
||||
cursor.execute("ALTER TABLE user_schedules ADD COLUMN next_run_at TIMESTAMP")
|
||||
print(" [OK] 添加 user_schedules.next_run_at 字段")
|
||||
columns = _get_table_columns(cursor, "user_schedules")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"user_schedules",
|
||||
columns,
|
||||
"random_delay",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 user_schedules.random_delay 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"user_schedules",
|
||||
columns,
|
||||
"next_run_at",
|
||||
"TIMESTAMP",
|
||||
ok_message=" [OK] 添加 user_schedules.next_run_at 字段",
|
||||
)
|
||||
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_schedules_next_run ON user_schedules(next_run_at)")
|
||||
conn.commit()
|
||||
@@ -392,12 +449,12 @@ def _migrate_to_v8(conn):
|
||||
fixed = 0
|
||||
for row in rows:
|
||||
try:
|
||||
schedule_id = row["id"] if isinstance(row, sqlite3.Row) else row[0]
|
||||
schedule_time = row["schedule_time"] if isinstance(row, sqlite3.Row) else row[1]
|
||||
weekdays = row["weekdays"] if isinstance(row, sqlite3.Row) else row[2]
|
||||
random_delay = row["random_delay"] if isinstance(row, sqlite3.Row) else row[3]
|
||||
last_run_at = row["last_run_at"] if isinstance(row, sqlite3.Row) else row[4]
|
||||
next_run_at = row["next_run_at"] if isinstance(row, sqlite3.Row) else row[5]
|
||||
schedule_id = _read_row_value(row, "id", 0)
|
||||
schedule_time = _read_row_value(row, "schedule_time", 1)
|
||||
weekdays = _read_row_value(row, "weekdays", 2)
|
||||
random_delay = _read_row_value(row, "random_delay", 3)
|
||||
last_run_at = _read_row_value(row, "last_run_at", 4)
|
||||
next_run_at = _read_row_value(row, "next_run_at", 5)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -430,27 +487,46 @@ def _migrate_to_v9(conn):
|
||||
"""迁移到版本9 - 邮件设置字段迁移(清理 email_service scattered ALTER TABLE)"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
|
||||
if not cursor.fetchone():
|
||||
if not _table_exists(cursor, "email_settings"):
|
||||
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
|
||||
return
|
||||
|
||||
cursor.execute("PRAGMA table_info(email_settings)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
columns = _get_table_columns(cursor, "email_settings")
|
||||
|
||||
changed = False
|
||||
if "register_verify_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE email_settings ADD COLUMN register_verify_enabled INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 email_settings.register_verify_enabled 字段")
|
||||
changed = True
|
||||
if "base_url" not in columns:
|
||||
cursor.execute("ALTER TABLE email_settings ADD COLUMN base_url TEXT DEFAULT ''")
|
||||
print(" [OK] 添加 email_settings.base_url 字段")
|
||||
changed = True
|
||||
if "task_notify_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE email_settings ADD COLUMN task_notify_enabled INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 email_settings.task_notify_enabled 字段")
|
||||
changed = True
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"email_settings",
|
||||
columns,
|
||||
"register_verify_enabled",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 email_settings.register_verify_enabled 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"email_settings",
|
||||
columns,
|
||||
"base_url",
|
||||
"TEXT DEFAULT ''",
|
||||
ok_message=" [OK] 添加 email_settings.base_url 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"email_settings",
|
||||
columns,
|
||||
"task_notify_enabled",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 email_settings.task_notify_enabled 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
|
||||
if changed:
|
||||
conn.commit()
|
||||
@@ -459,18 +535,31 @@ def _migrate_to_v9(conn):
|
||||
def _migrate_to_v10(conn):
|
||||
"""迁移到版本10 - users 邮箱字段迁移(避免运行时 ALTER TABLE)"""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA table_info(users)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
columns = _get_table_columns(cursor, "users")
|
||||
|
||||
changed = False
|
||||
if "email_verified" not in columns:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN email_verified INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 users.email_verified 字段")
|
||||
changed = True
|
||||
if "email_notify_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN email_notify_enabled INTEGER DEFAULT 1")
|
||||
print(" [OK] 添加 users.email_notify_enabled 字段")
|
||||
changed = True
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"users",
|
||||
columns,
|
||||
"email_verified",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 users.email_verified 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"users",
|
||||
columns,
|
||||
"email_notify_enabled",
|
||||
"INTEGER DEFAULT 1",
|
||||
ok_message=" [OK] 添加 users.email_notify_enabled 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
|
||||
if changed:
|
||||
conn.commit()
|
||||
@@ -657,19 +746,24 @@ def _migrate_to_v15(conn):
|
||||
"""迁移到版本15 - 邮件设置:新设备登录提醒全局开关"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
|
||||
if not cursor.fetchone():
|
||||
if not _table_exists(cursor, "email_settings"):
|
||||
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
|
||||
return
|
||||
|
||||
cursor.execute("PRAGMA table_info(email_settings)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
columns = _get_table_columns(cursor, "email_settings")
|
||||
|
||||
changed = False
|
||||
if "login_alert_enabled" not in columns:
|
||||
cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1")
|
||||
print(" [OK] 添加 email_settings.login_alert_enabled 字段")
|
||||
changed = True
|
||||
changed = (
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"email_settings",
|
||||
columns,
|
||||
"login_alert_enabled",
|
||||
"INTEGER DEFAULT 1",
|
||||
ok_message=" [OK] 添加 email_settings.login_alert_enabled 字段",
|
||||
)
|
||||
or changed
|
||||
)
|
||||
|
||||
try:
|
||||
cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL")
|
||||
@@ -686,22 +780,24 @@ def _migrate_to_v15(conn):
|
||||
def _migrate_to_v16(conn):
|
||||
"""迁移到版本16 - 公告支持图片字段"""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA table_info(announcements)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
columns = _get_table_columns(cursor, "announcements")
|
||||
|
||||
if "image_url" not in columns:
|
||||
cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT")
|
||||
if _add_column_if_missing(
|
||||
cursor,
|
||||
"announcements",
|
||||
columns,
|
||||
"image_url",
|
||||
"TEXT",
|
||||
ok_message=" [OK] 添加 announcements.image_url 字段",
|
||||
):
|
||||
conn.commit()
|
||||
print(" [OK] 添加 announcements.image_url 字段")
|
||||
|
||||
|
||||
def _migrate_to_v17(conn):
|
||||
"""迁移到版本17 - 金山文档上传配置与用户开关"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(system_config)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
system_columns = _get_table_columns(cursor, "system_config")
|
||||
system_fields = [
|
||||
("kdocs_enabled", "INTEGER DEFAULT 0"),
|
||||
("kdocs_doc_url", "TEXT DEFAULT ''"),
|
||||
@@ -714,21 +810,29 @@ def _migrate_to_v17(conn):
|
||||
("kdocs_admin_notify_email", "TEXT DEFAULT ''"),
|
||||
]
|
||||
for field, ddl in system_fields:
|
||||
if field not in columns:
|
||||
cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}")
|
||||
print(f" [OK] 添加 system_config.{field} 字段")
|
||||
|
||||
cursor.execute("PRAGMA table_info(users)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
system_columns,
|
||||
field,
|
||||
ddl,
|
||||
ok_message=f" [OK] 添加 system_config.{field} 字段",
|
||||
)
|
||||
|
||||
user_columns = _get_table_columns(cursor, "users")
|
||||
user_fields = [
|
||||
("kdocs_unit", "TEXT DEFAULT ''"),
|
||||
("kdocs_auto_upload", "INTEGER DEFAULT 0"),
|
||||
]
|
||||
for field, ddl in user_fields:
|
||||
if field not in columns:
|
||||
cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}")
|
||||
print(f" [OK] 添加 users.{field} 字段")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"users",
|
||||
user_columns,
|
||||
field,
|
||||
ddl,
|
||||
ok_message=f" [OK] 添加 users.{field} 字段",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
@@ -737,15 +841,22 @@ def _migrate_to_v18(conn):
|
||||
"""迁移到版本18 - 金山文档上传:有效行范围配置"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA table_info(system_config)")
|
||||
columns = [col[1] for col in cursor.fetchall()]
|
||||
|
||||
if "kdocs_row_start" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 system_config.kdocs_row_start 字段")
|
||||
|
||||
if "kdocs_row_end" not in columns:
|
||||
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0")
|
||||
print(" [OK] 添加 system_config.kdocs_row_end 字段")
|
||||
columns = _get_table_columns(cursor, "system_config")
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"kdocs_row_start",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 system_config.kdocs_row_start 字段",
|
||||
)
|
||||
_add_column_if_missing(
|
||||
cursor,
|
||||
"system_config",
|
||||
columns,
|
||||
"kdocs_row_end",
|
||||
"INTEGER DEFAULT 0",
|
||||
ok_message=" [OK] 添加 system_config.kdocs_row_end 字段",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
292
db/schedules.py
292
db/schedules.py
@@ -2,12 +2,93 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import db_pool
|
||||
from services.schedule_utils import compute_next_run_at, format_cst
|
||||
from services.time_utils import get_beijing_now
|
||||
|
||||
_SCHEDULE_DEFAULT_TIME = "08:00"
|
||||
_SCHEDULE_DEFAULT_WEEKDAYS = "1,2,3,4,5"
|
||||
|
||||
_ALLOWED_SCHEDULE_UPDATE_FIELDS = (
|
||||
"name",
|
||||
"enabled",
|
||||
"schedule_time",
|
||||
"weekdays",
|
||||
"browse_type",
|
||||
"enable_screenshot",
|
||||
"random_delay",
|
||||
"account_ids",
|
||||
)
|
||||
|
||||
_ALLOWED_EXEC_LOG_UPDATE_FIELDS = (
|
||||
"total_accounts",
|
||||
"success_accounts",
|
||||
"failed_accounts",
|
||||
"total_items",
|
||||
"total_attachments",
|
||||
"total_screenshots",
|
||||
"duration_seconds",
|
||||
"status",
|
||||
"error_message",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_limit(limit, default: int, *, minimum: int = 1) -> int:
|
||||
try:
|
||||
parsed = int(limit)
|
||||
except Exception:
|
||||
parsed = default
|
||||
if parsed < minimum:
|
||||
return minimum
|
||||
return parsed
|
||||
|
||||
|
||||
def _to_int(value, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _format_optional_datetime(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
return format_cst(dt)
|
||||
|
||||
|
||||
def _serialize_account_ids(account_ids) -> str:
|
||||
return json.dumps(account_ids) if account_ids else "[]"
|
||||
|
||||
|
||||
def _compute_schedule_next_run_str(
|
||||
*,
|
||||
now_dt,
|
||||
schedule_time,
|
||||
weekdays,
|
||||
random_delay,
|
||||
last_run_at,
|
||||
) -> str:
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(schedule_time or _SCHEDULE_DEFAULT_TIME),
|
||||
weekdays=str(weekdays or _SCHEDULE_DEFAULT_WEEKDAYS),
|
||||
random_delay=_to_int(random_delay, 0),
|
||||
last_run_at=str(last_run_at or "") if last_run_at else None,
|
||||
)
|
||||
return format_cst(next_dt)
|
||||
|
||||
|
||||
def _map_schedule_log_row(row) -> dict:
|
||||
log = dict(row)
|
||||
log["created_at"] = log.get("execute_time")
|
||||
log["success_count"] = log.get("success_accounts", 0)
|
||||
log["failed_count"] = log.get("failed_accounts", 0)
|
||||
log["duration"] = log.get("duration_seconds", 0)
|
||||
return log
|
||||
|
||||
|
||||
def get_user_schedules(user_id):
|
||||
"""获取用户的所有定时任务"""
|
||||
@@ -44,14 +125,10 @@ def create_user_schedule(
|
||||
account_ids=None,
|
||||
):
|
||||
"""创建用户定时任务"""
|
||||
import json
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = format_cst(get_beijing_now())
|
||||
|
||||
account_ids_str = json.dumps(account_ids) if account_ids else "[]"
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO user_schedules (
|
||||
@@ -66,8 +143,8 @@ def create_user_schedule(
|
||||
weekdays,
|
||||
browse_type,
|
||||
enable_screenshot,
|
||||
int(random_delay or 0),
|
||||
account_ids_str,
|
||||
_to_int(random_delay, 0),
|
||||
_serialize_account_ids(account_ids),
|
||||
cst_time,
|
||||
cst_time,
|
||||
),
|
||||
@@ -79,28 +156,11 @@ def create_user_schedule(
|
||||
|
||||
def update_user_schedule(schedule_id, **kwargs):
|
||||
"""更新用户定时任务"""
|
||||
import json
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
now_dt = get_beijing_now()
|
||||
now_str = format_cst(now_dt)
|
||||
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
allowed_fields = [
|
||||
"name",
|
||||
"enabled",
|
||||
"schedule_time",
|
||||
"weekdays",
|
||||
"browse_type",
|
||||
"enable_screenshot",
|
||||
"random_delay",
|
||||
"account_ids",
|
||||
]
|
||||
|
||||
# 读取旧值,用于决定是否需要重算 next_run_at
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT enabled, schedule_time, weekdays, random_delay, last_run_at
|
||||
@@ -112,10 +172,11 @@ def update_user_schedule(schedule_id, **kwargs):
|
||||
current = cursor.fetchone()
|
||||
if not current:
|
||||
return False
|
||||
current_enabled = int(current[0] or 0)
|
||||
|
||||
current_enabled = _to_int(current[0], 0)
|
||||
current_time = current[1]
|
||||
current_weekdays = current[2]
|
||||
current_random_delay = int(current[3] or 0)
|
||||
current_random_delay = _to_int(current[3], 0)
|
||||
current_last_run_at = current[4]
|
||||
|
||||
will_enabled = current_enabled
|
||||
@@ -123,21 +184,28 @@ def update_user_schedule(schedule_id, **kwargs):
|
||||
next_weekdays = current_weekdays
|
||||
next_random_delay = current_random_delay
|
||||
|
||||
for field in allowed_fields:
|
||||
if field in kwargs:
|
||||
value = kwargs[field]
|
||||
if field == "account_ids" and isinstance(value, list):
|
||||
value = json.dumps(value)
|
||||
if field == "enabled":
|
||||
will_enabled = 1 if value else 0
|
||||
if field == "schedule_time":
|
||||
next_time = value
|
||||
if field == "weekdays":
|
||||
next_weekdays = value
|
||||
if field == "random_delay":
|
||||
next_random_delay = int(value or 0)
|
||||
updates.append(f"{field} = ?")
|
||||
params.append(value)
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
for field in _ALLOWED_SCHEDULE_UPDATE_FIELDS:
|
||||
if field not in kwargs:
|
||||
continue
|
||||
|
||||
value = kwargs[field]
|
||||
if field == "account_ids" and isinstance(value, list):
|
||||
value = json.dumps(value)
|
||||
|
||||
if field == "enabled":
|
||||
will_enabled = 1 if value else 0
|
||||
if field == "schedule_time":
|
||||
next_time = value
|
||||
if field == "weekdays":
|
||||
next_weekdays = value
|
||||
if field == "random_delay":
|
||||
next_random_delay = int(value or 0)
|
||||
|
||||
updates.append(f"{field} = ?")
|
||||
params.append(value)
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
@@ -145,30 +213,26 @@ def update_user_schedule(schedule_id, **kwargs):
|
||||
updates.append("updated_at = ?")
|
||||
params.append(now_str)
|
||||
|
||||
# 关键字段变更后重算 next_run_at,确保索引驱动不会跑偏
|
||||
#
|
||||
# 需求:当用户修改“执行时间/执行日期/随机±15分钟”后,即使今天已经执行过,也允许按新配置在今天再次触发。
|
||||
# 做法:这些关键字段发生变更时,重算 next_run_at 时忽略 last_run_at 的“同日仅一次”限制。
|
||||
config_changed = any(key in kwargs for key in ["schedule_time", "weekdays", "random_delay"])
|
||||
config_changed = any(key in kwargs for key in ("schedule_time", "weekdays", "random_delay"))
|
||||
enabled_toggled = "enabled" in kwargs
|
||||
should_recompute_next = config_changed or (enabled_toggled and will_enabled == 1)
|
||||
|
||||
if should_recompute_next:
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(next_time or "08:00"),
|
||||
weekdays=str(next_weekdays or "1,2,3,4,5"),
|
||||
random_delay=int(next_random_delay or 0),
|
||||
last_run_at=None if config_changed else (str(current_last_run_at or "") if current_last_run_at else None),
|
||||
next_run_at = _compute_schedule_next_run_str(
|
||||
now_dt=now_dt,
|
||||
schedule_time=next_time,
|
||||
weekdays=next_weekdays,
|
||||
random_delay=next_random_delay,
|
||||
last_run_at=None if config_changed else current_last_run_at,
|
||||
)
|
||||
updates.append("next_run_at = ?")
|
||||
params.append(format_cst(next_dt))
|
||||
params.append(next_run_at)
|
||||
|
||||
# 若本次显式禁用任务,则 next_run_at 清空(与 toggle 行为保持一致)
|
||||
if enabled_toggled and will_enabled == 0:
|
||||
updates.append("next_run_at = ?")
|
||||
params.append(None)
|
||||
params.append(schedule_id)
|
||||
|
||||
params.append(schedule_id)
|
||||
sql = f"UPDATE user_schedules SET {', '.join(updates)} WHERE id = ?"
|
||||
cursor.execute(sql, params)
|
||||
conn.commit()
|
||||
@@ -203,28 +267,19 @@ def toggle_user_schedule(schedule_id, enabled):
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = (
|
||||
row[0],
|
||||
row[1],
|
||||
row[2],
|
||||
row[3],
|
||||
row[4],
|
||||
)
|
||||
|
||||
schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = row
|
||||
existing_next_run_at = str(existing_next_run_at or "").strip() or None
|
||||
# 若 next_run_at 已经被“修改配置”逻辑预先计算好且仍在未来,则优先沿用,
|
||||
# 避免 last_run_at 的“同日仅一次”限制阻塞用户把任务调整到今天再次触发。
|
||||
|
||||
if existing_next_run_at and existing_next_run_at > now_str:
|
||||
next_run_at = existing_next_run_at
|
||||
else:
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(schedule_time or "08:00"),
|
||||
weekdays=str(weekdays or "1,2,3,4,5"),
|
||||
random_delay=int(random_delay or 0),
|
||||
last_run_at=str(last_run_at or "") if last_run_at else None,
|
||||
next_run_at = _compute_schedule_next_run_str(
|
||||
now_dt=now_dt,
|
||||
schedule_time=schedule_time,
|
||||
weekdays=weekdays,
|
||||
random_delay=random_delay,
|
||||
last_run_at=last_run_at,
|
||||
)
|
||||
next_run_at = format_cst(next_dt)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -272,16 +327,15 @@ def update_schedule_last_run(schedule_id):
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return False
|
||||
schedule_time, weekdays, random_delay = row[0], row[1], row[2]
|
||||
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(schedule_time or "08:00"),
|
||||
weekdays=str(weekdays or "1,2,3,4,5"),
|
||||
random_delay=int(random_delay or 0),
|
||||
schedule_time, weekdays, random_delay = row
|
||||
next_run_at = _compute_schedule_next_run_str(
|
||||
now_dt=now_dt,
|
||||
schedule_time=schedule_time,
|
||||
weekdays=weekdays,
|
||||
random_delay=random_delay,
|
||||
last_run_at=now_str,
|
||||
)
|
||||
next_run_at = format_cst(next_dt)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -305,7 +359,11 @@ def update_schedule_next_run(schedule_id: int, next_run_at: str) -> bool:
|
||||
SET next_run_at = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(str(next_run_at or "").strip() or None, format_cst(get_beijing_now()), int(schedule_id)),
|
||||
(
|
||||
str(next_run_at or "").strip() or None,
|
||||
format_cst(get_beijing_now()),
|
||||
int(schedule_id),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -328,15 +386,15 @@ def recompute_schedule_next_run(schedule_id: int, *, now_dt=None) -> bool:
|
||||
if not row:
|
||||
return False
|
||||
|
||||
schedule_time, weekdays, random_delay, last_run_at = row[0], row[1], row[2], row[3]
|
||||
next_dt = compute_next_run_at(
|
||||
now=now_dt,
|
||||
schedule_time=str(schedule_time or "08:00"),
|
||||
weekdays=str(weekdays or "1,2,3,4,5"),
|
||||
random_delay=int(random_delay or 0),
|
||||
last_run_at=str(last_run_at or "") if last_run_at else None,
|
||||
schedule_time, weekdays, random_delay, last_run_at = row
|
||||
next_run_at = _compute_schedule_next_run_str(
|
||||
now_dt=now_dt,
|
||||
schedule_time=schedule_time,
|
||||
weekdays=weekdays,
|
||||
random_delay=random_delay,
|
||||
last_run_at=last_run_at,
|
||||
)
|
||||
return update_schedule_next_run(int(schedule_id), format_cst(next_dt))
|
||||
return update_schedule_next_run(int(schedule_id), next_run_at)
|
||||
|
||||
|
||||
def get_due_user_schedules(now_cst: str, limit: int = 50):
|
||||
@@ -345,6 +403,8 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
|
||||
if not now_cst:
|
||||
now_cst = format_cst(get_beijing_now())
|
||||
|
||||
safe_limit = _normalize_limit(limit, 50, minimum=1)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -358,7 +418,7 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
|
||||
ORDER BY us.next_run_at ASC
|
||||
LIMIT ?
|
||||
""",
|
||||
(now_cst, int(limit)),
|
||||
(now_cst, safe_limit),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
@@ -370,15 +430,13 @@ def create_schedule_execution_log(schedule_id, user_id, schedule_name):
|
||||
"""创建定时任务执行日志"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
execute_time = format_cst(get_beijing_now())
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO schedule_execution_logs (
|
||||
schedule_id, user_id, schedule_name, execute_time, status
|
||||
) VALUES (?, ?, ?, ?, 'running')
|
||||
""",
|
||||
(schedule_id, user_id, schedule_name, execute_time),
|
||||
(schedule_id, user_id, schedule_name, format_cst(get_beijing_now())),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -393,22 +451,11 @@ def update_schedule_execution_log(log_id, **kwargs):
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
allowed_fields = [
|
||||
"total_accounts",
|
||||
"success_accounts",
|
||||
"failed_accounts",
|
||||
"total_items",
|
||||
"total_attachments",
|
||||
"total_screenshots",
|
||||
"duration_seconds",
|
||||
"status",
|
||||
"error_message",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
if field in kwargs:
|
||||
updates.append(f"{field} = ?")
|
||||
params.append(kwargs[field])
|
||||
for field in _ALLOWED_EXEC_LOG_UPDATE_FIELDS:
|
||||
if field not in kwargs:
|
||||
continue
|
||||
updates.append(f"{field} = ?")
|
||||
params.append(kwargs[field])
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
@@ -424,6 +471,7 @@ def update_schedule_execution_log(log_id, **kwargs):
|
||||
def get_schedule_execution_logs(schedule_id, limit=10):
|
||||
"""获取定时任务执行日志"""
|
||||
try:
|
||||
safe_limit = _normalize_limit(limit, 10, minimum=1)
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -433,24 +481,16 @@ def get_schedule_execution_logs(schedule_id, limit=10):
|
||||
ORDER BY execute_time DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(schedule_id, limit),
|
||||
(schedule_id, safe_limit),
|
||||
)
|
||||
|
||||
logs = []
|
||||
rows = cursor.fetchall()
|
||||
|
||||
for row in rows:
|
||||
for row in cursor.fetchall():
|
||||
try:
|
||||
log = dict(row)
|
||||
log["created_at"] = log.get("execute_time")
|
||||
log["success_count"] = log.get("success_accounts", 0)
|
||||
log["failed_count"] = log.get("failed_accounts", 0)
|
||||
log["duration"] = log.get("duration_seconds", 0)
|
||||
logs.append(log)
|
||||
logs.append(_map_schedule_log_row(row))
|
||||
except Exception as e:
|
||||
print(f"[数据库] 处理日志行时出错: {e}")
|
||||
continue
|
||||
|
||||
return logs
|
||||
except Exception as e:
|
||||
print(f"[数据库] 查询定时任务日志时出错: {e}")
|
||||
@@ -462,6 +502,7 @@ def get_schedule_execution_logs(schedule_id, limit=10):
|
||||
|
||||
def get_user_all_schedule_logs(user_id, limit=50):
|
||||
"""获取用户所有定时任务的执行日志"""
|
||||
safe_limit = _normalize_limit(limit, 50, minimum=1)
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -471,7 +512,7 @@ def get_user_all_schedule_logs(user_id, limit=50):
|
||||
ORDER BY execute_time DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(user_id, limit),
|
||||
(user_id, safe_limit),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
@@ -493,14 +534,21 @@ def delete_schedule_logs(schedule_id, user_id):
|
||||
|
||||
def clean_old_schedule_logs(days=30):
|
||||
"""清理指定天数前的定时任务执行日志"""
|
||||
safe_days = _to_int(days, 30)
|
||||
if safe_days < 0:
|
||||
safe_days = 0
|
||||
|
||||
cutoff_dt = get_beijing_now() - timedelta(days=safe_days)
|
||||
cutoff_str = format_cst(cutoff_dt)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM schedule_execution_logs
|
||||
WHERE execute_time < datetime('now', 'localtime', '-' || ? || ' days')
|
||||
WHERE execute_time < ?
|
||||
""",
|
||||
(days,),
|
||||
(cutoff_str,),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
@@ -362,6 +362,8 @@ def ensure_schema(conn) -> None:
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_status ON users(status)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_vip_expire ON users(vip_expire_time)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
|
||||
|
||||
@@ -391,6 +393,8 @@ def ensure_schema(conn) -> None:
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_id ON task_logs(user_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_status ON task_logs(status)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source ON task_logs(source)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source_created_at ON task_logs(source, created_at)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)")
|
||||
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)")
|
||||
@@ -409,6 +413,9 @@ def ensure_schema(conn) -> None:
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_id ON schedule_execution_logs(schedule_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_id ON schedule_execution_logs(user_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_status ON schedule_execution_logs(status)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_execute_time ON schedule_execution_logs(execute_time)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_time ON schedule_execution_logs(schedule_id, execute_time)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_time ON schedule_execution_logs(user_id, execute_time)")
|
||||
|
||||
# 初始化VIP配置(幂等)
|
||||
try:
|
||||
|
||||
167
db/security.py
167
db/security.py
@@ -3,13 +3,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
from typing import Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import db_pool
|
||||
from db.utils import get_cst_now, get_cst_now_str
|
||||
|
||||
|
||||
_THREAT_EVENT_SELECT_COLUMNS = """
|
||||
id,
|
||||
threat_type,
|
||||
score,
|
||||
rule,
|
||||
field_name,
|
||||
matched,
|
||||
value_preview,
|
||||
ip,
|
||||
user_id,
|
||||
request_method,
|
||||
request_path,
|
||||
user_agent,
|
||||
created_at
|
||||
"""
|
||||
|
||||
|
||||
def _normalize_page(page: int) -> int:
|
||||
try:
|
||||
page_i = int(page)
|
||||
except Exception:
|
||||
page_i = 1
|
||||
return max(1, page_i)
|
||||
|
||||
|
||||
def _normalize_per_page(per_page: int, default: int = 20) -> int:
|
||||
try:
|
||||
value = int(per_page)
|
||||
except Exception:
|
||||
value = default
|
||||
return max(1, min(200, value))
|
||||
|
||||
|
||||
def _normalize_limit(limit: int, default: int = 50) -> int:
|
||||
try:
|
||||
value = int(limit)
|
||||
except Exception:
|
||||
value = default
|
||||
return max(1, min(200, value))
|
||||
|
||||
|
||||
def _row_value(row, key: str, index: int = 0, default=None):
|
||||
if row is None:
|
||||
return default
|
||||
try:
|
||||
return row[key]
|
||||
except Exception:
|
||||
try:
|
||||
return row[index]
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _fetch_threat_events_history(where_clause: str, params: tuple[Any, ...], limit_i: int) -> list[dict]:
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT
|
||||
{_THREAT_EVENT_SELECT_COLUMNS}
|
||||
FROM threat_events
|
||||
WHERE {where_clause}
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
tuple(params) + (limit_i,),
|
||||
)
|
||||
return [dict(r) for r in cursor.fetchall()]
|
||||
|
||||
|
||||
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
|
||||
"""记录登录环境信息,返回是否新设备/新IP。"""
|
||||
user_id = int(user_id)
|
||||
@@ -36,7 +105,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
|
||||
SET last_seen = ?, last_ip = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(now_str, ip_text, row["id"] if isinstance(row, dict) else row[0]),
|
||||
(now_str, ip_text, _row_value(row, "id", 0)),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
@@ -61,7 +130,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
|
||||
SET last_seen = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(now_str, row["id"] if isinstance(row, dict) else row[0]),
|
||||
(now_str, _row_value(row, "id", 0)),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
@@ -166,15 +235,8 @@ def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, lis
|
||||
|
||||
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
|
||||
"""分页获取威胁事件。"""
|
||||
try:
|
||||
page_i = max(1, int(page))
|
||||
except Exception:
|
||||
page_i = 1
|
||||
try:
|
||||
per_page_i = int(per_page)
|
||||
except Exception:
|
||||
per_page_i = 20
|
||||
per_page_i = max(1, min(200, per_page_i))
|
||||
page_i = _normalize_page(page)
|
||||
per_page_i = _normalize_per_page(per_page, default=20)
|
||||
|
||||
where_sql, params = _build_threat_events_where_clause(filters)
|
||||
offset = (page_i - 1) * per_page_i
|
||||
@@ -188,19 +250,7 @@ def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = N
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT
|
||||
id,
|
||||
threat_type,
|
||||
score,
|
||||
rule,
|
||||
field_name,
|
||||
matched,
|
||||
value_preview,
|
||||
ip,
|
||||
user_id,
|
||||
request_method,
|
||||
request_path,
|
||||
user_agent,
|
||||
created_at
|
||||
{_THREAT_EVENT_SELECT_COLUMNS}
|
||||
FROM threat_events
|
||||
{where_sql}
|
||||
ORDER BY created_at DESC, id DESC
|
||||
@@ -218,75 +268,20 @@ def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
|
||||
ip_text = str(ip or "").strip()[:64]
|
||||
if not ip_text:
|
||||
return []
|
||||
try:
|
||||
limit_i = max(1, min(200, int(limit)))
|
||||
except Exception:
|
||||
limit_i = 50
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
id,
|
||||
threat_type,
|
||||
score,
|
||||
rule,
|
||||
field_name,
|
||||
matched,
|
||||
value_preview,
|
||||
ip,
|
||||
user_id,
|
||||
request_method,
|
||||
request_path,
|
||||
user_agent,
|
||||
created_at
|
||||
FROM threat_events
|
||||
WHERE ip = ?
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(ip_text, limit_i),
|
||||
)
|
||||
return [dict(r) for r in cursor.fetchall()]
|
||||
limit_i = _normalize_limit(limit, default=50)
|
||||
return _fetch_threat_events_history("ip = ?", (ip_text,), limit_i)
|
||||
|
||||
|
||||
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
|
||||
"""获取用户的威胁历史(最近limit条)。"""
|
||||
if user_id is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
user_id_int = int(user_id)
|
||||
except Exception:
|
||||
return []
|
||||
try:
|
||||
limit_i = max(1, min(200, int(limit)))
|
||||
except Exception:
|
||||
limit_i = 50
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
id,
|
||||
threat_type,
|
||||
score,
|
||||
rule,
|
||||
field_name,
|
||||
matched,
|
||||
value_preview,
|
||||
ip,
|
||||
user_id,
|
||||
request_method,
|
||||
request_path,
|
||||
user_agent,
|
||||
created_at
|
||||
FROM threat_events
|
||||
WHERE user_id = ?
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(user_id_int, limit_i),
|
||||
)
|
||||
return [dict(r) for r in cursor.fetchall()]
|
||||
limit_i = _normalize_limit(limit, default=50)
|
||||
return _fetch_threat_events_history("user_id = ?", (user_id_int,), limit_i)
|
||||
|
||||
303
db/tasks.py
303
db/tasks.py
@@ -2,12 +2,135 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import db_pool
|
||||
from db.utils import sanitize_sql_like_pattern
|
||||
from db.utils import get_cst_now, get_cst_now_str, sanitize_sql_like_pattern
|
||||
|
||||
_TASK_STATS_SELECT_SQL = """
|
||||
SELECT
|
||||
COUNT(*) as total_tasks,
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
"""
|
||||
|
||||
_USER_RUN_STATS_SELECT_SQL = """
|
||||
SELECT
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
"""
|
||||
|
||||
|
||||
def _build_day_bounds(date_filter: str) -> tuple[str | None, str | None]:
|
||||
"""将 YYYY-MM-DD 转换为 [day_start, day_end) 区间。"""
|
||||
try:
|
||||
day_start = datetime.strptime(str(date_filter), "%Y-%m-%d")
|
||||
except Exception:
|
||||
return None, None
|
||||
|
||||
day_end = day_start + timedelta(days=1)
|
||||
return day_start.strftime("%Y-%m-%d %H:%M:%S"), day_end.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def _normalize_int(value, default: int, *, minimum: int | None = None) -> int:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
if minimum is not None and parsed < minimum:
|
||||
return minimum
|
||||
return parsed
|
||||
|
||||
|
||||
def _stat_value(row, key: str) -> int:
|
||||
try:
|
||||
value = row[key] if row else 0
|
||||
except Exception:
|
||||
value = 0
|
||||
return int(value or 0)
|
||||
|
||||
|
||||
def _build_task_logs_where_sql(
|
||||
*,
|
||||
date_filter=None,
|
||||
status_filter=None,
|
||||
source_filter=None,
|
||||
user_id_filter=None,
|
||||
account_filter=None,
|
||||
) -> tuple[str, list]:
|
||||
where_clauses = ["1=1"]
|
||||
params = []
|
||||
|
||||
if date_filter:
|
||||
day_start, day_end = _build_day_bounds(date_filter)
|
||||
if day_start and day_end:
|
||||
where_clauses.append("tl.created_at >= ? AND tl.created_at < ?")
|
||||
params.extend([day_start, day_end])
|
||||
else:
|
||||
where_clauses.append("date(tl.created_at) = ?")
|
||||
params.append(date_filter)
|
||||
|
||||
if status_filter:
|
||||
where_clauses.append("tl.status = ?")
|
||||
params.append(status_filter)
|
||||
|
||||
if source_filter:
|
||||
source_filter = str(source_filter or "").strip()
|
||||
if source_filter == "user_scheduled":
|
||||
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
|
||||
params.append("user_scheduled:%")
|
||||
elif source_filter.endswith("*"):
|
||||
prefix = source_filter[:-1]
|
||||
safe_prefix = sanitize_sql_like_pattern(prefix)
|
||||
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
|
||||
params.append(f"{safe_prefix}%")
|
||||
else:
|
||||
where_clauses.append("tl.source = ?")
|
||||
params.append(source_filter)
|
||||
|
||||
if user_id_filter:
|
||||
where_clauses.append("tl.user_id = ?")
|
||||
params.append(user_id_filter)
|
||||
|
||||
if account_filter:
|
||||
safe_filter = sanitize_sql_like_pattern(account_filter)
|
||||
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
|
||||
params.append(f"%{safe_filter}%")
|
||||
|
||||
return " AND ".join(where_clauses), params
|
||||
|
||||
|
||||
def _fetch_task_stats_row(cursor, *, where_clause: str = "", params: tuple | list = ()) -> dict:
|
||||
sql = _TASK_STATS_SELECT_SQL
|
||||
if where_clause:
|
||||
sql = f"{sql}\nWHERE {where_clause}"
|
||||
cursor.execute(sql, params)
|
||||
row = cursor.fetchone()
|
||||
return {
|
||||
"total_tasks": _stat_value(row, "total_tasks"),
|
||||
"success_tasks": _stat_value(row, "success_tasks"),
|
||||
"failed_tasks": _stat_value(row, "failed_tasks"),
|
||||
"total_items": _stat_value(row, "total_items"),
|
||||
"total_attachments": _stat_value(row, "total_attachments"),
|
||||
}
|
||||
|
||||
|
||||
def _fetch_user_run_stats_row(cursor, *, where_clause: str, params: tuple | list) -> dict:
|
||||
sql = f"{_USER_RUN_STATS_SELECT_SQL}\nWHERE {where_clause}"
|
||||
cursor.execute(sql, params)
|
||||
row = cursor.fetchone()
|
||||
return {
|
||||
"completed": _stat_value(row, "completed"),
|
||||
"failed": _stat_value(row, "failed"),
|
||||
"total_items": _stat_value(row, "total_items"),
|
||||
"total_attachments": _stat_value(row, "total_attachments"),
|
||||
}
|
||||
|
||||
|
||||
def create_task_log(
|
||||
@@ -25,8 +148,6 @@ def create_task_log(
|
||||
"""创建任务日志记录"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -45,7 +166,7 @@ def create_task_log(
|
||||
total_attachments,
|
||||
error_message,
|
||||
duration,
|
||||
cst_time,
|
||||
get_cst_now_str(),
|
||||
source,
|
||||
),
|
||||
)
|
||||
@@ -64,54 +185,27 @@ def get_task_logs(
|
||||
account_filter=None,
|
||||
):
|
||||
"""获取任务日志列表(支持分页和多种筛选)"""
|
||||
limit = _normalize_int(limit, 100, minimum=1)
|
||||
offset = _normalize_int(offset, 0, minimum=0)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
where_clauses = ["1=1"]
|
||||
params = []
|
||||
|
||||
if date_filter:
|
||||
where_clauses.append("date(tl.created_at) = ?")
|
||||
params.append(date_filter)
|
||||
|
||||
if status_filter:
|
||||
where_clauses.append("tl.status = ?")
|
||||
params.append(status_filter)
|
||||
|
||||
if source_filter:
|
||||
source_filter = str(source_filter or "").strip()
|
||||
# 兼容“虚拟来源”:用于筛选 user_scheduled:batch_xxx 这类动态值
|
||||
if source_filter == "user_scheduled":
|
||||
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
|
||||
params.append("user_scheduled:%")
|
||||
elif source_filter.endswith("*"):
|
||||
prefix = source_filter[:-1]
|
||||
safe_prefix = sanitize_sql_like_pattern(prefix)
|
||||
where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
|
||||
params.append(f"{safe_prefix}%")
|
||||
else:
|
||||
where_clauses.append("tl.source = ?")
|
||||
params.append(source_filter)
|
||||
|
||||
if user_id_filter:
|
||||
where_clauses.append("tl.user_id = ?")
|
||||
params.append(user_id_filter)
|
||||
|
||||
if account_filter:
|
||||
safe_filter = sanitize_sql_like_pattern(account_filter)
|
||||
where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
|
||||
params.append(f"%{safe_filter}%")
|
||||
|
||||
where_sql = " AND ".join(where_clauses)
|
||||
where_sql, params = _build_task_logs_where_sql(
|
||||
date_filter=date_filter,
|
||||
status_filter=status_filter,
|
||||
source_filter=source_filter,
|
||||
user_id_filter=user_id_filter,
|
||||
account_filter=account_filter,
|
||||
)
|
||||
|
||||
count_sql = f"""
|
||||
SELECT COUNT(*) as total
|
||||
FROM task_logs tl
|
||||
LEFT JOIN users u ON tl.user_id = u.id
|
||||
WHERE {where_sql}
|
||||
"""
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
total = _stat_value(cursor.fetchone(), "total")
|
||||
|
||||
data_sql = f"""
|
||||
SELECT
|
||||
@@ -123,9 +217,10 @@ def get_task_logs(
|
||||
ORDER BY tl.created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
params.extend([limit, offset])
|
||||
data_params = list(params)
|
||||
data_params.extend([limit, offset])
|
||||
|
||||
cursor.execute(data_sql, params)
|
||||
cursor.execute(data_sql, data_params)
|
||||
logs = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
return {"logs": logs, "total": total}
|
||||
@@ -133,61 +228,39 @@ def get_task_logs(
|
||||
|
||||
def get_task_stats(date_filter=None):
|
||||
"""获取任务统计信息"""
|
||||
if date_filter is None:
|
||||
date_filter = get_cst_now().strftime("%Y-%m-%d")
|
||||
|
||||
day_start, day_end = _build_day_bounds(date_filter)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
if date_filter is None:
|
||||
date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d")
|
||||
if day_start and day_end:
|
||||
today_stats = _fetch_task_stats_row(
|
||||
cursor,
|
||||
where_clause="created_at >= ? AND created_at < ?",
|
||||
params=(day_start, day_end),
|
||||
)
|
||||
else:
|
||||
today_stats = _fetch_task_stats_row(
|
||||
cursor,
|
||||
where_clause="date(created_at) = ?",
|
||||
params=(date_filter,),
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(*) as total_tasks,
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
WHERE date(created_at) = ?
|
||||
""",
|
||||
(date_filter,),
|
||||
)
|
||||
today_stats = cursor.fetchone()
|
||||
total_stats = _fetch_task_stats_row(cursor)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(*) as total_tasks,
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
"""
|
||||
)
|
||||
total_stats = cursor.fetchone()
|
||||
|
||||
return {
|
||||
"today": {
|
||||
"total_tasks": today_stats["total_tasks"] or 0,
|
||||
"success_tasks": today_stats["success_tasks"] or 0,
|
||||
"failed_tasks": today_stats["failed_tasks"] or 0,
|
||||
"total_items": today_stats["total_items"] or 0,
|
||||
"total_attachments": today_stats["total_attachments"] or 0,
|
||||
},
|
||||
"total": {
|
||||
"total_tasks": total_stats["total_tasks"] or 0,
|
||||
"success_tasks": total_stats["success_tasks"] or 0,
|
||||
"failed_tasks": total_stats["failed_tasks"] or 0,
|
||||
"total_items": total_stats["total_items"] or 0,
|
||||
"total_attachments": total_stats["total_attachments"] or 0,
|
||||
},
|
||||
}
|
||||
return {"today": today_stats, "total": total_stats}
|
||||
|
||||
|
||||
def delete_old_task_logs(days=30, batch_size=1000):
|
||||
"""删除N天前的任务日志(分批删除,避免长时间锁表)"""
|
||||
days = _normalize_int(days, 30, minimum=0)
|
||||
batch_size = _normalize_int(batch_size, 1000, minimum=1)
|
||||
|
||||
cutoff = (get_cst_now() - timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
total_deleted = 0
|
||||
while True:
|
||||
with db_pool.get_db() as conn:
|
||||
@@ -197,16 +270,16 @@ def delete_old_task_logs(days=30, batch_size=1000):
|
||||
DELETE FROM task_logs
|
||||
WHERE rowid IN (
|
||||
SELECT rowid FROM task_logs
|
||||
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
|
||||
WHERE created_at < ?
|
||||
LIMIT ?
|
||||
)
|
||||
""",
|
||||
(days, batch_size),
|
||||
(cutoff, batch_size),
|
||||
)
|
||||
deleted = cursor.rowcount
|
||||
conn.commit()
|
||||
|
||||
if deleted == 0:
|
||||
if deleted <= 0:
|
||||
break
|
||||
total_deleted += deleted
|
||||
|
||||
@@ -215,31 +288,23 @@ def delete_old_task_logs(days=30, batch_size=1000):
|
||||
|
||||
def get_user_run_stats(user_id, date_filter=None):
|
||||
"""获取用户的运行统计信息"""
|
||||
if date_filter is None:
|
||||
date_filter = get_cst_now().strftime("%Y-%m-%d")
|
||||
|
||||
day_start, day_end = _build_day_bounds(date_filter)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cursor = conn.cursor()
|
||||
|
||||
if date_filter is None:
|
||||
date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d")
|
||||
if day_start and day_end:
|
||||
return _fetch_user_run_stats_row(
|
||||
cursor,
|
||||
where_clause="user_id = ? AND created_at >= ? AND created_at < ?",
|
||||
params=(user_id, day_start, day_end),
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
|
||||
SUM(total_items) as total_items,
|
||||
SUM(total_attachments) as total_attachments
|
||||
FROM task_logs
|
||||
WHERE user_id = ? AND date(created_at) = ?
|
||||
""",
|
||||
(user_id, date_filter),
|
||||
return _fetch_user_run_stats_row(
|
||||
cursor,
|
||||
where_clause="user_id = ? AND date(created_at) = ?",
|
||||
params=(user_id, date_filter),
|
||||
)
|
||||
|
||||
stats = cursor.fetchone()
|
||||
|
||||
return {
|
||||
"completed": stats["completed"] or 0,
|
||||
"failed": stats["failed"] or 0,
|
||||
"total_items": stats["total_items"] or 0,
|
||||
"total_attachments": stats["total_attachments"] or 0,
|
||||
}
|
||||
|
||||
174
db/users.py
174
db/users.py
@@ -16,8 +16,41 @@ from password_utils import (
|
||||
verify_password_bcrypt,
|
||||
verify_password_sha256,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_CST_TZ = pytz.timezone("Asia/Shanghai")
|
||||
_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
|
||||
|
||||
|
||||
def _row_to_dict(row):
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
def _get_user_by_field(field_name: str, field_value):
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
|
||||
return _row_to_dict(cursor.fetchone())
|
||||
|
||||
|
||||
def _parse_cst_datetime(datetime_str: str | None):
|
||||
if not datetime_str:
|
||||
return None
|
||||
try:
|
||||
naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S")
|
||||
return _CST_TZ.localize(naive_dt)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str:
|
||||
if int(days) == 999999:
|
||||
return _PERMANENT_VIP_EXPIRE
|
||||
if base_dt is None:
|
||||
base_dt = datetime.now(_CST_TZ)
|
||||
return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def get_vip_config():
|
||||
"""获取VIP配置"""
|
||||
@@ -32,13 +65,12 @@ def set_default_vip_days(days):
|
||||
"""设置默认VIP天数"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = get_cst_now_str()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
|
||||
VALUES (1, ?, ?)
|
||||
""",
|
||||
(days, cst_time),
|
||||
(days, get_cst_now_str()),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
@@ -47,14 +79,8 @@ def set_default_vip_days(days):
|
||||
def set_user_vip(user_id, days):
|
||||
"""设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久"""
|
||||
with db_pool.get_db() as conn:
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
cursor = conn.cursor()
|
||||
|
||||
if days == 999999:
|
||||
expire_time = "2099-12-31 23:59:59"
|
||||
else:
|
||||
expire_time = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
expire_time = _format_vip_expire(days)
|
||||
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -63,29 +89,26 @@ def set_user_vip(user_id, days):
|
||||
def extend_user_vip(user_id, days):
|
||||
"""延长用户VIP时间"""
|
||||
user = get_user_by_id(user_id)
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
if not user:
|
||||
return False
|
||||
|
||||
current_expire = user.get("vip_expire_time")
|
||||
now_dt = datetime.now(_CST_TZ)
|
||||
|
||||
if current_expire and current_expire != _PERMANENT_VIP_EXPIRE:
|
||||
expire_time = _parse_cst_datetime(current_expire)
|
||||
if expire_time is not None:
|
||||
if expire_time < now_dt:
|
||||
expire_time = now_dt
|
||||
new_expire = _format_vip_expire(days, base_dt=expire_time)
|
||||
else:
|
||||
logger.warning("解析VIP过期时间失败,使用当前时间")
|
||||
new_expire = _format_vip_expire(days, base_dt=now_dt)
|
||||
else:
|
||||
new_expire = _format_vip_expire(days, base_dt=now_dt)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
current_expire = user.get("vip_expire_time")
|
||||
|
||||
if current_expire and current_expire != "2099-12-31 23:59:59":
|
||||
try:
|
||||
expire_time_naive = datetime.strptime(current_expire, "%Y-%m-%d %H:%M:%S")
|
||||
expire_time = cst_tz.localize(expire_time_naive)
|
||||
now = datetime.now(cst_tz)
|
||||
if expire_time < now:
|
||||
expire_time = now
|
||||
new_expire = (expire_time + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"解析VIP过期时间失败: {e}, 使用当前时间")
|
||||
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -105,45 +128,49 @@ def is_user_vip(user_id):
|
||||
|
||||
注意:数据库中存储的时间统一使用CST(Asia/Shanghai)时区
|
||||
"""
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
user = get_user_by_id(user_id)
|
||||
|
||||
if not user or not user.get("vip_expire_time"):
|
||||
if not user:
|
||||
return False
|
||||
|
||||
try:
|
||||
expire_time_naive = datetime.strptime(user["vip_expire_time"], "%Y-%m-%d %H:%M:%S")
|
||||
expire_time = cst_tz.localize(expire_time_naive)
|
||||
now = datetime.now(cst_tz)
|
||||
return now < expire_time
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"检查VIP状态失败 (user_id={user_id}): {e}")
|
||||
vip_expire_time = user.get("vip_expire_time")
|
||||
if not vip_expire_time:
|
||||
return False
|
||||
|
||||
expire_time = _parse_cst_datetime(vip_expire_time)
|
||||
if expire_time is None:
|
||||
logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间")
|
||||
return False
|
||||
|
||||
return datetime.now(_CST_TZ) < expire_time
|
||||
|
||||
|
||||
def get_user_vip_info(user_id):
|
||||
"""获取用户VIP信息"""
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
user = get_user_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""}
|
||||
|
||||
vip_expire_time = user.get("vip_expire_time")
|
||||
username = user.get("username", "")
|
||||
|
||||
if not vip_expire_time:
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
|
||||
|
||||
try:
|
||||
expire_time_naive = datetime.strptime(vip_expire_time, "%Y-%m-%d %H:%M:%S")
|
||||
expire_time = cst_tz.localize(expire_time_naive)
|
||||
now = datetime.now(cst_tz)
|
||||
is_vip = now < expire_time
|
||||
days_left = (expire_time - now).days if is_vip else 0
|
||||
expire_time = _parse_cst_datetime(vip_expire_time)
|
||||
if expire_time is None:
|
||||
logger.warning("VIP信息获取错误: 无法解析过期时间")
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
|
||||
|
||||
return {"username": user.get("username", ""), "is_vip": is_vip, "expire_time": vip_expire_time, "days_left": max(0, days_left)}
|
||||
except Exception as e:
|
||||
logger.warning(f"VIP信息获取错误: {e}")
|
||||
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
|
||||
now_dt = datetime.now(_CST_TZ)
|
||||
is_vip = now_dt < expire_time
|
||||
days_left = (expire_time - now_dt).days if is_vip else 0
|
||||
|
||||
return {
|
||||
"username": username,
|
||||
"is_vip": is_vip,
|
||||
"expire_time": vip_expire_time,
|
||||
"days_left": max(0, days_left),
|
||||
}
|
||||
|
||||
|
||||
# ==================== 用户相关 ====================
|
||||
@@ -151,8 +178,6 @@ def get_user_vip_info(user_id):
|
||||
|
||||
def create_user(username, password, email=""):
|
||||
"""创建新用户(默认直接通过,赠送默认VIP)"""
|
||||
cst_tz = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
password_hash = hash_password_bcrypt(password)
|
||||
@@ -160,12 +185,8 @@ def create_user(username, password, email=""):
|
||||
|
||||
default_vip_days = get_vip_config()["default_vip_days"]
|
||||
vip_expire_time = None
|
||||
|
||||
if default_vip_days > 0:
|
||||
if default_vip_days == 999999:
|
||||
vip_expire_time = "2099-12-31 23:59:59"
|
||||
else:
|
||||
vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
if int(default_vip_days or 0) > 0:
|
||||
vip_expire_time = _format_vip_expire(int(default_vip_days))
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
@@ -210,28 +231,28 @@ def verify_user(username, password):
|
||||
|
||||
def get_user_by_id(user_id):
|
||||
"""根据ID获取用户"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
return _get_user_by_field("id", user_id)
|
||||
|
||||
|
||||
def get_user_kdocs_settings(user_id):
|
||||
"""获取用户的金山文档配置"""
|
||||
user = get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
return {
|
||||
"kdocs_unit": user.get("kdocs_unit") or "",
|
||||
"kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0,
|
||||
}
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return {
|
||||
"kdocs_unit": (row["kdocs_unit"] or "") if isinstance(row, dict) else (row[0] or ""),
|
||||
"kdocs_auto_upload": 1 if ((row["kdocs_auto_upload"] if isinstance(row, dict) else row[1]) or 0) else 0,
|
||||
}
|
||||
|
||||
|
||||
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
|
||||
"""更新用户的金山文档配置"""
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
if kdocs_unit is not None:
|
||||
updates.append("kdocs_unit = ?")
|
||||
params.append(kdocs_unit)
|
||||
@@ -252,11 +273,7 @@ def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=No
|
||||
|
||||
def get_user_by_username(username):
|
||||
"""根据用户名获取用户"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
||||
user = cursor.fetchone()
|
||||
return dict(user) if user else None
|
||||
return _get_user_by_field("username", username)
|
||||
|
||||
|
||||
def get_all_users():
|
||||
@@ -279,14 +296,13 @@ def approve_user(user_id):
|
||||
"""审核通过用户"""
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cst_time = get_cst_now_str()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET status = 'approved', approved_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(cst_time, user_id),
|
||||
(get_cst_now_str(), user_id),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
@@ -315,5 +331,5 @@ def get_user_stats(user_id):
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,))
|
||||
account_count = cursor.fetchone()["count"]
|
||||
return {"account_count": account_count}
|
||||
row = cursor.fetchone()
|
||||
return {"account_count": int((row["count"] if row else 0) or 0)}
|
||||
|
||||
119
db_pool.py
119
db_pool.py
@@ -7,8 +7,12 @@
|
||||
|
||||
import sqlite3
|
||||
import threading
|
||||
from queue import Queue, Empty
|
||||
import time
|
||||
from queue import Empty, Full, Queue
|
||||
|
||||
from app_logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("database")
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
@@ -44,12 +48,55 @@ class ConnectionPool:
|
||||
"""创建新的数据库连接"""
|
||||
conn = sqlite3.connect(self.database, check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
# 启用外键约束,确保 ON DELETE CASCADE 等约束生效
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
# 设置WAL模式提高并发性能
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# 在WAL模式下使用NORMAL同步,兼顾性能与可靠性
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
# 设置合理的超时时间
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
return conn
|
||||
|
||||
def _close_connection(self, conn) -> None:
|
||||
if conn is None:
|
||||
return
|
||||
try:
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭连接失败: {e}")
|
||||
|
||||
def _is_connection_healthy(self, conn) -> bool:
|
||||
if conn is None:
|
||||
return False
|
||||
try:
|
||||
conn.rollback()
|
||||
conn.execute("SELECT 1")
|
||||
return True
|
||||
except sqlite3.Error as e:
|
||||
logger.warning(f"连接健康检查失败(数据库错误): {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"连接健康检查失败(未知错误): {e}")
|
||||
return False
|
||||
|
||||
def _replenish_pool_if_needed(self) -> None:
|
||||
with self._lock:
|
||||
if self._pool.qsize() >= self.pool_size:
|
||||
return
|
||||
|
||||
new_conn = None
|
||||
try:
|
||||
new_conn = self._create_connection()
|
||||
self._pool.put(new_conn, block=False)
|
||||
self._created_connections += 1
|
||||
except Full:
|
||||
if new_conn:
|
||||
self._close_connection(new_conn)
|
||||
except Exception as e:
|
||||
if new_conn:
|
||||
self._close_connection(new_conn)
|
||||
logger.warning(f"重建连接失败: {e}")
|
||||
|
||||
def get_connection(self):
|
||||
"""
|
||||
从连接池获取连接
|
||||
@@ -70,66 +117,20 @@ class ConnectionPool:
|
||||
Args:
|
||||
conn: 要归还的连接
|
||||
"""
|
||||
import sqlite3
|
||||
from queue import Full
|
||||
|
||||
if conn is None:
|
||||
return
|
||||
|
||||
connection_healthy = False
|
||||
try:
|
||||
# 回滚任何未提交的事务
|
||||
conn.rollback()
|
||||
# 安全修复:验证连接是否健康,防止损坏的连接污染连接池
|
||||
conn.execute("SELECT 1")
|
||||
connection_healthy = True
|
||||
except sqlite3.Error as e:
|
||||
# 数据库相关错误,连接可能损坏
|
||||
print(f"连接健康检查失败(数据库错误): {e}")
|
||||
except Exception as e:
|
||||
print(f"连接健康检查失败(未知错误): {e}")
|
||||
|
||||
if connection_healthy:
|
||||
if self._is_connection_healthy(conn):
|
||||
try:
|
||||
self._pool.put(conn, block=False)
|
||||
return # 成功归还
|
||||
return
|
||||
except Full:
|
||||
# 队列已满(不应该发生,但处理它)
|
||||
print(f"警告: 连接池已满,关闭多余连接")
|
||||
connection_healthy = False # 标记为需要关闭
|
||||
logger.warning("连接池已满,关闭多余连接")
|
||||
self._close_connection(conn)
|
||||
return
|
||||
|
||||
# 连接不健康或队列已满,关闭它
|
||||
try:
|
||||
conn.close()
|
||||
except Exception as close_error:
|
||||
print(f"关闭连接失败: {close_error}")
|
||||
|
||||
# 如果连接不健康,尝试创建新连接补充池
|
||||
if not connection_healthy:
|
||||
with self._lock:
|
||||
# 双重检查:确保池确实需要补充
|
||||
if self._pool.qsize() < self.pool_size:
|
||||
new_conn = None
|
||||
try:
|
||||
new_conn = self._create_connection()
|
||||
self._pool.put(new_conn, block=False)
|
||||
# 只有成功放入池后才增加计数
|
||||
self._created_connections += 1
|
||||
except Full:
|
||||
# 在获取锁期间池被填满了,关闭新建的连接
|
||||
if new_conn:
|
||||
try:
|
||||
new_conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as create_error:
|
||||
# 创建连接失败,确保关闭已创建的连接
|
||||
if new_conn:
|
||||
try:
|
||||
new_conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
print(f"重建连接失败: {create_error}")
|
||||
self._close_connection(conn)
|
||||
self._replenish_pool_if_needed()
|
||||
|
||||
def close_all(self):
|
||||
"""关闭所有连接"""
|
||||
@@ -138,7 +139,7 @@ class ConnectionPool:
|
||||
conn = self._pool.get(block=False)
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"关闭连接失败: {e}")
|
||||
logger.warning(f"关闭连接失败: {e}")
|
||||
|
||||
def get_stats(self):
|
||||
"""获取连接池统计信息"""
|
||||
@@ -175,14 +176,14 @@ class PooledConnection:
|
||||
if exc_type is not None:
|
||||
# 发生异常,回滚事务
|
||||
self._conn.rollback()
|
||||
print(f"数据库事务已回滚: {exc_type.__name__}")
|
||||
logger.warning(f"数据库事务已回滚: {exc_type.__name__}")
|
||||
# 注意: 不自动commit,要求用户显式调用conn.commit()
|
||||
|
||||
if self._cursor:
|
||||
self._cursor.close()
|
||||
self._cursor = None
|
||||
except Exception as e:
|
||||
print(f"关闭游标失败: {e}")
|
||||
logger.warning(f"关闭游标失败: {e}")
|
||||
finally:
|
||||
# 归还连接
|
||||
self._pool.return_connection(self._conn)
|
||||
@@ -254,7 +255,7 @@ def init_pool(database, pool_size=5):
|
||||
with _pool_lock:
|
||||
if _pool is None:
|
||||
_pool = ConnectionPool(database, pool_size)
|
||||
print(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
|
||||
logger.info(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
|
||||
|
||||
|
||||
def get_db():
|
||||
|
||||
1338
email_service.py
1338
email_service.py
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
@@ -9,8 +10,40 @@ from services.runtime import get_logger, get_socketio
|
||||
from services.state import safe_get_account, safe_iter_task_status_items
|
||||
|
||||
|
||||
def _to_int(value, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return int(default)
|
||||
|
||||
|
||||
def _payload_signature(payload: dict) -> str:
|
||||
try:
|
||||
return json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=str)
|
||||
except Exception:
|
||||
return repr(payload)
|
||||
|
||||
|
||||
def _should_emit(
|
||||
*,
|
||||
last_sig: str | None,
|
||||
last_ts: float,
|
||||
new_sig: str,
|
||||
now_ts: float,
|
||||
min_interval: float,
|
||||
force_interval: float,
|
||||
) -> bool:
|
||||
if last_sig is None:
|
||||
return True
|
||||
if (now_ts - last_ts) >= force_interval:
|
||||
return True
|
||||
if new_sig != last_sig and (now_ts - last_ts) >= min_interval:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def status_push_worker() -> None:
|
||||
"""后台线程:按间隔推送排队/运行中任务的状态更新(可节流)。"""
|
||||
"""后台线程:按间隔推送排队/运行中任务状态(变更驱动+心跳兜底)。"""
|
||||
logger = get_logger()
|
||||
try:
|
||||
push_interval = float(os.environ.get("STATUS_PUSH_INTERVAL_SECONDS", "1"))
|
||||
@@ -18,18 +51,41 @@ def status_push_worker() -> None:
|
||||
push_interval = 1.0
|
||||
push_interval = max(0.5, push_interval)
|
||||
|
||||
try:
|
||||
queue_min_interval = float(os.environ.get("STATUS_PUSH_MIN_QUEUE_INTERVAL_SECONDS", str(push_interval)))
|
||||
except Exception:
|
||||
queue_min_interval = push_interval
|
||||
queue_min_interval = max(push_interval, queue_min_interval)
|
||||
|
||||
try:
|
||||
progress_min_interval = float(
|
||||
os.environ.get("STATUS_PUSH_MIN_PROGRESS_INTERVAL_SECONDS", str(push_interval))
|
||||
)
|
||||
except Exception:
|
||||
progress_min_interval = push_interval
|
||||
progress_min_interval = max(push_interval, progress_min_interval)
|
||||
|
||||
try:
|
||||
force_interval = float(os.environ.get("STATUS_PUSH_FORCE_INTERVAL_SECONDS", "10"))
|
||||
except Exception:
|
||||
force_interval = 10.0
|
||||
force_interval = max(push_interval, force_interval)
|
||||
|
||||
socketio = get_socketio()
|
||||
from services.tasks import get_task_scheduler
|
||||
|
||||
scheduler = get_task_scheduler()
|
||||
emitted_state: dict[str, dict] = {}
|
||||
|
||||
while True:
|
||||
try:
|
||||
now_ts = time.time()
|
||||
queue_snapshot = scheduler.get_queue_state_snapshot()
|
||||
pending_total = int(queue_snapshot.get("pending_total", 0) or 0)
|
||||
running_total = int(queue_snapshot.get("running_total", 0) or 0)
|
||||
running_by_user = queue_snapshot.get("running_by_user") or {}
|
||||
positions = queue_snapshot.get("positions") or {}
|
||||
active_account_ids = set()
|
||||
|
||||
status_items = safe_iter_task_status_items()
|
||||
for account_id, status_info in status_items:
|
||||
@@ -39,11 +95,15 @@ def status_push_worker() -> None:
|
||||
user_id = status_info.get("user_id")
|
||||
if not user_id:
|
||||
continue
|
||||
|
||||
active_account_ids.add(str(account_id))
|
||||
account = safe_get_account(user_id, account_id)
|
||||
if not account:
|
||||
continue
|
||||
|
||||
user_id_int = _to_int(user_id)
|
||||
account_data = account.to_dict()
|
||||
pos = positions.get(account_id) or {}
|
||||
pos = positions.get(account_id) or positions.get(str(account_id)) or {}
|
||||
account_data.update(
|
||||
{
|
||||
"queue_pending_total": pending_total,
|
||||
@@ -51,10 +111,23 @@ def status_push_worker() -> None:
|
||||
"queue_ahead": pos.get("queue_ahead"),
|
||||
"queue_position": pos.get("queue_position"),
|
||||
"queue_is_vip": pos.get("is_vip"),
|
||||
"queue_running_user": int(running_by_user.get(int(user_id), 0) or 0),
|
||||
"queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))),
|
||||
}
|
||||
)
|
||||
socketio.emit("account_update", account_data, room=f"user_{user_id}")
|
||||
|
||||
cache_entry = emitted_state.setdefault(str(account_id), {})
|
||||
account_sig = _payload_signature(account_data)
|
||||
if _should_emit(
|
||||
last_sig=cache_entry.get("account_sig"),
|
||||
last_ts=float(cache_entry.get("account_ts", 0) or 0),
|
||||
new_sig=account_sig,
|
||||
now_ts=now_ts,
|
||||
min_interval=queue_min_interval,
|
||||
force_interval=force_interval,
|
||||
):
|
||||
socketio.emit("account_update", account_data, room=f"user_{user_id}")
|
||||
cache_entry["account_sig"] = account_sig
|
||||
cache_entry["account_ts"] = now_ts
|
||||
|
||||
if status != "运行中":
|
||||
continue
|
||||
@@ -74,9 +147,26 @@ def status_push_worker() -> None:
|
||||
"queue_running_total": running_total,
|
||||
"queue_ahead": pos.get("queue_ahead"),
|
||||
"queue_position": pos.get("queue_position"),
|
||||
"queue_running_user": int(running_by_user.get(int(user_id), 0) or 0),
|
||||
"queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))),
|
||||
}
|
||||
socketio.emit("task_progress", progress_data, room=f"user_{user_id}")
|
||||
|
||||
progress_sig = _payload_signature(progress_data)
|
||||
if _should_emit(
|
||||
last_sig=cache_entry.get("progress_sig"),
|
||||
last_ts=float(cache_entry.get("progress_ts", 0) or 0),
|
||||
new_sig=progress_sig,
|
||||
now_ts=now_ts,
|
||||
min_interval=progress_min_interval,
|
||||
force_interval=force_interval,
|
||||
):
|
||||
socketio.emit("task_progress", progress_data, room=f"user_{user_id}")
|
||||
cache_entry["progress_sig"] = progress_sig
|
||||
cache_entry["progress_ts"] = now_ts
|
||||
|
||||
if emitted_state:
|
||||
stale_ids = [account_id for account_id in emitted_state.keys() if account_id not in active_account_ids]
|
||||
for account_id in stale_ids:
|
||||
emitted_state.pop(account_id, None)
|
||||
|
||||
time.sleep(push_interval)
|
||||
except Exception as e:
|
||||
|
||||
@@ -8,6 +8,15 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api")
|
||||
|
||||
# Import side effects: register routes on blueprint
|
||||
from routes.admin_api import core as _core # noqa: F401
|
||||
from routes.admin_api import system_config_api as _system_config_api # noqa: F401
|
||||
from routes.admin_api import operations_api as _operations_api # noqa: F401
|
||||
from routes.admin_api import announcements_api as _announcements_api # noqa: F401
|
||||
from routes.admin_api import users_api as _users_api # noqa: F401
|
||||
from routes.admin_api import account_api as _account_api # noqa: F401
|
||||
from routes.admin_api import feedback_api as _feedback_api # noqa: F401
|
||||
from routes.admin_api import infra_api as _infra_api # noqa: F401
|
||||
from routes.admin_api import tasks_api as _tasks_api # noqa: F401
|
||||
from routes.admin_api import email_api as _email_api # noqa: F401
|
||||
|
||||
# Export security blueprint for app registration
|
||||
from routes.admin_api.security import security_bp # noqa: F401
|
||||
|
||||
63
routes/admin_api/account_api.py
Normal file
63
routes/admin_api/account_api.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import database
|
||||
from app_security import validate_password
|
||||
from flask import jsonify, request, session
|
||||
from routes.admin_api import admin_api_bp
|
||||
from routes.decorators import admin_required
|
||||
|
||||
# ==================== 密码重置 / 反馈(管理员) ====================
|
||||
|
||||
|
||||
@admin_api_bp.route("/admin/password", methods=["PUT"])
|
||||
@admin_required
|
||||
def update_admin_password():
|
||||
"""修改管理员密码"""
|
||||
data = request.json or {}
|
||||
new_password = (data.get("new_password") or "").strip()
|
||||
|
||||
if not new_password:
|
||||
return jsonify({"error": "密码不能为空"}), 400
|
||||
|
||||
username = session.get("admin_username")
|
||||
if database.update_admin_password(username, new_password):
|
||||
return jsonify({"success": True})
|
||||
return jsonify({"error": "修改失败"}), 400
|
||||
|
||||
|
||||
@admin_api_bp.route("/admin/username", methods=["PUT"])
|
||||
@admin_required
|
||||
def update_admin_username():
|
||||
"""修改管理员用户名"""
|
||||
data = request.json or {}
|
||||
new_username = (data.get("new_username") or "").strip()
|
||||
|
||||
if not new_username:
|
||||
return jsonify({"error": "用户名不能为空"}), 400
|
||||
|
||||
old_username = session.get("admin_username")
|
||||
if database.update_admin_username(old_username, new_username):
|
||||
session["admin_username"] = new_username
|
||||
return jsonify({"success": True})
|
||||
return jsonify({"error": "修改失败,用户名可能已存在"}), 400
|
||||
|
||||
|
||||
@admin_api_bp.route("/users/<int:user_id>/reset_password", methods=["POST"])
|
||||
@admin_required
|
||||
def admin_reset_password_route(user_id):
|
||||
"""管理员直接重置用户密码(无需审核)"""
|
||||
data = request.json or {}
|
||||
new_password = (data.get("new_password") or "").strip()
|
||||
|
||||
if not new_password:
|
||||
return jsonify({"error": "新密码不能为空"}), 400
|
||||
|
||||
is_valid, error_msg = validate_password(new_password)
|
||||
if not is_valid:
|
||||
return jsonify({"error": error_msg}), 400
|
||||
|
||||
if database.admin_reset_user_password(user_id, new_password):
|
||||
return jsonify({"message": "密码重置成功"})
|
||||
return jsonify({"error": "重置失败,用户不存在"}), 400
|
||||
144
routes/admin_api/announcements_api.py
Normal file
144
routes/admin_api/announcements_api.py
Normal file
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import posixpath
|
||||
import secrets
|
||||
import time
|
||||
|
||||
import database
|
||||
from app_config import get_config
|
||||
from app_logger import get_logger
|
||||
from app_security import is_safe_path, sanitize_filename
|
||||
from flask import current_app, jsonify, request, url_for
|
||||
from routes.admin_api import admin_api_bp
|
||||
from routes.decorators import admin_required
|
||||
|
||||
logger = get_logger("app")
|
||||
config = get_config()
|
||||
|
||||
|
||||
def _get_upload_dir():
|
||||
rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements")
|
||||
if not is_safe_path(current_app.root_path, rel_dir):
|
||||
rel_dir = "static/announcements"
|
||||
abs_dir = os.path.join(current_app.root_path, rel_dir)
|
||||
os.makedirs(abs_dir, exist_ok=True)
|
||||
return abs_dir, rel_dir
|
||||
|
||||
|
||||
def _get_file_size(file_storage):
|
||||
try:
|
||||
file_storage.stream.seek(0, os.SEEK_END)
|
||||
size = file_storage.stream.tell()
|
||||
file_storage.stream.seek(0)
|
||||
return size
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# ==================== 公告管理API(管理员) ====================
|
||||
|
||||
|
||||
@admin_api_bp.route("/announcements/upload_image", methods=["POST"])
|
||||
@admin_required
|
||||
def admin_upload_announcement_image():
|
||||
"""上传公告图片(返回可访问URL)"""
|
||||
file = request.files.get("file")
|
||||
if not file or not file.filename:
|
||||
return jsonify({"error": "请选择图片"}), 400
|
||||
|
||||
filename = sanitize_filename(file.filename)
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"})
|
||||
if not ext or ext not in allowed_exts:
|
||||
return jsonify({"error": "不支持的图片格式"}), 400
|
||||
if file.mimetype and not str(file.mimetype).startswith("image/"):
|
||||
return jsonify({"error": "文件类型无效"}), 400
|
||||
|
||||
size = _get_file_size(file)
|
||||
max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024))
|
||||
if size is not None and size > max_size:
|
||||
max_mb = max_size // 1024 // 1024
|
||||
return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400
|
||||
|
||||
abs_dir, rel_dir = _get_upload_dir()
|
||||
token = secrets.token_hex(6)
|
||||
name = f"announcement_{int(time.time())}_{token}{ext}"
|
||||
save_path = os.path.join(abs_dir, name)
|
||||
file.save(save_path)
|
||||
|
||||
static_root = os.path.join(current_app.root_path, "static")
|
||||
rel_to_static = os.path.relpath(abs_dir, static_root)
|
||||
if rel_to_static.startswith(".."):
|
||||
rel_to_static = "announcements"
|
||||
url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name)
|
||||
return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)})
|
||||
|
||||
|
||||
@admin_api_bp.route("/announcements", methods=["GET"])
|
||||
@admin_required
|
||||
def admin_get_announcements():
|
||||
"""获取公告列表"""
|
||||
try:
|
||||
limit = int(request.args.get("limit", 50))
|
||||
offset = int(request.args.get("offset", 0))
|
||||
except (TypeError, ValueError):
|
||||
limit, offset = 50, 0
|
||||
|
||||
limit = max(1, min(200, limit))
|
||||
offset = max(0, offset)
|
||||
return jsonify(database.get_announcements(limit=limit, offset=offset))
|
||||
|
||||
|
||||
@admin_api_bp.route("/announcements", methods=["POST"])
|
||||
@admin_required
|
||||
def admin_create_announcement():
|
||||
"""创建公告(默认启用并替换旧公告)"""
|
||||
data = request.json or {}
|
||||
title = (data.get("title") or "").strip()
|
||||
content = (data.get("content") or "").strip()
|
||||
image_url = (data.get("image_url") or "").strip()
|
||||
is_active = bool(data.get("is_active", True))
|
||||
|
||||
if image_url and len(image_url) > 1000:
|
||||
return jsonify({"error": "图片地址过长"}), 400
|
||||
|
||||
announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active)
|
||||
if not announcement_id:
|
||||
return jsonify({"error": "标题和内容不能为空"}), 400
|
||||
|
||||
return jsonify({"success": True, "id": announcement_id})
|
||||
|
||||
|
||||
@admin_api_bp.route("/announcements/<int:announcement_id>/activate", methods=["POST"])
|
||||
@admin_required
|
||||
def admin_activate_announcement(announcement_id):
|
||||
"""启用公告(会自动停用其他公告)"""
|
||||
if not database.get_announcement_by_id(announcement_id):
|
||||
return jsonify({"error": "公告不存在"}), 404
|
||||
ok = database.set_announcement_active(announcement_id, True)
|
||||
return jsonify({"success": ok})
|
||||
|
||||
|
||||
@admin_api_bp.route("/announcements/<int:announcement_id>/deactivate", methods=["POST"])
|
||||
@admin_required
|
||||
def admin_deactivate_announcement(announcement_id):
|
||||
"""停用公告"""
|
||||
if not database.get_announcement_by_id(announcement_id):
|
||||
return jsonify({"error": "公告不存在"}), 404
|
||||
ok = database.set_announcement_active(announcement_id, False)
|
||||
return jsonify({"success": ok})
|
||||
|
||||
|
||||
@admin_api_bp.route("/announcements/<int:announcement_id>", methods=["DELETE"])
|
||||
@admin_required
|
||||
def admin_delete_announcement(announcement_id):
|
||||
"""删除公告"""
|
||||
if not database.get_announcement_by_id(announcement_id):
|
||||
return jsonify({"error": "公告不存在"}), 404
|
||||
ok = database.delete_announcement(announcement_id)
|
||||
return jsonify({"success": ok})
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
214
routes/admin_api/email_api.py
Normal file
214
routes/admin_api/email_api.py
Normal file
@@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import email_service
|
||||
from app_logger import get_logger
|
||||
from app_security import validate_email
|
||||
from flask import jsonify, request
|
||||
from routes.admin_api import admin_api_bp
|
||||
from routes.decorators import admin_required
|
||||
|
||||
logger = get_logger("app")
|
||||
|
||||
|
||||
@admin_api_bp.route("/email/settings", methods=["GET"])
|
||||
@admin_required
|
||||
def get_email_settings_api():
|
||||
"""获取全局邮件设置"""
|
||||
try:
|
||||
settings = email_service.get_email_settings()
|
||||
return jsonify(settings)
|
||||
except Exception as e:
|
||||
logger.error(f"获取邮件设置失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/email/settings", methods=["POST"])
|
||||
@admin_required
|
||||
def update_email_settings_api():
|
||||
"""更新全局邮件设置"""
|
||||
try:
|
||||
data = request.json or {}
|
||||
enabled = data.get("enabled", False)
|
||||
failover_enabled = data.get("failover_enabled", True)
|
||||
register_verify_enabled = data.get("register_verify_enabled")
|
||||
login_alert_enabled = data.get("login_alert_enabled")
|
||||
base_url = data.get("base_url")
|
||||
task_notify_enabled = data.get("task_notify_enabled")
|
||||
|
||||
email_service.update_email_settings(
|
||||
enabled=enabled,
|
||||
failover_enabled=failover_enabled,
|
||||
register_verify_enabled=register_verify_enabled,
|
||||
login_alert_enabled=login_alert_enabled,
|
||||
base_url=base_url,
|
||||
task_notify_enabled=task_notify_enabled,
|
||||
)
|
||||
return jsonify({"success": True})
|
||||
except Exception as e:
|
||||
logger.error(f"更新邮件设置失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/smtp/configs", methods=["GET"])
|
||||
@admin_required
|
||||
def get_smtp_configs_api():
|
||||
"""获取所有SMTP配置列表"""
|
||||
try:
|
||||
configs = email_service.get_smtp_configs(include_password=False)
|
||||
return jsonify(configs)
|
||||
except Exception as e:
|
||||
logger.error(f"获取SMTP配置失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/smtp/configs", methods=["POST"])
|
||||
@admin_required
|
||||
def create_smtp_config_api():
|
||||
"""创建SMTP配置"""
|
||||
try:
|
||||
data = request.json or {}
|
||||
if not data.get("host"):
|
||||
return jsonify({"error": "SMTP服务器地址不能为空"}), 400
|
||||
if not data.get("username"):
|
||||
return jsonify({"error": "SMTP用户名不能为空"}), 400
|
||||
|
||||
config_id = email_service.create_smtp_config(data)
|
||||
return jsonify({"success": True, "id": config_id})
|
||||
except Exception as e:
|
||||
logger.error(f"创建SMTP配置失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["GET"])
|
||||
@admin_required
|
||||
def get_smtp_config_api(config_id):
|
||||
"""获取单个SMTP配置详情"""
|
||||
try:
|
||||
config_data = email_service.get_smtp_config(config_id, include_password=False)
|
||||
if not config_data:
|
||||
return jsonify({"error": "配置不存在"}), 404
|
||||
return jsonify(config_data)
|
||||
except Exception as e:
|
||||
logger.error(f"获取SMTP配置失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["PUT"])
|
||||
@admin_required
|
||||
def update_smtp_config_api(config_id):
|
||||
"""更新SMTP配置"""
|
||||
try:
|
||||
data = request.json or {}
|
||||
if email_service.update_smtp_config(config_id, data):
|
||||
return jsonify({"success": True})
|
||||
return jsonify({"error": "更新失败"}), 400
|
||||
except Exception as e:
|
||||
logger.error(f"更新SMTP配置失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/smtp/configs/<int:config_id>", methods=["DELETE"])
|
||||
@admin_required
|
||||
def delete_smtp_config_api(config_id):
|
||||
"""删除SMTP配置"""
|
||||
try:
|
||||
if email_service.delete_smtp_config(config_id):
|
||||
return jsonify({"success": True})
|
||||
return jsonify({"error": "删除失败"}), 400
|
||||
except Exception as e:
|
||||
logger.error(f"删除SMTP配置失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/smtp/configs/<int:config_id>/test", methods=["POST"])
|
||||
@admin_required
|
||||
def test_smtp_config_api(config_id):
|
||||
"""测试SMTP配置"""
|
||||
try:
|
||||
data = request.json or {}
|
||||
test_email = str(data.get("email", "") or "").strip()
|
||||
if not test_email:
|
||||
return jsonify({"error": "请提供测试邮箱"}), 400
|
||||
|
||||
is_valid, error_msg = validate_email(test_email)
|
||||
if not is_valid:
|
||||
return jsonify({"error": error_msg}), 400
|
||||
|
||||
result = email_service.test_smtp_config(config_id, test_email)
|
||||
return jsonify(result)
|
||||
except Exception as e:
|
||||
logger.error(f"测试SMTP配置失败: {e}")
|
||||
return jsonify({"success": False, "error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/smtp/configs/<int:config_id>/primary", methods=["POST"])
|
||||
@admin_required
|
||||
def set_primary_smtp_config_api(config_id):
|
||||
"""设置主SMTP配置"""
|
||||
try:
|
||||
if email_service.set_primary_smtp_config(config_id):
|
||||
return jsonify({"success": True})
|
||||
return jsonify({"error": "设置失败"}), 400
|
||||
except Exception as e:
|
||||
logger.error(f"设置主SMTP配置失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/smtp/configs/primary/clear", methods=["POST"])
|
||||
@admin_required
|
||||
def clear_primary_smtp_config_api():
|
||||
"""取消主SMTP配置"""
|
||||
try:
|
||||
email_service.clear_primary_smtp_config()
|
||||
return jsonify({"success": True})
|
||||
except Exception as e:
|
||||
logger.error(f"取消主SMTP配置失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/email/stats", methods=["GET"])
|
||||
@admin_required
|
||||
def get_email_stats_api():
|
||||
"""获取邮件发送统计"""
|
||||
try:
|
||||
stats = email_service.get_email_stats()
|
||||
return jsonify(stats)
|
||||
except Exception as e:
|
||||
logger.error(f"获取邮件统计失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/email/logs", methods=["GET"])
|
||||
@admin_required
|
||||
def get_email_logs_api():
|
||||
"""获取邮件发送日志"""
|
||||
try:
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
email_type = request.args.get("type", None)
|
||||
status = request.args.get("status", None)
|
||||
|
||||
page_size = min(max(page_size, 10), 100)
|
||||
result = email_service.get_email_logs(page, page_size, email_type, status)
|
||||
return jsonify(result)
|
||||
except Exception as e:
|
||||
logger.error(f"获取邮件日志失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/email/logs/cleanup", methods=["POST"])
|
||||
@admin_required
|
||||
def cleanup_email_logs_api():
|
||||
"""清理过期邮件日志"""
|
||||
try:
|
||||
data = request.json or {}
|
||||
days = data.get("days", 30)
|
||||
days = min(max(days, 7), 365)
|
||||
|
||||
deleted = email_service.cleanup_email_logs(days)
|
||||
return jsonify({"success": True, "deleted": deleted})
|
||||
except Exception as e:
|
||||
logger.error(f"清理邮件日志失败: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
58
routes/admin_api/feedback_api.py
Normal file
58
routes/admin_api/feedback_api.py
Normal file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import database
|
||||
from flask import jsonify, request
|
||||
from routes.admin_api import admin_api_bp
|
||||
from routes.decorators import admin_required
|
||||
|
||||
@admin_api_bp.route("/feedbacks", methods=["GET"])
|
||||
@admin_required
|
||||
def get_all_feedbacks():
|
||||
"""管理员获取所有反馈"""
|
||||
status = request.args.get("status")
|
||||
try:
|
||||
limit = int(request.args.get("limit", 100))
|
||||
offset = int(request.args.get("offset", 0))
|
||||
limit = min(max(1, limit), 1000)
|
||||
offset = max(0, offset)
|
||||
except (ValueError, TypeError):
|
||||
return jsonify({"error": "无效的分页参数"}), 400
|
||||
|
||||
feedbacks = database.get_bug_feedbacks(limit=limit, offset=offset, status_filter=status)
|
||||
stats = database.get_feedback_stats()
|
||||
return jsonify({"feedbacks": feedbacks, "stats": stats})
|
||||
|
||||
|
||||
@admin_api_bp.route("/feedbacks/<int:feedback_id>/reply", methods=["POST"])
|
||||
@admin_required
|
||||
def reply_to_feedback(feedback_id):
|
||||
"""管理员回复反馈"""
|
||||
data = request.get_json() or {}
|
||||
reply = (data.get("reply") or "").strip()
|
||||
|
||||
if not reply:
|
||||
return jsonify({"error": "回复内容不能为空"}), 400
|
||||
|
||||
if database.reply_feedback(feedback_id, reply):
|
||||
return jsonify({"message": "回复成功"})
|
||||
return jsonify({"error": "反馈不存在"}), 404
|
||||
|
||||
|
||||
@admin_api_bp.route("/feedbacks/<int:feedback_id>/close", methods=["POST"])
|
||||
@admin_required
|
||||
def close_feedback_api(feedback_id):
|
||||
"""管理员关闭反馈"""
|
||||
if database.close_feedback(feedback_id):
|
||||
return jsonify({"message": "已关闭"})
|
||||
return jsonify({"error": "反馈不存在"}), 404
|
||||
|
||||
|
||||
@admin_api_bp.route("/feedbacks/<int:feedback_id>", methods=["DELETE"])
|
||||
@admin_required
|
||||
def delete_feedback_api(feedback_id):
|
||||
"""管理员删除反馈"""
|
||||
if database.delete_feedback(feedback_id):
|
||||
return jsonify({"message": "已删除"})
|
||||
return jsonify({"error": "反馈不存在"}), 404
|
||||
226
routes/admin_api/infra_api.py
Normal file
226
routes/admin_api/infra_api.py
Normal file
@@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import database
|
||||
from app_logger import get_logger
|
||||
from flask import jsonify, session
|
||||
from routes.admin_api import admin_api_bp
|
||||
from routes.decorators import admin_required
|
||||
from services.time_utils import BEIJING_TZ, get_beijing_now
|
||||
|
||||
logger = get_logger("app")
|
||||
|
||||
|
||||
@admin_api_bp.route("/stats", methods=["GET"])
|
||||
@admin_required
|
||||
def get_system_stats():
|
||||
"""获取系统统计"""
|
||||
stats = database.get_system_stats()
|
||||
stats["admin_username"] = session.get("admin_username", "admin")
|
||||
return jsonify(stats)
|
||||
|
||||
|
||||
@admin_api_bp.route("/browser_pool/stats", methods=["GET"])
|
||||
@admin_required
|
||||
def get_browser_pool_stats():
|
||||
"""获取截图线程池状态"""
|
||||
try:
|
||||
from browser_pool_worker import get_browser_worker_pool
|
||||
|
||||
pool = get_browser_worker_pool()
|
||||
stats = pool.get_stats() or {}
|
||||
|
||||
worker_details = []
|
||||
for w in stats.get("workers") or []:
|
||||
last_ts = float(w.get("last_active_ts") or 0)
|
||||
last_active_at = None
|
||||
if last_ts > 0:
|
||||
try:
|
||||
last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except Exception:
|
||||
last_active_at = None
|
||||
|
||||
created_ts = w.get("browser_created_at")
|
||||
created_at = None
|
||||
if created_ts:
|
||||
try:
|
||||
created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except Exception:
|
||||
created_at = None
|
||||
|
||||
worker_details.append(
|
||||
{
|
||||
"worker_id": w.get("worker_id"),
|
||||
"idle": bool(w.get("idle")),
|
||||
"has_browser": bool(w.get("has_browser")),
|
||||
"total_tasks": int(w.get("total_tasks") or 0),
|
||||
"failed_tasks": int(w.get("failed_tasks") or 0),
|
||||
"browser_use_count": int(w.get("browser_use_count") or 0),
|
||||
"browser_created_at": created_at,
|
||||
"browser_created_ts": created_ts,
|
||||
"last_active_at": last_active_at,
|
||||
"last_active_ts": last_ts,
|
||||
"thread_alive": bool(w.get("thread_alive")),
|
||||
}
|
||||
)
|
||||
|
||||
total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0)
|
||||
return jsonify(
|
||||
{
|
||||
"total_workers": total_workers,
|
||||
"active_workers": int(stats.get("busy_workers") or 0),
|
||||
"idle_workers": int(stats.get("idle_workers") or 0),
|
||||
"queue_size": int(stats.get("queue_size") or 0),
|
||||
"workers": worker_details,
|
||||
"summary": {
|
||||
"total_tasks": int(stats.get("total_tasks") or 0),
|
||||
"failed_tasks": int(stats.get("failed_tasks") or 0),
|
||||
"success_rate": stats.get("success_rate"),
|
||||
},
|
||||
"server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}")
|
||||
return jsonify({"error": "获取截图线程池状态失败"}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/docker_stats", methods=["GET"])
|
||||
@admin_required
|
||||
def get_docker_stats():
|
||||
"""获取Docker容器运行状态"""
|
||||
import subprocess
|
||||
|
||||
docker_status = {
|
||||
"running": False,
|
||||
"container_name": "N/A",
|
||||
"uptime": "N/A",
|
||||
"memory_usage": "N/A",
|
||||
"memory_limit": "N/A",
|
||||
"memory_percent": "N/A",
|
||||
"cpu_percent": "N/A",
|
||||
"status": "Unknown",
|
||||
}
|
||||
|
||||
try:
|
||||
if os.path.exists("/.dockerenv"):
|
||||
docker_status["running"] = True
|
||||
|
||||
try:
|
||||
with open("/etc/hostname", "r") as f:
|
||||
docker_status["container_name"] = f.read().strip()
|
||||
except Exception as e:
|
||||
logger.debug(f"读取容器名称失败: {e}")
|
||||
|
||||
try:
|
||||
if os.path.exists("/sys/fs/cgroup/memory.current"):
|
||||
with open("/sys/fs/cgroup/memory.current", "r") as f:
|
||||
mem_total = int(f.read().strip())
|
||||
|
||||
cache = 0
|
||||
if os.path.exists("/sys/fs/cgroup/memory.stat"):
|
||||
with open("/sys/fs/cgroup/memory.stat", "r") as f:
|
||||
for line in f:
|
||||
if line.startswith("inactive_file "):
|
||||
cache = int(line.split()[1])
|
||||
break
|
||||
|
||||
mem_bytes = mem_total - cache
|
||||
docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
|
||||
|
||||
if os.path.exists("/sys/fs/cgroup/memory.max"):
|
||||
with open("/sys/fs/cgroup/memory.max", "r") as f:
|
||||
limit_str = f.read().strip()
|
||||
if limit_str != "max":
|
||||
limit_bytes = int(limit_str)
|
||||
docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
|
||||
docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
|
||||
elif os.path.exists("/sys/fs/cgroup/memory/memory.usage_in_bytes"):
|
||||
with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f:
|
||||
mem_bytes = int(f.read().strip())
|
||||
docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
|
||||
|
||||
with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f:
|
||||
limit_bytes = int(f.read().strip())
|
||||
if limit_bytes < 1e18:
|
||||
docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
|
||||
docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
|
||||
except Exception as e:
|
||||
logger.debug(f"读取内存信息失败: {e}")
|
||||
|
||||
try:
|
||||
if os.path.exists("/sys/fs/cgroup/cpu.stat"):
|
||||
cpu_usage = 0
|
||||
with open("/sys/fs/cgroup/cpu.stat", "r") as f:
|
||||
for line in f:
|
||||
if line.startswith("usage_usec"):
|
||||
cpu_usage = int(line.split()[1])
|
||||
break
|
||||
|
||||
time.sleep(0.1)
|
||||
cpu_usage2 = 0
|
||||
with open("/sys/fs/cgroup/cpu.stat", "r") as f:
|
||||
for line in f:
|
||||
if line.startswith("usage_usec"):
|
||||
cpu_usage2 = int(line.split()[1])
|
||||
break
|
||||
|
||||
cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e6 * 100
|
||||
docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
|
||||
elif os.path.exists("/sys/fs/cgroup/cpu/cpuacct.usage"):
|
||||
with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
|
||||
cpu_usage = int(f.read().strip())
|
||||
|
||||
time.sleep(0.1)
|
||||
with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
|
||||
cpu_usage2 = int(f.read().strip())
|
||||
|
||||
cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e9 * 100
|
||||
docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
|
||||
except Exception as e:
|
||||
logger.debug(f"读取CPU信息失败: {e}")
|
||||
|
||||
try:
|
||||
# 读取系统运行时间
|
||||
with open('/proc/uptime', 'r') as f:
|
||||
system_uptime = float(f.read().split()[0])
|
||||
|
||||
# 读取 PID 1 的启动时间 (jiffies)
|
||||
with open('/proc/1/stat', 'r') as f:
|
||||
stat = f.read().split()
|
||||
starttime_jiffies = int(stat[21])
|
||||
|
||||
# 获取 CLK_TCK (通常是 100)
|
||||
clk_tck = os.sysconf(os.sysconf_names['SC_CLK_TCK'])
|
||||
|
||||
# 计算容器运行时长(秒)
|
||||
container_uptime_seconds = system_uptime - (starttime_jiffies / clk_tck)
|
||||
|
||||
# 格式化为可读字符串
|
||||
days = int(container_uptime_seconds // 86400)
|
||||
hours = int((container_uptime_seconds % 86400) // 3600)
|
||||
minutes = int((container_uptime_seconds % 3600) // 60)
|
||||
|
||||
if days > 0:
|
||||
docker_status["uptime"] = f"{days}天{hours}小时{minutes}分钟"
|
||||
elif hours > 0:
|
||||
docker_status["uptime"] = f"{hours}小时{minutes}分钟"
|
||||
else:
|
||||
docker_status["uptime"] = f"{minutes}分钟"
|
||||
except Exception as e:
|
||||
logger.debug(f"获取容器运行时间失败: {e}")
|
||||
|
||||
docker_status["status"] = "Running"
|
||||
|
||||
else:
|
||||
docker_status["status"] = "Not in Docker"
|
||||
except Exception as e:
|
||||
docker_status["status"] = f"Error: {str(e)}"
|
||||
|
||||
return jsonify(docker_status)
|
||||
|
||||
228
routes/admin_api/operations_api.py
Normal file
228
routes/admin_api/operations_api.py
Normal file
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import database
|
||||
import requests
|
||||
from app_logger import get_logger
|
||||
from app_security import is_safe_outbound_url
|
||||
from flask import jsonify, request
|
||||
from routes.admin_api import admin_api_bp
|
||||
from routes.decorators import admin_required
|
||||
from services.scheduler import run_scheduled_task
|
||||
from services.time_utils import BEIJING_TZ, get_beijing_now
|
||||
|
||||
logger = get_logger("app")
|
||||
|
||||
_server_cpu_percent_lock = threading.Lock()
|
||||
_server_cpu_percent_last: float | None = None
|
||||
_server_cpu_percent_last_ts = 0.0
|
||||
|
||||
|
||||
def _get_server_cpu_percent() -> float:
|
||||
import psutil
|
||||
|
||||
global _server_cpu_percent_last, _server_cpu_percent_last_ts
|
||||
|
||||
now = time.time()
|
||||
with _server_cpu_percent_lock:
|
||||
if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5:
|
||||
return _server_cpu_percent_last
|
||||
|
||||
try:
|
||||
if _server_cpu_percent_last is None:
|
||||
cpu_percent = float(psutil.cpu_percent(interval=0.1))
|
||||
else:
|
||||
cpu_percent = float(psutil.cpu_percent(interval=None))
|
||||
except Exception:
|
||||
cpu_percent = float(_server_cpu_percent_last or 0.0)
|
||||
|
||||
if cpu_percent < 0:
|
||||
cpu_percent = 0.0
|
||||
|
||||
_server_cpu_percent_last = cpu_percent
|
||||
_server_cpu_percent_last_ts = now
|
||||
return cpu_percent
|
||||
|
||||
|
||||
@admin_api_bp.route("/kdocs/status", methods=["GET"])
|
||||
@admin_required
|
||||
def get_kdocs_status_api():
|
||||
"""获取金山文档上传状态"""
|
||||
try:
|
||||
from services.kdocs_uploader import get_kdocs_uploader
|
||||
|
||||
uploader = get_kdocs_uploader()
|
||||
status = uploader.get_status()
|
||||
live = str(request.args.get("live", "")).lower() in ("1", "true", "yes")
|
||||
if live:
|
||||
live_status = uploader.refresh_login_status()
|
||||
if live_status.get("success"):
|
||||
logged_in = bool(live_status.get("logged_in"))
|
||||
status["logged_in"] = logged_in
|
||||
status["last_login_ok"] = logged_in
|
||||
status["login_required"] = not logged_in
|
||||
if live_status.get("error"):
|
||||
status["last_error"] = live_status.get("error")
|
||||
else:
|
||||
status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None
|
||||
if status.get("last_login_ok") is True and status.get("last_error") == "操作超时":
|
||||
status["last_error"] = None
|
||||
return jsonify(status)
|
||||
except Exception as e:
|
||||
return jsonify({"error": f"获取状态失败: {e}"}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/kdocs/qr", methods=["POST"])
|
||||
@admin_required
|
||||
def get_kdocs_qr_api():
|
||||
"""获取金山文档登录二维码"""
|
||||
try:
|
||||
from services.kdocs_uploader import get_kdocs_uploader
|
||||
|
||||
uploader = get_kdocs_uploader()
|
||||
data = request.get_json(silent=True) or {}
|
||||
force = bool(data.get("force"))
|
||||
if not force:
|
||||
force = str(request.args.get("force", "")).lower() in ("1", "true", "yes")
|
||||
result = uploader.request_qr(force=force)
|
||||
if not result.get("success"):
|
||||
return jsonify({"error": result.get("error", "获取二维码失败")}), 400
|
||||
return jsonify(result)
|
||||
except Exception as e:
|
||||
return jsonify({"error": f"获取二维码失败: {e}"}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/kdocs/clear-login", methods=["POST"])
|
||||
@admin_required
|
||||
def clear_kdocs_login_api():
|
||||
"""清除金山文档登录态"""
|
||||
try:
|
||||
from services.kdocs_uploader import get_kdocs_uploader
|
||||
|
||||
uploader = get_kdocs_uploader()
|
||||
result = uploader.clear_login()
|
||||
if not result.get("success"):
|
||||
return jsonify({"error": result.get("error", "清除失败")}), 400
|
||||
return jsonify({"success": True})
|
||||
except Exception as e:
|
||||
return jsonify({"error": f"清除失败: {e}"}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/schedule/execute", methods=["POST"])
|
||||
@admin_required
|
||||
def execute_schedule_now():
|
||||
"""立即执行定时任务(无视定时时间和星期限制)"""
|
||||
try:
|
||||
threading.Thread(target=run_scheduled_task, args=(True,), daemon=True).start()
|
||||
logger.info("[立即执行定时任务] 管理员手动触发定时任务执行(跳过星期检查)")
|
||||
return jsonify({"message": "定时任务已开始执行,请查看任务列表获取进度"})
|
||||
except Exception as e:
|
||||
logger.error(f"[立即执行定时任务] 启动失败: {str(e)}")
|
||||
return jsonify({"error": f"启动失败: {str(e)}"}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/proxy/config", methods=["GET"])
|
||||
@admin_required
|
||||
def get_proxy_config_api():
|
||||
"""获取代理配置"""
|
||||
config_data = database.get_system_config()
|
||||
return jsonify(
|
||||
{
|
||||
"proxy_enabled": config_data.get("proxy_enabled", 0),
|
||||
"proxy_api_url": config_data.get("proxy_api_url", ""),
|
||||
"proxy_expire_minutes": config_data.get("proxy_expire_minutes", 3),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@admin_api_bp.route("/proxy/config", methods=["POST"])
|
||||
@admin_required
|
||||
def update_proxy_config_api():
|
||||
"""更新代理配置"""
|
||||
data = request.json or {}
|
||||
proxy_enabled = data.get("proxy_enabled")
|
||||
proxy_api_url = (data.get("proxy_api_url", "") or "").strip()
|
||||
proxy_expire_minutes = data.get("proxy_expire_minutes")
|
||||
|
||||
if proxy_enabled is not None and proxy_enabled not in [0, 1]:
|
||||
return jsonify({"error": "proxy_enabled必须是0或1"}), 400
|
||||
|
||||
if proxy_expire_minutes is not None:
|
||||
if not isinstance(proxy_expire_minutes, int) or proxy_expire_minutes < 1:
|
||||
return jsonify({"error": "代理有效期必须是大于0的整数"}), 400
|
||||
|
||||
if database.update_system_config(
|
||||
proxy_enabled=proxy_enabled,
|
||||
proxy_api_url=proxy_api_url,
|
||||
proxy_expire_minutes=proxy_expire_minutes,
|
||||
):
|
||||
return jsonify({"message": "代理配置已更新"})
|
||||
return jsonify({"error": "更新失败"}), 400
|
||||
|
||||
|
||||
@admin_api_bp.route("/proxy/test", methods=["POST"])
|
||||
@admin_required
|
||||
def test_proxy_api():
|
||||
"""测试代理连接"""
|
||||
data = request.json or {}
|
||||
api_url = (data.get("api_url") or "").strip()
|
||||
|
||||
if not api_url:
|
||||
return jsonify({"error": "请提供API地址"}), 400
|
||||
|
||||
if not is_safe_outbound_url(api_url):
|
||||
return jsonify({"error": "API地址不可用或不安全"}), 400
|
||||
|
||||
try:
|
||||
response = requests.get(api_url, timeout=10)
|
||||
if response.status_code == 200:
|
||||
ip_port = response.text.strip()
|
||||
if ip_port and ":" in ip_port:
|
||||
return jsonify({"success": True, "proxy": ip_port, "message": f"代理获取成功: {ip_port}"})
|
||||
return jsonify({"success": False, "message": f"代理格式错误: {ip_port}"}), 400
|
||||
return jsonify({"success": False, "message": f"HTTP错误: {response.status_code}"}), 400
|
||||
except Exception as e:
|
||||
return jsonify({"success": False, "message": f"连接失败: {str(e)}"}), 500
|
||||
|
||||
|
||||
@admin_api_bp.route("/server/info", methods=["GET"])
|
||||
@admin_required
|
||||
def get_server_info_api():
|
||||
"""获取服务器信息"""
|
||||
import psutil
|
||||
|
||||
cpu_percent = _get_server_cpu_percent()
|
||||
|
||||
memory = psutil.virtual_memory()
|
||||
memory_total = f"{memory.total / (1024**3):.1f}GB"
|
||||
memory_used = f"{memory.used / (1024**3):.1f}GB"
|
||||
memory_percent = memory.percent
|
||||
|
||||
disk = psutil.disk_usage("/")
|
||||
disk_total = f"{disk.total / (1024**3):.1f}GB"
|
||||
disk_used = f"{disk.used / (1024**3):.1f}GB"
|
||||
disk_percent = disk.percent
|
||||
|
||||
boot_time = datetime.fromtimestamp(psutil.boot_time(), tz=BEIJING_TZ)
|
||||
uptime_delta = get_beijing_now() - boot_time
|
||||
days = uptime_delta.days
|
||||
hours = uptime_delta.seconds // 3600
|
||||
uptime = f"{days}天{hours}小时"
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory_total": memory_total,
|
||||
"memory_used": memory_used,
|
||||
"memory_percent": memory_percent,
|
||||
"disk_total": disk_total,
|
||||
"disk_used": disk_used,
|
||||
"disk_percent": disk_percent,
|
||||
"uptime": uptime,
|
||||
}
|
||||
)
|
||||
@@ -62,6 +62,19 @@ def _parse_bool(value: Any) -> bool:
|
||||
return text in {"1", "true", "yes", "y", "on"}
|
||||
|
||||
|
||||
def _parse_int(value: Any, *, default: int | None = None, min_value: int | None = None) -> int | None:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except Exception:
|
||||
parsed = default
|
||||
|
||||
if parsed is None:
|
||||
return None
|
||||
if min_value is not None:
|
||||
parsed = max(int(min_value), parsed)
|
||||
return parsed
|
||||
|
||||
|
||||
def _sanitize_threat_event(event: dict) -> dict:
|
||||
return {
|
||||
"id": event.get("id"),
|
||||
@@ -199,10 +212,7 @@ def ban_ip():
|
||||
if not reason:
|
||||
return jsonify({"error": "reason不能为空"}), 400
|
||||
|
||||
try:
|
||||
duration_hours = max(1, int(duration_hours_raw))
|
||||
except Exception:
|
||||
duration_hours = 24
|
||||
duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24
|
||||
|
||||
ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent)
|
||||
if not ok:
|
||||
@@ -235,20 +245,14 @@ def ban_user():
|
||||
duration_hours_raw = data.get("duration_hours", 24)
|
||||
permanent = _parse_bool(data.get("permanent", False))
|
||||
|
||||
try:
|
||||
user_id = int(user_id_raw)
|
||||
except Exception:
|
||||
user_id = None
|
||||
user_id = _parse_int(user_id_raw)
|
||||
|
||||
if user_id is None:
|
||||
return jsonify({"error": "user_id不能为空"}), 400
|
||||
if not reason:
|
||||
return jsonify({"error": "reason不能为空"}), 400
|
||||
|
||||
try:
|
||||
duration_hours = max(1, int(duration_hours_raw))
|
||||
except Exception:
|
||||
duration_hours = 24
|
||||
duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24
|
||||
|
||||
ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent)
|
||||
if not ok:
|
||||
@@ -262,10 +266,7 @@ def unban_user():
|
||||
"""解除用户封禁"""
|
||||
data = _parse_json()
|
||||
user_id_raw = data.get("user_id")
|
||||
try:
|
||||
user_id = int(user_id_raw)
|
||||
except Exception:
|
||||
user_id = None
|
||||
user_id = _parse_int(user_id_raw)
|
||||
|
||||
if user_id is None:
|
||||
return jsonify({"error": "user_id不能为空"}), 400
|
||||
|
||||
228
routes/admin_api/system_config_api.py
Normal file
228
routes/admin_api/system_config_api.py
Normal file
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import database
|
||||
from app_logger import get_logger
|
||||
from app_security import is_safe_outbound_url, validate_email
|
||||
from flask import jsonify, request
|
||||
from routes.admin_api import admin_api_bp
|
||||
from routes.decorators import admin_required
|
||||
from services.browse_types import BROWSE_TYPE_SHOULD_READ, validate_browse_type
|
||||
from services.tasks import get_task_scheduler
|
||||
|
||||
logger = get_logger("app")
|
||||
|
||||
|
||||
@admin_api_bp.route("/system/config", methods=["GET"])
|
||||
@admin_required
|
||||
def get_system_config_api():
|
||||
"""获取系统配置"""
|
||||
return jsonify(database.get_system_config())
|
||||
|
||||
|
||||
@admin_api_bp.route("/system/config", methods=["POST"])
|
||||
@admin_required
|
||||
def update_system_config_api():
|
||||
"""更新系统配置"""
|
||||
data = request.json or {}
|
||||
|
||||
max_concurrent = data.get("max_concurrent_global")
|
||||
schedule_enabled = data.get("schedule_enabled")
|
||||
schedule_time = data.get("schedule_time")
|
||||
schedule_browse_type = data.get("schedule_browse_type")
|
||||
schedule_weekdays = data.get("schedule_weekdays")
|
||||
new_max_concurrent_per_account = data.get("max_concurrent_per_account")
|
||||
new_max_screenshot_concurrent = data.get("max_screenshot_concurrent")
|
||||
enable_screenshot = data.get("enable_screenshot")
|
||||
auto_approve_enabled = data.get("auto_approve_enabled")
|
||||
auto_approve_hourly_limit = data.get("auto_approve_hourly_limit")
|
||||
auto_approve_vip_days = data.get("auto_approve_vip_days")
|
||||
kdocs_enabled = data.get("kdocs_enabled")
|
||||
kdocs_doc_url = data.get("kdocs_doc_url")
|
||||
kdocs_default_unit = data.get("kdocs_default_unit")
|
||||
kdocs_sheet_name = data.get("kdocs_sheet_name")
|
||||
kdocs_sheet_index = data.get("kdocs_sheet_index")
|
||||
kdocs_unit_column = data.get("kdocs_unit_column")
|
||||
kdocs_image_column = data.get("kdocs_image_column")
|
||||
kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled")
|
||||
kdocs_admin_notify_email = data.get("kdocs_admin_notify_email")
|
||||
kdocs_row_start = data.get("kdocs_row_start")
|
||||
kdocs_row_end = data.get("kdocs_row_end")
|
||||
|
||||
if max_concurrent is not None:
|
||||
if not isinstance(max_concurrent, int) or max_concurrent < 1:
|
||||
return jsonify({"error": "全局并发数必须大于0(建议:小型服务器2-5,中型5-10,大型10-20)"}), 400
|
||||
|
||||
if new_max_concurrent_per_account is not None:
|
||||
if not isinstance(new_max_concurrent_per_account, int) or new_max_concurrent_per_account < 1:
|
||||
return jsonify({"error": "单账号并发数必须大于0(建议设为1,避免同一用户任务相互影响)"}), 400
|
||||
|
||||
if new_max_screenshot_concurrent is not None:
|
||||
if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1:
|
||||
return jsonify({"error": "截图并发数必须大于0(建议根据服务器配置设置,wkhtmltoimage 资源占用较低)"}), 400
|
||||
|
||||
if enable_screenshot is not None:
|
||||
if isinstance(enable_screenshot, bool):
|
||||
enable_screenshot = 1 if enable_screenshot else 0
|
||||
if enable_screenshot not in (0, 1):
|
||||
return jsonify({"error": "截图开关必须是0或1"}), 400
|
||||
|
||||
if schedule_time is not None:
|
||||
import re
|
||||
|
||||
if not re.match(r"^([01]\\d|2[0-3]):([0-5]\\d)$", schedule_time):
|
||||
return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400
|
||||
|
||||
if schedule_browse_type is not None:
|
||||
normalized = validate_browse_type(schedule_browse_type, default=BROWSE_TYPE_SHOULD_READ)
|
||||
if not normalized:
|
||||
return jsonify({"error": "浏览类型无效"}), 400
|
||||
schedule_browse_type = normalized
|
||||
|
||||
if schedule_weekdays is not None:
|
||||
try:
|
||||
days = [int(d.strip()) for d in schedule_weekdays.split(",") if d.strip()]
|
||||
if not all(1 <= d <= 7 for d in days):
|
||||
return jsonify({"error": "星期数字必须在1-7之间"}), 400
|
||||
except (ValueError, AttributeError):
|
||||
return jsonify({"error": "星期格式错误"}), 400
|
||||
|
||||
if auto_approve_hourly_limit is not None:
|
||||
if not isinstance(auto_approve_hourly_limit, int) or auto_approve_hourly_limit < 1:
|
||||
return jsonify({"error": "每小时注册限制必须大于0"}), 400
|
||||
|
||||
if auto_approve_vip_days is not None:
|
||||
if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0:
|
||||
return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400
|
||||
|
||||
if kdocs_enabled is not None:
|
||||
if isinstance(kdocs_enabled, bool):
|
||||
kdocs_enabled = 1 if kdocs_enabled else 0
|
||||
if kdocs_enabled not in (0, 1):
|
||||
return jsonify({"error": "表格上传开关必须是0或1"}), 400
|
||||
|
||||
if kdocs_doc_url is not None:
|
||||
kdocs_doc_url = str(kdocs_doc_url or "").strip()
|
||||
if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url):
|
||||
return jsonify({"error": "文档链接格式不正确"}), 400
|
||||
|
||||
if kdocs_default_unit is not None:
|
||||
kdocs_default_unit = str(kdocs_default_unit or "").strip()
|
||||
if len(kdocs_default_unit) > 50:
|
||||
return jsonify({"error": "默认县区长度不能超过50"}), 400
|
||||
|
||||
if kdocs_sheet_name is not None:
|
||||
kdocs_sheet_name = str(kdocs_sheet_name or "").strip()
|
||||
if len(kdocs_sheet_name) > 50:
|
||||
return jsonify({"error": "Sheet名称长度不能超过50"}), 400
|
||||
|
||||
if kdocs_sheet_index is not None:
|
||||
try:
|
||||
kdocs_sheet_index = int(kdocs_sheet_index)
|
||||
except Exception:
|
||||
return jsonify({"error": "Sheet序号必须是数字"}), 400
|
||||
if kdocs_sheet_index < 0:
|
||||
return jsonify({"error": "Sheet序号不能为负数"}), 400
|
||||
|
||||
if kdocs_unit_column is not None:
|
||||
kdocs_unit_column = str(kdocs_unit_column or "").strip().upper()
|
||||
if not kdocs_unit_column:
|
||||
return jsonify({"error": "县区列不能为空"}), 400
|
||||
import re
|
||||
|
||||
if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column):
|
||||
return jsonify({"error": "县区列格式错误"}), 400
|
||||
|
||||
if kdocs_image_column is not None:
|
||||
kdocs_image_column = str(kdocs_image_column or "").strip().upper()
|
||||
if not kdocs_image_column:
|
||||
return jsonify({"error": "图片列不能为空"}), 400
|
||||
import re
|
||||
|
||||
if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column):
|
||||
return jsonify({"error": "图片列格式错误"}), 400
|
||||
|
||||
if kdocs_admin_notify_enabled is not None:
|
||||
if isinstance(kdocs_admin_notify_enabled, bool):
|
||||
kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0
|
||||
if kdocs_admin_notify_enabled not in (0, 1):
|
||||
return jsonify({"error": "管理员通知开关必须是0或1"}), 400
|
||||
|
||||
if kdocs_admin_notify_email is not None:
|
||||
kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip()
|
||||
if kdocs_admin_notify_email:
|
||||
is_valid, error_msg = validate_email(kdocs_admin_notify_email)
|
||||
if not is_valid:
|
||||
return jsonify({"error": error_msg}), 400
|
||||
|
||||
if kdocs_row_start is not None:
|
||||
try:
|
||||
kdocs_row_start = int(kdocs_row_start)
|
||||
except (ValueError, TypeError):
|
||||
return jsonify({"error": "起始行必须是数字"}), 400
|
||||
if kdocs_row_start < 0:
|
||||
return jsonify({"error": "起始行不能为负数"}), 400
|
||||
|
||||
if kdocs_row_end is not None:
|
||||
try:
|
||||
kdocs_row_end = int(kdocs_row_end)
|
||||
except (ValueError, TypeError):
|
||||
return jsonify({"error": "结束行必须是数字"}), 400
|
||||
if kdocs_row_end < 0:
|
||||
return jsonify({"error": "结束行不能为负数"}), 400
|
||||
|
||||
old_config = database.get_system_config() or {}
|
||||
|
||||
if not database.update_system_config(
|
||||
max_concurrent=max_concurrent,
|
||||
schedule_enabled=schedule_enabled,
|
||||
schedule_time=schedule_time,
|
||||
schedule_browse_type=schedule_browse_type,
|
||||
schedule_weekdays=schedule_weekdays,
|
||||
max_concurrent_per_account=new_max_concurrent_per_account,
|
||||
max_screenshot_concurrent=new_max_screenshot_concurrent,
|
||||
enable_screenshot=enable_screenshot,
|
||||
auto_approve_enabled=auto_approve_enabled,
|
||||
auto_approve_hourly_limit=auto_approve_hourly_limit,
|
||||
auto_approve_vip_days=auto_approve_vip_days,
|
||||
kdocs_enabled=kdocs_enabled,
|
||||
kdocs_doc_url=kdocs_doc_url,
|
||||
kdocs_default_unit=kdocs_default_unit,
|
||||
kdocs_sheet_name=kdocs_sheet_name,
|
||||
kdocs_sheet_index=kdocs_sheet_index,
|
||||
kdocs_unit_column=kdocs_unit_column,
|
||||
kdocs_image_column=kdocs_image_column,
|
||||
kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
|
||||
kdocs_admin_notify_email=kdocs_admin_notify_email,
|
||||
kdocs_row_start=kdocs_row_start,
|
||||
kdocs_row_end=kdocs_row_end,
|
||||
):
|
||||
return jsonify({"error": "更新失败"}), 400
|
||||
|
||||
try:
|
||||
new_config = database.get_system_config() or {}
|
||||
scheduler = get_task_scheduler()
|
||||
scheduler.update_limits(
|
||||
max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))),
|
||||
max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))),
|
||||
)
|
||||
if new_max_screenshot_concurrent is not None:
|
||||
try:
|
||||
from browser_pool_worker import resize_browser_worker_pool
|
||||
|
||||
if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))):
|
||||
logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}")
|
||||
except Exception as pool_error:
|
||||
logger.warning(f"截图线程池并发更新失败: {pool_error}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if max_concurrent is not None and max_concurrent != old_config.get("max_concurrent_global"):
|
||||
logger.info(f"全局并发数已更新为: {max_concurrent}")
|
||||
if new_max_concurrent_per_account is not None and new_max_concurrent_per_account != old_config.get("max_concurrent_per_account"):
|
||||
logger.info(f"单用户并发数已更新为: {new_max_concurrent_per_account}")
|
||||
if new_max_screenshot_concurrent is not None:
|
||||
logger.info(f"截图并发数已更新为: {new_max_screenshot_concurrent}")
|
||||
|
||||
return jsonify({"message": "系统配置已更新"})
|
||||
138
routes/admin_api/tasks_api.py
Normal file
138
routes/admin_api/tasks_api.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import database
|
||||
from app_logger import get_logger
|
||||
from flask import jsonify, request
|
||||
from routes.admin_api import admin_api_bp
|
||||
from routes.decorators import admin_required
|
||||
from services.state import safe_iter_task_status_items
|
||||
from services.tasks import get_task_scheduler
|
||||
|
||||
logger = get_logger("app")
|
||||
|
||||
|
||||
def _parse_page_int(name: str, default: int, *, minimum: int, maximum: int) -> int:
|
||||
try:
|
||||
value = int(request.args.get(name, default))
|
||||
return max(minimum, min(value, maximum))
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
|
||||
@admin_api_bp.route("/task/stats", methods=["GET"])
|
||||
@admin_required
|
||||
def get_task_stats_api():
|
||||
"""获取任务统计数据"""
|
||||
date_filter = request.args.get("date")
|
||||
stats = database.get_task_stats(date_filter)
|
||||
return jsonify(stats)
|
||||
|
||||
|
||||
@admin_api_bp.route("/task/running", methods=["GET"])
|
||||
@admin_required
|
||||
def get_running_tasks_api():
|
||||
"""获取当前运行中和排队中的任务"""
|
||||
import time as time_mod
|
||||
|
||||
current_time = time_mod.time()
|
||||
running = []
|
||||
queuing = []
|
||||
user_cache = {}
|
||||
|
||||
for account_id, info in safe_iter_task_status_items():
|
||||
elapsed = int(current_time - info.get("start_time", current_time))
|
||||
|
||||
info_user_id = info.get("user_id")
|
||||
if info_user_id not in user_cache:
|
||||
user_cache[info_user_id] = database.get_user_by_id(info_user_id)
|
||||
user = user_cache.get(info_user_id)
|
||||
user_username = user["username"] if user else "N/A"
|
||||
|
||||
progress = info.get("progress", {"items": 0, "attachments": 0})
|
||||
task_info = {
|
||||
"account_id": account_id,
|
||||
"user_id": info.get("user_id"),
|
||||
"user_username": user_username,
|
||||
"username": info.get("username"),
|
||||
"browse_type": info.get("browse_type"),
|
||||
"source": info.get("source", "manual"),
|
||||
"detail_status": info.get("detail_status", "未知"),
|
||||
"progress_items": progress.get("items", 0),
|
||||
"progress_attachments": progress.get("attachments", 0),
|
||||
"elapsed_seconds": elapsed,
|
||||
"elapsed_display": f"{elapsed // 60}分{elapsed % 60}秒" if elapsed >= 60 else f"{elapsed}秒",
|
||||
}
|
||||
|
||||
if info.get("status") == "运行中":
|
||||
running.append(task_info)
|
||||
else:
|
||||
queuing.append(task_info)
|
||||
|
||||
running.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
|
||||
queuing.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
|
||||
|
||||
try:
|
||||
max_concurrent = int(get_task_scheduler().max_global)
|
||||
except Exception:
|
||||
max_concurrent = int((database.get_system_config() or {}).get("max_concurrent_global", 2))
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"running": running,
|
||||
"queuing": queuing,
|
||||
"running_count": len(running),
|
||||
"queuing_count": len(queuing),
|
||||
"max_concurrent": max_concurrent,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@admin_api_bp.route("/task/logs", methods=["GET"])
|
||||
@admin_required
|
||||
def get_task_logs_api():
|
||||
"""获取任务日志列表(支持分页和多种筛选)"""
|
||||
limit = _parse_page_int("limit", 20, minimum=1, maximum=200)
|
||||
offset = _parse_page_int("offset", 0, minimum=0, maximum=10**9)
|
||||
|
||||
date_filter = request.args.get("date")
|
||||
status_filter = request.args.get("status")
|
||||
source_filter = request.args.get("source")
|
||||
user_id_filter = request.args.get("user_id")
|
||||
account_filter = (request.args.get("account") or "").strip()
|
||||
|
||||
if user_id_filter:
|
||||
try:
|
||||
user_id_filter = int(user_id_filter)
|
||||
except (ValueError, TypeError):
|
||||
user_id_filter = None
|
||||
|
||||
try:
|
||||
result = database.get_task_logs(
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
date_filter=date_filter,
|
||||
status_filter=status_filter,
|
||||
source_filter=source_filter,
|
||||
user_id_filter=user_id_filter,
|
||||
account_filter=account_filter if account_filter else None,
|
||||
)
|
||||
return jsonify(result)
|
||||
except Exception as e:
|
||||
logger.error(f"获取任务日志失败: {e}")
|
||||
return jsonify({"logs": [], "total": 0, "error": "查询失败"})
|
||||
|
||||
|
||||
@admin_api_bp.route("/task/logs/clear", methods=["POST"])
|
||||
@admin_required
|
||||
def clear_old_task_logs_api():
|
||||
"""清理旧的任务日志"""
|
||||
data = request.json or {}
|
||||
days = data.get("days", 30)
|
||||
|
||||
if not isinstance(days, int) or days < 1:
|
||||
return jsonify({"error": "天数必须是大于0的整数"}), 400
|
||||
|
||||
deleted_count = database.delete_old_task_logs(days)
|
||||
return jsonify({"message": f"已删除{days}天前的{deleted_count}条日志"})
|
||||
117
routes/admin_api/users_api.py
Normal file
117
routes/admin_api/users_api.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import database
|
||||
from flask import jsonify, request
|
||||
from routes.admin_api import admin_api_bp
|
||||
from routes.decorators import admin_required
|
||||
from services.state import safe_clear_user_logs, safe_remove_user_accounts
|
||||
|
||||
|
||||
# ==================== 用户管理/统计(管理员) ====================
|
||||
|
||||
|
||||
@admin_api_bp.route("/users", methods=["GET"])
|
||||
@admin_required
|
||||
def get_all_users():
|
||||
"""获取所有用户"""
|
||||
users = database.get_all_users()
|
||||
return jsonify(users)
|
||||
|
||||
|
||||
@admin_api_bp.route("/users/pending", methods=["GET"])
|
||||
@admin_required
|
||||
def get_pending_users():
|
||||
"""获取待审核用户"""
|
||||
users = database.get_pending_users()
|
||||
return jsonify(users)
|
||||
|
||||
|
||||
@admin_api_bp.route("/users/<int:user_id>/approve", methods=["POST"])
|
||||
@admin_required
|
||||
def approve_user_route(user_id):
|
||||
"""审核通过用户"""
|
||||
if database.approve_user(user_id):
|
||||
return jsonify({"success": True})
|
||||
return jsonify({"error": "审核失败"}), 400
|
||||
|
||||
|
||||
@admin_api_bp.route("/users/<int:user_id>/reject", methods=["POST"])
|
||||
@admin_required
|
||||
def reject_user_route(user_id):
|
||||
"""拒绝用户"""
|
||||
if database.reject_user(user_id):
|
||||
return jsonify({"success": True})
|
||||
return jsonify({"error": "拒绝失败"}), 400
|
||||
|
||||
|
||||
@admin_api_bp.route("/users/<int:user_id>", methods=["DELETE"])
|
||||
@admin_required
|
||||
def delete_user_route(user_id):
|
||||
"""删除用户"""
|
||||
if database.delete_user(user_id):
|
||||
safe_remove_user_accounts(user_id)
|
||||
safe_clear_user_logs(user_id)
|
||||
return jsonify({"success": True})
|
||||
return jsonify({"error": "删除失败"}), 400
|
||||
|
||||
|
||||
# ==================== VIP 管理(管理员) ====================
|
||||
|
||||
|
||||
@admin_api_bp.route("/vip/config", methods=["GET"])
|
||||
@admin_required
|
||||
def get_vip_config_api():
|
||||
"""获取VIP配置"""
|
||||
config = database.get_vip_config()
|
||||
return jsonify(config)
|
||||
|
||||
|
||||
@admin_api_bp.route("/vip/config", methods=["POST"])
|
||||
@admin_required
|
||||
def set_vip_config_api():
|
||||
"""设置默认VIP天数"""
|
||||
data = request.json or {}
|
||||
days = data.get("default_vip_days", 0)
|
||||
|
||||
if not isinstance(days, int) or days < 0:
|
||||
return jsonify({"error": "VIP天数必须是非负整数"}), 400
|
||||
|
||||
database.set_default_vip_days(days)
|
||||
return jsonify({"message": "VIP配置已更新", "default_vip_days": days})
|
||||
|
||||
|
||||
@admin_api_bp.route("/users/<int:user_id>/vip", methods=["POST"])
|
||||
@admin_required
|
||||
def set_user_vip_api(user_id):
|
||||
"""设置用户VIP"""
|
||||
data = request.json or {}
|
||||
days = data.get("days", 30)
|
||||
|
||||
valid_days = [7, 30, 365, 999999]
|
||||
if days not in valid_days:
|
||||
return jsonify({"error": "VIP天数必须是 7/30/365/999999 之一"}), 400
|
||||
|
||||
if database.set_user_vip(user_id, days):
|
||||
vip_type = {7: "一周", 30: "一个月", 365: "一年", 999999: "永久"}[days]
|
||||
return jsonify({"message": f"VIP设置成功: {vip_type}"})
|
||||
return jsonify({"error": "设置失败,用户不存在"}), 400
|
||||
|
||||
|
||||
@admin_api_bp.route("/users/<int:user_id>/vip", methods=["DELETE"])
|
||||
@admin_required
|
||||
def remove_user_vip_api(user_id):
|
||||
"""移除用户VIP"""
|
||||
if database.remove_user_vip(user_id):
|
||||
return jsonify({"message": "VIP已移除"})
|
||||
return jsonify({"error": "移除失败"}), 400
|
||||
|
||||
|
||||
@admin_api_bp.route("/users/<int:user_id>/vip", methods=["GET"])
|
||||
@admin_required
|
||||
def get_user_vip_info_api(user_id):
|
||||
"""获取用户VIP信息(管理员)"""
|
||||
vip_info = database.get_user_vip_info(user_id)
|
||||
return jsonify(vip_info)
|
||||
@@ -40,6 +40,48 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _emit_account_update(user_id: int, account) -> None:
|
||||
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
|
||||
|
||||
|
||||
def _request_json(default=None):
|
||||
if default is None:
|
||||
default = {}
|
||||
data = request.get_json(silent=True)
|
||||
return data if isinstance(data, dict) else default
|
||||
|
||||
|
||||
def _ensure_accounts_loaded(user_id: int) -> dict:
|
||||
accounts = safe_get_user_accounts_snapshot(user_id)
|
||||
if accounts:
|
||||
return accounts
|
||||
load_user_accounts(user_id)
|
||||
return safe_get_user_accounts_snapshot(user_id)
|
||||
|
||||
|
||||
def _get_user_account(user_id: int, account_id: str, *, refresh_if_missing: bool = False):
|
||||
account = safe_get_account(user_id, account_id)
|
||||
if account or (not refresh_if_missing):
|
||||
return account
|
||||
load_user_accounts(user_id)
|
||||
return safe_get_account(user_id, account_id)
|
||||
|
||||
|
||||
def _validate_browse_type_input(raw_browse_type, *, default=BROWSE_TYPE_SHOULD_READ):
|
||||
browse_type = validate_browse_type(raw_browse_type, default=default)
|
||||
if not browse_type:
|
||||
return None, (jsonify({"error": "浏览类型无效"}), 400)
|
||||
return browse_type, None
|
||||
|
||||
|
||||
def _cancel_pending_account_task(user_id: int, account_id: str) -> bool:
|
||||
try:
|
||||
scheduler = get_task_scheduler()
|
||||
return bool(scheduler.cancel_pending_task(user_id=user_id, account_id=account_id))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@api_accounts_bp.route("/api/accounts", methods=["GET"])
|
||||
@login_required
|
||||
def get_accounts():
|
||||
@@ -49,8 +91,7 @@ def get_accounts():
|
||||
|
||||
accounts = safe_get_user_accounts_snapshot(user_id)
|
||||
if refresh or not accounts:
|
||||
load_user_accounts(user_id)
|
||||
accounts = safe_get_user_accounts_snapshot(user_id)
|
||||
accounts = _ensure_accounts_loaded(user_id)
|
||||
|
||||
return jsonify([acc.to_dict() for acc in accounts.values()])
|
||||
|
||||
@@ -63,20 +104,18 @@ def add_account():
|
||||
|
||||
current_count = len(database.get_user_accounts(user_id))
|
||||
is_vip = database.is_user_vip(user_id)
|
||||
if not is_vip and current_count >= 3:
|
||||
if (not is_vip) and current_count >= 3:
|
||||
return jsonify({"error": "普通用户最多添加3个账号,升级VIP可无限添加"}), 403
|
||||
data = request.json
|
||||
username = data.get("username", "").strip()
|
||||
password = data.get("password", "").strip()
|
||||
remark = data.get("remark", "").strip()[:200]
|
||||
|
||||
data = _request_json()
|
||||
username = str(data.get("username", "")).strip()
|
||||
password = str(data.get("password", "")).strip()
|
||||
remark = str(data.get("remark", "")).strip()[:200]
|
||||
|
||||
if not username or not password:
|
||||
return jsonify({"error": "用户名和密码不能为空"}), 400
|
||||
|
||||
accounts = safe_get_user_accounts_snapshot(user_id)
|
||||
if not accounts:
|
||||
load_user_accounts(user_id)
|
||||
accounts = safe_get_user_accounts_snapshot(user_id)
|
||||
accounts = _ensure_accounts_loaded(user_id)
|
||||
for acc in accounts.values():
|
||||
if acc.username == username:
|
||||
return jsonify({"error": f"账号 '{username}' 已存在"}), 400
|
||||
@@ -92,7 +131,7 @@ def add_account():
|
||||
safe_set_account(user_id, account_id, account)
|
||||
|
||||
log_to_client(f"添加账号: {username}", user_id)
|
||||
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
|
||||
_emit_account_update(user_id, account)
|
||||
|
||||
return jsonify(account.to_dict())
|
||||
|
||||
@@ -103,15 +142,15 @@ def update_account(account_id):
|
||||
"""更新账号信息(密码等)"""
|
||||
user_id = current_user.id
|
||||
|
||||
account = safe_get_account(user_id, account_id)
|
||||
account = _get_user_account(user_id, account_id)
|
||||
if not account:
|
||||
return jsonify({"error": "账号不存在"}), 404
|
||||
|
||||
if account.is_running:
|
||||
return jsonify({"error": "账号正在运行中,请先停止"}), 400
|
||||
|
||||
data = request.json
|
||||
new_password = data.get("password", "").strip()
|
||||
data = _request_json()
|
||||
new_password = str(data.get("password", "")).strip()
|
||||
new_remember = data.get("remember", account.remember)
|
||||
|
||||
if not new_password:
|
||||
@@ -147,7 +186,7 @@ def delete_account(account_id):
|
||||
"""删除账号"""
|
||||
user_id = current_user.id
|
||||
|
||||
account = safe_get_account(user_id, account_id)
|
||||
account = _get_user_account(user_id, account_id)
|
||||
if not account:
|
||||
return jsonify({"error": "账号不存在"}), 404
|
||||
|
||||
@@ -159,7 +198,6 @@ def delete_account(account_id):
|
||||
username = account.username
|
||||
|
||||
database.delete_account(account_id)
|
||||
|
||||
safe_remove_account(user_id, account_id)
|
||||
|
||||
log_to_client(f"删除账号: {username}", user_id)
|
||||
@@ -196,12 +234,12 @@ def update_remark(account_id):
|
||||
"""更新备注"""
|
||||
user_id = current_user.id
|
||||
|
||||
account = safe_get_account(user_id, account_id)
|
||||
account = _get_user_account(user_id, account_id)
|
||||
if not account:
|
||||
return jsonify({"error": "账号不存在"}), 404
|
||||
|
||||
data = request.json
|
||||
remark = data.get("remark", "").strip()[:200]
|
||||
data = _request_json()
|
||||
remark = str(data.get("remark", "")).strip()[:200]
|
||||
|
||||
database.update_account_remark(account_id, remark)
|
||||
|
||||
@@ -217,17 +255,18 @@ def start_account(account_id):
|
||||
"""启动账号任务"""
|
||||
user_id = current_user.id
|
||||
|
||||
account = safe_get_account(user_id, account_id)
|
||||
account = _get_user_account(user_id, account_id)
|
||||
if not account:
|
||||
return jsonify({"error": "账号不存在"}), 404
|
||||
|
||||
if account.is_running:
|
||||
return jsonify({"error": "任务已在运行中"}), 400
|
||||
|
||||
data = request.json or {}
|
||||
browse_type = validate_browse_type(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
|
||||
if not browse_type:
|
||||
return jsonify({"error": "浏览类型无效"}), 400
|
||||
data = _request_json()
|
||||
browse_type, browse_error = _validate_browse_type_input(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
|
||||
if browse_error:
|
||||
return browse_error
|
||||
|
||||
enable_screenshot = data.get("enable_screenshot", True)
|
||||
ok, message = submit_account_task(
|
||||
user_id=user_id,
|
||||
@@ -249,7 +288,7 @@ def stop_account(account_id):
|
||||
"""停止账号任务"""
|
||||
user_id = current_user.id
|
||||
|
||||
account = safe_get_account(user_id, account_id)
|
||||
account = _get_user_account(user_id, account_id)
|
||||
if not account:
|
||||
return jsonify({"error": "账号不存在"}), 404
|
||||
|
||||
@@ -259,20 +298,16 @@ def stop_account(account_id):
|
||||
account.should_stop = True
|
||||
account.status = "正在停止"
|
||||
|
||||
try:
|
||||
scheduler = get_task_scheduler()
|
||||
if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id):
|
||||
account.status = "已停止"
|
||||
account.is_running = False
|
||||
safe_remove_task_status(account_id)
|
||||
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
|
||||
log_to_client(f"任务已取消: {account.username}", user_id)
|
||||
return jsonify({"success": True, "canceled": True})
|
||||
except Exception:
|
||||
pass
|
||||
if _cancel_pending_account_task(user_id, account_id):
|
||||
account.status = "已停止"
|
||||
account.is_running = False
|
||||
safe_remove_task_status(account_id)
|
||||
_emit_account_update(user_id, account)
|
||||
log_to_client(f"任务已取消: {account.username}", user_id)
|
||||
return jsonify({"success": True, "canceled": True})
|
||||
|
||||
log_to_client(f"停止任务: {account.username}", user_id)
|
||||
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
|
||||
_emit_account_update(user_id, account)
|
||||
|
||||
return jsonify({"success": True})
|
||||
|
||||
@@ -283,23 +318,20 @@ def manual_screenshot(account_id):
|
||||
"""手动为指定账号截图"""
|
||||
user_id = current_user.id
|
||||
|
||||
account = safe_get_account(user_id, account_id)
|
||||
if not account:
|
||||
load_user_accounts(user_id)
|
||||
account = safe_get_account(user_id, account_id)
|
||||
account = _get_user_account(user_id, account_id, refresh_if_missing=True)
|
||||
if not account:
|
||||
return jsonify({"error": "账号不存在"}), 404
|
||||
if account.is_running:
|
||||
return jsonify({"error": "任务运行中,无法截图"}), 400
|
||||
|
||||
data = request.json or {}
|
||||
data = _request_json()
|
||||
requested_browse_type = data.get("browse_type", None)
|
||||
if requested_browse_type is None:
|
||||
browse_type = normalize_browse_type(account.last_browse_type)
|
||||
else:
|
||||
browse_type = validate_browse_type(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ)
|
||||
if not browse_type:
|
||||
return jsonify({"error": "浏览类型无效"}), 400
|
||||
browse_type, browse_error = _validate_browse_type_input(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ)
|
||||
if browse_error:
|
||||
return browse_error
|
||||
|
||||
account.last_browse_type = browse_type
|
||||
|
||||
@@ -317,12 +349,16 @@ def manual_screenshot(account_id):
|
||||
def batch_start_accounts():
|
||||
"""批量启动账号"""
|
||||
user_id = current_user.id
|
||||
data = request.json or {}
|
||||
data = _request_json()
|
||||
|
||||
account_ids = data.get("account_ids", [])
|
||||
browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ)
|
||||
if not browse_type:
|
||||
return jsonify({"error": "浏览类型无效"}), 400
|
||||
browse_type, browse_error = _validate_browse_type_input(
|
||||
data.get("browse_type", BROWSE_TYPE_SHOULD_READ),
|
||||
default=BROWSE_TYPE_SHOULD_READ,
|
||||
)
|
||||
if browse_error:
|
||||
return browse_error
|
||||
|
||||
enable_screenshot = data.get("enable_screenshot", True)
|
||||
|
||||
if not account_ids:
|
||||
@@ -331,11 +367,10 @@ def batch_start_accounts():
|
||||
started = []
|
||||
failed = []
|
||||
|
||||
if not safe_get_user_accounts_snapshot(user_id):
|
||||
load_user_accounts(user_id)
|
||||
_ensure_accounts_loaded(user_id)
|
||||
|
||||
for account_id in account_ids:
|
||||
account = safe_get_account(user_id, account_id)
|
||||
account = _get_user_account(user_id, account_id)
|
||||
if not account:
|
||||
failed.append({"id": account_id, "reason": "账号不存在"})
|
||||
continue
|
||||
@@ -357,7 +392,13 @@ def batch_start_accounts():
|
||||
failed.append({"id": account_id, "reason": msg})
|
||||
|
||||
return jsonify(
|
||||
{"success": True, "started_count": len(started), "failed_count": len(failed), "started": started, "failed": failed}
|
||||
{
|
||||
"success": True,
|
||||
"started_count": len(started),
|
||||
"failed_count": len(failed),
|
||||
"started": started,
|
||||
"failed": failed,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -366,39 +407,29 @@ def batch_start_accounts():
|
||||
def batch_stop_accounts():
|
||||
"""批量停止账号"""
|
||||
user_id = current_user.id
|
||||
data = request.json
|
||||
data = _request_json()
|
||||
|
||||
account_ids = data.get("account_ids", [])
|
||||
|
||||
if not account_ids:
|
||||
return jsonify({"error": "请选择要停止的账号"}), 400
|
||||
|
||||
stopped = []
|
||||
|
||||
if not safe_get_user_accounts_snapshot(user_id):
|
||||
load_user_accounts(user_id)
|
||||
_ensure_accounts_loaded(user_id)
|
||||
|
||||
for account_id in account_ids:
|
||||
account = safe_get_account(user_id, account_id)
|
||||
if not account:
|
||||
continue
|
||||
|
||||
if not account.is_running:
|
||||
account = _get_user_account(user_id, account_id)
|
||||
if (not account) or (not account.is_running):
|
||||
continue
|
||||
|
||||
account.should_stop = True
|
||||
account.status = "正在停止"
|
||||
stopped.append(account_id)
|
||||
|
||||
try:
|
||||
scheduler = get_task_scheduler()
|
||||
if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id):
|
||||
account.status = "已停止"
|
||||
account.is_running = False
|
||||
safe_remove_task_status(account_id)
|
||||
except Exception:
|
||||
pass
|
||||
if _cancel_pending_account_task(user_id, account_id):
|
||||
account.status = "已停止"
|
||||
account.is_running = False
|
||||
safe_remove_task_status(account_id)
|
||||
|
||||
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
|
||||
_emit_account_update(user_id, account)
|
||||
|
||||
return jsonify({"success": True, "stopped_count": len(stopped), "stopped": stopped})
|
||||
|
||||
@@ -2,16 +2,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import random
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
|
||||
import database
|
||||
import email_service
|
||||
from app_config import get_config
|
||||
from app_logger import get_logger
|
||||
from app_security import get_rate_limit_ip, require_ip_not_locked, validate_email, validate_password, validate_username
|
||||
from flask import Blueprint, jsonify, redirect, render_template, request, url_for
|
||||
from flask import Blueprint, jsonify, request
|
||||
from flask_login import login_required, login_user, logout_user
|
||||
from routes.pages import render_app_spa_or_legacy
|
||||
from services.accounts_service import load_user_accounts
|
||||
@@ -39,12 +43,162 @@ config = get_config()
|
||||
|
||||
api_auth_bp = Blueprint("api_auth", __name__)
|
||||
|
||||
_CAPTCHA_FONT_PATHS = [
|
||||
"/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
|
||||
"/usr/share/fonts/truetype/freefont/FreeSansBold.ttf",
|
||||
]
|
||||
_CAPTCHA_FONT = None
|
||||
_CAPTCHA_FONT_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _get_json_payload() -> dict:
|
||||
data = request.get_json(silent=True)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def _load_captcha_font(image_font_module):
|
||||
global _CAPTCHA_FONT
|
||||
|
||||
if _CAPTCHA_FONT is not None:
|
||||
return _CAPTCHA_FONT
|
||||
|
||||
with _CAPTCHA_FONT_LOCK:
|
||||
if _CAPTCHA_FONT is not None:
|
||||
return _CAPTCHA_FONT
|
||||
|
||||
for font_path in _CAPTCHA_FONT_PATHS:
|
||||
try:
|
||||
_CAPTCHA_FONT = image_font_module.truetype(font_path, 42)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if _CAPTCHA_FONT is None:
|
||||
_CAPTCHA_FONT = image_font_module.load_default()
|
||||
|
||||
return _CAPTCHA_FONT
|
||||
|
||||
|
||||
def _generate_captcha_image_data_uri(code: str) -> str:
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
width, height = 160, 60
|
||||
image = Image.new("RGB", (width, height), color=(255, 255, 255))
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
for _ in range(6):
|
||||
x1 = random.randint(0, width)
|
||||
y1 = random.randint(0, height)
|
||||
x2 = random.randint(0, width)
|
||||
y2 = random.randint(0, height)
|
||||
draw.line(
|
||||
[(x1, y1), (x2, y2)],
|
||||
fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)),
|
||||
width=1,
|
||||
)
|
||||
|
||||
for _ in range(80):
|
||||
x = random.randint(0, width)
|
||||
y = random.randint(0, height)
|
||||
draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)))
|
||||
|
||||
font = _load_captcha_font(ImageFont)
|
||||
for i, char in enumerate(code):
|
||||
x = 12 + i * 35 + random.randint(-3, 3)
|
||||
y = random.randint(5, 12)
|
||||
color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150))
|
||||
draw.text((x, y), char, font=font, fill=color)
|
||||
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
return f"data:image/png;base64,{img_base64}"
|
||||
|
||||
|
||||
def _with_vip_suffix(message: str, auto_approve_enabled: bool, auto_approve_vip_days: int) -> str:
|
||||
if auto_approve_enabled and auto_approve_vip_days > 0:
|
||||
return f"{message},赠送{auto_approve_vip_days}天VIP"
|
||||
return message
|
||||
|
||||
|
||||
def _verify_common_captcha(client_ip: str, captcha_session: str, captcha_code: str):
|
||||
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
|
||||
if success:
|
||||
return True, None
|
||||
|
||||
is_locked = record_failed_captcha(client_ip)
|
||||
if is_locked:
|
||||
return False, (jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429)
|
||||
return False, (jsonify({"error": message}), 400)
|
||||
|
||||
|
||||
def _verify_login_captcha_if_needed(
|
||||
*,
|
||||
captcha_required: bool,
|
||||
captcha_session: str,
|
||||
captcha_code: str,
|
||||
client_ip: str,
|
||||
username_key: str,
|
||||
):
|
||||
if not captcha_required:
|
||||
return True, None
|
||||
|
||||
if not captcha_session or not captcha_code:
|
||||
return False, (jsonify({"error": "请填写验证码", "need_captcha": True}), 400)
|
||||
|
||||
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
|
||||
if success:
|
||||
return True, None
|
||||
|
||||
record_login_failure(client_ip, username_key)
|
||||
return False, (jsonify({"error": message, "need_captcha": True}), 400)
|
||||
|
||||
|
||||
def _send_password_reset_email_if_possible(email: str, username: str, user_id: int) -> None:
|
||||
result = email_service.send_password_reset_email(email=email, username=username, user_id=user_id)
|
||||
if not result["success"]:
|
||||
logger.error(f"密码重置邮件发送失败: {result['error']}")
|
||||
|
||||
|
||||
def _send_login_security_alert_if_needed(user: dict, username: str, client_ip: str) -> None:
|
||||
try:
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
context = database.record_login_context(user["id"], client_ip, user_agent)
|
||||
if not context or (not context.get("new_ip") and not context.get("new_device")):
|
||||
return
|
||||
|
||||
if not config.LOGIN_ALERT_ENABLED:
|
||||
return
|
||||
if not should_send_login_alert(user["id"], client_ip):
|
||||
return
|
||||
if not email_service.get_email_settings().get("login_alert_enabled", True):
|
||||
return
|
||||
|
||||
user_info = database.get_user_by_id(user["id"]) or {}
|
||||
if (not user_info.get("email")) or (not user_info.get("email_verified")):
|
||||
return
|
||||
if not database.get_user_email_notify(user["id"]):
|
||||
return
|
||||
|
||||
email_service.send_security_alert_email(
|
||||
email=user_info.get("email"),
|
||||
username=user_info.get("username") or username,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
new_ip=context.get("new_ip", False),
|
||||
new_device=context.get("new_device", False),
|
||||
user_id=user["id"],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@api_auth_bp.route("/api/register", methods=["POST"])
|
||||
@require_ip_not_locked
|
||||
def register():
|
||||
"""用户注册"""
|
||||
data = request.json or {}
|
||||
data = _get_json_payload()
|
||||
username = data.get("username", "").strip()
|
||||
password = data.get("password", "").strip()
|
||||
email = data.get("email", "").strip().lower()
|
||||
@@ -67,12 +221,9 @@ def register():
|
||||
if not allowed:
|
||||
return jsonify({"error": error_msg}), 429
|
||||
|
||||
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
|
||||
if not success:
|
||||
is_locked = record_failed_captcha(client_ip)
|
||||
if is_locked:
|
||||
return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
|
||||
return jsonify({"error": message}), 400
|
||||
captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
|
||||
if not captcha_ok:
|
||||
return captcha_error_response
|
||||
|
||||
email_settings = email_service.get_email_settings()
|
||||
email_verify_enabled = email_settings.get("register_verify_enabled", False) and email_settings.get("enabled", False)
|
||||
@@ -105,20 +256,22 @@ def register():
|
||||
if email_verify_enabled and email:
|
||||
result = email_service.send_register_verification_email(email=email, username=username, user_id=user_id)
|
||||
if result["success"]:
|
||||
message = "注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)"
|
||||
if auto_approve_enabled and auto_approve_vip_days > 0:
|
||||
message += f",赠送{auto_approve_vip_days}天VIP"
|
||||
message = _with_vip_suffix(
|
||||
"注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)",
|
||||
auto_approve_enabled,
|
||||
auto_approve_vip_days,
|
||||
)
|
||||
return jsonify({"success": True, "message": message, "need_verify": True})
|
||||
|
||||
logger.error(f"注册验证邮件发送失败: {result['error']}")
|
||||
message = f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录"
|
||||
if auto_approve_enabled and auto_approve_vip_days > 0:
|
||||
message += f",赠送{auto_approve_vip_days}天VIP"
|
||||
message = _with_vip_suffix(
|
||||
f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录",
|
||||
auto_approve_enabled,
|
||||
auto_approve_vip_days,
|
||||
)
|
||||
return jsonify({"success": True, "message": message, "need_verify": True})
|
||||
|
||||
message = "注册成功!可直接登录"
|
||||
if auto_approve_enabled and auto_approve_vip_days > 0:
|
||||
message += f",赠送{auto_approve_vip_days}天VIP"
|
||||
message = _with_vip_suffix("注册成功!可直接登录", auto_approve_enabled, auto_approve_vip_days)
|
||||
return jsonify({"success": True, "message": message})
|
||||
return jsonify({"error": "用户名已存在"}), 400
|
||||
|
||||
@@ -175,7 +328,7 @@ def verify_email(token):
|
||||
@require_ip_not_locked
|
||||
def resend_verify_email():
|
||||
"""重发验证邮件"""
|
||||
data = request.json or {}
|
||||
data = _get_json_payload()
|
||||
email = data.get("email", "").strip().lower()
|
||||
captcha_session = data.get("captcha_session", "")
|
||||
captcha_code = data.get("captcha", "").strip()
|
||||
@@ -195,12 +348,9 @@ def resend_verify_email():
|
||||
if not allowed:
|
||||
return jsonify({"error": error_msg}), 429
|
||||
|
||||
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
|
||||
if not success:
|
||||
is_locked = record_failed_captcha(client_ip)
|
||||
if is_locked:
|
||||
return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
|
||||
return jsonify({"error": message}), 400
|
||||
captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
|
||||
if not captcha_ok:
|
||||
return captcha_error_response
|
||||
|
||||
user = database.get_user_by_email(email)
|
||||
if not user:
|
||||
@@ -235,7 +385,7 @@ def get_email_verify_status():
|
||||
@require_ip_not_locked
|
||||
def forgot_password():
|
||||
"""发送密码重置邮件"""
|
||||
data = request.json or {}
|
||||
data = _get_json_payload()
|
||||
email = data.get("email", "").strip().lower()
|
||||
username = data.get("username", "").strip()
|
||||
captcha_session = data.get("captcha_session", "")
|
||||
@@ -263,12 +413,9 @@ def forgot_password():
|
||||
if not allowed:
|
||||
return jsonify({"error": error_msg}), 429
|
||||
|
||||
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
|
||||
if not success:
|
||||
is_locked = record_failed_captcha(client_ip)
|
||||
if is_locked:
|
||||
return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
|
||||
return jsonify({"error": message}), 400
|
||||
captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
|
||||
if not captcha_ok:
|
||||
return captcha_error_response
|
||||
|
||||
email_settings = email_service.get_email_settings()
|
||||
if not email_settings.get("enabled", False):
|
||||
@@ -293,20 +440,16 @@ def forgot_password():
|
||||
if not allowed:
|
||||
return jsonify({"error": error_msg}), 429
|
||||
|
||||
result = email_service.send_password_reset_email(
|
||||
_send_password_reset_email_if_possible(
|
||||
email=bound_email,
|
||||
username=user["username"],
|
||||
user_id=user["id"],
|
||||
)
|
||||
if not result["success"]:
|
||||
logger.error(f"密码重置邮件发送失败: {result['error']}")
|
||||
return jsonify({"success": True, "message": "如果该账号已绑定邮箱,您将收到密码重置邮件"})
|
||||
|
||||
user = database.get_user_by_email(email)
|
||||
if user and user.get("status") == "approved":
|
||||
result = email_service.send_password_reset_email(email=email, username=user["username"], user_id=user["id"])
|
||||
if not result["success"]:
|
||||
logger.error(f"密码重置邮件发送失败: {result['error']}")
|
||||
_send_password_reset_email_if_possible(email=email, username=user["username"], user_id=user["id"])
|
||||
|
||||
return jsonify({"success": True, "message": "如果该邮箱已注册,您将收到密码重置邮件"})
|
||||
|
||||
@@ -331,7 +474,7 @@ def reset_password_page(token):
|
||||
@api_auth_bp.route("/api/reset-password-confirm", methods=["POST"])
|
||||
def reset_password_confirm():
|
||||
"""确认密码重置"""
|
||||
data = request.json or {}
|
||||
data = _get_json_payload()
|
||||
token = data.get("token", "").strip()
|
||||
new_password = data.get("new_password", "").strip()
|
||||
|
||||
@@ -356,67 +499,15 @@ def reset_password_confirm():
|
||||
@api_auth_bp.route("/api/generate_captcha", methods=["POST"])
|
||||
def generate_captcha():
|
||||
"""生成4位数字验证码图片"""
|
||||
import base64
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
code = "".join([str(secrets.randbelow(10)) for _ in range(4)])
|
||||
code = "".join(str(secrets.randbelow(10)) for _ in range(4))
|
||||
|
||||
safe_set_captcha(session_id, {"code": code, "expire_time": time.time() + 300, "failed_attempts": 0})
|
||||
safe_cleanup_expired_captcha()
|
||||
|
||||
try:
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import io
|
||||
|
||||
width, height = 160, 60
|
||||
image = Image.new("RGB", (width, height), color=(255, 255, 255))
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
for _ in range(6):
|
||||
x1 = random.randint(0, width)
|
||||
y1 = random.randint(0, height)
|
||||
x2 = random.randint(0, width)
|
||||
y2 = random.randint(0, height)
|
||||
draw.line(
|
||||
[(x1, y1), (x2, y2)],
|
||||
fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)),
|
||||
width=1,
|
||||
)
|
||||
|
||||
for _ in range(80):
|
||||
x = random.randint(0, width)
|
||||
y = random.randint(0, height)
|
||||
draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)))
|
||||
|
||||
font = None
|
||||
font_paths = [
|
||||
"/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
|
||||
"/usr/share/fonts/truetype/freefont/FreeSansBold.ttf",
|
||||
]
|
||||
for font_path in font_paths:
|
||||
try:
|
||||
font = ImageFont.truetype(font_path, 42)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if font is None:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
for i, char in enumerate(code):
|
||||
x = 12 + i * 35 + random.randint(-3, 3)
|
||||
y = random.randint(5, 12)
|
||||
color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150))
|
||||
draw.text((x, y), char, font=font, fill=color)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
return jsonify({"session_id": session_id, "captcha_image": f"data:image/png;base64,{img_base64}"})
|
||||
captcha_image = _generate_captcha_image_data_uri(code)
|
||||
return jsonify({"session_id": session_id, "captcha_image": captcha_image})
|
||||
except ImportError as e:
|
||||
logger.error(f"PIL库未安装,验证码功能不可用: {e}")
|
||||
safe_delete_captcha(session_id)
|
||||
@@ -427,7 +518,7 @@ def generate_captcha():
|
||||
@require_ip_not_locked
|
||||
def login():
|
||||
"""用户登录"""
|
||||
data = request.json or {}
|
||||
data = _get_json_payload()
|
||||
username = data.get("username", "").strip()
|
||||
password = data.get("password", "").strip()
|
||||
captcha_session = data.get("captcha_session", "")
|
||||
@@ -452,13 +543,15 @@ def login():
|
||||
return jsonify({"error": error_msg, "need_captcha": True}), 429
|
||||
|
||||
captcha_required = check_login_captcha_required(client_ip, username_key) or scan_locked or bool(need_captcha)
|
||||
if captcha_required:
|
||||
if not captcha_session or not captcha_code:
|
||||
return jsonify({"error": "请填写验证码", "need_captcha": True}), 400
|
||||
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
|
||||
if not success:
|
||||
record_login_failure(client_ip, username_key)
|
||||
return jsonify({"error": message, "need_captcha": True}), 400
|
||||
captcha_ok, captcha_error_response = _verify_login_captcha_if_needed(
|
||||
captcha_required=captcha_required,
|
||||
captcha_session=captcha_session,
|
||||
captcha_code=captcha_code,
|
||||
client_ip=client_ip,
|
||||
username_key=username_key,
|
||||
)
|
||||
if not captcha_ok:
|
||||
return captcha_error_response
|
||||
|
||||
user = database.verify_user(username, password)
|
||||
if not user:
|
||||
@@ -476,29 +569,7 @@ def login():
|
||||
login_user(user_obj)
|
||||
load_user_accounts(user["id"])
|
||||
|
||||
try:
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
context = database.record_login_context(user["id"], client_ip, user_agent)
|
||||
if context and (context.get("new_ip") or context.get("new_device")):
|
||||
if (
|
||||
config.LOGIN_ALERT_ENABLED
|
||||
and should_send_login_alert(user["id"], client_ip)
|
||||
and email_service.get_email_settings().get("login_alert_enabled", True)
|
||||
):
|
||||
user_info = database.get_user_by_id(user["id"]) or {}
|
||||
if user_info.get("email") and user_info.get("email_verified"):
|
||||
if database.get_user_email_notify(user["id"]):
|
||||
email_service.send_security_alert_email(
|
||||
email=user_info.get("email"),
|
||||
username=user_info.get("username") or username,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
new_ip=context.get("new_ip", False),
|
||||
new_device=context.get("new_device", False),
|
||||
user_id=user["id"],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
_send_login_security_alert_if_needed(user=user, username=username, client_ip=client_ip)
|
||||
return jsonify({"success": True})
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
import time as time_mod
|
||||
import uuid
|
||||
|
||||
import database
|
||||
from flask import Blueprint, jsonify, request
|
||||
@@ -17,6 +21,13 @@ api_schedules_bp = Blueprint("api_schedules", __name__)
|
||||
_HHMM_RE = re.compile(r"^(\d{1,2}):(\d{2})$")
|
||||
|
||||
|
||||
def _request_json(default=None):
|
||||
if default is None:
|
||||
default = {}
|
||||
data = request.get_json(silent=True)
|
||||
return data if isinstance(data, dict) else default
|
||||
|
||||
|
||||
def _normalize_hhmm(value: object) -> str | None:
|
||||
match = _HHMM_RE.match(str(value or "").strip())
|
||||
if not match:
|
||||
@@ -28,18 +39,53 @@ def _normalize_hhmm(value: object) -> str | None:
|
||||
return f"{hour:02d}:{minute:02d}"
|
||||
|
||||
|
||||
def _normalize_random_delay(value) -> tuple[int | None, str | None]:
|
||||
try:
|
||||
normalized = int(value or 0)
|
||||
except Exception:
|
||||
return None, "random_delay必须是0或1"
|
||||
if normalized not in (0, 1):
|
||||
return None, "random_delay必须是0或1"
|
||||
return normalized, None
|
||||
|
||||
|
||||
def _parse_schedule_account_ids(raw_value) -> list:
|
||||
try:
|
||||
parsed = json.loads(raw_value or "[]")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
return parsed if isinstance(parsed, list) else []
|
||||
|
||||
|
||||
def _get_owned_schedule_or_error(schedule_id: int):
|
||||
schedule = database.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return None, (jsonify({"error": "定时任务不存在"}), 404)
|
||||
if schedule.get("user_id") != current_user.id:
|
||||
return None, (jsonify({"error": "无权访问"}), 403)
|
||||
return schedule, None
|
||||
|
||||
|
||||
def _ensure_user_accounts_loaded(user_id: int) -> None:
|
||||
if safe_get_user_accounts_snapshot(user_id):
|
||||
return
|
||||
load_user_accounts(user_id)
|
||||
|
||||
|
||||
def _parse_browse_type_or_error(raw_value, *, default=BROWSE_TYPE_SHOULD_READ):
|
||||
browse_type = validate_browse_type(raw_value, default=default)
|
||||
if not browse_type:
|
||||
return None, (jsonify({"error": "浏览类型无效"}), 400)
|
||||
return browse_type, None
|
||||
|
||||
|
||||
@api_schedules_bp.route("/api/schedules", methods=["GET"])
|
||||
@login_required
|
||||
def get_user_schedules_api():
|
||||
"""获取当前用户的所有定时任务"""
|
||||
schedules = database.get_user_schedules(current_user.id)
|
||||
import json
|
||||
|
||||
for s in schedules:
|
||||
try:
|
||||
s["account_ids"] = json.loads(s.get("account_ids", "[]") or "[]")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
s["account_ids"] = []
|
||||
for schedule in schedules:
|
||||
schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids"))
|
||||
return jsonify(schedules)
|
||||
|
||||
|
||||
@@ -47,23 +93,26 @@ def get_user_schedules_api():
|
||||
@login_required
|
||||
def create_user_schedule_api():
|
||||
"""创建用户定时任务"""
|
||||
data = request.json or {}
|
||||
data = _request_json()
|
||||
|
||||
name = data.get("name", "我的定时任务")
|
||||
schedule_time = data.get("schedule_time", "08:00")
|
||||
weekdays = data.get("weekdays", "1,2,3,4,5")
|
||||
browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ)
|
||||
if not browse_type:
|
||||
return jsonify({"error": "浏览类型无效"}), 400
|
||||
|
||||
browse_type, browse_error = _parse_browse_type_or_error(data.get("browse_type", BROWSE_TYPE_SHOULD_READ))
|
||||
if browse_error:
|
||||
return browse_error
|
||||
|
||||
enable_screenshot = data.get("enable_screenshot", 1)
|
||||
random_delay = int(data.get("random_delay", 0) or 0)
|
||||
random_delay, delay_error = _normalize_random_delay(data.get("random_delay", 0))
|
||||
if delay_error:
|
||||
return jsonify({"error": delay_error}), 400
|
||||
|
||||
account_ids = data.get("account_ids", [])
|
||||
|
||||
normalized_time = _normalize_hhmm(schedule_time)
|
||||
if not normalized_time:
|
||||
return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400
|
||||
if random_delay not in (0, 1):
|
||||
return jsonify({"error": "random_delay必须是0或1"}), 400
|
||||
|
||||
schedule_id = database.create_user_schedule(
|
||||
user_id=current_user.id,
|
||||
@@ -85,18 +134,11 @@ def create_user_schedule_api():
|
||||
@login_required
|
||||
def get_schedule_detail_api(schedule_id):
|
||||
"""获取定时任务详情"""
|
||||
schedule = database.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return jsonify({"error": "定时任务不存在"}), 404
|
||||
if schedule["user_id"] != current_user.id:
|
||||
return jsonify({"error": "无权访问"}), 403
|
||||
schedule, error_response = _get_owned_schedule_or_error(schedule_id)
|
||||
if error_response:
|
||||
return error_response
|
||||
|
||||
import json
|
||||
|
||||
try:
|
||||
schedule["account_ids"] = json.loads(schedule.get("account_ids", "[]") or "[]")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
schedule["account_ids"] = []
|
||||
schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids"))
|
||||
return jsonify(schedule)
|
||||
|
||||
|
||||
@@ -104,14 +146,12 @@ def get_schedule_detail_api(schedule_id):
|
||||
@login_required
|
||||
def update_schedule_api(schedule_id):
|
||||
"""更新定时任务"""
|
||||
schedule = database.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return jsonify({"error": "定时任务不存在"}), 404
|
||||
if schedule["user_id"] != current_user.id:
|
||||
return jsonify({"error": "无权访问"}), 403
|
||||
_, error_response = _get_owned_schedule_or_error(schedule_id)
|
||||
if error_response:
|
||||
return error_response
|
||||
|
||||
data = request.json or {}
|
||||
allowed_fields = [
|
||||
data = _request_json()
|
||||
allowed_fields = {
|
||||
"name",
|
||||
"schedule_time",
|
||||
"weekdays",
|
||||
@@ -120,27 +160,26 @@ def update_schedule_api(schedule_id):
|
||||
"random_delay",
|
||||
"account_ids",
|
||||
"enabled",
|
||||
]
|
||||
|
||||
update_data = {k: v for k, v in data.items() if k in allowed_fields}
|
||||
}
|
||||
update_data = {key: value for key, value in data.items() if key in allowed_fields}
|
||||
|
||||
if "schedule_time" in update_data:
|
||||
normalized_time = _normalize_hhmm(update_data["schedule_time"])
|
||||
if not normalized_time:
|
||||
return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400
|
||||
update_data["schedule_time"] = normalized_time
|
||||
|
||||
if "random_delay" in update_data:
|
||||
try:
|
||||
update_data["random_delay"] = int(update_data.get("random_delay") or 0)
|
||||
except Exception:
|
||||
return jsonify({"error": "random_delay必须是0或1"}), 400
|
||||
if update_data["random_delay"] not in (0, 1):
|
||||
return jsonify({"error": "random_delay必须是0或1"}), 400
|
||||
random_delay, delay_error = _normalize_random_delay(update_data.get("random_delay"))
|
||||
if delay_error:
|
||||
return jsonify({"error": delay_error}), 400
|
||||
update_data["random_delay"] = random_delay
|
||||
|
||||
if "browse_type" in update_data:
|
||||
normalized = validate_browse_type(update_data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
|
||||
if not normalized:
|
||||
return jsonify({"error": "浏览类型无效"}), 400
|
||||
update_data["browse_type"] = normalized
|
||||
normalized_browse_type, browse_error = _parse_browse_type_or_error(update_data.get("browse_type"))
|
||||
if browse_error:
|
||||
return browse_error
|
||||
update_data["browse_type"] = normalized_browse_type
|
||||
|
||||
success = database.update_user_schedule(schedule_id, **update_data)
|
||||
if success:
|
||||
@@ -152,11 +191,9 @@ def update_schedule_api(schedule_id):
|
||||
@login_required
|
||||
def delete_schedule_api(schedule_id):
|
||||
"""删除定时任务"""
|
||||
schedule = database.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return jsonify({"error": "定时任务不存在"}), 404
|
||||
if schedule["user_id"] != current_user.id:
|
||||
return jsonify({"error": "无权访问"}), 403
|
||||
_, error_response = _get_owned_schedule_or_error(schedule_id)
|
||||
if error_response:
|
||||
return error_response
|
||||
|
||||
success = database.delete_user_schedule(schedule_id)
|
||||
if success:
|
||||
@@ -168,13 +205,11 @@ def delete_schedule_api(schedule_id):
|
||||
@login_required
|
||||
def toggle_schedule_api(schedule_id):
|
||||
"""启用/禁用定时任务"""
|
||||
schedule = database.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return jsonify({"error": "定时任务不存在"}), 404
|
||||
if schedule["user_id"] != current_user.id:
|
||||
return jsonify({"error": "无权访问"}), 403
|
||||
schedule, error_response = _get_owned_schedule_or_error(schedule_id)
|
||||
if error_response:
|
||||
return error_response
|
||||
|
||||
data = request.json
|
||||
data = _request_json()
|
||||
enabled = data.get("enabled", not schedule["enabled"])
|
||||
|
||||
success = database.toggle_user_schedule(schedule_id, enabled)
|
||||
@@ -187,22 +222,11 @@ def toggle_schedule_api(schedule_id):
|
||||
@login_required
|
||||
def run_schedule_now_api(schedule_id):
|
||||
"""立即执行定时任务"""
|
||||
import json
|
||||
import threading
|
||||
import time as time_mod
|
||||
import uuid
|
||||
|
||||
schedule = database.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return jsonify({"error": "定时任务不存在"}), 404
|
||||
if schedule["user_id"] != current_user.id:
|
||||
return jsonify({"error": "无权访问"}), 403
|
||||
|
||||
try:
|
||||
account_ids = json.loads(schedule.get("account_ids", "[]") or "[]")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
account_ids = []
|
||||
schedule, error_response = _get_owned_schedule_or_error(schedule_id)
|
||||
if error_response:
|
||||
return error_response
|
||||
|
||||
account_ids = _parse_schedule_account_ids(schedule.get("account_ids"))
|
||||
if not account_ids:
|
||||
return jsonify({"error": "没有配置账号"}), 400
|
||||
|
||||
@@ -210,8 +234,7 @@ def run_schedule_now_api(schedule_id):
|
||||
browse_type = normalize_browse_type(schedule.get("browse_type", BROWSE_TYPE_SHOULD_READ))
|
||||
enable_screenshot = schedule["enable_screenshot"]
|
||||
|
||||
if not safe_get_user_accounts_snapshot(user_id):
|
||||
load_user_accounts(user_id)
|
||||
_ensure_user_accounts_loaded(user_id)
|
||||
|
||||
from services.state import safe_create_batch, safe_finalize_batch_after_dispatch
|
||||
from services.task_batches import _send_batch_task_email_if_configured
|
||||
@@ -250,6 +273,7 @@ def run_schedule_now_api(schedule_id):
|
||||
if remaining["done"] or remaining["count"] > 0:
|
||||
return
|
||||
remaining["done"] = True
|
||||
|
||||
execution_duration = int(time_mod.time() - execution_start_time)
|
||||
database.update_schedule_execution_log(
|
||||
log_id,
|
||||
@@ -260,19 +284,17 @@ def run_schedule_now_api(schedule_id):
|
||||
status="completed",
|
||||
)
|
||||
|
||||
task_source = f"user_scheduled:{batch_id}"
|
||||
for account_id in account_ids:
|
||||
account = safe_get_account(user_id, account_id)
|
||||
if not account:
|
||||
skipped_count += 1
|
||||
continue
|
||||
if account.is_running:
|
||||
if (not account) or account.is_running:
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
task_source = f"user_scheduled:{batch_id}"
|
||||
with completion_lock:
|
||||
remaining["count"] += 1
|
||||
ok, msg = submit_account_task(
|
||||
|
||||
ok, _ = submit_account_task(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
browse_type=browse_type,
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Iterator
|
||||
|
||||
import database
|
||||
from app_config import get_config
|
||||
@@ -15,41 +16,67 @@ from services.time_utils import BEIJING_TZ
|
||||
|
||||
config = get_config()
|
||||
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
|
||||
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
|
||||
|
||||
api_screenshots_bp = Blueprint("api_screenshots", __name__)
|
||||
|
||||
|
||||
def _get_user_prefix(user_id: int) -> str:
|
||||
user_info = database.get_user_by_id(user_id)
|
||||
return user_info["username"] if user_info else f"user{user_id}"
|
||||
|
||||
|
||||
def _is_user_screenshot(filename: str, username_prefix: str) -> bool:
|
||||
return filename.startswith(username_prefix + "_") and filename.lower().endswith(_IMAGE_EXTENSIONS)
|
||||
|
||||
|
||||
def _iter_user_screenshot_entries(username_prefix: str) -> Iterator[os.DirEntry]:
|
||||
if not os.path.exists(SCREENSHOTS_DIR):
|
||||
return
|
||||
|
||||
with os.scandir(SCREENSHOTS_DIR) as entries:
|
||||
for entry in entries:
|
||||
if (not entry.is_file()) or (not _is_user_screenshot(entry.name, username_prefix)):
|
||||
continue
|
||||
yield entry
|
||||
|
||||
|
||||
def _build_display_name(filename: str) -> str:
|
||||
base_name, ext = filename.rsplit(".", 1)
|
||||
parts = base_name.split("_", 1)
|
||||
if len(parts) > 1:
|
||||
return f"{parts[1]}.{ext}"
|
||||
return filename
|
||||
|
||||
|
||||
@api_screenshots_bp.route("/api/screenshots", methods=["GET"])
|
||||
@login_required
|
||||
def get_screenshots():
|
||||
"""获取当前用户的截图列表"""
|
||||
user_id = current_user.id
|
||||
user_info = database.get_user_by_id(user_id)
|
||||
username_prefix = user_info["username"] if user_info else f"user{user_id}"
|
||||
username_prefix = _get_user_prefix(user_id)
|
||||
|
||||
try:
|
||||
screenshots = []
|
||||
if os.path.exists(SCREENSHOTS_DIR):
|
||||
for filename in os.listdir(SCREENSHOTS_DIR):
|
||||
if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"):
|
||||
filepath = os.path.join(SCREENSHOTS_DIR, filename)
|
||||
stat = os.stat(filepath)
|
||||
created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ)
|
||||
parts = filename.rsplit(".", 1)[0].split("_", 1)
|
||||
if len(parts) > 1:
|
||||
display_name = parts[1] + "." + filename.rsplit(".", 1)[1]
|
||||
else:
|
||||
display_name = filename
|
||||
for entry in _iter_user_screenshot_entries(username_prefix):
|
||||
filename = entry.name
|
||||
stat = entry.stat()
|
||||
created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ)
|
||||
|
||||
screenshots.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"display_name": _build_display_name(filename),
|
||||
"size": stat.st_size,
|
||||
"created": created_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"_created_ts": stat.st_mtime,
|
||||
}
|
||||
)
|
||||
|
||||
screenshots.sort(key=lambda item: item.get("_created_ts", 0), reverse=True)
|
||||
for item in screenshots:
|
||||
item.pop("_created_ts", None)
|
||||
|
||||
screenshots.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"display_name": display_name,
|
||||
"size": stat.st_size,
|
||||
"created": created_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
)
|
||||
screenshots.sort(key=lambda x: x["created"], reverse=True)
|
||||
return jsonify(screenshots)
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
@@ -60,10 +87,9 @@ def get_screenshots():
|
||||
def serve_screenshot(filename):
|
||||
"""提供截图文件访问"""
|
||||
user_id = current_user.id
|
||||
user_info = database.get_user_by_id(user_id)
|
||||
username_prefix = user_info["username"] if user_info else f"user{user_id}"
|
||||
username_prefix = _get_user_prefix(user_id)
|
||||
|
||||
if not filename.startswith(username_prefix + "_"):
|
||||
if not _is_user_screenshot(filename, username_prefix):
|
||||
return jsonify({"error": "无权访问"}), 403
|
||||
|
||||
if not is_safe_path(SCREENSHOTS_DIR, filename):
|
||||
@@ -77,12 +103,14 @@ def serve_screenshot(filename):
|
||||
def delete_screenshot(filename):
|
||||
"""删除指定截图"""
|
||||
user_id = current_user.id
|
||||
user_info = database.get_user_by_id(user_id)
|
||||
username_prefix = user_info["username"] if user_info else f"user{user_id}"
|
||||
username_prefix = _get_user_prefix(user_id)
|
||||
|
||||
if not filename.startswith(username_prefix + "_"):
|
||||
if not _is_user_screenshot(filename, username_prefix):
|
||||
return jsonify({"error": "无权删除"}), 403
|
||||
|
||||
if not is_safe_path(SCREENSHOTS_DIR, filename):
|
||||
return jsonify({"error": "非法路径"}), 403
|
||||
|
||||
try:
|
||||
filepath = os.path.join(SCREENSHOTS_DIR, filename)
|
||||
if os.path.exists(filepath):
|
||||
@@ -99,19 +127,15 @@ def delete_screenshot(filename):
|
||||
def clear_all_screenshots():
|
||||
"""清空当前用户的所有截图"""
|
||||
user_id = current_user.id
|
||||
user_info = database.get_user_by_id(user_id)
|
||||
username_prefix = user_info["username"] if user_info else f"user{user_id}"
|
||||
username_prefix = _get_user_prefix(user_id)
|
||||
|
||||
try:
|
||||
deleted_count = 0
|
||||
if os.path.exists(SCREENSHOTS_DIR):
|
||||
for filename in os.listdir(SCREENSHOTS_DIR):
|
||||
if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"):
|
||||
filepath = os.path.join(SCREENSHOTS_DIR, filename)
|
||||
os.remove(filepath)
|
||||
deleted_count += 1
|
||||
for entry in _iter_user_screenshot_entries(username_prefix):
|
||||
os.remove(entry.path)
|
||||
deleted_count += 1
|
||||
|
||||
log_to_client(f"清理了 {deleted_count} 个截图文件", user_id)
|
||||
return jsonify({"success": True, "deleted": deleted_count})
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@@ -10,12 +10,96 @@ from flask import Blueprint, jsonify, request
|
||||
from flask_login import current_user, login_required
|
||||
from routes.pages import render_app_spa_or_legacy
|
||||
from services.state import check_email_rate_limit, check_ip_request_rate, safe_iter_task_status_items
|
||||
from services.tasks import get_task_scheduler
|
||||
|
||||
logger = get_logger("app")
|
||||
|
||||
api_user_bp = Blueprint("api_user", __name__)
|
||||
|
||||
|
||||
def _get_current_user_record():
|
||||
return database.get_user_by_id(current_user.id)
|
||||
|
||||
|
||||
def _get_current_user_or_404():
|
||||
user = _get_current_user_record()
|
||||
if user:
|
||||
return user, None
|
||||
return None, (jsonify({"error": "用户不存在"}), 404)
|
||||
|
||||
|
||||
def _get_current_username(*, fallback: str) -> str:
|
||||
user = _get_current_user_record()
|
||||
username = (user or {}).get("username", "")
|
||||
return username if username else fallback
|
||||
|
||||
|
||||
def _coerce_binary_flag(value, *, field_label: str):
|
||||
if isinstance(value, bool):
|
||||
value = 1 if value else 0
|
||||
try:
|
||||
value = int(value)
|
||||
except Exception:
|
||||
return None, f"{field_label}必须是0或1"
|
||||
if value not in (0, 1):
|
||||
return None, f"{field_label}必须是0或1"
|
||||
return value, None
|
||||
|
||||
|
||||
def _check_bind_email_rate_limits(email: str):
|
||||
client_ip = get_rate_limit_ip()
|
||||
allowed, error_msg = check_ip_request_rate(client_ip, "email")
|
||||
if not allowed:
|
||||
return False, error_msg, 429
|
||||
allowed, error_msg = check_email_rate_limit(email, "bind_email")
|
||||
if not allowed:
|
||||
return False, error_msg, 429
|
||||
return True, "", 200
|
||||
|
||||
|
||||
def _render_verify_bind_failed(*, title: str, error_message: str):
|
||||
spa_initial_state = {
|
||||
"page": "verify_result",
|
||||
"success": False,
|
||||
"title": title,
|
||||
"error_message": error_message,
|
||||
"primary_label": "返回登录",
|
||||
"primary_url": "/login",
|
||||
}
|
||||
return render_app_spa_or_legacy(
|
||||
"verify_failed.html",
|
||||
legacy_context={"error_message": error_message},
|
||||
spa_initial_state=spa_initial_state,
|
||||
)
|
||||
|
||||
|
||||
def _render_verify_bind_success(email: str):
|
||||
spa_initial_state = {
|
||||
"page": "verify_result",
|
||||
"success": True,
|
||||
"title": "邮箱绑定成功",
|
||||
"message": f"邮箱 {email} 已成功绑定到您的账号!",
|
||||
"primary_label": "返回登录",
|
||||
"primary_url": "/login",
|
||||
"redirect_url": "/login",
|
||||
"redirect_seconds": 5,
|
||||
}
|
||||
return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
|
||||
|
||||
|
||||
def _get_current_running_count(user_id: int) -> int:
|
||||
try:
|
||||
queue_snapshot = get_task_scheduler().get_queue_state_snapshot() or {}
|
||||
running_by_user = queue_snapshot.get("running_by_user") or {}
|
||||
return int(running_by_user.get(int(user_id), running_by_user.get(str(user_id), 0)) or 0)
|
||||
except Exception:
|
||||
current_running = 0
|
||||
for _, info in safe_iter_task_status_items():
|
||||
if info.get("user_id") == user_id and info.get("status") == "运行中":
|
||||
current_running += 1
|
||||
return current_running
|
||||
|
||||
|
||||
@api_user_bp.route("/api/announcements/active", methods=["GET"])
|
||||
@login_required
|
||||
def get_active_announcement():
|
||||
@@ -77,8 +161,7 @@ def submit_feedback():
|
||||
if len(description) > 2000:
|
||||
return jsonify({"error": "描述不能超过2000个字符"}), 400
|
||||
|
||||
user_info = database.get_user_by_id(current_user.id)
|
||||
username = user_info["username"] if user_info else f"用户{current_user.id}"
|
||||
username = _get_current_username(fallback=f"用户{current_user.id}")
|
||||
|
||||
feedback_id = database.create_bug_feedback(
|
||||
user_id=current_user.id,
|
||||
@@ -104,8 +187,7 @@ def get_my_feedbacks():
|
||||
def get_current_user_vip():
|
||||
"""获取当前用户VIP信息"""
|
||||
vip_info = database.get_user_vip_info(current_user.id)
|
||||
user_info = database.get_user_by_id(current_user.id)
|
||||
vip_info["username"] = user_info["username"] if user_info else "Unknown"
|
||||
vip_info["username"] = _get_current_username(fallback="Unknown")
|
||||
return jsonify(vip_info)
|
||||
|
||||
|
||||
@@ -124,9 +206,9 @@ def change_user_password():
|
||||
if not is_valid:
|
||||
return jsonify({"error": error_msg}), 400
|
||||
|
||||
user = database.get_user_by_id(current_user.id)
|
||||
if not user:
|
||||
return jsonify({"error": "用户不存在"}), 404
|
||||
user, error_response = _get_current_user_or_404()
|
||||
if error_response:
|
||||
return error_response
|
||||
|
||||
username = user.get("username", "")
|
||||
if not username or not database.verify_user(username, current_password):
|
||||
@@ -141,9 +223,9 @@ def change_user_password():
|
||||
@login_required
|
||||
def get_user_email():
|
||||
"""获取当前用户的邮箱信息"""
|
||||
user = database.get_user_by_id(current_user.id)
|
||||
if not user:
|
||||
return jsonify({"error": "用户不存在"}), 404
|
||||
user, error_response = _get_current_user_or_404()
|
||||
if error_response:
|
||||
return error_response
|
||||
|
||||
return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)})
|
||||
|
||||
@@ -172,14 +254,9 @@ def update_user_kdocs_settings():
|
||||
return jsonify({"error": "县区长度不能超过50"}), 400
|
||||
|
||||
if kdocs_auto_upload is not None:
|
||||
if isinstance(kdocs_auto_upload, bool):
|
||||
kdocs_auto_upload = 1 if kdocs_auto_upload else 0
|
||||
try:
|
||||
kdocs_auto_upload = int(kdocs_auto_upload)
|
||||
except Exception:
|
||||
return jsonify({"error": "自动上传开关必须是0或1"}), 400
|
||||
if kdocs_auto_upload not in (0, 1):
|
||||
return jsonify({"error": "自动上传开关必须是0或1"}), 400
|
||||
kdocs_auto_upload, parse_error = _coerce_binary_flag(kdocs_auto_upload, field_label="自动上传开关")
|
||||
if parse_error:
|
||||
return jsonify({"error": parse_error}), 400
|
||||
|
||||
if not database.update_user_kdocs_settings(
|
||||
current_user.id,
|
||||
@@ -207,13 +284,9 @@ def bind_user_email():
|
||||
if not is_valid:
|
||||
return jsonify({"error": error_msg}), 400
|
||||
|
||||
client_ip = get_rate_limit_ip()
|
||||
allowed, error_msg = check_ip_request_rate(client_ip, "email")
|
||||
allowed, error_msg, status_code = _check_bind_email_rate_limits(email)
|
||||
if not allowed:
|
||||
return jsonify({"error": error_msg}), 429
|
||||
allowed, error_msg = check_email_rate_limit(email, "bind_email")
|
||||
if not allowed:
|
||||
return jsonify({"error": error_msg}), 429
|
||||
return jsonify({"error": error_msg}), status_code
|
||||
|
||||
settings = email_service.get_email_settings()
|
||||
if not settings.get("enabled", False):
|
||||
@@ -223,9 +296,9 @@ def bind_user_email():
|
||||
if existing_user and existing_user["id"] != current_user.id:
|
||||
return jsonify({"error": "该邮箱已被其他用户绑定"}), 400
|
||||
|
||||
user = database.get_user_by_id(current_user.id)
|
||||
if not user:
|
||||
return jsonify({"error": "用户不存在"}), 404
|
||||
user, error_response = _get_current_user_or_404()
|
||||
if error_response:
|
||||
return error_response
|
||||
|
||||
if user.get("email") == email and user.get("email_verified"):
|
||||
return jsonify({"error": "该邮箱已绑定并验证"}), 400
|
||||
@@ -247,56 +320,20 @@ def verify_bind_email(token):
|
||||
email = result["email"]
|
||||
|
||||
if database.update_user_email(user_id, email, verified=True):
|
||||
spa_initial_state = {
|
||||
"page": "verify_result",
|
||||
"success": True,
|
||||
"title": "邮箱绑定成功",
|
||||
"message": f"邮箱 {email} 已成功绑定到您的账号!",
|
||||
"primary_label": "返回登录",
|
||||
"primary_url": "/login",
|
||||
"redirect_url": "/login",
|
||||
"redirect_seconds": 5,
|
||||
}
|
||||
return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
|
||||
return _render_verify_bind_success(email)
|
||||
|
||||
error_message = "邮箱绑定失败,请重试"
|
||||
spa_initial_state = {
|
||||
"page": "verify_result",
|
||||
"success": False,
|
||||
"title": "绑定失败",
|
||||
"error_message": error_message,
|
||||
"primary_label": "返回登录",
|
||||
"primary_url": "/login",
|
||||
}
|
||||
return render_app_spa_or_legacy(
|
||||
"verify_failed.html",
|
||||
legacy_context={"error_message": error_message},
|
||||
spa_initial_state=spa_initial_state,
|
||||
)
|
||||
return _render_verify_bind_failed(title="绑定失败", error_message="邮箱绑定失败,请重试")
|
||||
|
||||
error_message = "验证链接已过期或无效,请重新发送验证邮件"
|
||||
spa_initial_state = {
|
||||
"page": "verify_result",
|
||||
"success": False,
|
||||
"title": "链接无效",
|
||||
"error_message": error_message,
|
||||
"primary_label": "返回登录",
|
||||
"primary_url": "/login",
|
||||
}
|
||||
return render_app_spa_or_legacy(
|
||||
"verify_failed.html",
|
||||
legacy_context={"error_message": error_message},
|
||||
spa_initial_state=spa_initial_state,
|
||||
)
|
||||
return _render_verify_bind_failed(title="链接无效", error_message="验证链接已过期或无效,请重新发送验证邮件")
|
||||
|
||||
|
||||
@api_user_bp.route("/api/user/unbind-email", methods=["POST"])
|
||||
@login_required
|
||||
def unbind_user_email():
|
||||
"""解绑用户邮箱"""
|
||||
user = database.get_user_by_id(current_user.id)
|
||||
if not user:
|
||||
return jsonify({"error": "用户不存在"}), 404
|
||||
user, error_response = _get_current_user_or_404()
|
||||
if error_response:
|
||||
return error_response
|
||||
|
||||
if not user.get("email"):
|
||||
return jsonify({"error": "当前未绑定邮箱"}), 400
|
||||
@@ -334,10 +371,7 @@ def get_run_stats():
|
||||
|
||||
stats = database.get_user_run_stats(user_id)
|
||||
|
||||
current_running = 0
|
||||
for _, info in safe_iter_task_status_items():
|
||||
if info.get("user_id") == user_id and info.get("status") == "运行中":
|
||||
current_running += 1
|
||||
current_running = _get_current_running_count(user_id)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
|
||||
@@ -31,7 +31,7 @@ def admin_required(f):
|
||||
if is_api:
|
||||
return jsonify({"error": "需要管理员权限"}), 403
|
||||
return redirect(url_for("pages.admin_login_page"))
|
||||
logger.info(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}")
|
||||
logger.debug(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
@@ -2,12 +2,62 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from flask import Blueprint, jsonify
|
||||
|
||||
import database
|
||||
import db_pool
|
||||
from services.time_utils import get_beijing_now
|
||||
|
||||
health_bp = Blueprint("health", __name__)
|
||||
_PROCESS_START_TS = time.time()
|
||||
|
||||
|
||||
def _build_runtime_metrics() -> dict:
|
||||
metrics = {
|
||||
"uptime_seconds": max(0, int(time.time() - _PROCESS_START_TS)),
|
||||
}
|
||||
|
||||
try:
|
||||
pool_stats = db_pool.get_pool_stats() or {}
|
||||
metrics["db_pool"] = {
|
||||
"pool_size": int(pool_stats.get("pool_size", 0) or 0),
|
||||
"available": int(pool_stats.get("available", 0) or 0),
|
||||
"in_use": int(pool_stats.get("in_use", 0) or 0),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
import psutil
|
||||
|
||||
proc = psutil.Process(os.getpid())
|
||||
with proc.oneshot():
|
||||
mem_info = proc.memory_info()
|
||||
metrics["process"] = {
|
||||
"rss_mb": round(float(mem_info.rss) / 1024 / 1024, 2),
|
||||
"cpu_percent": round(float(proc.cpu_percent(interval=None)), 2),
|
||||
"threads": int(proc.num_threads()),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from services import tasks as tasks_module
|
||||
|
||||
scheduler = getattr(tasks_module, "_task_scheduler", None)
|
||||
if scheduler is not None:
|
||||
queue_snapshot = scheduler.get_queue_state_snapshot() or {}
|
||||
metrics["task_queue"] = {
|
||||
"pending_total": int(queue_snapshot.get("pending_total", 0) or 0),
|
||||
"running_total": int(queue_snapshot.get("running_total", 0) or 0),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
@health_bp.route("/health", methods=["GET"])
|
||||
@@ -26,6 +76,6 @@ def health_check():
|
||||
"time": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"db_ok": db_ok,
|
||||
"db_error": db_error,
|
||||
"metrics": _build_runtime_metrics(),
|
||||
}
|
||||
return jsonify(payload), (200 if db_ok else 500)
|
||||
|
||||
|
||||
60
scripts/HEALTH_MONITOR_README.md
Normal file
60
scripts/HEALTH_MONITOR_README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# 健康监控(邮件版)
|
||||
|
||||
本目录提供 `health_email_monitor.py`,通过调用 `/health` 接口并使用**容器内已有邮件配置**发告警邮件。
|
||||
|
||||
## 1) 快速试跑
|
||||
|
||||
```bash
|
||||
cd /root/zsglpt
|
||||
python3 scripts/health_email_monitor.py \
|
||||
--to 你的告警邮箱@example.com \
|
||||
--container knowledge-automation-multiuser \
|
||||
--url http://127.0.0.1:51232/health \
|
||||
--dry-run
|
||||
```
|
||||
|
||||
去掉 `--dry-run` 即会实际发邮件。
|
||||
|
||||
## 2) 建议 cron(每分钟)
|
||||
|
||||
```bash
|
||||
* * * * * cd /root/zsglpt && /usr/bin/python3 scripts/health_email_monitor.py \
|
||||
--to 你的告警邮箱@example.com \
|
||||
--container knowledge-automation-multiuser \
|
||||
--url http://127.0.0.1:51232/health \
|
||||
>> /root/zsglpt/logs/health_monitor.log 2>&1
|
||||
```
|
||||
|
||||
## 3) 支持的规则
|
||||
|
||||
- `service_down`:健康接口请求失败(立即告警)
|
||||
- `health_fail`:返回 `ok/db_ok` 异常或 HTTP 5xx(立即告警)
|
||||
- `db_pool_exhausted`:连接池耗尽(默认连续 3 次才告警)
|
||||
- `queue_backlog_high`:任务堆积过高(默认 `pending_total >= 50` 且连续 5 次)
|
||||
|
||||
脚本支持恢复通知(规则恢复正常会发“恢复”邮件)。
|
||||
|
||||
## 4) 常用参数
|
||||
|
||||
- `--to`:收件人(必填)
|
||||
- `--container`:Docker 容器名(默认 `knowledge-automation-multiuser`)
|
||||
- `--url`:健康地址(默认 `http://127.0.0.1:51232/health`)
|
||||
- `--state-file`:状态文件路径(默认 `/tmp/zsglpt_health_monitor_state.json`)
|
||||
- `--remind-seconds`:重复告警间隔(默认 3600 秒)
|
||||
- `--queue-threshold`:队列告警阈值(默认 50)
|
||||
- `--queue-streak`:队列连续次数阈值(默认 5)
|
||||
- `--db-pool-streak`:连接池连续次数阈值(默认 3)
|
||||
|
||||
## 5) 环境变量方式(可选)
|
||||
|
||||
也可不用命令行参数,改用环境变量:
|
||||
|
||||
- `MONITOR_EMAIL_TO`
|
||||
- `MONITOR_DOCKER_CONTAINER`
|
||||
- `HEALTH_URL`
|
||||
- `MONITOR_STATE_FILE`
|
||||
- `MONITOR_REMIND_SECONDS`
|
||||
- `MONITOR_QUEUE_THRESHOLD`
|
||||
- `MONITOR_QUEUE_STREAK`
|
||||
- `MONITOR_DB_POOL_STREAK`
|
||||
|
||||
348
scripts/health_email_monitor.py
Normal file
348
scripts/health_email_monitor.py
Normal file
@@ -0,0 +1,348 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Tuple
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
DEFAULT_STATE_FILE = "/tmp/zsglpt_health_monitor_state.json"
|
||||
|
||||
|
||||
def _now_text() -> str:
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def _safe_int(value: Any, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return int(default)
|
||||
|
||||
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return float(default)
|
||||
|
||||
|
||||
def _load_state(path: str) -> Dict[str, Any]:
|
||||
if not path or not os.path.exists(path):
|
||||
return {
|
||||
"version": 1,
|
||||
"rules": {},
|
||||
"counters": {},
|
||||
}
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
if not isinstance(raw, dict):
|
||||
raise ValueError("state is not dict")
|
||||
raw.setdefault("version", 1)
|
||||
raw.setdefault("rules", {})
|
||||
raw.setdefault("counters", {})
|
||||
return raw
|
||||
except Exception:
|
||||
return {
|
||||
"version": 1,
|
||||
"rules": {},
|
||||
"counters": {},
|
||||
}
|
||||
|
||||
|
||||
def _save_state(path: str, state: Dict[str, Any]) -> None:
|
||||
if not path:
|
||||
return
|
||||
state_dir = os.path.dirname(path)
|
||||
if state_dir:
|
||||
os.makedirs(state_dir, exist_ok=True)
|
||||
tmp_path = f"{path}.tmp"
|
||||
with open(tmp_path, "w", encoding="utf-8") as f:
|
||||
json.dump(state, f, ensure_ascii=False, indent=2)
|
||||
os.replace(tmp_path, path)
|
||||
|
||||
|
||||
def _fetch_health(url: str, timeout: int) -> Tuple[int | None, Dict[str, Any], str | None]:
|
||||
req = Request(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "zsglpt-health-email-monitor/1.0",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
method="GET",
|
||||
)
|
||||
try:
|
||||
with urlopen(req, timeout=max(1, int(timeout))) as resp:
|
||||
status = int(resp.getcode())
|
||||
body = resp.read().decode("utf-8", errors="ignore")
|
||||
except HTTPError as e:
|
||||
status = int(getattr(e, "code", 0) or 0)
|
||||
body = ""
|
||||
try:
|
||||
body = e.read().decode("utf-8", errors="ignore")
|
||||
except Exception:
|
||||
pass
|
||||
data = {}
|
||||
if body:
|
||||
try:
|
||||
data = json.loads(body)
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
except Exception:
|
||||
data = {}
|
||||
return status, data, f"HTTPError: {e}"
|
||||
except URLError as e:
|
||||
return None, {}, f"URLError: {e}"
|
||||
except Exception as e:
|
||||
return None, {}, f"RequestError: {e}"
|
||||
|
||||
data: Dict[str, Any] = {}
|
||||
if body:
|
||||
try:
|
||||
loaded = json.loads(body)
|
||||
if isinstance(loaded, dict):
|
||||
data = loaded
|
||||
except Exception:
|
||||
data = {}
|
||||
|
||||
return status, data, None
|
||||
|
||||
|
||||
def _inc_streak(state: Dict[str, Any], key: str, bad: bool) -> int:
|
||||
counters = state.setdefault("counters", {})
|
||||
current = _safe_int(counters.get(key), 0)
|
||||
current = (current + 1) if bad else 0
|
||||
counters[key] = current
|
||||
return current
|
||||
|
||||
|
||||
def _rule_transition(
|
||||
state: Dict[str, Any],
|
||||
*,
|
||||
rule_name: str,
|
||||
bad: bool,
|
||||
streak: int,
|
||||
threshold: int,
|
||||
remind_seconds: int,
|
||||
now_ts: float,
|
||||
) -> str | None:
|
||||
rules = state.setdefault("rules", {})
|
||||
rule_state = rules.setdefault(rule_name, {"active": False, "last_sent": 0})
|
||||
|
||||
is_active = bool(rule_state.get("active", False))
|
||||
last_sent = _safe_float(rule_state.get("last_sent", 0), 0.0)
|
||||
threshold = max(1, int(threshold))
|
||||
remind_seconds = max(60, int(remind_seconds))
|
||||
|
||||
if bad and streak >= threshold:
|
||||
if not is_active:
|
||||
rule_state["active"] = True
|
||||
rule_state["last_sent"] = now_ts
|
||||
return "alert"
|
||||
if (now_ts - last_sent) >= remind_seconds:
|
||||
rule_state["last_sent"] = now_ts
|
||||
return "alert"
|
||||
return None
|
||||
|
||||
if is_active and (not bad):
|
||||
rule_state["active"] = False
|
||||
rule_state["last_sent"] = now_ts
|
||||
return "recover"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _send_email_via_container(
|
||||
*,
|
||||
container_name: str,
|
||||
to_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
timeout_seconds: int = 45,
|
||||
) -> Tuple[bool, str]:
|
||||
code = (
|
||||
"import sys,email_service;"
|
||||
"res=email_service.send_email(to_email=sys.argv[1],subject=sys.argv[2],body=sys.argv[3],email_type='health_monitor');"
|
||||
"ok=bool(res.get('success'));"
|
||||
"print('ok' if ok else ('error:'+str(res.get('error'))));"
|
||||
"raise SystemExit(0 if ok else 2)"
|
||||
)
|
||||
cmd = [
|
||||
"docker",
|
||||
"exec",
|
||||
container_name,
|
||||
"python",
|
||||
"-c",
|
||||
code,
|
||||
to_email,
|
||||
subject,
|
||||
body,
|
||||
]
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=max(5, int(timeout_seconds)),
|
||||
check=False,
|
||||
)
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
output = (proc.stdout or "") + (proc.stderr or "")
|
||||
output = output.strip()
|
||||
return proc.returncode == 0, output
|
||||
|
||||
|
||||
def _build_common_lines(status: int | None, data: Dict[str, Any], fetch_error: str | None) -> list[str]:
|
||||
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {}
|
||||
db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {}
|
||||
process = metrics.get("process") if isinstance(metrics.get("process"), dict) else {}
|
||||
task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {}
|
||||
|
||||
lines = [
|
||||
f"时间: {_now_text()}",
|
||||
f"健康地址: {data.get('_monitor_url', '')}",
|
||||
f"HTTP状态: {status if status is not None else '请求失败'}",
|
||||
f"ok/db_ok: {data.get('ok')} / {data.get('db_ok')}",
|
||||
]
|
||||
if fetch_error:
|
||||
lines.append(f"请求错误: {fetch_error}")
|
||||
lines.extend(
|
||||
[
|
||||
f"队列: pending={task_queue.get('pending_total', 'N/A')}, running={task_queue.get('running_total', 'N/A')}",
|
||||
f"连接池: size={db_pool.get('pool_size', 'N/A')}, available={db_pool.get('available', 'N/A')}, in_use={db_pool.get('in_use', 'N/A')}",
|
||||
f"进程: rss_mb={process.get('rss_mb', 'N/A')}, cpu%={process.get('cpu_percent', 'N/A')}, threads={process.get('threads', 'N/A')}",
|
||||
f"运行时长: {metrics.get('uptime_seconds', 'N/A')} 秒",
|
||||
]
|
||||
)
|
||||
return lines
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="zsglpt 邮件健康监控(基于 /health)")
|
||||
parser.add_argument("--url", default=os.environ.get("HEALTH_URL", "http://127.0.0.1:51232/health"))
|
||||
parser.add_argument("--to", default=os.environ.get("MONITOR_EMAIL_TO", ""))
|
||||
parser.add_argument(
|
||||
"--container",
|
||||
default=os.environ.get("MONITOR_DOCKER_CONTAINER", "knowledge-automation-multiuser"),
|
||||
)
|
||||
parser.add_argument("--state-file", default=os.environ.get("MONITOR_STATE_FILE", DEFAULT_STATE_FILE))
|
||||
parser.add_argument("--timeout", type=int, default=_safe_int(os.environ.get("MONITOR_TIMEOUT", 8), 8))
|
||||
parser.add_argument(
|
||||
"--remind-seconds",
|
||||
type=int,
|
||||
default=_safe_int(os.environ.get("MONITOR_REMIND_SECONDS", 3600), 3600),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--queue-threshold",
|
||||
type=int,
|
||||
default=_safe_int(os.environ.get("MONITOR_QUEUE_THRESHOLD", 50), 50),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--queue-streak",
|
||||
type=int,
|
||||
default=_safe_int(os.environ.get("MONITOR_QUEUE_STREAK", 5), 5),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--db-pool-streak",
|
||||
type=int,
|
||||
default=_safe_int(os.environ.get("MONITOR_DB_POOL_STREAK", 3), 3),
|
||||
)
|
||||
parser.add_argument("--dry-run", action="store_true", help="仅打印,不实际发邮件")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.to:
|
||||
print("[monitor] 缺少收件人,请设置 --to 或 MONITOR_EMAIL_TO", flush=True)
|
||||
return 2
|
||||
|
||||
state = _load_state(args.state_file)
|
||||
now_ts = time.time()
|
||||
|
||||
status, data, fetch_error = _fetch_health(args.url, args.timeout)
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
data["_monitor_url"] = args.url
|
||||
|
||||
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {}
|
||||
db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {}
|
||||
task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {}
|
||||
|
||||
service_down = status is None
|
||||
health_fail = bool(status is not None and (status >= 500 or (not data.get("ok", False)) or (not data.get("db_ok", False))))
|
||||
db_pool_exhausted = (
|
||||
_safe_int(db_pool.get("pool_size"), 0) > 0
|
||||
and _safe_int(db_pool.get("available"), 0) <= 0
|
||||
and _safe_int(db_pool.get("in_use"), 0) >= _safe_int(db_pool.get("pool_size"), 0)
|
||||
)
|
||||
queue_backlog_high = _safe_int(task_queue.get("pending_total"), 0) >= max(1, int(args.queue_threshold))
|
||||
|
||||
rule_defs = [
|
||||
("service_down", service_down, 1),
|
||||
("health_fail", health_fail, 1),
|
||||
("db_pool_exhausted", db_pool_exhausted, max(1, int(args.db_pool_streak))),
|
||||
("queue_backlog_high", queue_backlog_high, max(1, int(args.queue_streak))),
|
||||
]
|
||||
|
||||
pending_notifications: list[tuple[str, str]] = []
|
||||
for rule_name, bad, threshold in rule_defs:
|
||||
streak = _inc_streak(state, rule_name, bad)
|
||||
action = _rule_transition(
|
||||
state,
|
||||
rule_name=rule_name,
|
||||
bad=bad,
|
||||
streak=streak,
|
||||
threshold=threshold,
|
||||
remind_seconds=args.remind_seconds,
|
||||
now_ts=now_ts,
|
||||
)
|
||||
if action:
|
||||
pending_notifications.append((rule_name, action))
|
||||
|
||||
_save_state(args.state_file, state)
|
||||
|
||||
if not pending_notifications:
|
||||
print(f"[monitor] {_now_text()} 正常,无需发送邮件")
|
||||
return 0
|
||||
|
||||
common_lines = _build_common_lines(status, data, fetch_error)
|
||||
|
||||
for rule_name, action in pending_notifications:
|
||||
level = "告警" if action == "alert" else "恢复"
|
||||
subject = f"[zsglpt健康监控][{level}] {rule_name}"
|
||||
body_lines = [
|
||||
f"规则: {rule_name}",
|
||||
f"状态: {level}",
|
||||
"",
|
||||
*common_lines,
|
||||
]
|
||||
body = "\n".join(body_lines)
|
||||
|
||||
if args.dry_run:
|
||||
print(f"[monitor][dry-run] subject={subject}\n{body}\n")
|
||||
continue
|
||||
|
||||
ok, msg = _send_email_via_container(
|
||||
container_name=args.container,
|
||||
to_email=args.to,
|
||||
subject=subject,
|
||||
body=body,
|
||||
)
|
||||
if ok:
|
||||
print(f"[monitor] 邮件已发送: {subject}")
|
||||
else:
|
||||
print(f"[monitor] 邮件发送失败: {subject} | {msg}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
|
||||
@@ -243,6 +243,35 @@ class KDocsUploader:
|
||||
except queue.Empty:
|
||||
return {"success": False, "error": "操作超时"}
|
||||
|
||||
def _put_task_response(self, task: Dict[str, Any], result: Dict[str, Any]) -> None:
|
||||
response_queue = task.get("response")
|
||||
if not response_queue:
|
||||
return
|
||||
try:
|
||||
response_queue.put(result)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def _process_task(self, task: Dict[str, Any]) -> bool:
|
||||
action = task.get("action")
|
||||
payload = task.get("payload") or {}
|
||||
|
||||
if action == "shutdown":
|
||||
return False
|
||||
if action == "upload":
|
||||
self._handle_upload(payload)
|
||||
return True
|
||||
if action == "qr":
|
||||
self._put_task_response(task, self._handle_qr(payload))
|
||||
return True
|
||||
if action == "clear_login":
|
||||
self._put_task_response(task, self._handle_clear_login())
|
||||
return True
|
||||
if action == "status":
|
||||
self._put_task_response(task, self._handle_status_check())
|
||||
return True
|
||||
return True
|
||||
|
||||
def _run(self) -> None:
|
||||
thread_id = self._thread_id
|
||||
logger.info(f"[KDocs] 上传线程启动 (ID={thread_id})")
|
||||
@@ -261,34 +290,17 @@ class KDocsUploader:
|
||||
# 更新最后活动时间
|
||||
self._last_activity = time.time()
|
||||
|
||||
action = task.get("action")
|
||||
if action == "shutdown":
|
||||
break
|
||||
|
||||
try:
|
||||
if action == "upload":
|
||||
self._handle_upload(task.get("payload") or {})
|
||||
elif action == "qr":
|
||||
result = self._handle_qr(task.get("payload") or {})
|
||||
task.get("response").put(result)
|
||||
elif action == "clear_login":
|
||||
result = self._handle_clear_login()
|
||||
task.get("response").put(result)
|
||||
elif action == "status":
|
||||
result = self._handle_status_check()
|
||||
task.get("response").put(result)
|
||||
should_continue = self._process_task(task)
|
||||
if not should_continue:
|
||||
break
|
||||
|
||||
# 任务处理完成后更新活动时间
|
||||
self._last_activity = time.time()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[KDocs] 处理任务失败: {e}")
|
||||
# 如果有响应队列,返回错误
|
||||
if "response" in task and task.get("response"):
|
||||
try:
|
||||
task["response"].put({"success": False, "error": str(e)})
|
||||
except Exception:
|
||||
pass
|
||||
self._put_task_response(task, {"success": False, "error": str(e)})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[KDocs] 线程主循环异常: {e}")
|
||||
@@ -830,18 +842,180 @@ class KDocsUploader:
|
||||
except Exception as e:
|
||||
logger.warning(f"[KDocs] 保存登录态失败: {e}")
|
||||
|
||||
def _resolve_doc_url(self, cfg: Dict[str, Any]) -> str:
|
||||
return (cfg.get("kdocs_doc_url") or "").strip()
|
||||
|
||||
def _ensure_doc_access(
|
||||
self,
|
||||
doc_url: str,
|
||||
*,
|
||||
fast: bool = False,
|
||||
use_storage_state: bool = True,
|
||||
) -> Optional[str]:
|
||||
if not self._ensure_playwright(use_storage_state=use_storage_state):
|
||||
return self._last_error or "浏览器不可用"
|
||||
if not self._open_document(doc_url, fast=fast):
|
||||
return self._last_error or "打开文档失败"
|
||||
return None
|
||||
|
||||
def _trigger_fast_login_dialog(self, timeout_ms: int) -> None:
|
||||
self._ensure_login_dialog(
|
||||
timeout_ms=timeout_ms,
|
||||
frame_timeout_ms=timeout_ms,
|
||||
quick=True,
|
||||
)
|
||||
|
||||
def _capture_qr_with_retry(self, fast_login_timeout: int) -> Tuple[Optional[bytes], Optional[bytes]]:
|
||||
qr_image = None
|
||||
invalid_qr = None
|
||||
for attempt in range(10):
|
||||
if attempt in (3, 7):
|
||||
self._trigger_fast_login_dialog(fast_login_timeout)
|
||||
candidate = self._capture_qr_image()
|
||||
if candidate and self._is_valid_qr_image(candidate):
|
||||
qr_image = candidate
|
||||
break
|
||||
if candidate:
|
||||
invalid_qr = candidate
|
||||
time.sleep(0.8) # 优化: 1 -> 0.8
|
||||
return qr_image, invalid_qr
|
||||
|
||||
def _save_qr_debug_artifacts(self, invalid_qr: Optional[bytes]) -> None:
|
||||
try:
|
||||
pages = self._iter_pages()
|
||||
page_urls = [getattr(p, "url", "") for p in pages]
|
||||
logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
|
||||
|
||||
ts = int(time.time())
|
||||
saved = []
|
||||
for idx, page in enumerate(pages[:3]):
|
||||
try:
|
||||
path = f"data/kdocs_debug_{ts}_{idx}.png"
|
||||
page.screenshot(path=path, full_page=True)
|
||||
saved.append(path)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if saved:
|
||||
logger.warning(f"[KDocs] 已保存调试截图: {saved}")
|
||||
|
||||
if invalid_qr:
|
||||
try:
|
||||
path = f"data/kdocs_invalid_qr_{ts}.png"
|
||||
with open(path, "wb") as handle:
|
||||
handle.write(invalid_qr)
|
||||
logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _log_upload_failure(self, message: str, user_id: Any, account_id: Any) -> None:
|
||||
try:
|
||||
log_to_client(f"表格上传失败: {message}", user_id, account_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _mark_upload_tracking(self, user_id: Any, account_id: Any) -> Tuple[Any, Optional[str], bool]:
|
||||
account = None
|
||||
prev_status = None
|
||||
status_tracked = False
|
||||
try:
|
||||
account = safe_get_account(user_id, account_id)
|
||||
if account and self._should_mark_upload(account):
|
||||
prev_status = getattr(account, "status", None)
|
||||
account.status = "上传截图"
|
||||
self._emit_account_update(user_id, account)
|
||||
status_tracked = True
|
||||
except Exception:
|
||||
prev_status = None
|
||||
return account, prev_status, status_tracked
|
||||
|
||||
def _parse_upload_payload(self, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
unit = (payload.get("unit") or "").strip()
|
||||
name = (payload.get("name") or "").strip()
|
||||
image_path = payload.get("image_path")
|
||||
|
||||
if not unit or not name:
|
||||
return None
|
||||
if not image_path or not os.path.exists(image_path):
|
||||
return None
|
||||
|
||||
return {
|
||||
"unit": unit,
|
||||
"name": name,
|
||||
"image_path": image_path,
|
||||
"user_id": payload.get("user_id"),
|
||||
"account_id": payload.get("account_id"),
|
||||
}
|
||||
|
||||
def _resolve_upload_sheet_config(self, cfg: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
"sheet_name": (cfg.get("kdocs_sheet_name") or "").strip(),
|
||||
"sheet_index": int(cfg.get("kdocs_sheet_index") or 0),
|
||||
"unit_col": (cfg.get("kdocs_unit_column") or "A").strip().upper(),
|
||||
"image_col": (cfg.get("kdocs_image_column") or "D").strip().upper(),
|
||||
"row_start": int(cfg.get("kdocs_row_start") or 0),
|
||||
"row_end": int(cfg.get("kdocs_row_end") or 0),
|
||||
}
|
||||
|
||||
def _try_upload_to_sheet(self, cfg: Dict[str, Any], unit: str, name: str, image_path: str) -> Tuple[bool, str]:
|
||||
sheet_cfg = self._resolve_upload_sheet_config(cfg)
|
||||
success = False
|
||||
error_msg = ""
|
||||
|
||||
for _ in range(2):
|
||||
try:
|
||||
if sheet_cfg["sheet_name"] or sheet_cfg["sheet_index"]:
|
||||
self._select_sheet(sheet_cfg["sheet_name"], sheet_cfg["sheet_index"])
|
||||
|
||||
row_num = self._find_person_with_unit(
|
||||
unit,
|
||||
name,
|
||||
sheet_cfg["unit_col"],
|
||||
row_start=sheet_cfg["row_start"],
|
||||
row_end=sheet_cfg["row_end"],
|
||||
)
|
||||
if row_num < 0:
|
||||
error_msg = f"未找到人员: {unit}-{name}"
|
||||
break
|
||||
|
||||
success = self._upload_image_to_cell(row_num, image_path, sheet_cfg["image_col"])
|
||||
if success:
|
||||
break
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
return success, error_msg
|
||||
|
||||
def _handle_upload_login_invalid(
|
||||
self,
|
||||
*,
|
||||
unit: str,
|
||||
name: str,
|
||||
image_path: str,
|
||||
user_id: Any,
|
||||
account_id: Any,
|
||||
) -> None:
|
||||
error_msg = "登录已失效,请管理员重新扫码登录"
|
||||
self._login_required = True
|
||||
self._last_login_ok = False
|
||||
self._notify_admin(unit, name, image_path, error_msg)
|
||||
self._log_upload_failure(error_msg, user_id, account_id)
|
||||
|
||||
def _handle_qr(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
cfg = self._load_system_config()
|
||||
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
|
||||
doc_url = self._resolve_doc_url(cfg)
|
||||
if not doc_url:
|
||||
return {"success": False, "error": "未配置金山文档链接"}
|
||||
|
||||
force = bool(payload.get("force"))
|
||||
if force:
|
||||
self._handle_clear_login()
|
||||
if not self._ensure_playwright(use_storage_state=not force):
|
||||
return {"success": False, "error": self._last_error or "浏览器不可用"}
|
||||
if not self._open_document(doc_url, fast=True):
|
||||
return {"success": False, "error": self._last_error or "打开文档失败"}
|
||||
|
||||
doc_error = self._ensure_doc_access(doc_url, fast=True, use_storage_state=not force)
|
||||
if doc_error:
|
||||
return {"success": False, "error": doc_error}
|
||||
|
||||
if not force and self._has_saved_login_state() and self._is_logged_in():
|
||||
self._login_required = False
|
||||
@@ -850,54 +1024,12 @@ class KDocsUploader:
|
||||
return {"success": True, "logged_in": True, "qr_image": ""}
|
||||
|
||||
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
|
||||
self._ensure_login_dialog(
|
||||
timeout_ms=fast_login_timeout,
|
||||
frame_timeout_ms=fast_login_timeout,
|
||||
quick=True,
|
||||
)
|
||||
qr_image = None
|
||||
invalid_qr = None
|
||||
for attempt in range(10):
|
||||
if attempt in (3, 7):
|
||||
self._ensure_login_dialog(
|
||||
timeout_ms=fast_login_timeout,
|
||||
frame_timeout_ms=fast_login_timeout,
|
||||
quick=True,
|
||||
)
|
||||
candidate = self._capture_qr_image()
|
||||
if candidate and self._is_valid_qr_image(candidate):
|
||||
qr_image = candidate
|
||||
break
|
||||
if candidate:
|
||||
invalid_qr = candidate
|
||||
time.sleep(0.8) # 优化: 1 -> 0.8
|
||||
self._trigger_fast_login_dialog(fast_login_timeout)
|
||||
qr_image, invalid_qr = self._capture_qr_with_retry(fast_login_timeout)
|
||||
|
||||
if not qr_image:
|
||||
self._last_error = "二维码识别异常" if invalid_qr else "二维码获取失败"
|
||||
try:
|
||||
pages = self._iter_pages()
|
||||
page_urls = [getattr(p, "url", "") for p in pages]
|
||||
logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
|
||||
ts = int(time.time())
|
||||
saved = []
|
||||
for idx, page in enumerate(pages[:3]):
|
||||
try:
|
||||
path = f"data/kdocs_debug_{ts}_{idx}.png"
|
||||
page.screenshot(path=path, full_page=True)
|
||||
saved.append(path)
|
||||
except Exception:
|
||||
continue
|
||||
if saved:
|
||||
logger.warning(f"[KDocs] 已保存调试截图: {saved}")
|
||||
if invalid_qr:
|
||||
try:
|
||||
path = f"data/kdocs_invalid_qr_{ts}.png"
|
||||
with open(path, "wb") as handle:
|
||||
handle.write(invalid_qr)
|
||||
logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
self._save_qr_debug_artifacts(invalid_qr)
|
||||
return {"success": False, "error": self._last_error}
|
||||
|
||||
try:
|
||||
@@ -933,24 +1065,22 @@ class KDocsUploader:
|
||||
|
||||
def _handle_status_check(self) -> Dict[str, Any]:
|
||||
cfg = self._load_system_config()
|
||||
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
|
||||
doc_url = self._resolve_doc_url(cfg)
|
||||
if not doc_url:
|
||||
return {"success": True, "logged_in": False, "error": "未配置文档链接"}
|
||||
if not self._ensure_playwright():
|
||||
return {"success": False, "logged_in": False, "error": self._last_error or "浏览器不可用"}
|
||||
if not self._open_document(doc_url, fast=True):
|
||||
return {"success": False, "logged_in": False, "error": self._last_error or "打开文档失败"}
|
||||
|
||||
doc_error = self._ensure_doc_access(doc_url, fast=True)
|
||||
if doc_error:
|
||||
return {"success": False, "logged_in": False, "error": doc_error}
|
||||
|
||||
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
|
||||
self._ensure_login_dialog(
|
||||
timeout_ms=fast_login_timeout,
|
||||
frame_timeout_ms=fast_login_timeout,
|
||||
quick=True,
|
||||
)
|
||||
self._trigger_fast_login_dialog(fast_login_timeout)
|
||||
self._try_confirm_login(
|
||||
timeout_ms=fast_login_timeout,
|
||||
frame_timeout_ms=fast_login_timeout,
|
||||
quick=True,
|
||||
)
|
||||
|
||||
logged_in = self._is_logged_in()
|
||||
self._last_login_ok = logged_in
|
||||
self._login_required = not logged_in
|
||||
@@ -962,79 +1092,43 @@ class KDocsUploader:
|
||||
cfg = self._load_system_config()
|
||||
if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
|
||||
return
|
||||
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
|
||||
|
||||
doc_url = self._resolve_doc_url(cfg)
|
||||
if not doc_url:
|
||||
return
|
||||
|
||||
unit = (payload.get("unit") or "").strip()
|
||||
name = (payload.get("name") or "").strip()
|
||||
image_path = payload.get("image_path")
|
||||
user_id = payload.get("user_id")
|
||||
account_id = payload.get("account_id")
|
||||
|
||||
if not unit or not name:
|
||||
return
|
||||
if not image_path or not os.path.exists(image_path):
|
||||
upload_data = self._parse_upload_payload(payload)
|
||||
if not upload_data:
|
||||
return
|
||||
|
||||
account = None
|
||||
prev_status = None
|
||||
status_tracked = False
|
||||
unit = upload_data["unit"]
|
||||
name = upload_data["name"]
|
||||
image_path = upload_data["image_path"]
|
||||
user_id = upload_data["user_id"]
|
||||
account_id = upload_data["account_id"]
|
||||
|
||||
account, prev_status, status_tracked = self._mark_upload_tracking(user_id, account_id)
|
||||
|
||||
try:
|
||||
try:
|
||||
account = safe_get_account(user_id, account_id)
|
||||
if account and self._should_mark_upload(account):
|
||||
prev_status = getattr(account, "status", None)
|
||||
account.status = "上传截图"
|
||||
self._emit_account_update(user_id, account)
|
||||
status_tracked = True
|
||||
except Exception:
|
||||
prev_status = None
|
||||
|
||||
if not self._ensure_playwright():
|
||||
self._notify_admin(unit, name, image_path, self._last_error or "浏览器不可用")
|
||||
return
|
||||
|
||||
if not self._open_document(doc_url):
|
||||
self._notify_admin(unit, name, image_path, self._last_error or "打开文档失败")
|
||||
doc_error = self._ensure_doc_access(doc_url)
|
||||
if doc_error:
|
||||
self._notify_admin(unit, name, image_path, doc_error)
|
||||
return
|
||||
|
||||
if not self._is_logged_in():
|
||||
self._login_required = True
|
||||
self._last_login_ok = False
|
||||
self._notify_admin(unit, name, image_path, "登录已失效,请管理员重新扫码登录")
|
||||
try:
|
||||
log_to_client("表格上传失败: 登录已失效,请管理员重新扫码登录", user_id, account_id)
|
||||
except Exception:
|
||||
pass
|
||||
self._handle_upload_login_invalid(
|
||||
unit=unit,
|
||||
name=name,
|
||||
image_path=image_path,
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
)
|
||||
return
|
||||
|
||||
self._login_required = False
|
||||
self._last_login_ok = True
|
||||
|
||||
sheet_name = (cfg.get("kdocs_sheet_name") or "").strip()
|
||||
sheet_index = int(cfg.get("kdocs_sheet_index") or 0)
|
||||
unit_col = (cfg.get("kdocs_unit_column") or "A").strip().upper()
|
||||
image_col = (cfg.get("kdocs_image_column") or "D").strip().upper()
|
||||
row_start = int(cfg.get("kdocs_row_start") or 0)
|
||||
row_end = int(cfg.get("kdocs_row_end") or 0)
|
||||
|
||||
success = False
|
||||
error_msg = ""
|
||||
for attempt in range(2):
|
||||
try:
|
||||
if sheet_name or sheet_index:
|
||||
self._select_sheet(sheet_name, sheet_index)
|
||||
row_num = self._find_person_with_unit(unit, name, unit_col, row_start=row_start, row_end=row_end)
|
||||
if row_num < 0:
|
||||
error_msg = f"未找到人员: {unit}-{name}"
|
||||
break
|
||||
success = self._upload_image_to_cell(row_num, image_path, image_col)
|
||||
if success:
|
||||
break
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
success, error_msg = self._try_upload_to_sheet(cfg, unit, name, image_path)
|
||||
if success:
|
||||
self._last_success_at = time.time()
|
||||
self._last_error = None
|
||||
@@ -1048,10 +1142,7 @@ class KDocsUploader:
|
||||
error_msg = "上传失败"
|
||||
self._last_error = error_msg
|
||||
self._notify_admin(unit, name, image_path, error_msg)
|
||||
try:
|
||||
log_to_client(f"表格上传失败: {error_msg}", user_id, account_id)
|
||||
except Exception:
|
||||
pass
|
||||
self._log_upload_failure(error_msg, user_id, account_id)
|
||||
finally:
|
||||
if status_tracked:
|
||||
self._restore_account_status(user_id, account, prev_status)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -10,6 +11,8 @@ from app_config import get_config
|
||||
from app_logger import get_logger
|
||||
from services.state import (
|
||||
cleanup_expired_ip_rate_limits,
|
||||
cleanup_expired_ip_request_rates,
|
||||
cleanup_expired_login_security_state,
|
||||
safe_cleanup_expired_batches,
|
||||
safe_cleanup_expired_captcha,
|
||||
safe_cleanup_expired_pending_random,
|
||||
@@ -31,6 +34,69 @@ PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECON
|
||||
_kdocs_offline_notified: bool = False
|
||||
|
||||
|
||||
def _to_int(value, default: int = 0) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return int(default)
|
||||
|
||||
|
||||
def _collect_active_user_ids() -> set[int]:
|
||||
active_user_ids: set[int] = set()
|
||||
for _, info in safe_iter_task_status_items():
|
||||
user_id = info.get("user_id") if isinstance(info, dict) else None
|
||||
if user_id is None:
|
||||
continue
|
||||
try:
|
||||
active_user_ids.add(int(user_id))
|
||||
except Exception:
|
||||
continue
|
||||
return active_user_ids
|
||||
|
||||
|
||||
def _find_expired_user_cache_ids(current_time: float, active_user_ids: set[int]) -> list[int]:
|
||||
expired_users = []
|
||||
for user_id, last_access in (safe_get_user_accounts_last_access_items() or []):
|
||||
try:
|
||||
user_id_int = int(user_id)
|
||||
last_access_ts = float(last_access)
|
||||
except Exception:
|
||||
continue
|
||||
if (current_time - last_access_ts) <= USER_ACCOUNTS_EXPIRE_SECONDS:
|
||||
continue
|
||||
if user_id_int in active_user_ids:
|
||||
continue
|
||||
if safe_has_user(user_id_int):
|
||||
expired_users.append(user_id_int)
|
||||
return expired_users
|
||||
|
||||
|
||||
def _find_completed_task_status_ids(current_time: float) -> list[str]:
|
||||
completed_task_ids = []
|
||||
for account_id, status_data in safe_iter_task_status_items():
|
||||
status = status_data.get("status") if isinstance(status_data, dict) else None
|
||||
if status not in ["已完成", "失败", "已停止"]:
|
||||
continue
|
||||
|
||||
start_time = float(status_data.get("start_time", 0) or 0)
|
||||
if (current_time - start_time) > 600: # 10分钟
|
||||
completed_task_ids.append(account_id)
|
||||
return completed_task_ids
|
||||
|
||||
|
||||
def _reap_zombie_processes() -> None:
|
||||
while True:
|
||||
try:
|
||||
pid, _ = os.waitpid(-1, os.WNOHANG)
|
||||
if pid == 0:
|
||||
break
|
||||
logger.debug(f"已回收僵尸进程: PID={pid}")
|
||||
except ChildProcessError:
|
||||
break
|
||||
except Exception:
|
||||
break
|
||||
|
||||
|
||||
def cleanup_expired_data() -> None:
|
||||
"""定期清理过期数据,防止内存泄漏(逻辑保持不变)。"""
|
||||
current_time = time.time()
|
||||
@@ -43,48 +109,36 @@ def cleanup_expired_data() -> None:
|
||||
if deleted_ips:
|
||||
logger.debug(f"已清理 {deleted_ips} 个过期IP限流记录")
|
||||
|
||||
expired_users = []
|
||||
last_access_items = safe_get_user_accounts_last_access_items()
|
||||
if last_access_items:
|
||||
task_items = safe_iter_task_status_items()
|
||||
active_user_ids = {int(info.get("user_id")) for _, info in task_items if info.get("user_id")}
|
||||
for user_id, last_access in last_access_items:
|
||||
if (current_time - float(last_access)) <= USER_ACCOUNTS_EXPIRE_SECONDS:
|
||||
continue
|
||||
if int(user_id) in active_user_ids:
|
||||
continue
|
||||
if safe_has_user(user_id):
|
||||
expired_users.append(int(user_id))
|
||||
deleted_ip_requests = cleanup_expired_ip_request_rates(current_time)
|
||||
if deleted_ip_requests:
|
||||
logger.debug(f"已清理 {deleted_ip_requests} 个过期IP请求频率记录")
|
||||
|
||||
login_cleanup_stats = cleanup_expired_login_security_state(current_time)
|
||||
login_cleanup_total = sum(int(v or 0) for v in login_cleanup_stats.values())
|
||||
if login_cleanup_total:
|
||||
logger.debug(
|
||||
"已清理登录风控缓存: "
|
||||
f"失败计数={login_cleanup_stats.get('failures', 0)}, "
|
||||
f"限流桶={login_cleanup_stats.get('rate_limits', 0)}, "
|
||||
f"扫描状态={login_cleanup_stats.get('scan_states', 0)}, "
|
||||
f"短时锁={login_cleanup_stats.get('ip_user_locks', 0)}, "
|
||||
f"告警状态={login_cleanup_stats.get('alerts', 0)}"
|
||||
)
|
||||
|
||||
active_user_ids = _collect_active_user_ids()
|
||||
expired_users = _find_expired_user_cache_ids(current_time, active_user_ids)
|
||||
for user_id in expired_users:
|
||||
safe_remove_user_accounts(user_id)
|
||||
if expired_users:
|
||||
logger.debug(f"已清理 {len(expired_users)} 个过期用户账号缓存")
|
||||
|
||||
completed_tasks = []
|
||||
for account_id, status_data in safe_iter_task_status_items():
|
||||
if status_data.get("status") in ["已完成", "失败", "已停止"]:
|
||||
start_time = float(status_data.get("start_time", 0) or 0)
|
||||
if (current_time - start_time) > 600: # 10分钟
|
||||
completed_tasks.append(account_id)
|
||||
for account_id in completed_tasks:
|
||||
completed_task_ids = _find_completed_task_status_ids(current_time)
|
||||
for account_id in completed_task_ids:
|
||||
safe_remove_task_status(account_id)
|
||||
if completed_tasks:
|
||||
logger.debug(f"已清理 {len(completed_tasks)} 个已完成任务状态")
|
||||
if completed_task_ids:
|
||||
logger.debug(f"已清理 {len(completed_task_ids)} 个已完成任务状态")
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
while True:
|
||||
try:
|
||||
pid, status = os.waitpid(-1, os.WNOHANG)
|
||||
if pid == 0:
|
||||
break
|
||||
logger.debug(f"已回收僵尸进程: PID={pid}")
|
||||
except ChildProcessError:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
_reap_zombie_processes()
|
||||
|
||||
deleted_batches = safe_cleanup_expired_batches(BATCH_TASK_EXPIRE_SECONDS, current_time)
|
||||
if deleted_batches:
|
||||
@@ -95,52 +149,39 @@ def cleanup_expired_data() -> None:
|
||||
logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务")
|
||||
|
||||
|
||||
def check_kdocs_online_status() -> None:
|
||||
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
|
||||
global _kdocs_offline_notified
|
||||
def _load_kdocs_monitor_config():
|
||||
import database
|
||||
|
||||
cfg = database.get_system_config()
|
||||
if not cfg:
|
||||
return None
|
||||
|
||||
kdocs_enabled = _to_int(cfg.get("kdocs_enabled"), 0)
|
||||
if not kdocs_enabled:
|
||||
return None
|
||||
|
||||
admin_notify_enabled = _to_int(cfg.get("kdocs_admin_notify_enabled"), 0)
|
||||
admin_notify_email = str(cfg.get("kdocs_admin_notify_email") or "").strip()
|
||||
if (not admin_notify_enabled) or (not admin_notify_email):
|
||||
return None
|
||||
|
||||
return admin_notify_email
|
||||
|
||||
|
||||
def _is_kdocs_offline(status: dict) -> tuple[bool, bool, bool | None]:
|
||||
login_required = bool(status.get("login_required", False))
|
||||
last_login_ok = status.get("last_login_ok")
|
||||
is_offline = login_required or (last_login_ok is False)
|
||||
return is_offline, login_required, last_login_ok
|
||||
|
||||
|
||||
def _send_kdocs_offline_alert(admin_notify_email: str, *, login_required: bool, last_login_ok) -> bool:
|
||||
try:
|
||||
import database
|
||||
from services.kdocs_uploader import get_kdocs_uploader
|
||||
import email_service
|
||||
|
||||
# 获取系统配置
|
||||
cfg = database.get_system_config()
|
||||
if not cfg:
|
||||
return
|
||||
|
||||
# 检查是否启用了金山文档功能
|
||||
kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
|
||||
if not kdocs_enabled:
|
||||
return
|
||||
|
||||
# 检查是否启用了管理员通知
|
||||
admin_notify_enabled = int(cfg.get("kdocs_admin_notify_enabled") or 0)
|
||||
admin_notify_email = (cfg.get("kdocs_admin_notify_email") or "").strip()
|
||||
if not admin_notify_enabled or not admin_notify_email:
|
||||
return
|
||||
|
||||
# 获取金山文档状态
|
||||
kdocs = get_kdocs_uploader()
|
||||
status = kdocs.get_status()
|
||||
login_required = status.get("login_required", False)
|
||||
last_login_ok = status.get("last_login_ok")
|
||||
|
||||
# 如果需要登录或最后登录状态不是成功
|
||||
is_offline = login_required or (last_login_ok is False)
|
||||
|
||||
if is_offline:
|
||||
# 已经通知过了,不再重复通知
|
||||
if _kdocs_offline_notified:
|
||||
logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
|
||||
return
|
||||
|
||||
# 发送邮件通知
|
||||
try:
|
||||
import email_service
|
||||
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
subject = "【金山文档离线告警】需要重新登录"
|
||||
body = f"""
|
||||
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
subject = "【金山文档离线告警】需要重新登录"
|
||||
body = f"""
|
||||
您好,
|
||||
|
||||
系统检测到金山文档上传功能已离线,需要重新扫码登录。
|
||||
@@ -155,58 +196,92 @@ def check_kdocs_online_status() -> None:
|
||||
---
|
||||
此邮件由系统自动发送,请勿直接回复。
|
||||
"""
|
||||
email_service.send_email_async(
|
||||
to_email=admin_notify_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
email_type="kdocs_offline_alert",
|
||||
)
|
||||
_kdocs_offline_notified = True # 标记为已通知
|
||||
logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
|
||||
except Exception as e:
|
||||
logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
|
||||
else:
|
||||
# 恢复在线,重置通知状态
|
||||
email_service.send_email_async(
|
||||
to_email=admin_notify_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
email_type="kdocs_offline_alert",
|
||||
)
|
||||
logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def check_kdocs_online_status() -> None:
|
||||
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
|
||||
global _kdocs_offline_notified
|
||||
|
||||
try:
|
||||
admin_notify_email = _load_kdocs_monitor_config()
|
||||
if not admin_notify_email:
|
||||
return
|
||||
|
||||
from services.kdocs_uploader import get_kdocs_uploader
|
||||
|
||||
kdocs = get_kdocs_uploader()
|
||||
status = kdocs.get_status() or {}
|
||||
is_offline, login_required, last_login_ok = _is_kdocs_offline(status)
|
||||
|
||||
if is_offline:
|
||||
if _kdocs_offline_notified:
|
||||
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
|
||||
_kdocs_offline_notified = False
|
||||
logger.debug("[KDocs监控] 金山文档状态正常")
|
||||
logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
|
||||
return
|
||||
|
||||
if _send_kdocs_offline_alert(
|
||||
admin_notify_email,
|
||||
login_required=login_required,
|
||||
last_login_ok=last_login_ok,
|
||||
):
|
||||
_kdocs_offline_notified = True
|
||||
return
|
||||
|
||||
if _kdocs_offline_notified:
|
||||
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
|
||||
_kdocs_offline_notified = False
|
||||
logger.debug("[KDocs监控] 金山文档状态正常")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[KDocs监控] 检测失败: {e}")
|
||||
|
||||
|
||||
def start_cleanup_scheduler() -> None:
|
||||
"""启动定期清理调度器"""
|
||||
|
||||
def cleanup_loop():
|
||||
def _start_daemon_loop(name: str, *, startup_delay: float, interval_seconds: float, job, error_tag: str):
|
||||
def loop():
|
||||
if startup_delay > 0:
|
||||
time.sleep(startup_delay)
|
||||
while True:
|
||||
try:
|
||||
time.sleep(300) # 每5分钟执行一次清理
|
||||
cleanup_expired_data()
|
||||
job()
|
||||
time.sleep(interval_seconds)
|
||||
except Exception as e:
|
||||
logger.error(f"清理任务执行失败: {e}")
|
||||
logger.error(f"{error_tag}: {e}")
|
||||
time.sleep(min(60.0, max(1.0, interval_seconds / 5.0)))
|
||||
|
||||
cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True, name="cleanup-scheduler")
|
||||
cleanup_thread.start()
|
||||
thread = threading.Thread(target=loop, daemon=True, name=name)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
|
||||
def start_cleanup_scheduler() -> None:
|
||||
"""启动定期清理调度器"""
|
||||
_start_daemon_loop(
|
||||
"cleanup-scheduler",
|
||||
startup_delay=300,
|
||||
interval_seconds=300,
|
||||
job=cleanup_expired_data,
|
||||
error_tag="清理任务执行失败",
|
||||
)
|
||||
logger.info("内存清理调度器已启动")
|
||||
|
||||
|
||||
def start_kdocs_monitor() -> None:
|
||||
"""启动金山文档状态监控"""
|
||||
|
||||
def monitor_loop():
|
||||
# 启动后等待 60 秒再开始检测(给系统初始化的时间)
|
||||
time.sleep(60)
|
||||
while True:
|
||||
try:
|
||||
check_kdocs_online_status()
|
||||
time.sleep(300) # 每5分钟检测一次
|
||||
except Exception as e:
|
||||
logger.error(f"[KDocs监控] 监控任务执行失败: {e}")
|
||||
time.sleep(60)
|
||||
|
||||
monitor_thread = threading.Thread(target=monitor_loop, daemon=True, name="kdocs-monitor")
|
||||
monitor_thread.start()
|
||||
_start_daemon_loop(
|
||||
"kdocs-monitor",
|
||||
startup_delay=60,
|
||||
interval_seconds=300,
|
||||
job=check_kdocs_online_status,
|
||||
error_tag="[KDocs监控] 监控任务执行失败",
|
||||
)
|
||||
logger.info("[KDocs监控] 金山文档状态监控已启动(每5分钟检测一次)")
|
||||
|
||||
|
||||
@@ -27,6 +27,12 @@ from services.time_utils import get_beijing_now
|
||||
logger = get_logger("app")
|
||||
config = get_config()
|
||||
|
||||
try:
|
||||
_SCHEDULE_SUBMIT_DELAY_SECONDS = float(os.environ.get("SCHEDULE_SUBMIT_DELAY_SECONDS", "0.2"))
|
||||
except Exception:
|
||||
_SCHEDULE_SUBMIT_DELAY_SECONDS = 0.2
|
||||
_SCHEDULE_SUBMIT_DELAY_SECONDS = max(0.0, _SCHEDULE_SUBMIT_DELAY_SECONDS)
|
||||
|
||||
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
|
||||
os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
|
||||
|
||||
@@ -55,6 +61,150 @@ def _normalize_hhmm(value: object, *, default: str) -> str:
|
||||
return f"{hour:02d}:{minute:02d}"
|
||||
|
||||
|
||||
def _safe_recompute_schedule_next_run(schedule_id: int) -> None:
|
||||
try:
|
||||
database.recompute_schedule_next_run(schedule_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _load_accounts_for_users(approved_users: list[dict]) -> tuple[dict[int, dict], list[str]]:
|
||||
"""批量加载用户账号快照。"""
|
||||
user_accounts: dict[int, dict] = {}
|
||||
account_ids: list[str] = []
|
||||
for user in approved_users:
|
||||
user_id = user["id"]
|
||||
accounts = safe_get_user_accounts_snapshot(user_id)
|
||||
if not accounts:
|
||||
load_user_accounts(user_id)
|
||||
accounts = safe_get_user_accounts_snapshot(user_id)
|
||||
if accounts:
|
||||
user_accounts[user_id] = accounts
|
||||
account_ids.extend(list(accounts.keys()))
|
||||
return user_accounts, account_ids
|
||||
|
||||
|
||||
def _should_skip_suspended_account(account_status_info, account, username: str) -> bool:
|
||||
"""判断是否应跳过暂停账号,并输出日志。"""
|
||||
if not account_status_info:
|
||||
return False
|
||||
|
||||
status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
|
||||
if status != "suspended":
|
||||
return False
|
||||
|
||||
fail_count = account_status_info["login_fail_count"] if "login_fail_count" in account_status_info.keys() else 0
|
||||
logger.info(
|
||||
f"[定时任务] 跳过暂停账号: {account.username} (用户:{username}) - 连续{fail_count}次密码错误,需修改密码"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _parse_schedule_account_ids(schedule_config: dict, schedule_id: int):
|
||||
import json
|
||||
|
||||
try:
|
||||
account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
|
||||
account_ids = json.loads(account_ids_raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
|
||||
return []
|
||||
if isinstance(account_ids, list):
|
||||
return account_ids
|
||||
return []
|
||||
|
||||
|
||||
def _create_user_schedule_batch(*, batch_id: str, user_id: int, browse_type: str, schedule_name: str, now_ts: float) -> None:
|
||||
safe_create_batch(
|
||||
batch_id,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"browse_type": browse_type,
|
||||
"schedule_name": schedule_name,
|
||||
"screenshots": [],
|
||||
"total_accounts": 0,
|
||||
"completed": 0,
|
||||
"created_at": now_ts,
|
||||
"updated_at": now_ts,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_user_schedule_done_callback(
|
||||
*,
|
||||
completion_lock: threading.Lock,
|
||||
remaining: dict,
|
||||
counters: dict,
|
||||
execution_start_time: float,
|
||||
log_id: int,
|
||||
schedule_id: int,
|
||||
total_accounts: int,
|
||||
):
|
||||
def on_browse_done():
|
||||
with completion_lock:
|
||||
remaining["count"] -= 1
|
||||
if remaining["done"] or remaining["count"] > 0:
|
||||
return
|
||||
remaining["done"] = True
|
||||
|
||||
execution_duration = int(time.time() - execution_start_time)
|
||||
started_count = int(counters.get("started", 0) or 0)
|
||||
database.update_schedule_execution_log(
|
||||
log_id,
|
||||
total_accounts=total_accounts,
|
||||
success_accounts=started_count,
|
||||
failed_accounts=total_accounts - started_count,
|
||||
duration_seconds=execution_duration,
|
||||
status="completed",
|
||||
)
|
||||
logger.info(f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件")
|
||||
|
||||
return on_browse_done
|
||||
|
||||
|
||||
def _submit_user_schedule_accounts(
|
||||
*,
|
||||
user_id: int,
|
||||
account_ids: list,
|
||||
browse_type: str,
|
||||
enable_screenshot,
|
||||
task_source: str,
|
||||
done_callback,
|
||||
completion_lock: threading.Lock,
|
||||
remaining: dict,
|
||||
counters: dict,
|
||||
) -> tuple[int, int]:
|
||||
started_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for account_id in account_ids:
|
||||
account = safe_get_account(user_id, account_id)
|
||||
if (not account) or account.is_running:
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
with completion_lock:
|
||||
remaining["count"] += 1
|
||||
ok, msg = submit_account_task(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
browse_type=browse_type,
|
||||
enable_screenshot=enable_screenshot,
|
||||
source=task_source,
|
||||
done_callback=done_callback,
|
||||
)
|
||||
if ok:
|
||||
started_count += 1
|
||||
counters["started"] = started_count
|
||||
else:
|
||||
with completion_lock:
|
||||
remaining["count"] -= 1
|
||||
skipped_count += 1
|
||||
logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
|
||||
|
||||
return started_count, skipped_count
|
||||
|
||||
|
||||
def run_scheduled_task(skip_weekday_check: bool = False) -> None:
|
||||
"""执行所有账号的浏览任务(可被手动调用,过滤重复账号)"""
|
||||
try:
|
||||
@@ -87,17 +237,7 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
|
||||
cfg = database.get_system_config()
|
||||
enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1
|
||||
|
||||
user_accounts = {}
|
||||
account_ids = []
|
||||
for user in approved_users:
|
||||
user_id = user["id"]
|
||||
accounts = safe_get_user_accounts_snapshot(user_id)
|
||||
if not accounts:
|
||||
load_user_accounts(user_id)
|
||||
accounts = safe_get_user_accounts_snapshot(user_id)
|
||||
if accounts:
|
||||
user_accounts[user_id] = accounts
|
||||
account_ids.extend(list(accounts.keys()))
|
||||
user_accounts, account_ids = _load_accounts_for_users(approved_users)
|
||||
|
||||
account_statuses = database.get_account_status_batch(account_ids)
|
||||
|
||||
@@ -113,18 +253,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
|
||||
continue
|
||||
|
||||
account_status_info = account_statuses.get(str(account_id))
|
||||
if account_status_info:
|
||||
status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
|
||||
if status == "suspended":
|
||||
fail_count = (
|
||||
account_status_info["login_fail_count"]
|
||||
if "login_fail_count" in account_status_info.keys()
|
||||
else 0
|
||||
)
|
||||
logger.info(
|
||||
f"[定时任务] 跳过暂停账号: {account.username} (用户:{user['username']}) - 连续{fail_count}次密码错误,需修改密码"
|
||||
)
|
||||
continue
|
||||
if _should_skip_suspended_account(account_status_info, account, user["username"]):
|
||||
continue
|
||||
|
||||
if account.username in executed_usernames:
|
||||
skipped_duplicates += 1
|
||||
@@ -149,7 +279,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
|
||||
else:
|
||||
logger.warning(f"[定时任务] 启动失败({account.username}): {msg}")
|
||||
|
||||
time.sleep(2)
|
||||
if _SCHEDULE_SUBMIT_DELAY_SECONDS > 0:
|
||||
time.sleep(_SCHEDULE_SUBMIT_DELAY_SECONDS)
|
||||
|
||||
logger.info(
|
||||
f"[定时任务] 执行完成 - 总账号数:{total_accounts}, 已执行:{executed_accounts}, 跳过重复:{skipped_duplicates}"
|
||||
@@ -198,15 +329,16 @@ def scheduled_task_worker() -> None:
|
||||
deleted_screenshots = 0
|
||||
if os.path.exists(SCREENSHOTS_DIR):
|
||||
cutoff_time = time.time() - (7 * 24 * 60 * 60)
|
||||
for filename in os.listdir(SCREENSHOTS_DIR):
|
||||
if filename.lower().endswith((".png", ".jpg", ".jpeg")):
|
||||
filepath = os.path.join(SCREENSHOTS_DIR, filename)
|
||||
with os.scandir(SCREENSHOTS_DIR) as entries:
|
||||
for entry in entries:
|
||||
if (not entry.is_file()) or (not entry.name.lower().endswith((".png", ".jpg", ".jpeg"))):
|
||||
continue
|
||||
try:
|
||||
if os.path.getmtime(filepath) < cutoff_time:
|
||||
os.remove(filepath)
|
||||
if entry.stat().st_mtime < cutoff_time:
|
||||
os.remove(entry.path)
|
||||
deleted_screenshots += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"[定时清理] 删除截图失败 {filename}: {str(e)}")
|
||||
logger.warning(f"[定时清理] 删除截图失败 {entry.name}: {str(e)}")
|
||||
|
||||
logger.info(f"[定时清理] 已删除 {deleted_screenshots} 个截图文件")
|
||||
logger.info("[定时清理] 清理完成!")
|
||||
@@ -214,10 +346,97 @@ def scheduled_task_worker() -> None:
|
||||
except Exception as e:
|
||||
logger.exception(f"[定时清理] 清理任务出错: {str(e)}")
|
||||
|
||||
def _parse_due_schedule_weekdays(schedule_config: dict, schedule_id: int):
|
||||
weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5")
|
||||
try:
|
||||
return [int(d) for d in weekdays_str.split(",") if d.strip()]
|
||||
except Exception as e:
|
||||
logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
|
||||
_safe_recompute_schedule_next_run(schedule_id)
|
||||
return None
|
||||
|
||||
def _execute_due_user_schedule(schedule_config: dict) -> None:
|
||||
schedule_name = schedule_config.get("name", "未命名任务")
|
||||
schedule_id = schedule_config["id"]
|
||||
user_id = schedule_config["user_id"]
|
||||
browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
|
||||
enable_screenshot = schedule_config.get("enable_screenshot", 1)
|
||||
|
||||
account_ids = _parse_schedule_account_ids(schedule_config, schedule_id)
|
||||
if not account_ids:
|
||||
_safe_recompute_schedule_next_run(schedule_id)
|
||||
return
|
||||
|
||||
if not safe_get_user_accounts_snapshot(user_id):
|
||||
load_user_accounts(user_id)
|
||||
|
||||
import uuid
|
||||
|
||||
execution_start_time = time.time()
|
||||
log_id = database.create_schedule_execution_log(
|
||||
schedule_id=schedule_id,
|
||||
user_id=user_id,
|
||||
schedule_name=schedule_name,
|
||||
)
|
||||
|
||||
batch_id = f"batch_{uuid.uuid4().hex[:12]}"
|
||||
now_ts = time.time()
|
||||
_create_user_schedule_batch(
|
||||
batch_id=batch_id,
|
||||
user_id=user_id,
|
||||
browse_type=browse_type,
|
||||
schedule_name=schedule_name,
|
||||
now_ts=now_ts,
|
||||
)
|
||||
|
||||
completion_lock = threading.Lock()
|
||||
remaining = {"count": 0, "done": False}
|
||||
counters = {"started": 0}
|
||||
|
||||
on_browse_done = _build_user_schedule_done_callback(
|
||||
completion_lock=completion_lock,
|
||||
remaining=remaining,
|
||||
counters=counters,
|
||||
execution_start_time=execution_start_time,
|
||||
log_id=log_id,
|
||||
schedule_id=schedule_id,
|
||||
total_accounts=len(account_ids),
|
||||
)
|
||||
|
||||
task_source = f"user_scheduled:{batch_id}"
|
||||
started_count, skipped_count = _submit_user_schedule_accounts(
|
||||
user_id=user_id,
|
||||
account_ids=account_ids,
|
||||
browse_type=browse_type,
|
||||
enable_screenshot=enable_screenshot,
|
||||
task_source=task_source,
|
||||
done_callback=on_browse_done,
|
||||
completion_lock=completion_lock,
|
||||
remaining=remaining,
|
||||
counters=counters,
|
||||
)
|
||||
|
||||
batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time.time())
|
||||
if batch_info:
|
||||
_send_batch_task_email_if_configured(batch_info)
|
||||
|
||||
database.update_schedule_last_run(schedule_id)
|
||||
|
||||
logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号,批次ID: {batch_id}")
|
||||
if started_count <= 0:
|
||||
database.update_schedule_execution_log(
|
||||
log_id,
|
||||
total_accounts=len(account_ids),
|
||||
success_accounts=0,
|
||||
failed_accounts=len(account_ids),
|
||||
duration_seconds=0,
|
||||
status="completed",
|
||||
)
|
||||
if started_count == 0 and len(account_ids) > 0:
|
||||
logger.warning("[用户定时任务] ⚠️ 警告:所有账号都被跳过了!请检查user_accounts状态")
|
||||
|
||||
def check_user_schedules():
|
||||
"""检查并执行用户定时任务(O-08:next_run_at 索引驱动)。"""
|
||||
import json
|
||||
|
||||
try:
|
||||
now = get_beijing_now()
|
||||
now_str = now.strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -226,145 +445,22 @@ def scheduled_task_worker() -> None:
|
||||
due_schedules = database.get_due_user_schedules(now_str, limit=50) or []
|
||||
|
||||
for schedule_config in due_schedules:
|
||||
schedule_name = schedule_config.get("name", "未命名任务")
|
||||
schedule_id = schedule_config["id"]
|
||||
schedule_name = schedule_config.get("name", "未命名任务")
|
||||
|
||||
weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5")
|
||||
try:
|
||||
allowed_weekdays = [int(d) for d in weekdays_str.split(",") if d.strip()]
|
||||
except Exception as e:
|
||||
logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
|
||||
try:
|
||||
database.recompute_schedule_next_run(schedule_id)
|
||||
except Exception:
|
||||
pass
|
||||
allowed_weekdays = _parse_due_schedule_weekdays(schedule_config, schedule_id)
|
||||
if allowed_weekdays is None:
|
||||
continue
|
||||
|
||||
if current_weekday not in allowed_weekdays:
|
||||
try:
|
||||
database.recompute_schedule_next_run(schedule_id)
|
||||
except Exception:
|
||||
pass
|
||||
_safe_recompute_schedule_next_run(schedule_id)
|
||||
continue
|
||||
|
||||
logger.info(f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 (next_run_at={schedule_config.get('next_run_at')})")
|
||||
|
||||
user_id = schedule_config["user_id"]
|
||||
schedule_id = schedule_config["id"]
|
||||
browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
|
||||
enable_screenshot = schedule_config.get("enable_screenshot", 1)
|
||||
|
||||
try:
|
||||
account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
|
||||
account_ids = json.loads(account_ids_raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
|
||||
account_ids = []
|
||||
|
||||
if not account_ids:
|
||||
try:
|
||||
database.recompute_schedule_next_run(schedule_id)
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
if not safe_get_user_accounts_snapshot(user_id):
|
||||
load_user_accounts(user_id)
|
||||
|
||||
import time as time_mod
|
||||
import uuid
|
||||
|
||||
execution_start_time = time_mod.time()
|
||||
log_id = database.create_schedule_execution_log(
|
||||
schedule_id=schedule_id, user_id=user_id, schedule_name=schedule_config.get("name", "未命名任务")
|
||||
logger.info(
|
||||
f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 "
|
||||
f"(next_run_at={schedule_config.get('next_run_at')})"
|
||||
)
|
||||
|
||||
batch_id = f"batch_{uuid.uuid4().hex[:12]}"
|
||||
now_ts = time_mod.time()
|
||||
safe_create_batch(
|
||||
batch_id,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"browse_type": browse_type,
|
||||
"schedule_name": schedule_config.get("name", "未命名任务"),
|
||||
"screenshots": [],
|
||||
"total_accounts": 0,
|
||||
"completed": 0,
|
||||
"created_at": now_ts,
|
||||
"updated_at": now_ts,
|
||||
},
|
||||
)
|
||||
|
||||
started_count = 0
|
||||
skipped_count = 0
|
||||
completion_lock = threading.Lock()
|
||||
remaining = {"count": 0, "done": False}
|
||||
|
||||
def on_browse_done():
|
||||
with completion_lock:
|
||||
remaining["count"] -= 1
|
||||
if remaining["done"] or remaining["count"] > 0:
|
||||
return
|
||||
remaining["done"] = True
|
||||
execution_duration = int(time_mod.time() - execution_start_time)
|
||||
database.update_schedule_execution_log(
|
||||
log_id,
|
||||
total_accounts=len(account_ids),
|
||||
success_accounts=started_count,
|
||||
failed_accounts=len(account_ids) - started_count,
|
||||
duration_seconds=execution_duration,
|
||||
status="completed",
|
||||
)
|
||||
logger.info(
|
||||
f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件"
|
||||
)
|
||||
|
||||
for account_id in account_ids:
|
||||
account = safe_get_account(user_id, account_id)
|
||||
if not account:
|
||||
skipped_count += 1
|
||||
continue
|
||||
if account.is_running:
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
task_source = f"user_scheduled:{batch_id}"
|
||||
with completion_lock:
|
||||
remaining["count"] += 1
|
||||
ok, msg = submit_account_task(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
browse_type=browse_type,
|
||||
enable_screenshot=enable_screenshot,
|
||||
source=task_source,
|
||||
done_callback=on_browse_done,
|
||||
)
|
||||
if ok:
|
||||
started_count += 1
|
||||
else:
|
||||
with completion_lock:
|
||||
remaining["count"] -= 1
|
||||
skipped_count += 1
|
||||
logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
|
||||
|
||||
batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time_mod.time())
|
||||
if batch_info:
|
||||
_send_batch_task_email_if_configured(batch_info)
|
||||
|
||||
database.update_schedule_last_run(schedule_id)
|
||||
|
||||
logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号,批次ID: {batch_id}")
|
||||
if started_count <= 0:
|
||||
database.update_schedule_execution_log(
|
||||
log_id,
|
||||
total_accounts=len(account_ids),
|
||||
success_accounts=0,
|
||||
failed_accounts=len(account_ids),
|
||||
duration_seconds=0,
|
||||
status="completed",
|
||||
)
|
||||
if started_count == 0 and len(account_ids) > 0:
|
||||
logger.warning("[用户定时任务] ⚠️ 警告:所有账号都被跳过了!请检查user_accounts状态")
|
||||
_execute_due_user_schedule(schedule_config)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"[用户定时任务] 检查出错: {str(e)}")
|
||||
|
||||
@@ -6,12 +6,14 @@ import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import database
|
||||
import email_service
|
||||
from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh
|
||||
from app_config import get_config
|
||||
from app_logger import get_logger
|
||||
from app_security import sanitize_filename
|
||||
from browser_pool_worker import get_browser_worker_pool
|
||||
from services.client_log import log_to_client
|
||||
from services.runtime import get_socketio
|
||||
@@ -194,6 +196,293 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _set_screenshot_running_status(user_id: int, account_id: str) -> None:
|
||||
"""更新账号状态为截图中。"""
|
||||
acc = safe_get_account(user_id, account_id)
|
||||
if not acc:
|
||||
return
|
||||
acc.status = "截图中"
|
||||
safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
|
||||
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
|
||||
|
||||
|
||||
def _get_worker_display_info(browser_instance) -> tuple[str, int]:
|
||||
"""获取截图 worker 的展示信息。"""
|
||||
if isinstance(browser_instance, dict):
|
||||
return str(browser_instance.get("worker_id", "?")), int(browser_instance.get("use_count", 0) or 0)
|
||||
return "?", 0
|
||||
|
||||
|
||||
def _get_proxy_context(account) -> tuple[dict | None, str | None]:
|
||||
"""提取截图阶段代理配置。"""
|
||||
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
|
||||
proxy_server = proxy_config.get("server") if proxy_config else None
|
||||
return proxy_config, proxy_server
|
||||
|
||||
|
||||
def _build_screenshot_targets(browse_type: str) -> tuple[str, str, str]:
|
||||
"""构建截图目标 URL 与页面脚本。"""
|
||||
parsed = urlsplit(config.ZSGL_LOGIN_URL)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
if "注册前" in str(browse_type):
|
||||
bz = 0
|
||||
else:
|
||||
bz = 0
|
||||
|
||||
target_url = f"{base}/admin/center.aspx?bz={bz}"
|
||||
index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
|
||||
run_script = (
|
||||
"(function(){"
|
||||
"function done(){window.status='ready';}"
|
||||
"function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
|
||||
"function expandMenu(){"
|
||||
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
|
||||
"try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
|
||||
"try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
|
||||
"try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
|
||||
"try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
|
||||
"}"
|
||||
"function navReady(){"
|
||||
"try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
|
||||
"}"
|
||||
"function frameReady(){"
|
||||
"try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
|
||||
"}"
|
||||
"function check(){"
|
||||
"if(navReady() && frameReady()){done();return;}"
|
||||
"setTimeout(check,300);"
|
||||
"}"
|
||||
"var f=document.getElementById('mainframe');"
|
||||
"ensureNav();"
|
||||
"expandMenu();"
|
||||
"if(!f){done();return;}"
|
||||
f"f.src='{target_url}';"
|
||||
"f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
|
||||
"setTimeout(check,5000);"
|
||||
"})();"
|
||||
)
|
||||
return index_url, target_url, run_script
|
||||
|
||||
|
||||
def _build_screenshot_output_path(username_prefix: str, account, browse_type: str) -> tuple[str, str]:
|
||||
"""构建截图输出文件名与路径。"""
|
||||
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
|
||||
login_account = account.remark if account.remark else account.username
|
||||
raw_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
|
||||
screenshot_filename = sanitize_filename(raw_filename)
|
||||
return screenshot_filename, os.path.join(SCREENSHOTS_DIR, screenshot_filename)
|
||||
|
||||
|
||||
def _ensure_screenshot_login_state(
|
||||
*,
|
||||
account,
|
||||
proxy_config,
|
||||
cookie_path: str,
|
||||
attempt: int,
|
||||
max_retries: int,
|
||||
user_id: int,
|
||||
account_id: str,
|
||||
custom_log,
|
||||
) -> str:
|
||||
"""确保截图前登录态有效。返回: ok/retry/fail。"""
|
||||
should_refresh_login = not is_cookie_jar_fresh(cookie_path)
|
||||
if not should_refresh_login:
|
||||
return "ok"
|
||||
|
||||
log_to_client("正在刷新登录态...", user_id, account_id)
|
||||
if _ensure_login_cookies(account, proxy_config, custom_log):
|
||||
return "ok"
|
||||
|
||||
if attempt > 1:
|
||||
log_to_client("截图登录失败", user_id, account_id)
|
||||
if attempt < max_retries:
|
||||
log_to_client("将重试...", user_id, account_id)
|
||||
time.sleep(2)
|
||||
return "retry"
|
||||
|
||||
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
|
||||
return "fail"
|
||||
|
||||
|
||||
def _take_screenshot_once(
|
||||
*,
|
||||
index_url: str,
|
||||
target_url: str,
|
||||
screenshot_path: str,
|
||||
cookie_path: str,
|
||||
proxy_server: str | None,
|
||||
run_script: str,
|
||||
log_callback,
|
||||
) -> str:
|
||||
"""执行一次截图尝试并验证输出文件。返回: success/invalid/failed。"""
|
||||
cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
|
||||
|
||||
attempts = [
|
||||
{
|
||||
"url": index_url,
|
||||
"run_script": run_script,
|
||||
"window_status": "ready",
|
||||
},
|
||||
{
|
||||
"url": target_url,
|
||||
"run_script": None,
|
||||
"window_status": None,
|
||||
},
|
||||
]
|
||||
|
||||
ok = False
|
||||
for shot in attempts:
|
||||
ok = take_screenshot_wkhtmltoimage(
|
||||
shot["url"],
|
||||
screenshot_path,
|
||||
cookies_path=cookies_for_shot,
|
||||
proxy_server=proxy_server,
|
||||
run_script=shot["run_script"],
|
||||
window_status=shot["window_status"],
|
||||
log_callback=log_callback,
|
||||
)
|
||||
if ok:
|
||||
break
|
||||
|
||||
if not ok:
|
||||
return "failed"
|
||||
|
||||
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
|
||||
return "success"
|
||||
|
||||
if os.path.exists(screenshot_path):
|
||||
os.remove(screenshot_path)
|
||||
return "invalid"
|
||||
|
||||
|
||||
def _get_result_screenshot_path(result) -> str | None:
|
||||
"""从截图结果中提取截图文件绝对路径。"""
|
||||
if result and result.get("success") and result.get("filename"):
|
||||
return os.path.join(SCREENSHOTS_DIR, result["filename"])
|
||||
return None
|
||||
|
||||
|
||||
def _enqueue_kdocs_upload_if_needed(user_id: int, account_id: str, account, screenshot_path: str | None) -> None:
|
||||
"""按配置提交金山文档上传任务。"""
|
||||
if not screenshot_path:
|
||||
return
|
||||
|
||||
cfg = database.get_system_config() or {}
|
||||
if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
|
||||
return
|
||||
|
||||
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
|
||||
if not doc_url:
|
||||
return
|
||||
|
||||
user_cfg = database.get_user_kdocs_settings(user_id) or {}
|
||||
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) != 1:
|
||||
return
|
||||
|
||||
unit = (user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "").strip()
|
||||
name = (account.remark or "").strip()
|
||||
if not unit:
|
||||
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
|
||||
return
|
||||
if not name:
|
||||
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
|
||||
return
|
||||
|
||||
from services.kdocs_uploader import get_kdocs_uploader
|
||||
|
||||
ok = get_kdocs_uploader().enqueue_upload(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
unit=unit,
|
||||
name=name,
|
||||
image_path=screenshot_path,
|
||||
)
|
||||
if not ok:
|
||||
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
|
||||
|
||||
|
||||
def _dispatch_screenshot_result(
|
||||
*,
|
||||
user_id: int,
|
||||
account_id: str,
|
||||
source: str,
|
||||
browse_type: str,
|
||||
browse_result: dict,
|
||||
result,
|
||||
account,
|
||||
user_info,
|
||||
) -> None:
|
||||
"""将截图结果发送到批次统计/邮件通知链路。"""
|
||||
batch_id = _get_batch_id_from_source(source)
|
||||
screenshot_path = _get_result_screenshot_path(result)
|
||||
account_name = account.remark if account.remark else account.username
|
||||
|
||||
try:
|
||||
if result and result.get("success") and screenshot_path:
|
||||
_enqueue_kdocs_upload_if_needed(user_id, account_id, account, screenshot_path)
|
||||
except Exception as kdocs_error:
|
||||
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
|
||||
|
||||
if batch_id:
|
||||
_batch_task_record_result(
|
||||
batch_id=batch_id,
|
||||
account_name=account_name,
|
||||
screenshot_path=screenshot_path,
|
||||
total_items=browse_result.get("total_items", 0),
|
||||
total_attachments=browse_result.get("total_attachments", 0),
|
||||
)
|
||||
return
|
||||
|
||||
if source and source.startswith("user_scheduled"):
|
||||
if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
|
||||
email_service.send_task_complete_email_async(
|
||||
user_id=user_id,
|
||||
email=user_info["email"],
|
||||
username=user_info["username"],
|
||||
account_name=account_name,
|
||||
browse_type=browse_type,
|
||||
total_items=browse_result.get("total_items", 0),
|
||||
total_attachments=browse_result.get("total_attachments", 0),
|
||||
screenshot_path=screenshot_path,
|
||||
log_callback=lambda msg: log_to_client(msg, user_id, account_id),
|
||||
)
|
||||
|
||||
|
||||
def _finalize_screenshot_callback_state(user_id: int, account_id: str, account) -> None:
|
||||
"""截图回调的通用收尾状态变更。"""
|
||||
account.is_running = False
|
||||
account.status = "未开始"
|
||||
safe_remove_task_status(account_id)
|
||||
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
|
||||
|
||||
|
||||
def _persist_browse_log_after_screenshot(
|
||||
*,
|
||||
user_id: int,
|
||||
account_id: str,
|
||||
account,
|
||||
browse_type: str,
|
||||
source: str,
|
||||
task_start_time,
|
||||
browse_result,
|
||||
) -> None:
|
||||
"""截图完成后写入任务日志(浏览完成日志)。"""
|
||||
import time as time_module
|
||||
|
||||
total_elapsed = int(time_module.time() - task_start_time)
|
||||
database.create_task_log(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
username=account.username,
|
||||
browse_type=browse_type,
|
||||
status="success",
|
||||
total_items=browse_result.get("total_items", 0),
|
||||
total_attachments=browse_result.get("total_attachments", 0),
|
||||
duration=total_elapsed,
|
||||
source=source,
|
||||
)
|
||||
|
||||
|
||||
def take_screenshot_for_account(
|
||||
user_id,
|
||||
account_id,
|
||||
@@ -213,21 +502,21 @@ def take_screenshot_for_account(
|
||||
# 标记账号正在截图(防止重复提交截图任务)
|
||||
account.is_running = True
|
||||
|
||||
|
||||
user_info = database.get_user_by_id(user_id)
|
||||
username_prefix = user_info["username"] if user_info else f"user{user_id}"
|
||||
|
||||
def screenshot_task(
|
||||
browser_instance, user_id, account_id, account, browse_type, source, task_start_time, browse_result
|
||||
):
|
||||
"""在worker线程中执行的截图任务"""
|
||||
# ✅ 获得worker后,立即更新状态为"截图中"
|
||||
acc = safe_get_account(user_id, account_id)
|
||||
if acc:
|
||||
acc.status = "截图中"
|
||||
safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
|
||||
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
|
||||
_set_screenshot_running_status(user_id, account_id)
|
||||
|
||||
max_retries = 3
|
||||
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
|
||||
proxy_server = proxy_config.get("server") if proxy_config else None
|
||||
proxy_config, proxy_server = _get_proxy_context(account)
|
||||
cookie_path = get_cookie_jar_path(account.username)
|
||||
index_url, target_url, run_script = _build_screenshot_targets(browse_type)
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
@@ -239,8 +528,7 @@ def take_screenshot_for_account(
|
||||
if attempt > 1:
|
||||
log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id)
|
||||
|
||||
worker_id = browser_instance.get("worker_id", "?") if isinstance(browser_instance, dict) else "?"
|
||||
use_count = browser_instance.get("use_count", 0) if isinstance(browser_instance, dict) else 0
|
||||
worker_id, use_count = _get_worker_display_info(browser_instance)
|
||||
log_to_client(
|
||||
f"使用Worker-{worker_id}执行截图(已执行{use_count}次)",
|
||||
user_id,
|
||||
@@ -250,99 +538,39 @@ def take_screenshot_for_account(
|
||||
def custom_log(message: str):
|
||||
log_to_client(message, user_id, account_id)
|
||||
|
||||
# 智能登录状态检查:只在必要时才刷新登录
|
||||
should_refresh_login = not is_cookie_jar_fresh(cookie_path)
|
||||
if should_refresh_login and attempt > 1:
|
||||
# 重试时刷新登录(attempt > 1 表示第2次及以后的尝试)
|
||||
log_to_client("正在刷新登录态...", user_id, account_id)
|
||||
if not _ensure_login_cookies(account, proxy_config, custom_log):
|
||||
log_to_client("截图登录失败", user_id, account_id)
|
||||
if attempt < max_retries:
|
||||
log_to_client("将重试...", user_id, account_id)
|
||||
time.sleep(2)
|
||||
continue
|
||||
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
|
||||
return {"success": False, "error": "登录失败"}
|
||||
elif should_refresh_login:
|
||||
# 首次尝试时快速检查登录状态
|
||||
log_to_client("正在刷新登录态...", user_id, account_id)
|
||||
if not _ensure_login_cookies(account, proxy_config, custom_log):
|
||||
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
|
||||
return {"success": False, "error": "登录失败"}
|
||||
login_state = _ensure_screenshot_login_state(
|
||||
account=account,
|
||||
proxy_config=proxy_config,
|
||||
cookie_path=cookie_path,
|
||||
attempt=attempt,
|
||||
max_retries=max_retries,
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
custom_log=custom_log,
|
||||
)
|
||||
if login_state == "retry":
|
||||
continue
|
||||
if login_state == "fail":
|
||||
return {"success": False, "error": "登录失败"}
|
||||
|
||||
log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id)
|
||||
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
parsed = urlsplit(config.ZSGL_LOGIN_URL)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
if "注册前" in str(browse_type):
|
||||
bz = 0
|
||||
else:
|
||||
bz = 0 # 应读(网站更新后 bz=0 为应读)
|
||||
target_url = f"{base}/admin/center.aspx?bz={bz}"
|
||||
index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
|
||||
run_script = (
|
||||
"(function(){"
|
||||
"function done(){window.status='ready';}"
|
||||
"function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
|
||||
"function expandMenu(){"
|
||||
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
|
||||
"try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
|
||||
"try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
|
||||
"try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
|
||||
"try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
|
||||
"}"
|
||||
"function navReady(){"
|
||||
"try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
|
||||
"}"
|
||||
"function frameReady(){"
|
||||
"try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
|
||||
"}"
|
||||
"function check(){"
|
||||
"if(navReady() && frameReady()){done();return;}"
|
||||
"setTimeout(check,300);"
|
||||
"}"
|
||||
"var f=document.getElementById('mainframe');"
|
||||
"ensureNav();"
|
||||
"expandMenu();"
|
||||
"if(!f){done();return;}"
|
||||
f"f.src='{target_url}';"
|
||||
"f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
|
||||
"setTimeout(check,5000);"
|
||||
"})();"
|
||||
)
|
||||
|
||||
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
user_info = database.get_user_by_id(user_id)
|
||||
username_prefix = user_info["username"] if user_info else f"user{user_id}"
|
||||
login_account = account.remark if account.remark else account.username
|
||||
screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
|
||||
screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename)
|
||||
|
||||
cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
|
||||
if take_screenshot_wkhtmltoimage(
|
||||
index_url,
|
||||
screenshot_path,
|
||||
cookies_path=cookies_for_shot,
|
||||
screenshot_filename, screenshot_path = _build_screenshot_output_path(username_prefix, account, browse_type)
|
||||
shot_state = _take_screenshot_once(
|
||||
index_url=index_url,
|
||||
target_url=target_url,
|
||||
screenshot_path=screenshot_path,
|
||||
cookie_path=cookie_path,
|
||||
proxy_server=proxy_server,
|
||||
run_script=run_script,
|
||||
window_status="ready",
|
||||
log_callback=custom_log,
|
||||
) or take_screenshot_wkhtmltoimage(
|
||||
target_url,
|
||||
screenshot_path,
|
||||
cookies_path=cookies_for_shot,
|
||||
proxy_server=proxy_server,
|
||||
log_callback=custom_log,
|
||||
):
|
||||
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
|
||||
log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
|
||||
return {"success": True, "filename": screenshot_filename}
|
||||
)
|
||||
if shot_state == "success":
|
||||
log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
|
||||
return {"success": True, "filename": screenshot_filename}
|
||||
|
||||
if shot_state == "invalid":
|
||||
log_to_client("截图文件异常,将重试", user_id, account_id)
|
||||
if os.path.exists(screenshot_path):
|
||||
os.remove(screenshot_path)
|
||||
else:
|
||||
log_to_client("截图保存失败", user_id, account_id)
|
||||
|
||||
@@ -361,12 +589,7 @@ def take_screenshot_for_account(
|
||||
def screenshot_callback(result, error):
|
||||
"""截图完成回调"""
|
||||
try:
|
||||
account.is_running = False
|
||||
account.status = "未开始"
|
||||
|
||||
safe_remove_task_status(account_id)
|
||||
|
||||
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
|
||||
_finalize_screenshot_callback_state(user_id, account_id, account)
|
||||
|
||||
if error:
|
||||
log_to_client(f"❌ 截图失败: {error}", user_id, account_id)
|
||||
@@ -375,84 +598,27 @@ def take_screenshot_for_account(
|
||||
log_to_client(f"❌ 截图失败: {error_msg}", user_id, account_id)
|
||||
|
||||
if task_start_time and browse_result:
|
||||
import time as time_module
|
||||
|
||||
total_elapsed = int(time_module.time() - task_start_time)
|
||||
database.create_task_log(
|
||||
_persist_browse_log_after_screenshot(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
username=account.username,
|
||||
account=account,
|
||||
browse_type=browse_type,
|
||||
status="success",
|
||||
total_items=browse_result.get("total_items", 0),
|
||||
total_attachments=browse_result.get("total_attachments", 0),
|
||||
duration=total_elapsed,
|
||||
source=source,
|
||||
task_start_time=task_start_time,
|
||||
browse_result=browse_result,
|
||||
)
|
||||
|
||||
try:
|
||||
batch_id = _get_batch_id_from_source(source)
|
||||
|
||||
screenshot_path = None
|
||||
if result and result.get("success") and result.get("filename"):
|
||||
screenshot_path = os.path.join(SCREENSHOTS_DIR, result["filename"])
|
||||
|
||||
account_name = account.remark if account.remark else account.username
|
||||
|
||||
try:
|
||||
if screenshot_path and result and result.get("success"):
|
||||
cfg = database.get_system_config() or {}
|
||||
if int(cfg.get("kdocs_enabled", 0) or 0) == 1:
|
||||
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
|
||||
if doc_url:
|
||||
user_cfg = database.get_user_kdocs_settings(user_id) or {}
|
||||
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) == 1:
|
||||
unit = (
|
||||
user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or ""
|
||||
).strip()
|
||||
name = (account.remark or "").strip()
|
||||
if unit and name:
|
||||
from services.kdocs_uploader import get_kdocs_uploader
|
||||
|
||||
ok = get_kdocs_uploader().enqueue_upload(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
unit=unit,
|
||||
name=name,
|
||||
image_path=screenshot_path,
|
||||
)
|
||||
if not ok:
|
||||
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
|
||||
else:
|
||||
if not unit:
|
||||
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
|
||||
if not name:
|
||||
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
|
||||
except Exception as kdocs_error:
|
||||
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
|
||||
|
||||
if batch_id:
|
||||
_batch_task_record_result(
|
||||
batch_id=batch_id,
|
||||
account_name=account_name,
|
||||
screenshot_path=screenshot_path,
|
||||
total_items=browse_result.get("total_items", 0),
|
||||
total_attachments=browse_result.get("total_attachments", 0),
|
||||
)
|
||||
elif source and source.startswith("user_scheduled"):
|
||||
user_info = database.get_user_by_id(user_id)
|
||||
if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
|
||||
email_service.send_task_complete_email_async(
|
||||
user_id=user_id,
|
||||
email=user_info["email"],
|
||||
username=user_info["username"],
|
||||
account_name=account_name,
|
||||
browse_type=browse_type,
|
||||
total_items=browse_result.get("total_items", 0),
|
||||
total_attachments=browse_result.get("total_attachments", 0),
|
||||
screenshot_path=screenshot_path,
|
||||
log_callback=lambda msg: log_to_client(msg, user_id, account_id),
|
||||
)
|
||||
_dispatch_screenshot_result(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
source=source,
|
||||
browse_type=browse_type,
|
||||
browse_result=browse_result,
|
||||
result=result,
|
||||
account=account,
|
||||
user_info=user_info,
|
||||
)
|
||||
except Exception as email_error:
|
||||
logger.warning(f"发送任务完成邮件失败: {email_error}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -13,7 +13,7 @@ from __future__ import annotations
|
||||
import threading
|
||||
import time
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from app_config import get_config
|
||||
|
||||
@@ -161,6 +161,36 @@ _log_cache_lock = threading.RLock()
|
||||
_log_cache_total_count = 0
|
||||
|
||||
|
||||
def _pop_oldest_log_for_user(uid: int) -> bool:
|
||||
global _log_cache_total_count
|
||||
|
||||
logs = _log_cache.get(uid)
|
||||
if not logs:
|
||||
_log_cache.pop(uid, None)
|
||||
return False
|
||||
|
||||
logs.pop(0)
|
||||
_log_cache_total_count = max(0, _log_cache_total_count - 1)
|
||||
if not logs:
|
||||
_log_cache.pop(uid, None)
|
||||
return True
|
||||
|
||||
|
||||
def _pop_oldest_log_from_largest_user() -> bool:
|
||||
largest_uid = None
|
||||
largest_size = 0
|
||||
|
||||
for uid, logs in _log_cache.items():
|
||||
size = len(logs)
|
||||
if size > largest_size:
|
||||
largest_uid = uid
|
||||
largest_size = size
|
||||
|
||||
if largest_uid is None or largest_size <= 0:
|
||||
return False
|
||||
return _pop_oldest_log_for_user(int(largest_uid))
|
||||
|
||||
|
||||
def safe_add_log(
|
||||
user_id: int,
|
||||
log_entry: Dict[str, Any],
|
||||
@@ -175,24 +205,17 @@ def safe_add_log(
|
||||
max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS)
|
||||
|
||||
with _log_cache_lock:
|
||||
if uid not in _log_cache:
|
||||
_log_cache[uid] = []
|
||||
logs = _log_cache.setdefault(uid, [])
|
||||
|
||||
if len(_log_cache[uid]) >= max_logs_per_user:
|
||||
_log_cache[uid].pop(0)
|
||||
_log_cache_total_count = max(0, _log_cache_total_count - 1)
|
||||
if len(logs) >= max_logs_per_user:
|
||||
_pop_oldest_log_for_user(uid)
|
||||
logs = _log_cache.setdefault(uid, [])
|
||||
|
||||
_log_cache[uid].append(dict(log_entry or {}))
|
||||
logs.append(dict(log_entry or {}))
|
||||
_log_cache_total_count += 1
|
||||
|
||||
while _log_cache_total_count > max_total_logs:
|
||||
if not _log_cache:
|
||||
break
|
||||
max_user = max(_log_cache.keys(), key=lambda u: len(_log_cache[u]))
|
||||
if _log_cache.get(max_user):
|
||||
_log_cache[max_user].pop(0)
|
||||
_log_cache_total_count -= 1
|
||||
else:
|
||||
if not _pop_oldest_log_from_largest_user():
|
||||
break
|
||||
|
||||
|
||||
@@ -378,6 +401,34 @@ def _get_action_rate_limit(action: str) -> Tuple[int, int]:
|
||||
return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS)
|
||||
|
||||
|
||||
def _format_wait_hint(remaining_seconds: int) -> str:
|
||||
remaining = max(1, int(remaining_seconds or 0))
|
||||
if remaining >= 60:
|
||||
return f"{remaining // 60 + 1}分钟"
|
||||
return f"{remaining}秒"
|
||||
|
||||
|
||||
def _check_and_increment_rate_bucket(
|
||||
*,
|
||||
buckets: Dict[str, Dict[str, Any]],
|
||||
key: str,
|
||||
now_ts: float,
|
||||
max_requests: int,
|
||||
window_seconds: int,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
if not key or int(max_requests) <= 0:
|
||||
return True, None
|
||||
|
||||
data = _get_or_reset_bucket(buckets.get(key), now_ts, window_seconds)
|
||||
if int(data.get("count", 0) or 0) >= int(max_requests):
|
||||
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
|
||||
return False, f"请求过于频繁,请{_format_wait_hint(remaining)}后再试"
|
||||
|
||||
data["count"] = int(data.get("count", 0) or 0) + 1
|
||||
buckets[key] = data
|
||||
return True, None
|
||||
|
||||
|
||||
def check_ip_request_rate(
|
||||
ip_address: str,
|
||||
action: str,
|
||||
@@ -392,21 +443,13 @@ def check_ip_request_rate(
|
||||
|
||||
key = f"{action}:{ip_address}"
|
||||
with _ip_request_rate_lock:
|
||||
data = _ip_request_rate.get(key)
|
||||
if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds:
|
||||
data = {"window_start": now_ts, "count": 0}
|
||||
_ip_request_rate[key] = data
|
||||
|
||||
if int(data.get("count", 0) or 0) >= max_requests:
|
||||
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
|
||||
if remaining >= 60:
|
||||
wait_hint = f"{remaining // 60 + 1}分钟"
|
||||
else:
|
||||
wait_hint = f"{remaining}秒"
|
||||
return False, f"请求过于频繁,请{wait_hint}后再试"
|
||||
|
||||
data["count"] = int(data.get("count", 0) or 0) + 1
|
||||
return True, None
|
||||
return _check_and_increment_rate_bucket(
|
||||
buckets=_ip_request_rate,
|
||||
key=key,
|
||||
now_ts=now_ts,
|
||||
max_requests=max_requests,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
|
||||
def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
|
||||
@@ -417,8 +460,7 @@ def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
|
||||
data = _ip_request_rate.get(key) or {}
|
||||
action = key.split(":", 1)[0]
|
||||
_, window_seconds = _get_action_rate_limit(action)
|
||||
window_start = float(data.get("window_start", 0) or 0)
|
||||
if now_ts - window_start >= window_seconds:
|
||||
if _is_bucket_expired(data, now_ts, window_seconds):
|
||||
_ip_request_rate.pop(key, None)
|
||||
removed += 1
|
||||
return removed
|
||||
@@ -487,6 +529,30 @@ def _get_or_reset_bucket(data: Optional[Dict[str, Any]], now_ts: float, window_s
|
||||
return data
|
||||
|
||||
|
||||
def _is_bucket_expired(
|
||||
data: Optional[Dict[str, Any]],
|
||||
now_ts: float,
|
||||
window_seconds: int,
|
||||
*,
|
||||
time_field: str = "window_start",
|
||||
) -> bool:
|
||||
start_ts = float((data or {}).get(time_field, 0) or 0)
|
||||
return (now_ts - start_ts) >= max(1, int(window_seconds))
|
||||
|
||||
|
||||
def _cleanup_map_entries(
|
||||
store: Dict[Any, Dict[str, Any]],
|
||||
should_remove: Callable[[Dict[str, Any]], bool],
|
||||
) -> int:
|
||||
removed = 0
|
||||
for key, value in list(store.items()):
|
||||
item = value if isinstance(value, dict) else {}
|
||||
if should_remove(item):
|
||||
store.pop(key, None)
|
||||
removed += 1
|
||||
return removed
|
||||
|
||||
|
||||
def record_login_username_attempt(ip_address: str, username: str) -> bool:
|
||||
now_ts = time.time()
|
||||
threshold, window_seconds, cooldown_seconds = _get_login_scan_config()
|
||||
@@ -527,26 +593,32 @@ def check_login_rate_limits(ip_address: str, username: str) -> Tuple[bool, Optio
|
||||
user_key = _normalize_login_key("user", "", username)
|
||||
ip_user_key = _normalize_login_key("ipuser", ip_address, username)
|
||||
|
||||
def _check(key: str, max_requests: int) -> Tuple[bool, Optional[str]]:
|
||||
if not key or max_requests <= 0:
|
||||
return True, None
|
||||
data = _get_or_reset_bucket(_login_rate_limits.get(key), now_ts, window_seconds)
|
||||
if int(data.get("count", 0) or 0) >= max_requests:
|
||||
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
|
||||
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒"
|
||||
return False, f"请求过于频繁,请{wait_hint}后再试"
|
||||
data["count"] = int(data.get("count", 0) or 0) + 1
|
||||
_login_rate_limits[key] = data
|
||||
return True, None
|
||||
|
||||
with _login_rate_limits_lock:
|
||||
allowed, msg = _check(ip_key, ip_max)
|
||||
allowed, msg = _check_and_increment_rate_bucket(
|
||||
buckets=_login_rate_limits,
|
||||
key=ip_key,
|
||||
now_ts=now_ts,
|
||||
max_requests=ip_max,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
if not allowed:
|
||||
return False, msg
|
||||
allowed, msg = _check(ip_user_key, ip_user_max)
|
||||
allowed, msg = _check_and_increment_rate_bucket(
|
||||
buckets=_login_rate_limits,
|
||||
key=ip_user_key,
|
||||
now_ts=now_ts,
|
||||
max_requests=ip_user_max,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
if not allowed:
|
||||
return False, msg
|
||||
allowed, msg = _check(user_key, user_max)
|
||||
allowed, msg = _check_and_increment_rate_bucket(
|
||||
buckets=_login_rate_limits,
|
||||
key=user_key,
|
||||
now_ts=now_ts,
|
||||
max_requests=user_max,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
if not allowed:
|
||||
return False, msg
|
||||
|
||||
@@ -622,15 +694,18 @@ def check_login_captcha_required(ip_address: str, username: Optional[str] = None
|
||||
ip_key = _normalize_login_key("ip", ip_address)
|
||||
ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
|
||||
|
||||
def _is_over_threshold(data: Optional[Dict[str, Any]]) -> bool:
|
||||
if not data:
|
||||
return False
|
||||
if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
|
||||
return False
|
||||
return int(data.get("count", 0) or 0) >= max_failures
|
||||
|
||||
with _login_failures_lock:
|
||||
ip_data = _login_failures.get(ip_key)
|
||||
if ip_data and (now_ts - float(ip_data.get("first_failed", 0) or 0)) <= window_seconds:
|
||||
if int(ip_data.get("count", 0) or 0) >= max_failures:
|
||||
return True
|
||||
ip_user_data = _login_failures.get(ip_user_key)
|
||||
if ip_user_data and (now_ts - float(ip_user_data.get("first_failed", 0) or 0)) <= window_seconds:
|
||||
if int(ip_user_data.get("count", 0) or 0) >= max_failures:
|
||||
return True
|
||||
if _is_over_threshold(_login_failures.get(ip_key)):
|
||||
return True
|
||||
if _is_over_threshold(_login_failures.get(ip_user_key)):
|
||||
return True
|
||||
|
||||
if is_login_scan_locked(ip_address):
|
||||
return True
|
||||
@@ -685,6 +760,56 @@ def should_send_login_alert(user_id: int, ip_address: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def cleanup_expired_login_security_state(now_ts: Optional[float] = None) -> Dict[str, int]:
|
||||
now_ts = float(now_ts if now_ts is not None else time.time())
|
||||
_, captcha_window = _get_login_captcha_config()
|
||||
_, _, _, rate_window = _get_login_rate_limit_config()
|
||||
_, lock_window, _ = _get_login_lock_config()
|
||||
_, scan_window, _ = _get_login_scan_config()
|
||||
alert_expire_seconds = max(3600, int(config.LOGIN_ALERT_MIN_INTERVAL_SECONDS) * 3)
|
||||
|
||||
with _login_failures_lock:
|
||||
failures_removed = _cleanup_map_entries(
|
||||
_login_failures,
|
||||
lambda data: (now_ts - float(data.get("first_failed", 0) or 0)) > max(captcha_window, lock_window),
|
||||
)
|
||||
|
||||
with _login_rate_limits_lock:
|
||||
rate_removed = _cleanup_map_entries(
|
||||
_login_rate_limits,
|
||||
lambda data: _is_bucket_expired(data, now_ts, rate_window),
|
||||
)
|
||||
|
||||
with _login_scan_lock:
|
||||
scan_removed = _cleanup_map_entries(
|
||||
_login_scan_state,
|
||||
lambda data: (
|
||||
(now_ts - float(data.get("first_seen", 0) or 0)) > scan_window
|
||||
and now_ts >= float(data.get("scan_until", 0) or 0)
|
||||
),
|
||||
)
|
||||
|
||||
with _login_ip_user_lock:
|
||||
ip_user_locks_removed = _cleanup_map_entries(
|
||||
_login_ip_user_locks,
|
||||
lambda data: now_ts >= float(data.get("lock_until", 0) or 0),
|
||||
)
|
||||
|
||||
with _login_alert_lock:
|
||||
alerts_removed = _cleanup_map_entries(
|
||||
_login_alert_state,
|
||||
lambda data: (now_ts - float(data.get("last_sent", 0) or 0)) > alert_expire_seconds,
|
||||
)
|
||||
|
||||
return {
|
||||
"failures": failures_removed,
|
||||
"rate_limits": rate_removed,
|
||||
"scan_states": scan_removed,
|
||||
"ip_user_locks": ip_user_locks_removed,
|
||||
"alerts": alerts_removed,
|
||||
}
|
||||
|
||||
|
||||
# ==================== 邮箱维度限流 ====================
|
||||
|
||||
_email_rate_limit: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -701,14 +826,13 @@ def check_email_rate_limit(email: str, action: str) -> Tuple[bool, Optional[str]
|
||||
key = f"{action}:{email_key}"
|
||||
|
||||
with _email_rate_limit_lock:
|
||||
data = _get_or_reset_bucket(_email_rate_limit.get(key), now_ts, window_seconds)
|
||||
if int(data.get("count", 0) or 0) >= max_requests:
|
||||
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
|
||||
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒"
|
||||
return False, f"请求过于频繁,请{wait_hint}后再试"
|
||||
data["count"] = int(data.get("count", 0) or 0) + 1
|
||||
_email_rate_limit[key] = data
|
||||
return True, None
|
||||
return _check_and_increment_rate_bucket(
|
||||
buckets=_email_rate_limit,
|
||||
key=key,
|
||||
now_ts=now_ts,
|
||||
max_requests=max_requests,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
|
||||
# ==================== Batch screenshots(批次任务截图收集) ====================
|
||||
|
||||
365
services/task_scheduler.py
Normal file
365
services/task_scheduler.py
Normal file
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import heapq
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from dataclasses import dataclass
|
||||
|
||||
import database
|
||||
from app_logger import get_logger
|
||||
from services.state import safe_get_account, safe_get_task, safe_remove_task, safe_set_task
|
||||
from services.task_batches import _batch_task_record_result, _get_batch_id_from_source
|
||||
|
||||
logger = get_logger("app")
|
||||
|
||||
# VIP优先级队列(仅用于可视化/调试)
|
||||
vip_task_queue = [] # VIP用户任务队列
|
||||
normal_task_queue = [] # 普通用户任务队列
|
||||
task_queue_lock = threading.Lock()
|
||||
|
||||
@dataclass
|
||||
class _TaskRequest:
|
||||
user_id: int
|
||||
account_id: str
|
||||
browse_type: str
|
||||
enable_screenshot: bool
|
||||
source: str
|
||||
retry_count: int
|
||||
submitted_at: float
|
||||
is_vip: bool
|
||||
seq: int
|
||||
canceled: bool = False
|
||||
done_callback: object = None
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""全局任务调度器:队列排队,不为每个任务单独创建线程。"""
|
||||
|
||||
def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000, run_task_fn=None):
|
||||
self.max_global = max(1, int(max_global))
|
||||
self.max_per_user = max(1, int(max_per_user))
|
||||
self.max_queue_size = max(1, int(max_queue_size))
|
||||
|
||||
self._cond = threading.Condition()
|
||||
self._pending = [] # heap: (priority, submitted_at, seq, task)
|
||||
self._pending_by_account = {} # {account_id: task}
|
||||
self._seq = 0
|
||||
self._known_account_ids = set()
|
||||
|
||||
self._running_global = 0
|
||||
self._running_by_user = {} # {user_id: running_count}
|
||||
|
||||
self._executor_max_workers = self.max_global
|
||||
self._executor = ThreadPoolExecutor(max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker")
|
||||
|
||||
self._futures_lock = threading.Lock()
|
||||
self._active_futures = set()
|
||||
|
||||
self._running = True
|
||||
self._run_task_fn = run_task_fn
|
||||
self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, daemon=True, name="TaskDispatcher")
|
||||
self._dispatcher_thread.start()
|
||||
|
||||
def _track_future(self, future) -> None:
|
||||
with self._futures_lock:
|
||||
self._active_futures.add(future)
|
||||
try:
|
||||
future.add_done_callback(self._untrack_future)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _untrack_future(self, future) -> None:
|
||||
with self._futures_lock:
|
||||
self._active_futures.discard(future)
|
||||
|
||||
def shutdown(self, timeout: float = 5.0):
|
||||
"""停止调度器(用于进程退出清理)"""
|
||||
with self._cond:
|
||||
self._running = False
|
||||
self._cond.notify_all()
|
||||
|
||||
try:
|
||||
self._dispatcher_thread.join(timeout=timeout)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 等待已提交的任务收尾(最多等待 timeout 秒),避免遗留 active_task 干扰后续调度/测试
|
||||
try:
|
||||
deadline = time.time() + max(0.0, float(timeout or 0))
|
||||
while True:
|
||||
with self._futures_lock:
|
||||
pending = [f for f in self._active_futures if not f.done()]
|
||||
if not pending:
|
||||
break
|
||||
remaining = deadline - time.time()
|
||||
if remaining <= 0:
|
||||
break
|
||||
wait(pending, timeout=remaining)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
self._executor.shutdown(wait=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 最后兜底:清理本调度器提交过的 active_task,避免测试/重启时被“任务已在运行中”误拦截
|
||||
try:
|
||||
with self._cond:
|
||||
known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys())
|
||||
self._pending.clear()
|
||||
self._pending_by_account.clear()
|
||||
self._cond.notify_all()
|
||||
for account_id in known_ids:
|
||||
safe_remove_task(account_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def update_limits(self, max_global: int = None, max_per_user: int = None, max_queue_size: int = None):
|
||||
"""动态更新并发/队列上限(不影响已在运行的任务)"""
|
||||
with self._cond:
|
||||
if max_per_user is not None:
|
||||
self.max_per_user = max(1, int(max_per_user))
|
||||
if max_queue_size is not None:
|
||||
self.max_queue_size = max(1, int(max_queue_size))
|
||||
|
||||
if max_global is not None:
|
||||
new_max_global = max(1, int(max_global))
|
||||
self.max_global = new_max_global
|
||||
if new_max_global > self._executor_max_workers:
|
||||
# 立即关闭旧线程池,防止资源泄漏
|
||||
old_executor = self._executor
|
||||
self._executor_max_workers = new_max_global
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker"
|
||||
)
|
||||
# 立即关闭旧线程池
|
||||
try:
|
||||
old_executor.shutdown(wait=False)
|
||||
logger.info(f"线程池已扩容:{old_executor._max_workers} -> {self._executor_max_workers}")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭旧线程池失败: {e}")
|
||||
|
||||
self._cond.notify_all()
|
||||
|
||||
def get_queue_state_snapshot(self) -> dict:
|
||||
"""获取调度器队列/运行状态快照(用于前端展示/监控)。"""
|
||||
with self._cond:
|
||||
pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled]
|
||||
pending_tasks.sort(key=lambda t: (0 if t.is_vip else 1, t.submitted_at, t.seq))
|
||||
|
||||
positions = {}
|
||||
for idx, t in enumerate(pending_tasks):
|
||||
positions[t.account_id] = {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)}
|
||||
|
||||
return {
|
||||
"pending_total": len(pending_tasks),
|
||||
"running_total": int(self._running_global),
|
||||
"running_by_user": dict(self._running_by_user),
|
||||
"positions": positions,
|
||||
}
|
||||
|
||||
def submit_task(
|
||||
self,
|
||||
user_id: int,
|
||||
account_id: str,
|
||||
browse_type: str,
|
||||
enable_screenshot: bool = True,
|
||||
source: str = "manual",
|
||||
retry_count: int = 0,
|
||||
is_vip: bool = None,
|
||||
done_callback=None,
|
||||
):
|
||||
"""提交任务进入队列(返回: (ok, message))"""
|
||||
if not user_id or not account_id:
|
||||
return False, "参数错误"
|
||||
|
||||
submitted_at = time.time()
|
||||
if is_vip is None:
|
||||
try:
|
||||
is_vip = bool(database.is_user_vip(user_id))
|
||||
except Exception:
|
||||
is_vip = False
|
||||
else:
|
||||
is_vip = bool(is_vip)
|
||||
|
||||
with self._cond:
|
||||
if not self._running:
|
||||
return False, "调度器未运行"
|
||||
if len(self._pending_by_account) >= self.max_queue_size:
|
||||
return False, "任务队列已满,请稍后再试"
|
||||
if account_id in self._pending_by_account:
|
||||
return False, "任务已在队列中"
|
||||
if safe_get_task(account_id) is not None:
|
||||
return False, "任务已在运行中"
|
||||
|
||||
self._seq += 1
|
||||
task = _TaskRequest(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
browse_type=browse_type,
|
||||
enable_screenshot=bool(enable_screenshot),
|
||||
source=source,
|
||||
retry_count=int(retry_count or 0),
|
||||
submitted_at=submitted_at,
|
||||
is_vip=is_vip,
|
||||
seq=self._seq,
|
||||
done_callback=done_callback,
|
||||
)
|
||||
self._pending_by_account[account_id] = task
|
||||
self._known_account_ids.add(account_id)
|
||||
priority = 0 if is_vip else 1
|
||||
heapq.heappush(self._pending, (priority, task.submitted_at, task.seq, task))
|
||||
self._cond.notify_all()
|
||||
|
||||
# 用于可视化/调试:记录队列
|
||||
with task_queue_lock:
|
||||
if is_vip:
|
||||
vip_task_queue.append(account_id)
|
||||
else:
|
||||
normal_task_queue.append(account_id)
|
||||
|
||||
return True, "已加入队列"
|
||||
|
||||
def cancel_pending_task(self, user_id: int, account_id: str) -> bool:
|
||||
"""取消尚未开始的排队任务(已运行的任务由 should_stop 控制)"""
|
||||
canceled_task = None
|
||||
with self._cond:
|
||||
task = self._pending_by_account.pop(account_id, None)
|
||||
if not task:
|
||||
return False
|
||||
task.canceled = True
|
||||
canceled_task = task
|
||||
self._cond.notify_all()
|
||||
|
||||
# 从可视化队列移除
|
||||
with task_queue_lock:
|
||||
if account_id in vip_task_queue:
|
||||
vip_task_queue.remove(account_id)
|
||||
if account_id in normal_task_queue:
|
||||
normal_task_queue.remove(account_id)
|
||||
|
||||
# 批次任务:取消也要推进完成计数,避免批次缓存常驻
|
||||
try:
|
||||
batch_id = _get_batch_id_from_source(canceled_task.source)
|
||||
if batch_id:
|
||||
acc = safe_get_account(user_id, account_id)
|
||||
if acc:
|
||||
account_name = acc.remark if acc.remark else acc.username
|
||||
else:
|
||||
account_name = account_id
|
||||
_batch_task_record_result(
|
||||
batch_id=batch_id,
|
||||
account_name=account_name,
|
||||
screenshot_path=None,
|
||||
total_items=0,
|
||||
total_attachments=0,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
def _dispatch_loop(self):
|
||||
while True:
|
||||
task = None
|
||||
with self._cond:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
if not self._pending or self._running_global >= self.max_global:
|
||||
self._cond.wait(timeout=0.5)
|
||||
continue
|
||||
|
||||
task = self._pop_next_runnable_locked()
|
||||
if task is None:
|
||||
self._cond.wait(timeout=0.5)
|
||||
continue
|
||||
|
||||
self._running_global += 1
|
||||
self._running_by_user[task.user_id] = self._running_by_user.get(task.user_id, 0) + 1
|
||||
|
||||
# 从队列移除(可视化)
|
||||
with task_queue_lock:
|
||||
if task.account_id in vip_task_queue:
|
||||
vip_task_queue.remove(task.account_id)
|
||||
if task.account_id in normal_task_queue:
|
||||
normal_task_queue.remove(task.account_id)
|
||||
|
||||
try:
|
||||
future = self._executor.submit(self._run_task_wrapper, task)
|
||||
self._track_future(future)
|
||||
safe_set_task(task.account_id, future)
|
||||
except Exception:
|
||||
with self._cond:
|
||||
self._running_global = max(0, self._running_global - 1)
|
||||
# 使用默认值 0 与增加时保持一致
|
||||
self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
|
||||
if self._running_by_user.get(task.user_id) == 0:
|
||||
self._running_by_user.pop(task.user_id, None)
|
||||
self._cond.notify_all()
|
||||
|
||||
def _pop_next_runnable_locked(self):
|
||||
"""在锁内从优先队列取出“可运行”的任务,避免VIP任务占位阻塞普通任务。"""
|
||||
if not self._pending:
|
||||
return None
|
||||
|
||||
skipped = []
|
||||
selected = None
|
||||
|
||||
while self._pending:
|
||||
_, _, _, task = heapq.heappop(self._pending)
|
||||
|
||||
if task.canceled:
|
||||
continue
|
||||
if self._pending_by_account.get(task.account_id) is not task:
|
||||
continue
|
||||
|
||||
running_for_user = self._running_by_user.get(task.user_id, 0)
|
||||
if running_for_user >= self.max_per_user:
|
||||
skipped.append(task)
|
||||
continue
|
||||
|
||||
selected = task
|
||||
break
|
||||
|
||||
for t in skipped:
|
||||
priority = 0 if t.is_vip else 1
|
||||
heapq.heappush(self._pending, (priority, t.submitted_at, t.seq, t))
|
||||
|
||||
if selected is None:
|
||||
return None
|
||||
|
||||
self._pending_by_account.pop(selected.account_id, None)
|
||||
return selected
|
||||
|
||||
def _run_task_wrapper(self, task: _TaskRequest):
|
||||
try:
|
||||
if callable(self._run_task_fn):
|
||||
self._run_task_fn(
|
||||
user_id=task.user_id,
|
||||
account_id=task.account_id,
|
||||
browse_type=task.browse_type,
|
||||
enable_screenshot=task.enable_screenshot,
|
||||
source=task.source,
|
||||
retry_count=task.retry_count,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
if callable(task.done_callback):
|
||||
task.done_callback()
|
||||
except Exception:
|
||||
pass
|
||||
safe_remove_task(task.account_id)
|
||||
with self._cond:
|
||||
self._running_global = max(0, self._running_global - 1)
|
||||
# 使用默认值 0 与增加时保持一致
|
||||
self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
|
||||
if self._running_by_user.get(task.user_id) == 0:
|
||||
self._running_by_user.pop(task.user_id, None)
|
||||
self._cond.notify_all()
|
||||
|
||||
|
||||
1346
services/tasks.py
1346
services/tasks.py
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user