{
}
})
+const kdocsAutoUpload = ref(false)
+const kdocsSettingsLoading = ref(false)
+
const addOpen = ref(false)
const editOpen = ref(false)
const upgradeOpen = ref(false)
@@ -189,6 +193,30 @@ async function refreshAccounts() {
}
}
+async function loadKdocsSettings() {
+ kdocsSettingsLoading.value = true
+ try {
+ const data = await fetchKdocsSettings()
+ kdocsAutoUpload.value = Number(data?.kdocs_auto_upload || 0) === 1
+ } catch {
+ kdocsAutoUpload.value = false
+ } finally {
+ kdocsSettingsLoading.value = false
+ }
+}
+
+async function onToggleKdocsAutoUpload(value) {
+ kdocsSettingsLoading.value = true
+ try {
+ await updateKdocsSettings({ kdocs_auto_upload: value ? 1 : 0 })
+ ElMessage.success(value ? '已开启自动上传(测试)' : '已关闭自动上传')
+ } catch (e) {
+ kdocsAutoUpload.value = !value
+ } finally {
+ kdocsSettingsLoading.value = false
+ }
+}
+
async function onStart(acc) {
try {
await startAccount(acc.id, { browse_type: browseTypeById[acc.id] || '应读', enable_screenshot: batchEnableScreenshot.value })
@@ -524,6 +552,7 @@ onMounted(async () => {
unbindSocket = bindSocket()
await refreshAccounts()
+ await loadKdocsSettings()
await refreshStats()
syncStatsPolling()
})
@@ -612,6 +641,15 @@ onBeforeUnmount(() => {
+
+ 表格(测试)
diff --git a/app-frontend/src/pages/LoginPage.vue b/app-frontend/src/pages/LoginPage.vue
index 6ba625f..4c0ea0c 100644
--- a/app-frontend/src/pages/LoginPage.vue
+++ b/app-frontend/src/pages/LoginPage.vue
@@ -8,10 +8,8 @@ import {
forgotPassword,
generateCaptcha,
login,
- requestPasswordReset,
resendVerifyEmail,
} from '../api/auth'
-import { validateStrongPassword } from '../utils/password'
const router = useRouter()
@@ -32,20 +30,14 @@ const registerVerifyEnabled = ref(false)
const forgotOpen = ref(false)
const resendOpen = ref(false)
-const emailResetForm = reactive({
- email: '',
+const forgotForm = reactive({
+ username: '',
captcha: '',
})
-const emailResetCaptchaImage = ref('')
-const emailResetCaptchaSession = ref('')
-const emailResetLoading = ref(false)
-
-const manualResetForm = reactive({
- username: '',
- email: '',
- new_password: '',
-})
-const manualResetLoading = ref(false)
+const forgotCaptchaImage = ref('')
+const forgotCaptchaSession = ref('')
+const forgotLoading = ref(false)
+const forgotHint = ref('')
const resendForm = reactive({
email: '',
@@ -72,12 +64,12 @@ async function refreshLoginCaptcha() {
async function refreshEmailResetCaptcha() {
try {
const data = await generateCaptcha()
- emailResetCaptchaSession.value = data?.session_id || ''
- emailResetCaptchaImage.value = data?.captcha_image || ''
- emailResetForm.captcha = ''
+ forgotCaptchaSession.value = data?.session_id || ''
+ forgotCaptchaImage.value = data?.captcha_image || ''
+ forgotForm.captcha = ''
} catch {
- emailResetCaptchaSession.value = ''
- emailResetCaptchaImage.value = ''
+ forgotCaptchaSession.value = ''
+ forgotCaptchaImage.value = ''
}
}
@@ -113,8 +105,14 @@ async function onSubmit() {
need_captcha: needCaptcha.value,
})
ElMessage.success('登录成功,正在跳转...')
+ const urlParams = new URLSearchParams(window.location.search || '')
+ const next = String(urlParams.get('next') || '').trim()
+ const safeNext = next && next.startsWith('/') && !next.startsWith('//') && !next.startsWith('/\\') ? next : ''
setTimeout(() => {
- window.location.href = '/app'
+ const target = safeNext || '/app'
+ router.push(target).catch(() => {
+ window.location.href = target
+ })
}, 300)
} catch (e) {
const status = e?.response?.status
@@ -136,80 +134,54 @@ async function onSubmit() {
async function openForgot() {
forgotOpen.value = true
-
+ forgotHint.value = ''
+ forgotForm.username = ''
+ forgotForm.captcha = ''
if (emailEnabled.value) {
- emailResetForm.email = ''
- emailResetForm.captcha = ''
await refreshEmailResetCaptcha()
- } else {
- manualResetForm.username = ''
- manualResetForm.email = ''
- manualResetForm.new_password = ''
}
}
async function submitForgot() {
- if (emailEnabled.value) {
- const email = emailResetForm.email.trim()
- if (!email) {
- ElMessage.error('请输入邮箱')
- return
- }
- if (!emailResetForm.captcha.trim()) {
- ElMessage.error('请输入验证码')
- return
- }
+ forgotHint.value = ''
- emailResetLoading.value = true
- try {
- const res = await forgotPassword({
- email,
- captcha_session: emailResetCaptchaSession.value,
- captcha: emailResetForm.captcha.trim(),
- })
- ElMessage.success(res?.message || '已发送重置邮件')
- setTimeout(() => {
- forgotOpen.value = false
- }, 800)
- } catch (e) {
- const data = e?.response?.data
- ElMessage.error(data?.error || '发送失败')
- await refreshEmailResetCaptcha()
- } finally {
- emailResetLoading.value = false
- }
+ if (!emailEnabled.value) {
+ ElMessage.warning('邮件功能未启用,请联系管理员重置密码。')
return
}
- const username = manualResetForm.username.trim()
- const newPassword = manualResetForm.new_password
- if (!username || !newPassword) {
- ElMessage.error('用户名和新密码不能为空')
+ const username = forgotForm.username.trim()
+ if (!username) {
+ ElMessage.error('请输入用户名')
+ return
+ }
+ if (!forgotForm.captcha.trim()) {
+ ElMessage.error('请输入验证码')
return
}
- const check = validateStrongPassword(newPassword)
- if (!check.ok) {
- ElMessage.error(check.message)
- return
- }
-
- manualResetLoading.value = true
+ forgotLoading.value = true
try {
- await requestPasswordReset({
+ const res = await forgotPassword({
username,
- email: manualResetForm.email.trim(),
- new_password: newPassword,
+ captcha_session: forgotCaptchaSession.value,
+ captcha: forgotForm.captcha.trim(),
})
- ElMessage.success('申请已提交,请等待审核')
+ ElMessage.success(res?.message || '已发送重置邮件')
setTimeout(() => {
forgotOpen.value = false
}, 800)
} catch (e) {
const data = e?.response?.data
- ElMessage.error(data?.error || '提交失败')
+ const message = data?.error || '发送失败'
+ if (data?.code === 'email_not_bound') {
+ forgotHint.value = message
+ } else {
+ ElMessage.error(message)
+ }
+ await refreshEmailResetCaptcha()
} finally {
- manualResetLoading.value = false
+ forgotLoading.value = false
}
}
@@ -320,51 +292,55 @@ onMounted(async () => {
-
-
-
-
-
-
-
-
-
-
![点击刷新 验证码]()
-
刷新
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
![点击刷新 验证码]()
+
刷新
+
+
+
取消
-
- {{ emailEnabled ? '发送重置邮件' : '提交申请' }}
+
+ 发送重置邮件
diff --git a/app-frontend/src/pages/RegisterPage.vue b/app-frontend/src/pages/RegisterPage.vue
index 083910c..6671a8a 100644
--- a/app-frontend/src/pages/RegisterPage.vue
+++ b/app-frontend/src/pages/RegisterPage.vue
@@ -4,6 +4,7 @@ import { useRouter } from 'vue-router'
import { ElMessage } from 'element-plus'
import { fetchEmailVerifyStatus, generateCaptcha, register } from '../api/auth'
+import { validateStrongPassword } from '../utils/password'
const router = useRouter()
@@ -68,8 +69,9 @@ async function onSubmit() {
ElMessage.error(errorText.value)
return
}
- if (password.length < 6) {
- errorText.value = '密码至少6个字符'
+ const passwordCheck = validateStrongPassword(password)
+ if (!passwordCheck.ok) {
+ errorText.value = passwordCheck.message || '密码格式不正确'
ElMessage.error(errorText.value)
return
}
@@ -166,10 +168,10 @@ onMounted(async () => {
v-model="form.password"
type="password"
show-password
- placeholder="至少6个字符"
+ placeholder="至少8位且包含字母和数字"
autocomplete="new-password"
/>
-
至少6个字符
+
至少8位且包含字母和数字
None:
- """按需应用 nest_asyncio,避免 import 时产生全局副作用。"""
- global _NEST_ASYNCIO_APPLIED
-
- if _NEST_ASYNCIO_APPLIED:
- return
- with _NEST_ASYNCIO_LOCK:
- if _NEST_ASYNCIO_APPLIED:
- return
- try:
- nest_asyncio.apply()
- except Exception:
- pass
- _NEST_ASYNCIO_APPLIED = True
# 安全修复: 将魔法数字提取为可配置常量
BROWSER_IDLE_TIMEOUT = int(os.environ.get('BROWSER_IDLE_TIMEOUT', '300')) # 空闲超时(秒),默认5分钟
TASK_QUEUE_TIMEOUT = int(os.environ.get('TASK_QUEUE_TIMEOUT', '10')) # 队列获取超时(秒)
TASK_QUEUE_MAXSIZE = int(os.environ.get('BROWSER_TASK_QUEUE_MAXSIZE', '200')) # 队列最大长度(0表示无限制)
-BROWSER_MAX_USE_COUNT = int(os.environ.get('BROWSER_MAX_USE_COUNT', '0')) # 每个浏览器最大复用次数(0表示不限制)
+BROWSER_MAX_USE_COUNT = int(os.environ.get('BROWSER_MAX_USE_COUNT', '0')) # 每个执行环境最大复用次数(0表示不限制)
class BrowserWorker(threading.Thread):
- """浏览器工作线程 - 每个worker维护自己的浏览器"""
+ """截图工作线程 - 每个worker维护自己的执行环境"""
def __init__(
self,
@@ -55,99 +35,61 @@ class BrowserWorker(threading.Thread):
self.total_tasks = 0
self.failed_tasks = 0
self.pre_warm = pre_warm
+ self.last_activity_ts = 0.0
def log(self, message: str):
"""日志输出"""
if self.log_callback:
self.log_callback(f"[Worker-{self.worker_id}] {message}")
else:
- print(f"[浏览器池][Worker-{self.worker_id}] {message}")
+ print(f"[截图池][Worker-{self.worker_id}] {message}")
- def _create_browser(self):
- """创建浏览器实例"""
- try:
- from playwright.sync_api import sync_playwright
-
- self.log("正在创建浏览器...")
- playwright = sync_playwright().start()
- browser = playwright.chromium.launch(
- headless=True,
- args=[
- '--no-sandbox',
- '--disable-setuid-sandbox',
- '--disable-dev-shm-usage',
- '--disable-gpu',
- ]
- )
-
- self.browser_instance = {
- 'playwright': playwright,
- 'browser': browser,
- 'created_at': time.time(),
- 'use_count': 0,
- 'worker_id': self.worker_id
- }
- self.log(f"浏览器创建成功")
- return True
-
- except Exception as e:
- self.log(f"创建浏览器失败: {e}")
- return False
-
- def _close_browser(self):
- """关闭浏览器"""
- if self.browser_instance:
- try:
- self.log("正在关闭浏览器...")
- if self.browser_instance['browser']:
- self.browser_instance['browser'].close()
- if self.browser_instance['playwright']:
- self.browser_instance['playwright'].stop()
- self.log(f"浏览器已关闭(共处理{self.browser_instance['use_count']}个任务)")
- except Exception as e:
- self.log(f"关闭浏览器时出错: {e}")
- finally:
- self.browser_instance = None
-
- def _check_browser_health(self) -> bool:
- """检查浏览器是否健康"""
- if not self.browser_instance:
- return False
-
- try:
- return self.browser_instance['browser'].is_connected()
- except:
- return False
-
- def _ensure_browser(self) -> bool:
- """确保浏览器可用(如果不可用则重新创建)"""
- if self._check_browser_health():
- return True
-
- # 浏览器不可用,尝试重新创建
- self.log("浏览器不可用,尝试重新创建...")
- self._close_browser()
- return self._create_browser()
+ def _create_browser(self):
+ """创建截图执行环境(逻辑占位,无需真实浏览器)"""
+ created_at = time.time()
+ self.browser_instance = {
+ 'created_at': created_at,
+ 'use_count': 0,
+ 'worker_id': self.worker_id,
+ }
+ self.last_activity_ts = created_at
+ self.log("截图执行环境就绪")
+ return True
+
+ def _close_browser(self):
+ """关闭截图执行环境"""
+ if self.browser_instance:
+ self.log(f"执行环境已释放(共处理{self.browser_instance.get('use_count', 0)}个任务)")
+ self.browser_instance = None
+
+ def _check_browser_health(self) -> bool:
+ """检查执行环境是否就绪"""
+ return bool(self.browser_instance)
+
+ def _ensure_browser(self) -> bool:
+ """确保执行环境可用"""
+ if self._check_browser_health():
+ return True
+ self.log("执行环境不可用,尝试重新创建...")
+ self._close_browser()
+ return self._create_browser()
def run(self):
- """工作线程主循环 - 按需启动浏览器模式"""
+ """工作线程主循环 - 按需启动执行环境模式"""
if self.pre_warm:
- self.log("Worker启动(预热模式,启动即创建浏览器)")
+ self.log("Worker启动(预热模式,启动即准备执行环境)")
else:
- self.log("Worker启动(按需模式,等待任务时不占用浏览器资源)")
+ self.log("Worker启动(按需模式,等待任务时不占用资源)")
- last_activity_time = 0
if self.pre_warm and not self.browser_instance:
- if self._create_browser():
- last_activity_time = time.time()
+ self._create_browser()
self.pre_warm = False
while self.running:
try:
# 允许运行中触发预热(例如池在初始化后调用 warmup)
if self.pre_warm and not self.browser_instance:
- if self._create_browser():
- last_activity_time = time.time()
+ self._create_browser()
self.pre_warm = False
# 从队列获取任务(带超时,以便能响应停止信号和空闲检查)
@@ -155,60 +97,87 @@ class BrowserWorker(threading.Thread):
try:
task = self.task_queue.get(timeout=TASK_QUEUE_TIMEOUT)
except queue.Empty:
- # 检查是否需要关闭空闲的浏览器
- if self.browser_instance and last_activity_time > 0:
- idle_time = time.time() - last_activity_time
+ # 检查是否需要释放空闲的执行环境
+ if self.browser_instance and self.last_activity_ts > 0:
+ idle_time = time.time() - self.last_activity_ts
if idle_time > BROWSER_IDLE_TIMEOUT:
- self.log(f"空闲{int(idle_time)}秒,关闭浏览器释放资源")
+ self.log(f"空闲{int(idle_time)}秒,释放执行环境")
self._close_browser()
continue
self.idle = False
- if task is None: # None作为停止信号
- self.log("收到停止信号")
- break
-
- # 按需创建或确保浏览器可用
- if not self._ensure_browser():
- self.log("浏览器不可用,任务失败")
- task['callback'](None, "浏览器不可用")
- self.failed_tasks += 1
- continue
-
- # 执行任务
- task_func = task.get('func')
- task_args = task.get('args', ())
- task_kwargs = task.get('kwargs', {})
+ if task is None: # None作为停止信号
+ self.log("收到停止信号")
+ break
+
+ # 按需创建或确保执行环境可用
+ browser_ready = False
+ for attempt in range(2):
+ if self._ensure_browser():
+ browser_ready = True
+ break
+ if attempt < 1:
+ self.log("执行环境创建失败,重试...")
+ time.sleep(0.5)
+
+ if not browser_ready:
+ retry_count = int(task.get("retry_count", 0) or 0) if isinstance(task, dict) else 0
+ if retry_count < 1 and isinstance(task, dict):
+ task["retry_count"] = retry_count + 1
+ try:
+ self.task_queue.put(task, timeout=1)
+ self.log("执行环境不可用,任务重新入队")
+ except queue.Full:
+ self.log("任务队列已满,无法重新入队,任务失败")
+ callback = task.get("callback")
+ if callable(callback):
+ callback(None, "执行环境不可用")
+ self.total_tasks += 1
+ self.failed_tasks += 1
+ continue
+
+ self.log("执行环境不可用,任务失败")
+ callback = task.get("callback") if isinstance(task, dict) else None
+ if callable(callback):
+ callback(None, "执行环境不可用")
+ self.total_tasks += 1
+ self.failed_tasks += 1
+ continue
+
+ # 执行任务
+ task_func = task.get('func')
+ task_args = task.get('args', ())
+ task_kwargs = task.get('kwargs', {})
callback = task.get('callback')
self.total_tasks += 1
self.browser_instance['use_count'] += 1
- self.log(f"开始执行任务(第{self.browser_instance['use_count']}次使用浏览器)")
+ self.log(f"开始执行任务(第{self.browser_instance['use_count']}次执行)")
try:
- # 将浏览器实例传递给任务函数
+ # 将执行环境实例传递给任务函数
result = task_func(self.browser_instance, *task_args, **task_kwargs)
callback(result, None)
self.log(f"任务执行成功")
- last_activity_time = time.time()
+ self.last_activity_ts = time.time()
except Exception as e:
self.log(f"任务执行失败: {e}")
callback(None, str(e))
self.failed_tasks += 1
- last_activity_time = time.time()
+ self.last_activity_ts = time.time()
- # 任务失败后,检查浏览器健康
+ # 任务失败后,检查执行环境健康
if not self._check_browser_health():
- self.log("任务失败导致浏览器异常,将在下次任务前重建")
+ self.log("任务失败导致执行环境异常,将在下次任务前重建")
self._close_browser()
- # 定期重启浏览器,释放Chromium可能累积的内存
+ # 定期重启执行环境,释放可能累积的资源
if self.browser_instance and BROWSER_MAX_USE_COUNT > 0:
if self.browser_instance.get('use_count', 0) >= BROWSER_MAX_USE_COUNT:
- self.log(f"浏览器已复用{self.browser_instance['use_count']}次,重启释放资源")
+ self.log(f"执行环境已复用{self.browser_instance['use_count']}次,重启释放资源")
self._close_browser()
except Exception as e:
@@ -225,7 +194,7 @@ class BrowserWorker(threading.Thread):
class BrowserWorkerPool:
- """浏览器工作线程池"""
+ """截图工作线程池"""
def __init__(self, pool_size: int = 3, log_callback: Optional[Callable] = None):
self.pool_size = pool_size
@@ -238,20 +207,18 @@ class BrowserWorkerPool:
def log(self, message: str):
"""日志输出"""
- if self.log_callback:
- self.log_callback(message)
- else:
- print(f"[浏览器池] {message}")
+ if self.log_callback:
+ self.log_callback(message)
+ else:
+ print(f"[截图池] {message}")
def initialize(self):
- """初始化工作线程池(按需模式,默认预热1个浏览器)"""
+ """初始化工作线程池(按需模式,默认预热1个执行环境)"""
with self.lock:
if self.initialized:
return
- _apply_nest_asyncio_once()
-
- self.log(f"正在初始化工作线程池({self.pool_size}个worker,按需启动浏览器)...")
+ self.log(f"正在初始化截图线程池({self.pool_size}个worker,按需启动执行环境)...")
for i in range(self.pool_size):
worker = BrowserWorker(
@@ -264,13 +231,13 @@ class BrowserWorkerPool:
self.workers.append(worker)
self.initialized = True
- self.log(f"✓ 工作线程池初始化完成({self.pool_size}个worker就绪,浏览器将在有任务时按需启动)")
+ self.log(f"✓ 截图线程池初始化完成({self.pool_size}个worker就绪,执行环境将在有任务时按需启动)")
- # 初始化完成后,默认预热1个浏览器,降低容器重启后前几批任务的冷启动开销
+ # 初始化完成后,默认预热1个执行环境,降低容器重启后前几批任务的冷启动开销
self.warmup(1)
def warmup(self, count: int = 1) -> int:
- """预热浏览器池 - 预创建指定数量的浏览器"""
+ """预热截图线程池 - 预创建指定数量的执行环境"""
if count <= 0:
return 0
@@ -281,7 +248,7 @@ class BrowserWorkerPool:
with self.lock:
target_workers = list(self.workers[: min(count, len(self.workers))])
- self.log(f"预热浏览器池(预创建{len(target_workers)}个浏览器)...")
+ self.log(f"预热截图线程池(预创建{len(target_workers)}个执行环境)...")
for worker in target_workers:
if not worker.browser_instance:
@@ -296,7 +263,7 @@ class BrowserWorkerPool:
time.sleep(0.1)
warmed = sum(1 for w in target_workers if w.browser_instance)
- self.log(f"✓ 浏览器池预热完成({warmed}个浏览器就绪)")
+ self.log(f"✓ 截图线程池预热完成({warmed}个执行环境就绪)")
return warmed
def submit_task(self, task_func: Callable, callback: Callable, *args, **kwargs) -> bool:
@@ -315,12 +282,13 @@ class BrowserWorkerPool:
self.log("警告:线程池未初始化")
return False
- task = {
- 'func': task_func,
- 'args': args,
- 'kwargs': kwargs,
- 'callback': callback
- }
+ task = {
+ 'func': task_func,
+ 'args': args,
+ 'kwargs': kwargs,
+ 'callback': callback,
+ 'retry_count': 0,
+ }
try:
self.task_queue.put(task, timeout=1)
@@ -329,21 +297,47 @@ class BrowserWorkerPool:
self.log(f"警告:任务队列已满(maxsize={self.task_queue.maxsize}),拒绝提交任务")
return False
- def get_stats(self) -> Dict[str, Any]:
- """获取线程池统计信息"""
- idle_count = sum(1 for w in self.workers if w.idle)
- total_tasks = sum(w.total_tasks for w in self.workers)
- failed_tasks = sum(w.failed_tasks for w in self.workers)
-
- return {
- 'pool_size': self.pool_size,
- 'idle_workers': idle_count,
- 'busy_workers': self.pool_size - idle_count,
- 'queue_size': self.task_queue.qsize(),
- 'total_tasks': total_tasks,
- 'failed_tasks': failed_tasks,
- 'success_rate': f"{(total_tasks - failed_tasks) / total_tasks * 100:.1f}%" if total_tasks > 0 else "N/A"
- }
+ def get_stats(self) -> Dict[str, Any]:
+ """获取线程池统计信息"""
+ workers = list(self.workers or [])
+ idle_count = sum(1 for w in workers if getattr(w, "idle", False))
+ total_tasks = sum(int(getattr(w, "total_tasks", 0) or 0) for w in workers)
+ failed_tasks = sum(int(getattr(w, "failed_tasks", 0) or 0) for w in workers)
+
+ worker_details = []
+ for w in workers:
+ browser_instance = getattr(w, "browser_instance", None)
+ browser_use_count = 0
+ browser_created_at = None
+ if isinstance(browser_instance, dict):
+ browser_use_count = int(browser_instance.get("use_count", 0) or 0)
+ browser_created_at = browser_instance.get("created_at")
+
+ worker_details.append(
+ {
+ "worker_id": getattr(w, "worker_id", None),
+ "idle": bool(getattr(w, "idle", False)),
+ "has_browser": bool(browser_instance),
+ "total_tasks": int(getattr(w, "total_tasks", 0) or 0),
+ "failed_tasks": int(getattr(w, "failed_tasks", 0) or 0),
+ "browser_use_count": browser_use_count,
+ "browser_created_at": browser_created_at,
+ "last_active_ts": float(getattr(w, "last_activity_ts", 0) or 0),
+ "thread_alive": bool(getattr(w, "is_alive", lambda: False)()),
+ }
+ )
+
+ return {
+ 'pool_size': self.pool_size,
+ 'idle_workers': idle_count,
+ 'busy_workers': max(0, len(workers) - idle_count),
+ 'queue_size': self.task_queue.qsize(),
+ 'total_tasks': total_tasks,
+ 'failed_tasks': failed_tasks,
+ 'success_rate': f"{(total_tasks - failed_tasks) / total_tasks * 100:.1f}%" if total_tasks > 0 else "N/A",
+ 'workers': worker_details,
+ 'timestamp': time.time(),
+ }
def wait_for_completion(self, timeout: Optional[float] = None):
"""等待所有任务完成"""
@@ -380,8 +374,8 @@ _global_pool: Optional[BrowserWorkerPool] = None
_pool_lock = threading.Lock()
-def get_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None) -> BrowserWorkerPool:
- """获取全局浏览器工作线程池(单例)"""
+def get_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None) -> BrowserWorkerPool:
+ """获取全局截图工作线程池(单例)"""
global _global_pool
with _pool_lock:
@@ -392,14 +386,48 @@ def get_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable]
return _global_pool
-def init_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None):
- """初始化全局浏览器工作线程池"""
- get_browser_worker_pool(pool_size=pool_size, log_callback=log_callback)
-
-
-def shutdown_browser_worker_pool():
- """关闭全局浏览器工作线程池"""
- global _global_pool
+def init_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None):
+ """初始化全局截图工作线程池"""
+ get_browser_worker_pool(pool_size=pool_size, log_callback=log_callback)
+
+
+def _shutdown_pool_when_idle(pool: BrowserWorkerPool) -> None:
+ try:
+ pool.wait_for_completion(timeout=60)
+ except Exception:
+ pass
+ try:
+ pool.shutdown()
+ except Exception:
+ pass
+
+
+def resize_browser_worker_pool(pool_size: int, log_callback: Optional[Callable] = None) -> bool:
+ """调整截图线程池并发(新任务走新池,旧池空闲后自动关闭)"""
+ global _global_pool
+
+ try:
+ target_size = max(1, int(pool_size))
+ except Exception:
+ target_size = 1
+
+ with _pool_lock:
+ old_pool = _global_pool
+ if old_pool and int(getattr(old_pool, "pool_size", 0) or 0) == target_size:
+ return False
+ effective_log_callback = log_callback or (getattr(old_pool, "log_callback", None) if old_pool else None)
+ _global_pool = BrowserWorkerPool(pool_size=target_size, log_callback=effective_log_callback)
+ _global_pool.initialize()
+
+ if old_pool:
+ threading.Thread(target=_shutdown_pool_when_idle, args=(old_pool,), daemon=True).start()
+
+ return True
+
+
+def shutdown_browser_worker_pool():
+ """关闭全局截图工作线程池"""
+ global _global_pool
with _pool_lock:
if _global_pool:
@@ -407,9 +435,9 @@ def shutdown_browser_worker_pool():
_global_pool = None
-if __name__ == '__main__':
- # 测试代码
- print("测试浏览器工作线程池...")
+if __name__ == '__main__':
+ # 测试代码
+ print("测试截图工作线程池...")
def test_task(browser_instance, url: str, task_id: int):
"""测试任务:访问URL"""
@@ -424,8 +452,8 @@ if __name__ == '__main__':
else:
print(f"任务成功: {result}")
- # 创建线程池(2个worker)
- pool = BrowserWorkerPool(pool_size=2)
+ # 创建线程池(2个worker)
+ pool = BrowserWorkerPool(pool_size=2)
pool.initialize()
# 提交4个任务
diff --git a/database.py b/database.py
index 550b99a..d6a132a 100644
--- a/database.py
+++ b/database.py
@@ -24,15 +24,11 @@ from db.schema import ensure_schema
from db.migrations import migrate_database as _migrate_database
from db.admin import (
admin_reset_user_password,
- approve_password_reset,
clean_old_operation_logs,
- create_password_reset_request,
ensure_default_admin,
get_hourly_registration_count,
- get_pending_password_resets,
get_system_config_raw as _get_system_config_raw,
get_system_stats,
- reject_password_reset,
update_admin_password,
update_admin_username,
update_system_config as _update_system_config,
@@ -44,6 +40,7 @@ from db.accounts import (
delete_user_accounts,
get_account,
get_account_status,
+ get_account_status_batch,
get_user_accounts,
increment_account_login_fail,
reset_account_login_status,
@@ -103,6 +100,7 @@ from db.users import (
get_pending_users,
get_user_by_id,
get_user_by_username,
+ get_user_kdocs_settings,
get_user_stats,
get_user_vip_info,
get_vip_config,
@@ -111,6 +109,7 @@ from db.users import (
remove_user_vip,
set_default_vip_days,
set_user_vip,
+ update_user_kdocs_settings,
verify_user,
)
from db.security import record_login_context
@@ -121,7 +120,7 @@ config = get_config()
DB_FILE = config.DB_FILE
# 数据库版本 (用于迁移管理)
-DB_VERSION = 12
+DB_VERSION = 17
# ==================== 系统配置缓存(P1 / O-03) ====================
@@ -190,12 +189,24 @@ def update_system_config(
schedule_weekdays=None,
max_concurrent_per_account=None,
max_screenshot_concurrent=None,
+ enable_screenshot=None,
proxy_enabled=None,
proxy_api_url=None,
proxy_expire_minutes=None,
auto_approve_enabled=None,
auto_approve_hourly_limit=None,
auto_approve_vip_days=None,
+ kdocs_enabled=None,
+ kdocs_doc_url=None,
+ kdocs_default_unit=None,
+ kdocs_sheet_name=None,
+ kdocs_sheet_index=None,
+ kdocs_unit_column=None,
+ kdocs_image_column=None,
+ kdocs_admin_notify_enabled=None,
+ kdocs_admin_notify_email=None,
+ kdocs_row_start=None,
+ kdocs_row_end=None,
):
"""更新系统配置(写入后立即失效缓存)。"""
ok = _update_system_config(
@@ -206,12 +217,24 @@ def update_system_config(
schedule_weekdays=schedule_weekdays,
max_concurrent_per_account=max_concurrent_per_account,
max_screenshot_concurrent=max_screenshot_concurrent,
+ enable_screenshot=enable_screenshot,
proxy_enabled=proxy_enabled,
proxy_api_url=proxy_api_url,
proxy_expire_minutes=proxy_expire_minutes,
auto_approve_enabled=auto_approve_enabled,
auto_approve_hourly_limit=auto_approve_hourly_limit,
auto_approve_vip_days=auto_approve_vip_days,
+ kdocs_enabled=kdocs_enabled,
+ kdocs_doc_url=kdocs_doc_url,
+ kdocs_default_unit=kdocs_default_unit,
+ kdocs_sheet_name=kdocs_sheet_name,
+ kdocs_sheet_index=kdocs_sheet_index,
+ kdocs_unit_column=kdocs_unit_column,
+ kdocs_image_column=kdocs_image_column,
+ kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
+ kdocs_admin_notify_email=kdocs_admin_notify_email,
+ kdocs_row_start=kdocs_row_start,
+ kdocs_row_end=kdocs_row_end,
)
if ok:
invalidate_system_config_cache()
diff --git a/db/accounts.py b/db/accounts.py
index c1211ec..85e4132 100644
--- a/db/accounts.py
+++ b/db/accounts.py
@@ -140,6 +140,36 @@ def get_account_status(account_id):
return cursor.fetchone()
+def get_account_status_batch(account_ids):
+ """批量获取账号状态信息"""
+ account_ids = [str(account_id) for account_id in (account_ids or []) if account_id]
+ if not account_ids:
+ return {}
+
+ results = {}
+ chunk_size = 900 # 避免触发 SQLite 绑定参数上限
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ for idx in range(0, len(account_ids), chunk_size):
+ chunk = account_ids[idx : idx + chunk_size]
+ placeholders = ",".join("?" for _ in chunk)
+ cursor.execute(
+ f"""
+ SELECT id, status, login_fail_count, last_login_error
+ FROM accounts
+ WHERE id IN ({placeholders})
+ """,
+ chunk,
+ )
+ for row in cursor.fetchall():
+ row_dict = dict(row)
+ account_id = str(row_dict.pop("id", ""))
+ if account_id:
+ results[account_id] = row_dict
+
+ return results
+
+
def delete_user_accounts(user_id):
"""删除用户的所有账号"""
with db_pool.get_db() as conn:
@@ -147,4 +177,3 @@ def delete_user_accounts(user_id):
cursor.execute("DELETE FROM accounts WHERE user_id = ?", (user_id,))
conn.commit()
return cursor.rowcount
-
diff --git a/db/admin.py b/db/admin.py
index a805b1e..b087ad3 100644
--- a/db/admin.py
+++ b/db/admin.py
@@ -172,6 +172,17 @@ def get_system_config_raw() -> dict:
"auto_approve_enabled": 0,
"auto_approve_hourly_limit": 10,
"auto_approve_vip_days": 7,
+ "kdocs_enabled": 0,
+ "kdocs_doc_url": "",
+ "kdocs_default_unit": "",
+ "kdocs_sheet_name": "",
+ "kdocs_sheet_index": 0,
+ "kdocs_unit_column": "A",
+ "kdocs_image_column": "D",
+ "kdocs_admin_notify_enabled": 0,
+ "kdocs_admin_notify_email": "",
+ "kdocs_row_start": 0,
+ "kdocs_row_end": 0,
}
@@ -184,12 +195,24 @@ def update_system_config(
schedule_weekdays=None,
max_concurrent_per_account=None,
max_screenshot_concurrent=None,
+ enable_screenshot=None,
proxy_enabled=None,
proxy_api_url=None,
proxy_expire_minutes=None,
auto_approve_enabled=None,
auto_approve_hourly_limit=None,
auto_approve_vip_days=None,
+ kdocs_enabled=None,
+ kdocs_doc_url=None,
+ kdocs_default_unit=None,
+ kdocs_sheet_name=None,
+ kdocs_sheet_index=None,
+ kdocs_unit_column=None,
+ kdocs_image_column=None,
+ kdocs_admin_notify_enabled=None,
+ kdocs_admin_notify_email=None,
+ kdocs_row_start=None,
+ kdocs_row_end=None,
) -> bool:
"""更新系统配置(仅更新DB,不做缓存处理)。"""
allowed_fields = {
@@ -200,12 +223,24 @@ def update_system_config(
"schedule_weekdays",
"max_concurrent_per_account",
"max_screenshot_concurrent",
+ "enable_screenshot",
"proxy_enabled",
"proxy_api_url",
"proxy_expire_minutes",
"auto_approve_enabled",
"auto_approve_hourly_limit",
"auto_approve_vip_days",
+ "kdocs_enabled",
+ "kdocs_doc_url",
+ "kdocs_default_unit",
+ "kdocs_sheet_name",
+ "kdocs_sheet_index",
+ "kdocs_unit_column",
+ "kdocs_image_column",
+ "kdocs_admin_notify_enabled",
+ "kdocs_admin_notify_email",
+ "kdocs_row_start",
+ "kdocs_row_end",
"updated_at",
}
@@ -232,6 +267,9 @@ def update_system_config(
if max_screenshot_concurrent is not None:
updates.append("max_screenshot_concurrent = ?")
params.append(max_screenshot_concurrent)
+ if enable_screenshot is not None:
+ updates.append("enable_screenshot = ?")
+ params.append(enable_screenshot)
if schedule_weekdays is not None:
updates.append("schedule_weekdays = ?")
params.append(schedule_weekdays)
@@ -253,6 +291,39 @@ def update_system_config(
if auto_approve_vip_days is not None:
updates.append("auto_approve_vip_days = ?")
params.append(auto_approve_vip_days)
+ if kdocs_enabled is not None:
+ updates.append("kdocs_enabled = ?")
+ params.append(kdocs_enabled)
+ if kdocs_doc_url is not None:
+ updates.append("kdocs_doc_url = ?")
+ params.append(kdocs_doc_url)
+ if kdocs_default_unit is not None:
+ updates.append("kdocs_default_unit = ?")
+ params.append(kdocs_default_unit)
+ if kdocs_sheet_name is not None:
+ updates.append("kdocs_sheet_name = ?")
+ params.append(kdocs_sheet_name)
+ if kdocs_sheet_index is not None:
+ updates.append("kdocs_sheet_index = ?")
+ params.append(kdocs_sheet_index)
+ if kdocs_unit_column is not None:
+ updates.append("kdocs_unit_column = ?")
+ params.append(kdocs_unit_column)
+ if kdocs_image_column is not None:
+ updates.append("kdocs_image_column = ?")
+ params.append(kdocs_image_column)
+ if kdocs_admin_notify_enabled is not None:
+ updates.append("kdocs_admin_notify_enabled = ?")
+ params.append(kdocs_admin_notify_enabled)
+ if kdocs_admin_notify_email is not None:
+ updates.append("kdocs_admin_notify_email = ?")
+ params.append(kdocs_admin_notify_email)
+ if kdocs_row_start is not None:
+ updates.append("kdocs_row_start = ?")
+ params.append(kdocs_row_start)
+ if kdocs_row_end is not None:
+ updates.append("kdocs_row_end = ?")
+ params.append(kdocs_row_end)
if not updates:
return False
@@ -287,108 +358,6 @@ def get_hourly_registration_count() -> int:
# ==================== 密码重置(管理员) ====================
-def create_password_reset_request(user_id: int, new_password: str):
- """创建密码重置申请(存储哈希)"""
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- password_hash = hash_password_bcrypt(new_password)
- cst_time = get_cst_now_str()
-
- try:
- cursor.execute(
- """
- INSERT INTO password_reset_requests (user_id, new_password_hash, status, created_at)
- VALUES (?, ?, 'pending', ?)
- """,
- (user_id, password_hash, cst_time),
- )
- conn.commit()
- return cursor.lastrowid
- except Exception as e:
- print(f"创建密码重置申请失败: {e}")
- return None
-
-
-def get_pending_password_resets():
- """获取待审核的密码重置申请列表"""
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cursor.execute(
- """
- SELECT r.id, r.user_id, r.created_at, r.status,
- u.username, u.email
- FROM password_reset_requests r
- JOIN users u ON r.user_id = u.id
- WHERE r.status = 'pending'
- ORDER BY r.created_at DESC
- """
- )
- return [dict(row) for row in cursor.fetchall()]
-
-
-def approve_password_reset(request_id: int) -> bool:
- """批准密码重置申请"""
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cst_time = get_cst_now_str()
-
- try:
- cursor.execute(
- """
- SELECT user_id, new_password_hash
- FROM password_reset_requests
- WHERE id = ? AND status = 'pending'
- """,
- (request_id,),
- )
-
- result = cursor.fetchone()
- if not result:
- return False
-
- user_id = result["user_id"]
- new_password_hash = result["new_password_hash"]
-
- cursor.execute("UPDATE users SET password_hash = ? WHERE id = ?", (new_password_hash, user_id))
-
- cursor.execute(
- """
- UPDATE password_reset_requests
- SET status = 'approved', processed_at = ?
- WHERE id = ?
- """,
- (cst_time, request_id),
- )
-
- conn.commit()
- return True
- except Exception as e:
- print(f"批准密码重置失败: {e}")
- return False
-
-
-def reject_password_reset(request_id: int) -> bool:
- """拒绝密码重置申请"""
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cst_time = get_cst_now_str()
-
- try:
- cursor.execute(
- """
- UPDATE password_reset_requests
- SET status = 'rejected', processed_at = ?
- WHERE id = ? AND status = 'pending'
- """,
- (cst_time, request_id),
- )
- conn.commit()
- return cursor.rowcount > 0
- except Exception as e:
- print(f"拒绝密码重置失败: {e}")
- return False
-
-
def admin_reset_user_password(user_id: int, new_password: str) -> bool:
"""管理员直接重置用户密码"""
with db_pool.get_db() as conn:
diff --git a/db/announcements.py b/db/announcements.py
index 13e5303..c680816 100644
--- a/db/announcements.py
+++ b/db/announcements.py
@@ -6,10 +6,12 @@ import db_pool
from db.utils import get_cst_now_str
-def create_announcement(title, content, is_active=True):
+def create_announcement(title, content, image_url=None, is_active=True):
"""创建公告(默认启用;启用时会自动停用其他公告)"""
title = (title or "").strip()
content = (content or "").strip()
+ image_url = (image_url or "").strip()
+ image_url = image_url or None
if not title or not content:
return None
@@ -22,10 +24,10 @@ def create_announcement(title, content, is_active=True):
cursor.execute(
"""
- INSERT INTO announcements (title, content, is_active, created_at, updated_at)
- VALUES (?, ?, ?, ?, ?)
+ INSERT INTO announcements (title, content, image_url, is_active, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?, ?)
""",
- (title, content, 1 if is_active else 0, cst_time, cst_time),
+ (title, content, image_url, 1 if is_active else 0, cst_time, cst_time),
)
conn.commit()
return cursor.lastrowid
@@ -129,4 +131,3 @@ def dismiss_announcement_for_user(user_id, announcement_id):
)
conn.commit()
return cursor.rowcount >= 0
-
diff --git a/db/migrations.py b/db/migrations.py
index 91e588c..6933264 100644
--- a/db/migrations.py
+++ b/db/migrations.py
@@ -72,6 +72,24 @@ def migrate_database(conn, target_version: int) -> None:
if current_version < 12:
_migrate_to_v12(conn)
current_version = 12
+ if current_version < 13:
+ _migrate_to_v13(conn)
+ current_version = 13
+ if current_version < 14:
+ _migrate_to_v14(conn)
+ current_version = 14
+ if current_version < 15:
+ _migrate_to_v15(conn)
+ current_version = 15
+ if current_version < 16:
+ _migrate_to_v16(conn)
+ current_version = 16
+ if current_version < 17:
+ _migrate_to_v17(conn)
+ current_version = 17
+ if current_version < 18:
+ _migrate_to_v18(conn)
+ current_version = 18
if current_version != int(target_version):
set_current_version(conn, int(target_version))
@@ -519,3 +537,215 @@ def _migrate_to_v12(conn):
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
conn.commit()
+
+
+def _migrate_to_v13(conn):
+ """迁移到版本13 - 安全防护:威胁检测相关表"""
+ cursor = conn.cursor()
+
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS threat_events (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ threat_type TEXT NOT NULL,
+ score INTEGER NOT NULL DEFAULT 0,
+ rule TEXT,
+ field_name TEXT,
+ matched TEXT,
+ value_preview TEXT,
+ ip TEXT,
+ user_id INTEGER,
+ request_method TEXT,
+ request_path TEXT,
+ user_agent TEXT,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
+ )
+ """
+ )
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)")
+
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS ip_risk_scores (
+ ip TEXT PRIMARY KEY,
+ risk_score INTEGER NOT NULL DEFAULT 0,
+ last_seen TIMESTAMP,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """
+ )
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)")
+
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS user_risk_scores (
+ user_id INTEGER PRIMARY KEY,
+ risk_score INTEGER NOT NULL DEFAULT 0,
+ last_seen TIMESTAMP,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
+ )
+ """
+ )
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)")
+
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS ip_blacklist (
+ ip TEXT PRIMARY KEY,
+ reason TEXT,
+ is_active INTEGER DEFAULT 1,
+ added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ expires_at TIMESTAMP
+ )
+ """
+ )
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)")
+
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS threat_signatures (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT UNIQUE NOT NULL,
+ threat_type TEXT NOT NULL,
+ pattern TEXT NOT NULL,
+ pattern_type TEXT DEFAULT 'regex',
+ score INTEGER DEFAULT 0,
+ is_active INTEGER DEFAULT 1,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """
+ )
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)")
+
+ conn.commit()
+
+
+def _migrate_to_v14(conn):
+ """迁移到版本14 - 安全防护:用户黑名单表"""
+ cursor = conn.cursor()
+
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS user_blacklist (
+ user_id INTEGER PRIMARY KEY,
+ reason TEXT,
+ is_active INTEGER DEFAULT 1,
+ added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ expires_at TIMESTAMP
+ )
+ """
+ )
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
+
+ conn.commit()
+
+
+def _migrate_to_v15(conn):
+ """迁移到版本15 - 邮件设置:新设备登录提醒全局开关"""
+ cursor = conn.cursor()
+
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
+ if not cursor.fetchone():
+ # 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
+ return
+
+ cursor.execute("PRAGMA table_info(email_settings)")
+ columns = [col[1] for col in cursor.fetchall()]
+
+ changed = False
+ if "login_alert_enabled" not in columns:
+ cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1")
+ print(" ✓ 添加 email_settings.login_alert_enabled 字段")
+ changed = True
+
+ try:
+ cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL")
+ if cursor.rowcount:
+ changed = True
+ except sqlite3.OperationalError:
+ # 列不存在等情况由上方迁移兜底;不阻断主流程
+ pass
+
+ if changed:
+ conn.commit()
+
+
+def _migrate_to_v16(conn):
+ """迁移到版本16 - 公告支持图片字段"""
+ cursor = conn.cursor()
+ cursor.execute("PRAGMA table_info(announcements)")
+ columns = [col[1] for col in cursor.fetchall()]
+
+ if "image_url" not in columns:
+ cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT")
+ conn.commit()
+ print(" ✓ 添加 announcements.image_url 字段")
+
+
+def _migrate_to_v17(conn):
+ """迁移到版本17 - 金山文档上传配置与用户开关"""
+ cursor = conn.cursor()
+
+ cursor.execute("PRAGMA table_info(system_config)")
+ columns = [col[1] for col in cursor.fetchall()]
+
+ system_fields = [
+ ("kdocs_enabled", "INTEGER DEFAULT 0"),
+ ("kdocs_doc_url", "TEXT DEFAULT ''"),
+ ("kdocs_default_unit", "TEXT DEFAULT ''"),
+ ("kdocs_sheet_name", "TEXT DEFAULT ''"),
+ ("kdocs_sheet_index", "INTEGER DEFAULT 0"),
+ ("kdocs_unit_column", "TEXT DEFAULT 'A'"),
+ ("kdocs_image_column", "TEXT DEFAULT 'D'"),
+ ("kdocs_admin_notify_enabled", "INTEGER DEFAULT 0"),
+ ("kdocs_admin_notify_email", "TEXT DEFAULT ''"),
+ ]
+ for field, ddl in system_fields:
+ if field not in columns:
+ cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}")
+ print(f" ✓ 添加 system_config.{field} 字段")
+
+ cursor.execute("PRAGMA table_info(users)")
+ columns = [col[1] for col in cursor.fetchall()]
+
+ user_fields = [
+ ("kdocs_unit", "TEXT DEFAULT ''"),
+ ("kdocs_auto_upload", "INTEGER DEFAULT 0"),
+ ]
+ for field, ddl in user_fields:
+ if field not in columns:
+ cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}")
+ print(f" ✓ 添加 users.{field} 字段")
+
+ conn.commit()
+
+
+def _migrate_to_v18(conn):
+ """迁移到版本18 - 金山文档上传:有效行范围配置"""
+ cursor = conn.cursor()
+
+ cursor.execute("PRAGMA table_info(system_config)")
+ columns = [col[1] for col in cursor.fetchall()]
+
+ if "kdocs_row_start" not in columns:
+ cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0")
+ print(" ✓ 添加 system_config.kdocs_row_start 字段")
+
+ if "kdocs_row_end" not in columns:
+ cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0")
+ print(" ✓ 添加 system_config.kdocs_row_end 字段")
+
+ conn.commit()
diff --git a/db/schema.py b/db/schema.py
index 73bd377..59108a7 100644
--- a/db/schema.py
+++ b/db/schema.py
@@ -33,6 +33,8 @@ def ensure_schema(conn) -> None:
email TEXT,
email_verified INTEGER DEFAULT 0,
email_notify_enabled INTEGER DEFAULT 1,
+ kdocs_unit TEXT DEFAULT '',
+ kdocs_auto_upload INTEGER DEFAULT 0,
status TEXT DEFAULT 'approved',
vip_expire_time TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
@@ -72,6 +74,101 @@ def ensure_schema(conn) -> None:
"""
)
+ # ==================== 安全防护:威胁检测相关表 ====================
+
+ # 威胁事件日志表
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS threat_events (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ threat_type TEXT NOT NULL,
+ score INTEGER NOT NULL DEFAULT 0,
+ rule TEXT,
+ field_name TEXT,
+ matched TEXT,
+ value_preview TEXT,
+ ip TEXT,
+ user_id INTEGER,
+ request_method TEXT,
+ request_path TEXT,
+ user_agent TEXT,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
+ )
+ """
+ )
+
+ # IP风险评分表
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS ip_risk_scores (
+ ip TEXT PRIMARY KEY,
+ risk_score INTEGER NOT NULL DEFAULT 0,
+ last_seen TIMESTAMP,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """
+ )
+
+ # 用户风险评分表
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS user_risk_scores (
+ user_id INTEGER PRIMARY KEY,
+ risk_score INTEGER NOT NULL DEFAULT 0,
+ last_seen TIMESTAMP,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
+ )
+ """
+ )
+
+ # IP黑名单表
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS ip_blacklist (
+ ip TEXT PRIMARY KEY,
+ reason TEXT,
+ is_active INTEGER DEFAULT 1,
+ added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ expires_at TIMESTAMP
+ )
+ """
+ )
+
+ # 用户黑名单表
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS user_blacklist (
+ user_id INTEGER PRIMARY KEY,
+ reason TEXT,
+ is_active INTEGER DEFAULT 1,
+ added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ expires_at TIMESTAMP,
+ FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
+ )
+ """
+ )
+
+ # 威胁特征库表
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS threat_signatures (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT UNIQUE NOT NULL,
+ threat_type TEXT NOT NULL,
+ pattern TEXT NOT NULL,
+ pattern_type TEXT DEFAULT 'regex',
+ score INTEGER DEFAULT 0,
+ is_active INTEGER DEFAULT 1,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """
+ )
+
# 账号表(关联用户)
cursor.execute(
"""
@@ -118,6 +215,17 @@ def ensure_schema(conn) -> None:
auto_approve_enabled INTEGER DEFAULT 0,
auto_approve_hourly_limit INTEGER DEFAULT 10,
auto_approve_vip_days INTEGER DEFAULT 7,
+ kdocs_enabled INTEGER DEFAULT 0,
+ kdocs_doc_url TEXT DEFAULT '',
+ kdocs_default_unit TEXT DEFAULT '',
+ kdocs_sheet_name TEXT DEFAULT '',
+ kdocs_sheet_index INTEGER DEFAULT 0,
+ kdocs_unit_column TEXT DEFAULT 'A',
+ kdocs_image_column TEXT DEFAULT 'D',
+ kdocs_admin_notify_enabled INTEGER DEFAULT 0,
+ kdocs_admin_notify_email TEXT DEFAULT '',
+ kdocs_row_start INTEGER DEFAULT 0,
+ kdocs_row_end INTEGER DEFAULT 0,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
@@ -144,21 +252,6 @@ def ensure_schema(conn) -> None:
"""
)
- # 密码重置申请表
- cursor.execute(
- """
- CREATE TABLE IF NOT EXISTS password_reset_requests (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- user_id INTEGER NOT NULL,
- new_password_hash TEXT NOT NULL,
- status TEXT NOT NULL DEFAULT 'pending',
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- processed_at TIMESTAMP,
- FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
- )
- """
- )
-
# 数据库版本表
cursor.execute(
"""
@@ -196,6 +289,7 @@ def ensure_schema(conn) -> None:
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
content TEXT NOT NULL,
+ image_url TEXT,
is_active INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -271,6 +365,26 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)")
+
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)")
+
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)")
+
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)")
+
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
+
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)")
+
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_user_id ON accounts(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts(username)")
@@ -279,9 +393,6 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)")
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_password_reset_status ON password_reset_requests(status)")
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_password_reset_user_id ON password_reset_requests(user_id)")
-
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_status ON bug_feedbacks(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_created_at ON bug_feedbacks(created_at)")
diff --git a/db/security.py b/db/security.py
index a57b2d2..79ad0f3 100644
--- a/db/security.py
+++ b/db/security.py
@@ -2,10 +2,12 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
+from datetime import timedelta
+from typing import Any, Optional
from typing import Dict
import db_pool
-from db.utils import get_cst_now_str
+from db.utils import get_cst_now, get_cst_now_str
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
@@ -74,3 +76,217 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
conn.commit()
return {"new_device": new_device, "new_ip": new_ip}
+
+
+def get_threat_events_count(hours: int = 24) -> int:
+ """获取指定时间内的威胁事件数。"""
+ try:
+ hours_int = max(0, int(hours))
+ except Exception:
+ hours_int = 24
+
+ if hours_int <= 0:
+ return 0
+
+ start_time = (get_cst_now() - timedelta(hours=hours_int)).strftime("%Y-%m-%d %H:%M:%S")
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT COUNT(*) AS cnt FROM threat_events WHERE created_at >= ?", (start_time,))
+ row = cursor.fetchone()
+ try:
+ return int(row["cnt"] if row else 0)
+ except Exception:
+ return 0
+
+
+def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, list[Any]]:
+ clauses: list[str] = []
+ params: list[Any] = []
+
+ if not isinstance(filters, dict):
+ return "", []
+
+ event_type = filters.get("event_type") or filters.get("threat_type")
+ if event_type:
+ raw = str(event_type).strip()
+ types = [t.strip()[:64] for t in raw.split(",") if t.strip()]
+ if len(types) == 1:
+ clauses.append("threat_type = ?")
+ params.append(types[0])
+ elif types:
+ placeholders = ", ".join(["?"] * len(types))
+ clauses.append(f"threat_type IN ({placeholders})")
+ params.extend(types)
+
+ severity = filters.get("severity")
+ if severity is not None and str(severity).strip():
+ sev = str(severity).strip().lower()
+ if "-" in sev:
+ parts = [p.strip() for p in sev.split("-", 1)]
+ try:
+ min_score = int(parts[0])
+ max_score = int(parts[1])
+ clauses.append("score >= ? AND score <= ?")
+ params.extend([min_score, max_score])
+ except Exception:
+ pass
+ elif sev.isdigit():
+ clauses.append("score >= ?")
+ params.append(int(sev))
+ elif sev in {"high", "critical"}:
+ clauses.append("score >= ?")
+ params.append(80)
+ elif sev in {"medium", "med"}:
+ clauses.append("score >= ? AND score < ?")
+ params.extend([50, 80])
+ elif sev in {"low", "info"}:
+ clauses.append("score < ?")
+ params.append(50)
+
+ ip = filters.get("ip")
+ if ip is not None and str(ip).strip():
+ ip_text = str(ip).strip()[:64]
+ clauses.append("ip = ?")
+ params.append(ip_text)
+
+ user_id = filters.get("user_id")
+ if user_id is not None and str(user_id).strip():
+ try:
+ user_id_int = int(user_id)
+ except Exception:
+ user_id_int = None
+ if user_id_int is not None:
+ clauses.append("user_id = ?")
+ params.append(user_id_int)
+
+ if not clauses:
+ return "", []
+ return " WHERE " + " AND ".join(clauses), params
+
+
+def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
+ """分页获取威胁事件。"""
+ try:
+ page_i = max(1, int(page))
+ except Exception:
+ page_i = 1
+ try:
+ per_page_i = int(per_page)
+ except Exception:
+ per_page_i = 20
+ per_page_i = max(1, min(200, per_page_i))
+
+ where_sql, params = _build_threat_events_where_clause(filters)
+ offset = (page_i - 1) * per_page_i
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(f"SELECT COUNT(*) AS cnt FROM threat_events{where_sql}", tuple(params))
+ row = cursor.fetchone()
+ total = int(row["cnt"]) if row else 0
+
+ cursor.execute(
+ f"""
+ SELECT
+ id,
+ threat_type,
+ score,
+ rule,
+ field_name,
+ matched,
+ value_preview,
+ ip,
+ user_id,
+ request_method,
+ request_path,
+ user_agent,
+ created_at
+ FROM threat_events
+ {where_sql}
+ ORDER BY created_at DESC, id DESC
+ LIMIT ? OFFSET ?
+ """,
+ tuple(params + [per_page_i, offset]),
+ )
+ items = [dict(r) for r in cursor.fetchall()]
+
+ return {"page": page_i, "per_page": per_page_i, "total": total, "items": items, "filters": filters or {}}
+
+
+def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
+ """获取IP的威胁历史(最近limit条)。"""
+ ip_text = str(ip or "").strip()[:64]
+ if not ip_text:
+ return []
+ try:
+ limit_i = max(1, min(200, int(limit)))
+ except Exception:
+ limit_i = 50
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT
+ id,
+ threat_type,
+ score,
+ rule,
+ field_name,
+ matched,
+ value_preview,
+ ip,
+ user_id,
+ request_method,
+ request_path,
+ user_agent,
+ created_at
+ FROM threat_events
+ WHERE ip = ?
+ ORDER BY created_at DESC, id DESC
+ LIMIT ?
+ """,
+ (ip_text, limit_i),
+ )
+ return [dict(r) for r in cursor.fetchall()]
+
+
+def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
+ """获取用户的威胁历史(最近limit条)。"""
+ if user_id is None:
+ return []
+ try:
+ user_id_int = int(user_id)
+ except Exception:
+ return []
+ try:
+ limit_i = max(1, min(200, int(limit)))
+ except Exception:
+ limit_i = 50
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT
+ id,
+ threat_type,
+ score,
+ rule,
+ field_name,
+ matched,
+ value_preview,
+ ip,
+ user_id,
+ request_method,
+ request_path,
+ user_agent,
+ created_at
+ FROM threat_events
+ WHERE user_id = ?
+ ORDER BY created_at DESC, id DESC
+ LIMIT ?
+ """,
+ (user_id_int, limit_i),
+ )
+ return [dict(r) for r in cursor.fetchall()]
diff --git a/db/users.py b/db/users.py
index 2df5183..42423a5 100644
--- a/db/users.py
+++ b/db/users.py
@@ -217,6 +217,39 @@ def get_user_by_id(user_id):
return dict(user) if user else None
+def get_user_kdocs_settings(user_id):
+ """获取用户的金山文档配置"""
+ user = get_user_by_id(user_id)
+ if not user:
+ return None
+ return {
+ "kdocs_unit": user.get("kdocs_unit") or "",
+ "kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0,
+ }
+
+
+def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
+ """更新用户的金山文档配置"""
+ updates = []
+ params = []
+ if kdocs_unit is not None:
+ updates.append("kdocs_unit = ?")
+ params.append(kdocs_unit)
+ if kdocs_auto_upload is not None:
+ updates.append("kdocs_auto_upload = ?")
+ params.append(kdocs_auto_upload)
+
+ if not updates:
+ return False
+
+ params.append(user_id)
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(f"UPDATE users SET {', '.join(updates)} WHERE id = ?", params)
+ conn.commit()
+ return cursor.rowcount > 0
+
+
def get_user_by_username(username):
"""根据用户名获取用户"""
with db_pool.get_db() as conn:
diff --git a/docker-compose.yml b/docker-compose.yml
index 8b8b550..60ef3c8 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -7,60 +7,48 @@ services:
ports:
- "51232:51233"
volumes:
- - ./data:/app/data # 数据库持久化
- - ./logs:/app/logs # 日志持久化
- - ./截图:/app/截图 # 截图持久化
- - ./playwright:/ms-playwright # Playwright浏览器持久化(避免重复下载)
- - /etc/localtime:/etc/localtime:ro # 时区同步
- - ./static:/app/static # 静态文件(实时更新)
- - ./templates:/app/templates # 模板文件(实时更新)
- - ./app.py:/app/app.py # 主程序(实时更新)
- - ./database.py:/app/database.py # 数据库模块(实时更新)
+ - ./data:/app/data
+ - ./logs:/app/logs
+ - ./截图:/app/截图
+ - ./playwright:/ms-playwright
+ - /etc/localtime:/etc/localtime:ro
+ - ./static:/app/static
+ - ./templates:/app/templates
+ - ./app.py:/app/app.py
+ - ./database.py:/app/database.py
+ # 代码热更新
+ - ./services:/app/services
+ - ./routes:/app/routes
+ - ./db:/app/db
+ - ./security:/app/security
+ - ./realtime:/app/realtime
+ - ./api_browser.py:/app/api_browser.py
+ - ./app_config.py:/app/app_config.py
+ - ./app_logger.py:/app/app_logger.py
+ - ./app_security.py:/app/app_security.py
+ - ./browser_pool_worker.py:/app/browser_pool_worker.py
+ - ./crypto_utils.py:/app/crypto_utils.py
+ - ./db_pool.py:/app/db_pool.py
+ - ./email_service.py:/app/email_service.py
+ - ./password_utils.py:/app/password_utils.py
+ - ./playwright_automation.py:/app/playwright_automation.py
+ - ./task_checkpoint.py:/app/task_checkpoint.py
dns:
- 223.5.5.5
- 114.114.114.114
- - 119.29.29.29
environment:
- TZ=Asia/Shanghai
- PYTHONUNBUFFERED=1
- PLAYWRIGHT_BROWSERS_PATH=/ms-playwright
- - PLAYWRIGHT_DOWNLOAD_HOST=https://npmmirror.com/mirrors/playwright
- # Flask 配置
- FLASK_ENV=production
- - FLASK_DEBUG=false
- # 服务器配置
- SERVER_HOST=0.0.0.0
- SERVER_PORT=51233
- # 数据库配置
- - DB_FILE=data/app_data.db
- - DB_POOL_SIZE=5
- # 并发控制配置
- - MAX_CONCURRENT_GLOBAL=2
- - MAX_CONCURRENT_PER_ACCOUNT=1
- - MAX_CONCURRENT_CONTEXTS=100
- # 安全配置
- - SESSION_LIFETIME_HOURS=24
- - SESSION_COOKIE_SECURE=false
- - MAX_CAPTCHA_ATTEMPTS=5
- - MAX_IP_ATTEMPTS_PER_HOUR=10
- # 日志配置
- LOG_LEVEL=INFO
- - LOG_FILE=logs/app.log
- - API_DIAGNOSTIC_LOG=0
- - API_DIAGNOSTIC_SLOW_MS=0
- # 知识管理平台配置
- - ZSGL_LOGIN_URL=https://postoa.aidunsoft.com/admin/login.aspx
- - ZSGL_INDEX_URL_PATTERN=index.aspx
- - PAGE_LOAD_TIMEOUT=60000
restart: unless-stopped
- shm_size: 2gb # 为Chromium分配共享内存
-
- # 内存和CPU资源限制
- mem_limit: 4g # 硬限制:最大4GB内存
- mem_reservation: 2g # 软限制:预留2GB
- cpus: '2.0' # 限制使用2个CPU核心
-
- # 健康检查(可选)
+ shm_size: 2gb
+ mem_limit: 4g
+ mem_reservation: 2g
+ cpus: '2.0'
healthcheck:
test: ["CMD-SHELL", "curl -f http://localhost:51233 || exit 1"]
interval: 5m
diff --git a/email_service.py b/email_service.py
index e28f1d3..c02629c 100644
--- a/email_service.py
+++ b/email_service.py
@@ -154,6 +154,7 @@ def init_email_tables():
enabled INTEGER DEFAULT 0,
failover_enabled INTEGER DEFAULT 1,
register_verify_enabled INTEGER DEFAULT 0,
+ login_alert_enabled INTEGER DEFAULT 1,
task_notify_enabled INTEGER DEFAULT 0,
base_url TEXT DEFAULT '',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -244,8 +245,8 @@ def get_email_settings() -> Dict[str, Any]:
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("""
- SELECT enabled, failover_enabled, register_verify_enabled, base_url,
- task_notify_enabled, updated_at
+ SELECT enabled, failover_enabled, register_verify_enabled, login_alert_enabled,
+ base_url, task_notify_enabled, updated_at
FROM email_settings WHERE id = 1
""")
row = cursor.fetchone()
@@ -254,14 +255,16 @@ def get_email_settings() -> Dict[str, Any]:
'enabled': bool(row[0]),
'failover_enabled': bool(row[1]),
'register_verify_enabled': bool(row[2]) if row[2] is not None else False,
- 'base_url': row[3] or '',
- 'task_notify_enabled': bool(row[4]) if row[4] is not None else False,
- 'updated_at': row[5]
+ 'login_alert_enabled': bool(row[3]) if row[3] is not None else True,
+ 'base_url': row[4] or '',
+ 'task_notify_enabled': bool(row[5]) if row[5] is not None else False,
+ 'updated_at': row[6]
}
return {
'enabled': False,
'failover_enabled': True,
'register_verify_enabled': False,
+ 'login_alert_enabled': True,
'base_url': '',
'task_notify_enabled': False,
'updated_at': None
@@ -272,6 +275,7 @@ def update_email_settings(
enabled: bool,
failover_enabled: bool,
register_verify_enabled: bool = None,
+ login_alert_enabled: bool = None,
base_url: str = None,
task_notify_enabled: bool = None
) -> bool:
@@ -287,6 +291,10 @@ def update_email_settings(
updates.append('register_verify_enabled = ?')
params.append(int(register_verify_enabled))
+ if login_alert_enabled is not None:
+ updates.append('login_alert_enabled = ?')
+ params.append(int(login_alert_enabled))
+
if base_url is not None:
updates.append('base_url = ?')
params.append(base_url)
diff --git a/playwright_automation.py b/playwright_automation.py
index 7c2d9d5..cefcbaf 100755
--- a/playwright_automation.py
+++ b/playwright_automation.py
@@ -424,7 +424,7 @@ class PlaywrightAutomation:
# 等待跳转
# self.log("等待登录处理...") # 精简日志
- self.page.wait_for_load_state('networkidle', timeout=10000) # 优化为10秒
+ self.page.wait_for_load_state('networkidle', timeout=30000) # 增加到30秒
# 检查登录结果
current_url = self.page.url
@@ -823,7 +823,7 @@ class PlaywrightAutomation:
self.log(f"导航到 '{browse_type}' 页面...")
try:
# 等待页面完全加载
- time.sleep(0.5)
+ time.sleep(2)
self.log(f"当前URL: {self.main_page.url}")
except Exception as e:
self.log(f"获取URL失败: {str(e)}")
@@ -835,7 +835,7 @@ class PlaywrightAutomation:
# 如果只是导航(用于截图),切换完成后直接返回
if navigate_only:
- time.sleep(0.3) # 等待页面稳定
+ time.sleep(1) # 等待页面稳定
result.success = True
return result
@@ -867,21 +867,27 @@ class PlaywrightAutomation:
except Exception: # Bug fix: 明确捕获Exception
self.log("等待表格超时,继续尝试...")
- # 等待页面网络空闲,确保AJAX加载完成
- try:
- self.page.wait_for_load_state('networkidle', timeout=5000)
- except Exception:
- pass # 超时继续,不阻塞
+ # 额外等待,确保AJAX内容加载完成
+ # 第一页等待更长时间,因为是首次加载(并发时尤其���要)
+ if current_page == 1 and total_items == 0:
+ time.sleep(3.0)
+ else:
+ time.sleep(1.0)
- # 获取内容行数量(简化重试:2次快速检测)
+ # 获取内容行数量(带重试机制,避免AJAX加载慢导致误判)
+ # 第一页使用更多重试次数(8次×3秒=24秒),处理高并发时的慢加载
+ # 后续页使用3次×1.5秒=4.5秒
+ max_retries = 8 if (current_page == 1 and total_items == 0) else 3
+ retry_wait = 3.0 if (current_page == 1 and total_items == 0) else 1.5
rows_count = 0
- for retry in range(2):
+ for retry in range(max_retries):
rows_locator = self.page.locator("//table[@class='ltable']/tbody/tr[position()>1 and count(td)>=5]")
rows_count = rows_locator.count()
if rows_count > 0:
break
- if retry == 0:
- time.sleep(0.5) # 仅重试一次,等待0.5秒
+ if retry < max_retries - 1:
+ self.log(f"未检测到内容,等待后重试... ({retry+1}/{max_retries})")
+ time.sleep(retry_wait)
if rows_count == 0:
self.log("当前页面没有内容")
diff --git a/requirements.txt b/requirements.txt
index 16d0382..31546bf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,6 @@ flask==3.0.0
flask-socketio==5.3.5
flask-login==0.6.3
python-socketio==5.10.0
-playwright==1.40.0
schedule==1.2.0
psutil==5.9.6
pytz==2024.1
@@ -10,6 +9,6 @@ bcrypt==4.0.1
requests==2.31.0
python-dotenv==1.0.0
beautifulsoup4==4.12.2
-nest_asyncio
cryptography>=41.0.0
Pillow>=10.0.0
+playwright==1.42.0
diff --git a/routes/__init__.py b/routes/__init__.py
index 1327246..439daa9 100644
--- a/routes/__init__.py
+++ b/routes/__init__.py
@@ -5,6 +5,7 @@ from __future__ import annotations
def register_blueprints(app) -> None:
from routes.admin_api import admin_api_bp
+ from routes.admin_api import security_bp as admin_security_bp
from routes.api_accounts import api_accounts_bp
from routes.api_auth import api_auth_bp
from routes.api_schedules import api_schedules_bp
@@ -21,3 +22,6 @@ def register_blueprints(app) -> None:
app.register_blueprint(api_screenshots_bp)
app.register_blueprint(api_schedules_bp)
app.register_blueprint(admin_api_bp)
+ # Security admin APIs (support both /api/admin/* and /yuyx/api/admin/*)
+ app.register_blueprint(admin_security_bp)
+ app.register_blueprint(admin_security_bp, url_prefix="/yuyx", name="admin_security_yuyx")
diff --git a/routes/admin_api/__init__.py b/routes/admin_api/__init__.py
index c5e56f8..ea19e72 100644
--- a/routes/admin_api/__init__.py
+++ b/routes/admin_api/__init__.py
@@ -8,4 +8,6 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api")
# Import side effects: register routes on blueprint
from routes.admin_api import core as _core # noqa: F401
-from routes.admin_api import update as _update # noqa: F401
+
+# Export security blueprint for app registration
+from routes.admin_api.security import security_bp # noqa: F401
diff --git a/routes/admin_api/core.py b/routes/admin_api/core.py
index 2024345..83a8434 100644
--- a/routes/admin_api/core.py
+++ b/routes/admin_api/core.py
@@ -3,6 +3,8 @@
from __future__ import annotations
import os
+import posixpath
+import secrets
import threading
import time
from datetime import datetime
@@ -15,7 +17,9 @@ from app_logger import get_logger
from app_security import (
get_rate_limit_ip,
is_safe_outbound_url,
+ is_safe_path,
require_ip_not_locked,
+ sanitize_filename,
validate_email,
validate_password,
)
@@ -48,6 +52,36 @@ from services.time_utils import BEIJING_TZ, get_beijing_now
logger = get_logger("app")
config = get_config()
+_server_cpu_percent_lock = threading.Lock()
+_server_cpu_percent_last: float | None = None
+_server_cpu_percent_last_ts = 0.0
+
+
+def _get_server_cpu_percent() -> float:
+ import psutil
+
+ global _server_cpu_percent_last, _server_cpu_percent_last_ts
+
+ now = time.time()
+ with _server_cpu_percent_lock:
+ if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5:
+ return _server_cpu_percent_last
+
+ try:
+ if _server_cpu_percent_last is None:
+ cpu_percent = float(psutil.cpu_percent(interval=0.1))
+ else:
+ cpu_percent = float(psutil.cpu_percent(interval=None))
+ except Exception:
+ cpu_percent = float(_server_cpu_percent_last or 0.0)
+
+ if cpu_percent < 0:
+ cpu_percent = 0.0
+
+ _server_cpu_percent_last = cpu_percent
+ _server_cpu_percent_last_ts = now
+ return cpu_percent
+
def _admin_reauth_required() -> bool:
try:
@@ -61,6 +95,24 @@ def _require_admin_reauth():
return jsonify({"error": "需要二次确认", "code": "reauth_required"}), 401
return None
+def _get_upload_dir():
+ rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements")
+ if not is_safe_path(current_app.root_path, rel_dir):
+ rel_dir = "static/announcements"
+ abs_dir = os.path.join(current_app.root_path, rel_dir)
+ os.makedirs(abs_dir, exist_ok=True)
+ return abs_dir, rel_dir
+
+
+def _get_file_size(file_storage):
+ try:
+ file_storage.stream.seek(0, os.SEEK_END)
+ size = file_storage.stream.tell()
+ file_storage.stream.seek(0)
+ return size
+ except Exception:
+ return None
+
@admin_api_bp.route("/debug-config", methods=["GET"])
@admin_required
@@ -199,6 +251,42 @@ def admin_reauth():
# ==================== 公告管理API(管理员) ====================
+@admin_api_bp.route("/announcements/upload_image", methods=["POST"])
+@admin_required
+def admin_upload_announcement_image():
+ """上传公告图片(返回可访问URL)"""
+ file = request.files.get("file")
+ if not file or not file.filename:
+ return jsonify({"error": "请选择图片"}), 400
+
+ filename = sanitize_filename(file.filename)
+ ext = os.path.splitext(filename)[1].lower()
+ allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"})
+ if not ext or ext not in allowed_exts:
+ return jsonify({"error": "不支持的图片格式"}), 400
+ if file.mimetype and not str(file.mimetype).startswith("image/"):
+ return jsonify({"error": "文件类型无效"}), 400
+
+ size = _get_file_size(file)
+ max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024))
+ if size is not None and size > max_size:
+ max_mb = max_size // 1024 // 1024
+ return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400
+
+ abs_dir, rel_dir = _get_upload_dir()
+ token = secrets.token_hex(6)
+ name = f"announcement_{int(time.time())}_{token}{ext}"
+ save_path = os.path.join(abs_dir, name)
+ file.save(save_path)
+
+ static_root = os.path.join(current_app.root_path, "static")
+ rel_to_static = os.path.relpath(abs_dir, static_root)
+ if rel_to_static.startswith(".."):
+ rel_to_static = "announcements"
+ url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name)
+ return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)})
+
+
@admin_api_bp.route("/announcements", methods=["GET"])
@admin_required
def admin_get_announcements():
@@ -221,9 +309,13 @@ def admin_create_announcement():
data = request.json or {}
title = (data.get("title") or "").strip()
content = (data.get("content") or "").strip()
+ image_url = (data.get("image_url") or "").strip()
is_active = bool(data.get("is_active", True))
- announcement_id = database.create_announcement(title, content, is_active=is_active)
+ if image_url and len(image_url) > 1000:
+ return jsonify({"error": "图片地址过长"}), 400
+
+ announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active)
if not announcement_id:
return jsonify({"error": "标题和内容不能为空"}), 400
@@ -317,6 +409,71 @@ def get_system_stats():
return jsonify(stats)
+@admin_api_bp.route("/browser_pool/stats", methods=["GET"])
+@admin_required
+def get_browser_pool_stats():
+ """获取截图线程池状态"""
+ try:
+ from browser_pool_worker import get_browser_worker_pool
+
+ pool = get_browser_worker_pool()
+ stats = pool.get_stats() or {}
+
+ worker_details = []
+ for w in stats.get("workers") or []:
+ last_ts = float(w.get("last_active_ts") or 0)
+ last_active_at = None
+ if last_ts > 0:
+ try:
+ last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
+ except Exception:
+ last_active_at = None
+
+ created_ts = w.get("browser_created_at")
+ created_at = None
+ if created_ts:
+ try:
+ created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
+ except Exception:
+ created_at = None
+
+ worker_details.append(
+ {
+ "worker_id": w.get("worker_id"),
+ "idle": bool(w.get("idle")),
+ "has_browser": bool(w.get("has_browser")),
+ "total_tasks": int(w.get("total_tasks") or 0),
+ "failed_tasks": int(w.get("failed_tasks") or 0),
+ "browser_use_count": int(w.get("browser_use_count") or 0),
+ "browser_created_at": created_at,
+ "browser_created_ts": created_ts,
+ "last_active_at": last_active_at,
+ "last_active_ts": last_ts,
+ "thread_alive": bool(w.get("thread_alive")),
+ }
+ )
+
+ total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0)
+ return jsonify(
+ {
+ "total_workers": total_workers,
+ "active_workers": int(stats.get("busy_workers") or 0),
+ "idle_workers": int(stats.get("idle_workers") or 0),
+ "queue_size": int(stats.get("queue_size") or 0),
+ "workers": worker_details,
+ "summary": {
+ "total_tasks": int(stats.get("total_tasks") or 0),
+ "failed_tasks": int(stats.get("failed_tasks") or 0),
+ "success_rate": stats.get("success_rate"),
+ },
+ "server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
+ }
+ )
+ except Exception as e:
+ logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}")
+ return jsonify({"error": "获取截图线程池状态失败"}), 500
+
+
@admin_api_bp.route("/docker_stats", methods=["GET"])
@admin_required
def get_docker_stats():
@@ -510,9 +667,21 @@ def update_system_config_api():
schedule_weekdays = data.get("schedule_weekdays")
new_max_concurrent_per_account = data.get("max_concurrent_per_account")
new_max_screenshot_concurrent = data.get("max_screenshot_concurrent")
+ enable_screenshot = data.get("enable_screenshot")
auto_approve_enabled = data.get("auto_approve_enabled")
auto_approve_hourly_limit = data.get("auto_approve_hourly_limit")
auto_approve_vip_days = data.get("auto_approve_vip_days")
+ kdocs_enabled = data.get("kdocs_enabled")
+ kdocs_doc_url = data.get("kdocs_doc_url")
+ kdocs_default_unit = data.get("kdocs_default_unit")
+ kdocs_sheet_name = data.get("kdocs_sheet_name")
+ kdocs_sheet_index = data.get("kdocs_sheet_index")
+ kdocs_unit_column = data.get("kdocs_unit_column")
+ kdocs_image_column = data.get("kdocs_image_column")
+ kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled")
+ kdocs_admin_notify_email = data.get("kdocs_admin_notify_email")
+ kdocs_row_start = data.get("kdocs_row_start")
+ kdocs_row_end = data.get("kdocs_row_end")
if max_concurrent is not None:
if not isinstance(max_concurrent, int) or max_concurrent < 1:
@@ -524,7 +693,13 @@ def update_system_config_api():
if new_max_screenshot_concurrent is not None:
if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1:
- return jsonify({"error": "截图并发数必须大于0(建议根据服务器配置设置,每个浏览器约占用200MB内存)"}), 400
+ return jsonify({"error": "截图并发数必须大于0(建议根据服务器配置设置,wkhtmltoimage 资源占用较低)"}), 400
+
+ if enable_screenshot is not None:
+ if isinstance(enable_screenshot, bool):
+ enable_screenshot = 1 if enable_screenshot else 0
+ if enable_screenshot not in (0, 1):
+ return jsonify({"error": "截图开关必须是0或1"}), 400
if schedule_time is not None:
import re
@@ -554,6 +729,82 @@ def update_system_config_api():
if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0:
return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400
+ if kdocs_enabled is not None:
+ if isinstance(kdocs_enabled, bool):
+ kdocs_enabled = 1 if kdocs_enabled else 0
+ if kdocs_enabled not in (0, 1):
+ return jsonify({"error": "表格上传开关必须是0或1"}), 400
+
+ if kdocs_doc_url is not None:
+ kdocs_doc_url = str(kdocs_doc_url or "").strip()
+ if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url):
+ return jsonify({"error": "文档链接格式不正确"}), 400
+
+ if kdocs_default_unit is not None:
+ kdocs_default_unit = str(kdocs_default_unit or "").strip()
+ if len(kdocs_default_unit) > 50:
+ return jsonify({"error": "默认县区长度不能超过50"}), 400
+
+ if kdocs_sheet_name is not None:
+ kdocs_sheet_name = str(kdocs_sheet_name or "").strip()
+ if len(kdocs_sheet_name) > 50:
+ return jsonify({"error": "Sheet名称长度不能超过50"}), 400
+
+ if kdocs_sheet_index is not None:
+ try:
+ kdocs_sheet_index = int(kdocs_sheet_index)
+ except Exception:
+ return jsonify({"error": "Sheet序号必须是数字"}), 400
+ if kdocs_sheet_index < 0:
+ return jsonify({"error": "Sheet序号不能为负数"}), 400
+
+ if kdocs_unit_column is not None:
+ kdocs_unit_column = str(kdocs_unit_column or "").strip().upper()
+ if not kdocs_unit_column:
+ return jsonify({"error": "县区列不能为空"}), 400
+ import re
+
+ if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column):
+ return jsonify({"error": "县区列格式错误"}), 400
+
+ if kdocs_image_column is not None:
+ kdocs_image_column = str(kdocs_image_column or "").strip().upper()
+ if not kdocs_image_column:
+ return jsonify({"error": "图片列不能为空"}), 400
+ import re
+
+ if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column):
+ return jsonify({"error": "图片列格式错误"}), 400
+
+ if kdocs_admin_notify_enabled is not None:
+ if isinstance(kdocs_admin_notify_enabled, bool):
+ kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0
+ if kdocs_admin_notify_enabled not in (0, 1):
+ return jsonify({"error": "管理员通知开关必须是0或1"}), 400
+
+ if kdocs_admin_notify_email is not None:
+ kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip()
+ if kdocs_admin_notify_email:
+ is_valid, error_msg = validate_email(kdocs_admin_notify_email)
+ if not is_valid:
+ return jsonify({"error": error_msg}), 400
+
+ if kdocs_row_start is not None:
+ try:
+ kdocs_row_start = int(kdocs_row_start)
+ except (ValueError, TypeError):
+ return jsonify({"error": "起始行必须是数字"}), 400
+ if kdocs_row_start < 0:
+ return jsonify({"error": "起始行不能为负数"}), 400
+
+ if kdocs_row_end is not None:
+ try:
+ kdocs_row_end = int(kdocs_row_end)
+ except (ValueError, TypeError):
+ return jsonify({"error": "结束行必须是数字"}), 400
+ if kdocs_row_end < 0:
+ return jsonify({"error": "结束行不能为负数"}), 400
+
old_config = database.get_system_config() or {}
if not database.update_system_config(
@@ -564,9 +815,21 @@ def update_system_config_api():
schedule_weekdays=schedule_weekdays,
max_concurrent_per_account=new_max_concurrent_per_account,
max_screenshot_concurrent=new_max_screenshot_concurrent,
+ enable_screenshot=enable_screenshot,
auto_approve_enabled=auto_approve_enabled,
auto_approve_hourly_limit=auto_approve_hourly_limit,
auto_approve_vip_days=auto_approve_vip_days,
+ kdocs_enabled=kdocs_enabled,
+ kdocs_doc_url=kdocs_doc_url,
+ kdocs_default_unit=kdocs_default_unit,
+ kdocs_sheet_name=kdocs_sheet_name,
+ kdocs_sheet_index=kdocs_sheet_index,
+ kdocs_unit_column=kdocs_unit_column,
+ kdocs_image_column=kdocs_image_column,
+ kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
+ kdocs_admin_notify_email=kdocs_admin_notify_email,
+ kdocs_row_start=kdocs_row_start,
+ kdocs_row_end=kdocs_row_end,
):
return jsonify({"error": "更新失败"}), 400
@@ -577,6 +840,14 @@ def update_system_config_api():
max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))),
max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))),
)
+ if new_max_screenshot_concurrent is not None:
+ try:
+ from browser_pool_worker import resize_browser_worker_pool
+
+ if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))):
+ logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}")
+ except Exception as pool_error:
+ logger.warning(f"截图线程池并发更新失败: {pool_error}")
except Exception:
pass
@@ -590,6 +861,70 @@ def update_system_config_api():
return jsonify({"message": "系统配置已更新"})
+@admin_api_bp.route("/kdocs/status", methods=["GET"])
+@admin_required
+def get_kdocs_status_api():
+ """获取金山文档上传状态"""
+ try:
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ uploader = get_kdocs_uploader()
+ status = uploader.get_status()
+ live = str(request.args.get("live", "")).lower() in ("1", "true", "yes")
+ if live:
+ live_status = uploader.refresh_login_status()
+ if live_status.get("success"):
+ logged_in = bool(live_status.get("logged_in"))
+ status["logged_in"] = logged_in
+ status["last_login_ok"] = logged_in
+ status["login_required"] = not logged_in
+ if live_status.get("error"):
+ status["last_error"] = live_status.get("error")
+ else:
+ status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None
+ if status.get("last_login_ok") is True and status.get("last_error") == "操作超时":
+ status["last_error"] = None
+ return jsonify(status)
+ except Exception as e:
+ return jsonify({"error": f"获取状态失败: {e}"}), 500
+
+
+@admin_api_bp.route("/kdocs/qr", methods=["POST"])
+@admin_required
+def get_kdocs_qr_api():
+ """获取金山文档登录二维码"""
+ try:
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ uploader = get_kdocs_uploader()
+ data = request.get_json(silent=True) or {}
+ force = bool(data.get("force"))
+ if not force:
+ force = str(request.args.get("force", "")).lower() in ("1", "true", "yes")
+ result = uploader.request_qr(force=force)
+ if not result.get("success"):
+ return jsonify({"error": result.get("error", "获取二维码失败")}), 400
+ return jsonify(result)
+ except Exception as e:
+ return jsonify({"error": f"获取二维码失败: {e}"}), 500
+
+
+@admin_api_bp.route("/kdocs/clear-login", methods=["POST"])
+@admin_required
+def clear_kdocs_login_api():
+ """清除金山文档登录态"""
+ try:
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ uploader = get_kdocs_uploader()
+ result = uploader.clear_login()
+ if not result.get("success"):
+ return jsonify({"error": result.get("error", "清除失败")}), 400
+ return jsonify({"success": True})
+ except Exception as e:
+ return jsonify({"error": f"清除失败: {e}"}), 500
+
+
@admin_api_bp.route("/schedule/execute", methods=["POST"])
@admin_required
def execute_schedule_now():
@@ -673,7 +1008,7 @@ def get_server_info_api():
"""获取服务器信息"""
import psutil
- cpu_percent = psutil.cpu_percent(interval=1)
+ cpu_percent = _get_server_cpu_percent()
memory = psutil.virtual_memory()
memory_total = f"{memory.total / (1024**3):.1f}GB"
@@ -776,30 +1111,44 @@ def get_running_tasks_api():
@admin_required
def get_task_logs_api():
"""获取任务日志列表(支持分页和多种筛选)"""
- limit = int(request.args.get("limit", 20))
- offset = int(request.args.get("offset", 0))
+ try:
+ limit = int(request.args.get("limit", 20))
+ limit = max(1, min(limit, 200)) # 限制 1-200 条
+ except (ValueError, TypeError):
+ limit = 20
+
+ try:
+ offset = int(request.args.get("offset", 0))
+ offset = max(0, offset)
+ except (ValueError, TypeError):
+ offset = 0
+
date_filter = request.args.get("date")
status_filter = request.args.get("status")
source_filter = request.args.get("source")
user_id_filter = request.args.get("user_id")
- account_filter = request.args.get("account")
+ account_filter = (request.args.get("account") or "").strip()
if user_id_filter:
try:
user_id_filter = int(user_id_filter)
- except ValueError:
+ except (ValueError, TypeError):
user_id_filter = None
- result = database.get_task_logs(
- limit=limit,
- offset=offset,
- date_filter=date_filter,
- status_filter=status_filter,
- source_filter=source_filter,
- user_id_filter=user_id_filter,
- account_filter=account_filter,
- )
- return jsonify(result)
+ try:
+ result = database.get_task_logs(
+ limit=limit,
+ offset=offset,
+ date_filter=date_filter,
+ status_filter=status_filter,
+ source_filter=source_filter,
+ user_id_filter=user_id_filter,
+ account_filter=account_filter if account_filter else None,
+ )
+ return jsonify(result)
+ except Exception as e:
+ logger.error(f"获取任务日志失败: {e}")
+ return jsonify({"logs": [], "total": 0, "error": "查询失败"})
@admin_api_bp.route("/task/logs/clear", methods=["POST"])
@@ -910,32 +1259,6 @@ def admin_reset_password_route(user_id):
return jsonify({"error": "重置失败,用户不存在"}), 400
-@admin_api_bp.route("/password_resets", methods=["GET"])
-@admin_required
-def get_password_resets_route():
- """获取所有待审核的密码重置申请"""
- resets = database.get_pending_password_resets()
- return jsonify(resets)
-
-
-@admin_api_bp.route("/password_resets//approve", methods=["POST"])
-@admin_required
-def approve_password_reset_route(request_id):
- """批准密码重置申请"""
- if database.approve_password_reset(request_id):
- return jsonify({"message": "密码重置申请已批准"})
- return jsonify({"error": "批准失败"}), 400
-
-
-@admin_api_bp.route("/password_resets//reject", methods=["POST"])
-@admin_required
-def reject_password_reset_route(request_id):
- """拒绝密码重置申请"""
- if database.reject_password_reset(request_id):
- return jsonify({"message": "密码重置申请已拒绝"})
- return jsonify({"error": "拒绝失败"}), 400
-
-
@admin_api_bp.route("/feedbacks", methods=["GET"])
@admin_required
def get_all_feedbacks():
@@ -1067,6 +1390,7 @@ def update_email_settings_api():
enabled = data.get("enabled", False)
failover_enabled = data.get("failover_enabled", True)
register_verify_enabled = data.get("register_verify_enabled")
+ login_alert_enabled = data.get("login_alert_enabled")
base_url = data.get("base_url")
task_notify_enabled = data.get("task_notify_enabled")
@@ -1074,6 +1398,7 @@ def update_email_settings_api():
enabled=enabled,
failover_enabled=failover_enabled,
register_verify_enabled=register_verify_enabled,
+ login_alert_enabled=login_alert_enabled,
base_url=base_url,
task_notify_enabled=task_notify_enabled,
)
diff --git a/routes/admin_api/security.py b/routes/admin_api/security.py
new file mode 100644
index 0000000..f6a4caa
--- /dev/null
+++ b/routes/admin_api/security.py
@@ -0,0 +1,348 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+from typing import Any
+
+from flask import Blueprint, jsonify, request
+
+import db_pool
+from db import security as security_db
+from routes.decorators import admin_required
+from security import BlacklistManager, RiskScorer
+
+security_bp = Blueprint("admin_security", __name__)
+blacklist = BlacklistManager()
+scorer = RiskScorer(blacklist_manager=blacklist)
+
+
+def _truncate(value: Any, max_len: int = 200) -> str:
+ text = str(value or "")
+ if max_len <= 0:
+ return ""
+ if len(text) <= max_len:
+ return text
+ return text[: max(0, max_len - 3)] + "..."
+
+
+def _parse_int_arg(name: str, default: int, *, min_value: int | None = None, max_value: int | None = None) -> int:
+ raw = request.args.get(name, None)
+ if raw is None or str(raw).strip() == "":
+ value = int(default)
+ else:
+ try:
+ value = int(str(raw).strip())
+ except Exception:
+ value = int(default)
+
+ if min_value is not None:
+ value = max(int(min_value), value)
+ if max_value is not None:
+ value = min(int(max_value), value)
+ return value
+
+
+def _parse_json() -> dict:
+ if request.is_json:
+ data = request.get_json(silent=True) or {}
+ return data if isinstance(data, dict) else {}
+ # 兼容 form-data
+ try:
+ return dict(request.form or {})
+ except Exception:
+ return {}
+
+
+def _parse_bool(value: Any) -> bool:
+ if isinstance(value, bool):
+ return value
+ if isinstance(value, int):
+ return value != 0
+ text = str(value or "").strip().lower()
+ return text in {"1", "true", "yes", "y", "on"}
+
+
+def _sanitize_threat_event(event: dict) -> dict:
+ return {
+ "id": event.get("id"),
+ "threat_type": event.get("threat_type") or "unknown",
+ "score": int(event.get("score") or 0),
+ "ip": _truncate(event.get("ip"), 64),
+ "user_id": event.get("user_id"),
+ "request_method": _truncate(event.get("request_method"), 16),
+ "request_path": _truncate(event.get("request_path"), 256),
+ "field_name": _truncate(event.get("field_name"), 80),
+ "rule": _truncate(event.get("rule"), 120),
+ "matched": _truncate(event.get("matched"), 120),
+ "value_preview": _truncate(event.get("value_preview"), 200),
+ "created_at": event.get("created_at"),
+ }
+
+
+def _sanitize_ban_entry(entry: dict, *, kind: str) -> dict:
+ if kind == "ip":
+ return {
+ "ip": _truncate(entry.get("ip"), 64),
+ "reason": _truncate(entry.get("reason"), 200),
+ "added_at": entry.get("added_at"),
+ "expires_at": entry.get("expires_at"),
+ "is_active": int(entry.get("is_active") or 0),
+ }
+ if kind == "user":
+ return {
+ "user_id": entry.get("user_id"),
+ "reason": _truncate(entry.get("reason"), 200),
+ "added_at": entry.get("added_at"),
+ "expires_at": entry.get("expires_at"),
+ "is_active": int(entry.get("is_active") or 0),
+ }
+ return {}
+
+
+@security_bp.route("/api/admin/security/dashboard", methods=["GET"])
+@admin_required
+def get_security_dashboard():
+ """
+ 获取安全仪表板数据
+ 返回:
+ - 最近24小时威胁事件数
+ - 当前封禁IP数
+ - 当前封禁用户数
+ - 最近10条威胁事件
+ """
+ try:
+ threat_24h = security_db.get_threat_events_count(hours=24)
+ except Exception:
+ threat_24h = 0
+
+ try:
+ banned_ips = blacklist.get_banned_ips()
+ except Exception:
+ banned_ips = []
+
+ try:
+ banned_users = blacklist.get_banned_users()
+ except Exception:
+ banned_users = []
+
+ try:
+ recent = security_db.get_threat_events_list(page=1, per_page=10, filters={}).get("items", [])
+ recent_items = [_sanitize_threat_event(e) for e in recent if isinstance(e, dict)]
+ except Exception:
+ recent_items = []
+
+ return jsonify(
+ {
+ "threat_events_24h": int(threat_24h or 0),
+ "banned_ip_count": len(banned_ips),
+ "banned_user_count": len(banned_users),
+ "recent_threat_events": recent_items,
+ }
+ )
+
+
+@security_bp.route("/api/admin/security/threats", methods=["GET"])
+@admin_required
+def get_threat_events():
+ """
+ 获取威胁事件列表(分页)
+ 参数: page, per_page, severity, event_type
+ """
+ page = _parse_int_arg("page", 1, min_value=1, max_value=100000)
+ per_page = _parse_int_arg("per_page", 20, min_value=1, max_value=200)
+ severity = (request.args.get("severity") or "").strip()
+ event_type = (request.args.get("event_type") or "").strip()
+
+ filters: dict[str, Any] = {}
+ if severity:
+ filters["severity"] = severity
+ if event_type:
+ filters["event_type"] = event_type
+
+ data = security_db.get_threat_events_list(page, per_page, filters)
+ items = data.get("items") or []
+ data["items"] = [_sanitize_threat_event(e) for e in items if isinstance(e, dict)]
+ return jsonify(data)
+
+
+@security_bp.route("/api/admin/security/banned-ips", methods=["GET"])
+@admin_required
+def get_banned_ips():
+ """获取封禁IP列表"""
+ items = blacklist.get_banned_ips()
+ return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="ip") for x in items]})
+
+
+@security_bp.route("/api/admin/security/banned-users", methods=["GET"])
+@admin_required
+def get_banned_users():
+ """获取封禁用户列表"""
+ items = blacklist.get_banned_users()
+ return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="user") for x in items]})
+
+
+@security_bp.route("/api/admin/security/ban-ip", methods=["POST"])
+@admin_required
+def ban_ip():
+ """
+ 手动封禁IP
+ 参数: ip, reason, duration_hours(可选), permanent(可选)
+ """
+ data = _parse_json()
+ ip = str(data.get("ip") or "").strip()
+ reason = str(data.get("reason") or "").strip()
+ duration_hours_raw = data.get("duration_hours", 24)
+ permanent = _parse_bool(data.get("permanent", False))
+
+ if not ip:
+ return jsonify({"error": "ip不能为空"}), 400
+ if not reason:
+ return jsonify({"error": "reason不能为空"}), 400
+
+ try:
+ duration_hours = max(1, int(duration_hours_raw))
+ except Exception:
+ duration_hours = 24
+
+ ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent)
+ if not ok:
+ return jsonify({"error": "封禁失败"}), 400
+ return jsonify({"success": True})
+
+
+@security_bp.route("/api/admin/security/unban-ip", methods=["POST"])
+@admin_required
+def unban_ip():
+ """解除IP封禁"""
+ data = _parse_json()
+ ip = str(data.get("ip") or "").strip()
+ if not ip:
+ return jsonify({"error": "ip不能为空"}), 400
+
+ ok = blacklist.unban_ip(ip)
+ if not ok:
+ return jsonify({"error": "未找到封禁记录"}), 404
+ return jsonify({"success": True})
+
+
+@security_bp.route("/api/admin/security/ban-user", methods=["POST"])
+@admin_required
+def ban_user():
+ """手动封禁用户"""
+ data = _parse_json()
+ user_id_raw = data.get("user_id")
+ reason = str(data.get("reason") or "").strip()
+ duration_hours_raw = data.get("duration_hours", 24)
+ permanent = _parse_bool(data.get("permanent", False))
+
+ try:
+ user_id = int(user_id_raw)
+ except Exception:
+ user_id = None
+
+ if user_id is None:
+ return jsonify({"error": "user_id不能为空"}), 400
+ if not reason:
+ return jsonify({"error": "reason不能为空"}), 400
+
+ try:
+ duration_hours = max(1, int(duration_hours_raw))
+ except Exception:
+ duration_hours = 24
+
+ ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent)
+ if not ok:
+ return jsonify({"error": "封禁失败"}), 400
+ return jsonify({"success": True})
+
+
+@security_bp.route("/api/admin/security/unban-user", methods=["POST"])
+@admin_required
+def unban_user():
+ """解除用户封禁"""
+ data = _parse_json()
+ user_id_raw = data.get("user_id")
+ try:
+ user_id = int(user_id_raw)
+ except Exception:
+ user_id = None
+
+ if user_id is None:
+ return jsonify({"error": "user_id不能为空"}), 400
+
+ ok = blacklist.unban_user(user_id)
+ if not ok:
+ return jsonify({"error": "未找到封禁记录"}), 404
+ return jsonify({"success": True})
+
+
+@security_bp.route("/api/admin/security/ip-risk/", methods=["GET"])
+@admin_required
+def get_ip_risk(ip):
+ """获取指定IP的风险评分和历史事件"""
+ ip_text = str(ip or "").strip()
+ if not ip_text:
+ return jsonify({"error": "ip不能为空"}), 400
+
+ history = security_db.get_ip_threat_history(ip_text)
+ return jsonify(
+ {
+ "ip": _truncate(ip_text, 64),
+ "risk_score": int(scorer.get_ip_score(ip_text) or 0),
+ "is_banned": bool(blacklist.is_ip_banned(ip_text)),
+ "threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)],
+ }
+ )
+
+
+@security_bp.route("/api/admin/security/ip-risk/clear", methods=["POST"])
+@admin_required
+def clear_ip_risk():
+ """清除指定IP的风险分"""
+ data = _parse_json()
+ ip_text = str(data.get("ip") or "").strip()
+ if not ip_text:
+ return jsonify({"error": "ip不能为空"}), 400
+
+ if not scorer.reset_ip_score(ip_text):
+ return jsonify({"error": "清理失败"}), 400
+
+ return jsonify({"success": True, "ip": _truncate(ip_text, 64), "risk_score": 0})
+
+
+@security_bp.route("/api/admin/security/user-risk/", methods=["GET"])
+@admin_required
+def get_user_risk(user_id):
+ """获取指定用户的风险评分和历史事件"""
+ history = security_db.get_user_threat_history(user_id)
+ return jsonify(
+ {
+ "user_id": int(user_id),
+ "risk_score": int(scorer.get_user_score(user_id) or 0),
+ "is_banned": bool(blacklist.is_user_banned(user_id)),
+ "threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)],
+ }
+ )
+
+
+@security_bp.route("/api/admin/security/cleanup", methods=["POST"])
+@admin_required
+def cleanup_expired():
+ """清理过期的封禁记录和衰减风险分"""
+ try:
+ blacklist.cleanup_expired()
+ except Exception:
+ pass
+ try:
+ scorer.decay_scores()
+ except Exception:
+ pass
+
+ # 可选:返回当前连接池统计信息,便于排查后台运行状态
+ pool_stats = None
+ try:
+ pool_stats = db_pool.get_pool_stats()
+ except Exception:
+ pool_stats = None
+
+ return jsonify({"success": True, "pool_stats": pool_stats})
diff --git a/routes/admin_api/update.py b/routes/admin_api/update.py
index 735d2e9..cd2e0cd 100644
--- a/routes/admin_api/update.py
+++ b/routes/admin_api/update.py
@@ -3,7 +3,6 @@
from __future__ import annotations
import os
-import time
import uuid
from flask import jsonify, request, session
@@ -66,13 +65,6 @@ def _parse_bool_field(data: dict, key: str) -> bool | None:
raise ValueError(f"{key} 必须是 0/1 或 true/false")
-def _admin_reauth_required() -> bool:
- try:
- return time.time() > float(session.get("admin_reauth_until", 0) or 0)
- except Exception:
- return True
-
-
@admin_api_bp.route("/update/status", methods=["GET"])
@admin_required
def get_update_status_api():
@@ -154,8 +146,6 @@ def request_update_check_api():
def request_update_run_api():
"""请求宿主机 Update-Agent 执行一键更新并重启服务。"""
ensure_update_dirs()
- if _admin_reauth_required():
- return jsonify({"error": "需要二次确认", "code": "reauth_required"}), 401
if _has_pending_request():
return jsonify({"error": "已有更新请求正在处理中,请稍后再试"}), 409
diff --git a/routes/api_accounts.py b/routes/api_accounts.py
index bc8ce6a..2bb5402 100644
--- a/routes/api_accounts.py
+++ b/routes/api_accounts.py
@@ -11,7 +11,6 @@ from crypto_utils import encrypt_password as encrypt_account_password
from flask import Blueprint, jsonify, request
from flask_login import current_user, login_required
from services.accounts_service import load_user_accounts
-from services.browser_manager import init_browser_manager_async
from services.browse_types import BROWSE_TYPE_SHOULD_READ, normalize_browse_type, validate_browse_type
from services.client_log import log_to_client
from services.models import Account
@@ -230,10 +229,6 @@ def start_account(account_id):
if not browse_type:
return jsonify({"error": "浏览类型无效"}), 400
enable_screenshot = data.get("enable_screenshot", True)
- if enable_screenshot:
- # 异步初始化浏览器环境,避免首次下载/安装 Chromium 阻塞请求导致“网页无响应”
- init_browser_manager_async()
-
ok, message = submit_account_task(
user_id=user_id,
account_id=account_id,
@@ -308,9 +303,6 @@ def manual_screenshot(account_id):
account.last_browse_type = browse_type
- # 异步初始化浏览器环境,避免首次下载/安装 Chromium 阻塞请求
- init_browser_manager_async()
-
threading.Thread(
target=take_screenshot_for_account,
args=(user_id, account_id, browse_type, "manual_screenshot"),
@@ -336,10 +328,6 @@ def batch_start_accounts():
if not account_ids:
return jsonify({"error": "请选择要启动的账号"}), 400
- if enable_screenshot:
- # 异步初始化浏览器环境,避免首次下载/安装 Chromium 阻塞请求
- init_browser_manager_async()
-
started = []
failed = []
diff --git a/routes/api_auth.py b/routes/api_auth.py
index 6080314..3621892 100644
--- a/routes/api_auth.py
+++ b/routes/api_auth.py
@@ -237,23 +237,31 @@ def forgot_password():
"""发送密码重置邮件"""
data = request.json or {}
email = data.get("email", "").strip().lower()
+ username = data.get("username", "").strip()
captcha_session = data.get("captcha_session", "")
captcha_code = data.get("captcha", "").strip()
- if not email:
- return jsonify({"error": "请输入邮箱"}), 400
+ if not email and not username:
+ return jsonify({"error": "请输入邮箱或用户名"}), 400
- is_valid, error_msg = validate_email(email)
- if not is_valid:
- return jsonify({"error": error_msg}), 400
+ if username:
+ is_valid, error_msg = validate_username(username)
+ if not is_valid:
+ return jsonify({"error": error_msg}), 400
+
+ if email:
+ is_valid, error_msg = validate_email(email)
+ if not is_valid:
+ return jsonify({"error": error_msg}), 400
client_ip = get_rate_limit_ip()
allowed, error_msg = check_ip_request_rate(client_ip, "email")
if not allowed:
return jsonify({"error": error_msg}), 429
- allowed, error_msg = check_email_rate_limit(email, "forgot_password")
- if not allowed:
- return jsonify({"error": error_msg}), 429
+ if email:
+ allowed, error_msg = check_email_rate_limit(email, "forgot_password")
+ if not allowed:
+ return jsonify({"error": error_msg}), 429
success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
if not success:
@@ -266,6 +274,34 @@ def forgot_password():
if not email_settings.get("enabled", False):
return jsonify({"error": "邮件功能未启用,请联系管理员"}), 400
+ if username:
+ user = database.get_user_by_username(username)
+ if user and user.get("status") == "approved":
+ bound_email = (user.get("email") or "").strip()
+ if not bound_email:
+ return (
+ jsonify(
+ {
+ "error": "您尚未绑定邮箱,无法通过邮箱找回密码。请联系管理员重置密码。",
+ "code": "email_not_bound",
+ }
+ ),
+ 400,
+ )
+
+ allowed, error_msg = check_email_rate_limit(bound_email, "forgot_password")
+ if not allowed:
+ return jsonify({"error": error_msg}), 429
+
+ result = email_service.send_password_reset_email(
+ email=bound_email,
+ username=user["username"],
+ user_id=user["id"],
+ )
+ if not result["success"]:
+ logger.error(f"密码重置邮件发送失败: {result['error']}")
+ return jsonify({"success": True, "message": "如果该账号已绑定邮箱,您将收到密码重置邮件"})
+
user = database.get_user_by_email(email)
if user and user.get("status") == "approved":
result = email_service.send_password_reset_email(email=email, username=user["username"], user_id=user["id"])
@@ -317,46 +353,6 @@ def reset_password_confirm():
return jsonify({"error": "密码重置失败"}), 500
-@api_auth_bp.route("/api/reset_password_request", methods=["POST"])
-def request_password_reset():
- """用户申请重置密码(需要审核)"""
- data = request.json or {}
- username = data.get("username", "").strip()
- email = data.get("email", "").strip().lower()
- new_password = data.get("new_password", "").strip()
-
- if not username or not new_password:
- return jsonify({"error": "用户名和新密码不能为空"}), 400
-
- is_valid, error_msg = validate_password(new_password)
- if not is_valid:
- return jsonify({"error": error_msg}), 400
-
- if email:
- is_valid, error_msg = validate_email(email)
- if not is_valid:
- return jsonify({"error": error_msg}), 400
-
- client_ip = get_rate_limit_ip()
- allowed, error_msg = check_ip_request_rate(client_ip, "email")
- if not allowed:
- return jsonify({"error": error_msg}), 429
- if email:
- allowed, error_msg = check_email_rate_limit(email, "reset_request")
- if not allowed:
- return jsonify({"error": error_msg}), 429
-
- user = database.get_user_by_username(username)
-
- if user:
- if email and user.get("email") != email:
- pass
- else:
- database.create_password_reset_request(user["id"], new_password)
-
- return jsonify({"message": "如果账号存在,密码重置申请已提交,请等待管理员审核"})
-
-
@api_auth_bp.route("/api/generate_captcha", methods=["POST"])
def generate_captcha():
"""生成4位数字验证码图片"""
@@ -481,15 +477,19 @@ def login():
load_user_accounts(user["id"])
try:
- user_agent = request.headers.get("User-Agent", "")
- context = database.record_login_context(user["id"], client_ip, user_agent)
- if context and (context.get("new_ip") or context.get("new_device")):
- if config.LOGIN_ALERT_ENABLED and should_send_login_alert(user["id"], client_ip):
- user_info = database.get_user_by_id(user["id"]) or {}
- if user_info.get("email") and user_info.get("email_verified"):
- if database.get_user_email_notify(user["id"]):
- email_service.send_security_alert_email(
- email=user_info.get("email"),
+ user_agent = request.headers.get("User-Agent", "")
+ context = database.record_login_context(user["id"], client_ip, user_agent)
+ if context and (context.get("new_ip") or context.get("new_device")):
+ if (
+ config.LOGIN_ALERT_ENABLED
+ and should_send_login_alert(user["id"], client_ip)
+ and email_service.get_email_settings().get("login_alert_enabled", True)
+ ):
+ user_info = database.get_user_by_id(user["id"]) or {}
+ if user_info.get("email") and user_info.get("email_verified"):
+ if database.get_user_email_notify(user["id"]):
+ email_service.send_security_alert_email(
+ email=user_info.get("email"),
username=user_info.get("username") or username,
ip_address=client_ip,
user_agent=user_agent,
diff --git a/routes/api_user.py b/routes/api_user.py
index 1c1ba8b..cb40502 100644
--- a/routes/api_user.py
+++ b/routes/api_user.py
@@ -35,6 +35,7 @@ def get_active_announcement():
"id": announcement.get("id"),
"title": announcement.get("title", ""),
"content": announcement.get("content", ""),
+ "image_url": announcement.get("image_url") or "",
"created_at": announcement.get("created_at"),
}
}
@@ -147,6 +148,50 @@ def get_user_email():
return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)})
+@api_user_bp.route("/api/user/kdocs", methods=["GET"])
+@login_required
+def get_user_kdocs_settings():
+ """获取当前用户的金山文档设置"""
+ settings = database.get_user_kdocs_settings(current_user.id)
+ if not settings:
+ return jsonify({"kdocs_unit": "", "kdocs_auto_upload": 0})
+ return jsonify(settings)
+
+
+@api_user_bp.route("/api/user/kdocs", methods=["POST"])
+@login_required
+def update_user_kdocs_settings():
+ """更新当前用户的金山文档设置"""
+ data = request.get_json() or {}
+ kdocs_unit = data.get("kdocs_unit")
+ kdocs_auto_upload = data.get("kdocs_auto_upload")
+
+ if kdocs_unit is not None:
+ kdocs_unit = str(kdocs_unit or "").strip()
+ if len(kdocs_unit) > 50:
+ return jsonify({"error": "县区长度不能超过50"}), 400
+
+ if kdocs_auto_upload is not None:
+ if isinstance(kdocs_auto_upload, bool):
+ kdocs_auto_upload = 1 if kdocs_auto_upload else 0
+ try:
+ kdocs_auto_upload = int(kdocs_auto_upload)
+ except Exception:
+ return jsonify({"error": "自动上传开关必须是0或1"}), 400
+ if kdocs_auto_upload not in (0, 1):
+ return jsonify({"error": "自动上传开关必须是0或1"}), 400
+
+ if not database.update_user_kdocs_settings(
+ current_user.id,
+ kdocs_unit=kdocs_unit,
+ kdocs_auto_upload=kdocs_auto_upload,
+ ):
+ return jsonify({"error": "更新失败"}), 400
+
+ settings = database.get_user_kdocs_settings(current_user.id) or {"kdocs_unit": "", "kdocs_auto_upload": 0}
+ return jsonify({"success": True, "settings": settings})
+
+
@api_user_bp.route("/api/user/bind-email", methods=["POST"])
@login_required
@require_ip_not_locked
@@ -303,3 +348,37 @@ def get_run_stats():
"today_attachments": stats.get("total_attachments", 0),
}
)
+
+
+@api_user_bp.route("/api/kdocs/status", methods=["GET"])
+@login_required
+def get_kdocs_status_for_user():
+ """获取金山文档在线状态(用户端简化版)"""
+ try:
+ # 检查系统是否启用了金山文档功能
+ cfg = database.get_system_config() or {}
+ kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
+
+ if not kdocs_enabled:
+ return jsonify({"enabled": False, "online": False, "message": "未启用"})
+
+ # 获取金山文档状态
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ kdocs = get_kdocs_uploader()
+ status = kdocs.get_status()
+
+ login_required_flag = status.get("login_required", False)
+ last_login_ok = status.get("last_login_ok")
+
+ # 判断是否在线
+ is_online = not login_required_flag and last_login_ok is True
+
+ return jsonify({
+ "enabled": True,
+ "online": is_online,
+ "message": "就绪" if is_online else "离线"
+ })
+ except Exception as e:
+ logger.error(f"获取金山文档状态失败: {e}")
+ return jsonify({"enabled": False, "online": False, "message": "获取失败"})
diff --git a/routes/decorators.py b/routes/decorators.py
index 798babf..f99f9c3 100644
--- a/routes/decorators.py
+++ b/routes/decorators.py
@@ -14,11 +14,20 @@ def admin_required(f):
@wraps(f)
def decorated_function(*args, **kwargs):
- logger = get_logger()
+ try:
+ logger = get_logger()
+ except Exception:
+ import logging
+
+ logger = logging.getLogger("app")
logger.debug(f"[admin_required] 检查会话,admin_id存在: {'admin_id' in session}")
if "admin_id" not in session:
logger.warning(f"[admin_required] 拒绝访问 {request.path} - session中无admin_id")
- is_api = request.blueprint == "admin_api" or request.path.startswith("/yuyx/api")
+ is_api = (
+ request.blueprint in {"admin_api", "admin_security", "admin_security_yuyx"}
+ or request.path.startswith("/yuyx/api")
+ or request.path.startswith("/api/admin")
+ )
if is_api:
return jsonify({"error": "需要管理员权限"}), 403
return redirect(url_for("pages.admin_login_page"))
diff --git a/routes/pages.py b/routes/pages.py
index b0656d8..66e6c77 100644
--- a/routes/pages.py
+++ b/routes/pages.py
@@ -6,7 +6,7 @@ import json
import os
from typing import Optional
-from flask import Blueprint, current_app, redirect, render_template, session, url_for
+from flask import Blueprint, current_app, redirect, render_template, request, session, url_for
from flask_login import current_user, login_required
from routes.decorators import admin_required
@@ -36,10 +36,18 @@ def render_app_spa_or_legacy(
logger.warning(f"[app_spa] manifest缺少入口文件: {manifest_path}")
return render_template(legacy_template_name, **legacy_context)
+ app_spa_js_file = f"app/{js_file}"
+ app_spa_css_files = [f"app/{p}" for p in css_files]
+ app_spa_build_id = _get_asset_build_id(
+ os.path.join(current_app.root_path, "static"),
+ [app_spa_js_file, *app_spa_css_files],
+ )
+
return render_template(
"app.html",
- app_spa_js_file=f"app/{js_file}",
- app_spa_css_files=[f"app/{p}" for p in css_files],
+ app_spa_js_file=app_spa_js_file,
+ app_spa_css_files=app_spa_css_files,
+ app_spa_build_id=app_spa_build_id,
app_spa_initial_state=spa_initial_state,
)
except FileNotFoundError:
@@ -50,6 +58,27 @@ def render_app_spa_or_legacy(
return render_template(legacy_template_name, **legacy_context)
+def _get_asset_build_id(static_root: str, rel_paths: list[str]) -> Optional[str]:
+ mtimes = []
+ for rel_path in rel_paths:
+ if not rel_path:
+ continue
+ try:
+ mtimes.append(os.path.getmtime(os.path.join(static_root, rel_path)))
+ except OSError:
+ continue
+ if not mtimes:
+ return None
+ return str(int(max(mtimes)))
+
+
+def _is_legacy_admin_user_agent(user_agent: str) -> bool:
+ if not user_agent:
+ return False
+ ua = user_agent.lower()
+ return "msie" in ua or "trident/" in ua
+
+
@pages_bp.route("/")
def index():
"""主页 - 重定向到登录或应用"""
@@ -96,6 +125,8 @@ def admin_login_page():
@admin_required
def admin_page():
"""后台管理页面"""
+ if request.args.get("legacy") == "1" or _is_legacy_admin_user_agent(request.headers.get("User-Agent", "")):
+ return render_template("admin_legacy.html")
logger = get_logger()
manifest_path = os.path.join(current_app.root_path, "static", "admin", ".vite", "manifest.json")
try:
@@ -110,10 +141,18 @@ def admin_page():
logger.warning(f"[admin_spa] manifest缺少入口文件: {manifest_path}")
return render_template("admin_legacy.html")
+ admin_spa_js_file = f"admin/{js_file}"
+ admin_spa_css_files = [f"admin/{p}" for p in css_files]
+ admin_spa_build_id = _get_asset_build_id(
+ os.path.join(current_app.root_path, "static"),
+ [admin_spa_js_file, *admin_spa_css_files],
+ )
+
return render_template(
"admin.html",
- admin_spa_js_file=f"admin/{js_file}",
- admin_spa_css_files=[f"admin/{p}" for p in css_files],
+ admin_spa_js_file=admin_spa_js_file,
+ admin_spa_css_files=admin_spa_css_files,
+ admin_spa_build_id=admin_spa_build_id,
)
except FileNotFoundError:
logger.warning(f"[admin_spa] 未找到manifest: {manifest_path},回退旧版后台模板")
diff --git a/security/__init__.py b/security/__init__.py
new file mode 100644
index 0000000..a82f31c
--- /dev/null
+++ b/security/__init__.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+from security.blacklist import BlacklistManager
+from security.honeypot import HoneypotResponder
+from security.middleware import init_security_middleware
+from security.response_handler import ResponseAction, ResponseHandler, ResponseStrategy
+from security.risk_scorer import RiskScorer
+from security.threat_detector import ThreatDetector, ThreatResult
+
+__all__ = [
+ "BlacklistManager",
+ "HoneypotResponder",
+ "init_security_middleware",
+ "ResponseAction",
+ "ResponseHandler",
+ "ResponseStrategy",
+ "RiskScorer",
+ "ThreatDetector",
+ "ThreatResult",
+]
diff --git a/security/blacklist.py b/security/blacklist.py
new file mode 100644
index 0000000..56b6f12
--- /dev/null
+++ b/security/blacklist.py
@@ -0,0 +1,255 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import threading
+from datetime import timedelta
+from typing import List, Optional
+
+import db_pool
+from db.utils import get_cst_now, get_cst_now_str
+
+
+class BlacklistManager:
+ """黑名单管理器"""
+
+ def __init__(self) -> None:
+ self._schema_ready = False
+ self._schema_lock = threading.Lock()
+
+ def is_ip_banned(self, ip: str) -> bool:
+ """检查IP是否被封禁"""
+ ip_text = str(ip or "").strip()[:64]
+ if not ip_text:
+ return False
+ now_str = get_cst_now_str()
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT 1
+ FROM ip_blacklist
+ WHERE ip = ?
+ AND is_active = 1
+ AND (expires_at IS NULL OR expires_at > ?)
+ LIMIT 1
+ """,
+ (ip_text, now_str),
+ )
+ return cursor.fetchone() is not None
+
+ def is_user_banned(self, user_id: int) -> bool:
+ """检查用户是否被封禁"""
+ if user_id is None:
+ return False
+ self._ensure_schema()
+ user_id_int = int(user_id)
+ now_str = get_cst_now_str()
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT 1
+ FROM user_blacklist
+ WHERE user_id = ?
+ AND is_active = 1
+ AND (expires_at IS NULL OR expires_at > ?)
+ LIMIT 1
+ """,
+ (user_id_int, now_str),
+ )
+ return cursor.fetchone() is not None
+
+ def ban_ip(self, ip: str, reason: str, duration_hours: int = 24, permanent: bool = False):
+ """封禁IP"""
+ ip_text = str(ip or "").strip()[:64]
+ if not ip_text:
+ return False
+ reason_text = str(reason or "").strip()[:512]
+ now_str = get_cst_now_str()
+
+ expires_at: Optional[str]
+ if permanent:
+ expires_at = None
+ else:
+ hours = max(1, int(duration_hours))
+ expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
+ VALUES (?, ?, 1, ?, ?)
+ ON CONFLICT(ip) DO UPDATE SET
+ reason = excluded.reason,
+ is_active = 1,
+ added_at = excluded.added_at,
+ expires_at = excluded.expires_at
+ """,
+ (ip_text, reason_text, now_str, expires_at),
+ )
+ conn.commit()
+ return True
+
+ def ban_user(self, user_id: int, reason: str):
+ """封禁用户"""
+ return self._ban_user_internal(user_id, reason=reason, duration_hours=24, permanent=False)
+
+ def unban_ip(self, ip: str):
+ """解除IP封禁"""
+ ip_text = str(ip or "").strip()[:64]
+ if not ip_text:
+ return False
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute("UPDATE ip_blacklist SET is_active = 0 WHERE ip = ?", (ip_text,))
+ conn.commit()
+ return cursor.rowcount > 0
+
+ def unban_user(self, user_id: int):
+ """解除用户封禁"""
+ if user_id is None:
+ return False
+ self._ensure_schema()
+ user_id_int = int(user_id)
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute("UPDATE user_blacklist SET is_active = 0 WHERE user_id = ?", (user_id_int,))
+ conn.commit()
+ return cursor.rowcount > 0
+
+ def get_banned_ips(self) -> List[dict]:
+ """获取所有被封禁的IP"""
+ now_str = get_cst_now_str()
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT ip, reason, is_active, added_at, expires_at
+ FROM ip_blacklist
+ WHERE is_active = 1
+ AND (expires_at IS NULL OR expires_at > ?)
+ ORDER BY added_at DESC
+ """,
+ (now_str,),
+ )
+ return [dict(row) for row in cursor.fetchall()]
+
+ def get_banned_users(self) -> List[dict]:
+ """获取所有被封禁的用户"""
+ self._ensure_schema()
+ now_str = get_cst_now_str()
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT user_id, reason, is_active, added_at, expires_at
+ FROM user_blacklist
+ WHERE is_active = 1
+ AND (expires_at IS NULL OR expires_at > ?)
+ ORDER BY added_at DESC
+ """,
+ (now_str,),
+ )
+ return [dict(row) for row in cursor.fetchall()]
+
+ def cleanup_expired(self):
+ """清理过期的封禁记录"""
+ now_str = get_cst_now_str()
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE ip_blacklist
+ SET is_active = 0
+ WHERE is_active = 1
+ AND expires_at IS NOT NULL
+ AND expires_at <= ?
+ """,
+ (now_str,),
+ )
+ conn.commit()
+
+ self._ensure_schema()
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE user_blacklist
+ SET is_active = 0
+ WHERE is_active = 1
+ AND expires_at IS NOT NULL
+ AND expires_at <= ?
+ """,
+ (now_str,),
+ )
+ conn.commit()
+
+ # ==================== Internal ====================
+
+ def _ensure_schema(self) -> None:
+ if self._schema_ready:
+ return
+ with self._schema_lock:
+ if self._schema_ready:
+ return
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS user_blacklist (
+ user_id INTEGER PRIMARY KEY,
+ reason TEXT,
+ is_active INTEGER DEFAULT 1,
+ added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ expires_at TIMESTAMP
+ )
+ """
+ )
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
+ conn.commit()
+ self._schema_ready = True
+
+ def _ban_user_internal(
+ self,
+ user_id: int,
+ *,
+ reason: str,
+ duration_hours: int = 24,
+ permanent: bool = False,
+ ) -> bool:
+ if user_id is None:
+ return False
+ self._ensure_schema()
+ user_id_int = int(user_id)
+ reason_text = str(reason or "").strip()[:512]
+ now_str = get_cst_now_str()
+
+ expires_at: Optional[str]
+ if permanent:
+ expires_at = None
+ else:
+ hours = max(1, int(duration_hours))
+ expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ INSERT INTO user_blacklist (user_id, reason, is_active, added_at, expires_at)
+ VALUES (?, ?, 1, ?, ?)
+ ON CONFLICT(user_id) DO UPDATE SET
+ reason = excluded.reason,
+ is_active = 1,
+ added_at = excluded.added_at,
+ expires_at = excluded.expires_at
+ """,
+ (user_id_int, reason_text, now_str, expires_at),
+ )
+ conn.commit()
+ return True
+
diff --git a/security/constants.py b/security/constants.py
new file mode 100644
index 0000000..923e50a
--- /dev/null
+++ b/security/constants.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import re
+
+# ==================== Threat Types ====================
+
+THREAT_TYPE_JNDI_INJECTION = "jndi_injection"
+THREAT_TYPE_NESTED_EXPRESSION = "nested_expression"
+THREAT_TYPE_SQL_INJECTION = "sql_injection"
+THREAT_TYPE_XSS = "xss"
+THREAT_TYPE_PATH_TRAVERSAL = "path_traversal"
+THREAT_TYPE_COMMAND_INJECTION = "command_injection"
+THREAT_TYPE_SSRF = "ssrf"
+THREAT_TYPE_XXE = "xxe"
+THREAT_TYPE_TEMPLATE_INJECTION = "template_injection"
+THREAT_TYPE_SENSITIVE_PATH_PROBE = "sensitive_path_probe"
+
+
+# ==================== Scores ====================
+
+SCORE_JNDI_DIRECT = 100
+SCORE_JNDI_OBFUSCATED = 100
+SCORE_NESTED_EXPRESSION = 80
+SCORE_SQL_INJECTION = 90
+SCORE_XSS = 70
+SCORE_PATH_TRAVERSAL = 60
+SCORE_COMMAND_INJECTION = 85
+SCORE_SSRF = 75
+SCORE_XXE = 85
+SCORE_TEMPLATE_INJECTION = 70
+SCORE_SENSITIVE_PATH_PROBE = 40
+
+
+# ==================== JNDI (Log4j) ====================
+#
+# - Direct: ${jndi:ldap://...} / ${jndi:rmi://...} => 100
+# - Obfuscated: ${${xxx:-j}${xxx:-n}...:ldap://...} => detect
+# - Nested expression: ${${...}} => 80
+
+JNDI_DIRECT_PATTERN = r"\$\{\s*jndi\s*:\s*(?:ldap|rmi)\s*://"
+
+# Common Log4j "default value" obfuscation variants:
+# ${${::-j}${::-n}${::-d}${::-i}:ldap://...}
+# ${${foo:-j}${bar:-n}${baz:-d}${qux:-i}:rmi://...}
+JNDI_OBFUSCATED_PATTERN = (
+ r"\$\{\s*"
+ r"(?:\$\{[^{}]{0,50}:-j\}|\$\{::-[jJ]\})\s*"
+ r"(?:\$\{[^{}]{0,50}:-n\}|\$\{::-[nN]\})\s*"
+ r"(?:\$\{[^{}]{0,50}:-d\}|\$\{::-[dD]\})\s*"
+ r"(?:\$\{[^{}]{0,50}:-i\}|\$\{::-[iI]\})\s*"
+ r":\s*(?:ldap|rmi)\s*://"
+)
+
+NESTED_EXPRESSION_PATTERN = r"\$\{\s*\$\{"
+
+
+# ==================== SQL Injection ====================
+
+SQLI_UNION_SELECT_PATTERN = r"\bunion\b\s+(?:all\s+)?\bselect\b"
+SQLI_OR_1_EQ_1_PATTERN = r"\bor\b\s+1\s*=\s*1\b"
+
+
+# ==================== XSS ====================
+
+XSS_SCRIPT_TAG_PATTERN = r"<\s*script\b"
+XSS_JS_PROTOCOL_PATTERN = r"javascript\s*:"
+XSS_INLINE_EVENT_HANDLER_PATTERN = r"\bon\w+\s*="
+
+
+# ==================== Path Traversal ====================
+
+PATH_TRAVERSAL_PATTERN = r"(?:\.\./|\.\.\\)+"
+
+
+# ==================== Command Injection ====================
+
+CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN = (
+ r"(?:;|&&|\|\||\|)\s*"
+ r"(?:bash|sh|zsh|cmd|powershell|pwsh|curl|wget|nc|netcat|python|perl|ruby|php|node|cat|ls|id|whoami|uname|rm)\b"
+)
+CMD_INJECTION_SUBSHELL_PATTERN = r"(?:`[^`]{1,200}`|\$\([^)]{1,200}\))"
+
+
+# ==================== SSRF ====================
+
+SSRF_LOCALHOST_URL_PATTERN = r"\bhttps?\s*:\s*//\s*(?:127\.0\.0\.1\b|localhost\b|0\.0\.0\.0\b)"
+SSRF_INTERNAL_IP_URL_PATTERN = r"\bhttps?\s*:\s*//\s*(?:10\.|192\.168\.|172\.(?:1[6-9]|2[0-9]|3[0-1])\.)"
+SSRF_DANGEROUS_PROTOCOL_PATTERN = r"\b(?:file|gopher|dict)\s*:\s*//"
+
+
+# ==================== XXE ====================
+
+XXE_DOCTYPE_PATTERN = r" None:
+ self._rng = rng or random.SystemRandom()
+ self._logger = get_logger("app")
+
+ def generate_fake_response(self, endpoint: str, original_data: dict = None) -> dict:
+ """
+ 根据端点生成假的成功响应
+
+ 策略:
+ - 邮件发送类: {"success": True, "message": "邮件已发送"}
+ - 注册类: {"success": True, "user_id": fake_uuid}
+ - 登录类: {"success": True} 但不设置session
+ - 通用: {"success": True, "message": "操作成功"}
+ """
+ endpoint_text = str(endpoint or "").strip()
+ endpoint_lc = endpoint_text.lower()
+
+ category = self._classify_endpoint(endpoint_lc)
+ response: dict[str, Any] = {"success": True}
+
+ if category == "email":
+ response["message"] = "邮件已发送"
+ elif category == "register":
+ response["user_id"] = str(uuid.uuid4())
+ elif category == "login":
+ # 登录类:保持正常成功响应,但不进行任何 session / token 设置(调用方负责不写 session)
+ pass
+ else:
+ response["message"] = "操作成功"
+
+ response = self._merge_safe_fields(response, original_data)
+
+ self._logger.warning(
+ "蜜罐响应已生成: endpoint=%s, category=%s, keys=%s",
+ endpoint_text[:256],
+ category,
+ sorted(response.keys()),
+ )
+ return response
+
+ def should_use_honeypot(self, risk_score: int) -> bool:
+ """风险分>=80使用蜜罐响应"""
+ score = self._normalize_risk_score(risk_score)
+ use = score >= 80
+ self._logger.debug("蜜罐判定: risk_score=%s => %s", score, use)
+ return use
+
+ def delay_response(self, risk_score: int) -> float:
+ """
+ 根据风险分计算延迟时间
+ 0-20: 0秒
+ 21-50: 随机0.5-1秒
+ 51-80: 随机1-3秒
+ 81-100: 随机3-8秒(蜜罐模式额外延迟消耗攻击者时间)
+ """
+ score = self._normalize_risk_score(risk_score)
+
+ delay = 0.0
+ if score <= 20:
+ delay = 0.0
+ elif score <= 50:
+ delay = float(self._rng.uniform(0.5, 1.0))
+ elif score <= 80:
+ delay = float(self._rng.uniform(1.0, 3.0))
+ else:
+ delay = float(self._rng.uniform(3.0, 8.0))
+
+ self._logger.debug("蜜罐延迟计算: risk_score=%s => delay_seconds=%.3f", score, delay)
+ return delay
+
+ # ==================== Internal ====================
+
+ def _normalize_risk_score(self, risk_score: Any) -> int:
+ try:
+ score = int(risk_score)
+ except Exception:
+ score = 0
+ return max(0, min(100, score))
+
+ def _classify_endpoint(self, endpoint_lc: str) -> str:
+ if not endpoint_lc:
+ return "generic"
+
+ # 先匹配更具体的:注册 / 登录
+ if any(k in endpoint_lc for k in ["/register", "register", "signup", "sign-up"]):
+ return "register"
+ if any(k in endpoint_lc for k in ["/login", "login", "signin", "sign-in"]):
+ return "login"
+
+ # 邮件相关:发送验证码 / 重置密码 / 重发验证等
+ if any(k in endpoint_lc for k in ["email", "mail", "forgot-password", "reset-password", "resend-verify"]):
+ return "email"
+
+ return "generic"
+
+ def _merge_safe_fields(self, base: dict, original_data: Optional[dict]) -> dict:
+ if not isinstance(original_data, dict) or not original_data:
+ return base
+
+ # 避免把攻击者输入或真实业务结果回显得太明显;仅合并少量“形状字段”
+ safe_bool_keys = {"need_verify", "need_captcha"}
+
+ merged = dict(base)
+ for key in safe_bool_keys:
+ if key in original_data and key not in merged:
+ try:
+ merged[key] = bool(original_data.get(key))
+ except Exception:
+ continue
+
+ return merged
+
diff --git a/security/middleware.py b/security/middleware.py
new file mode 100644
index 0000000..025f679
--- /dev/null
+++ b/security/middleware.py
@@ -0,0 +1,307 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import logging
+from typing import Optional
+
+from flask import g, jsonify, request
+from flask_login import current_user
+
+from app_logger import get_logger
+from app_security import get_rate_limit_ip
+
+from .blacklist import BlacklistManager
+from .honeypot import HoneypotResponder
+from .response_handler import ResponseAction, ResponseHandler, ResponseStrategy
+from .risk_scorer import RiskScorer
+from .threat_detector import ThreatDetector, ThreatResult
+
+# 全局实例(保持单例,避免重复初始化开销)
+detector = ThreatDetector()
+blacklist = BlacklistManager()
+scorer = RiskScorer(blacklist_manager=blacklist)
+handler: Optional[ResponseHandler] = None
+honeypot: Optional[HoneypotResponder] = None
+
+
+def _get_handler() -> ResponseHandler:
+ global handler
+ if handler is None:
+ handler = ResponseHandler()
+ return handler
+
+
+def _get_honeypot() -> HoneypotResponder:
+ global honeypot
+ if honeypot is None:
+ honeypot = HoneypotResponder()
+ return honeypot
+
+
+def _get_security_log_level(app) -> int:
+ level_name = str(getattr(app, "config", {}).get("SECURITY_LOG_LEVEL", "INFO") or "INFO").upper()
+ return int(getattr(logging, level_name, logging.INFO))
+
+
+def _log(app, level: int, message: str, *args, exc_info: bool = False) -> None:
+ """按 SECURITY_LOG_LEVEL 控制安全日志输出,避免过多日志影响性能。"""
+ try:
+ logger = get_logger("app")
+ min_level = _get_security_log_level(app)
+ if int(level) >= int(min_level):
+ logger.log(int(level), message, *args, exc_info=exc_info)
+ except Exception:
+ # 安全模块日志故障不得影响正常请求
+ return
+
+
+def _is_static_request(app) -> bool:
+ """对静态文件请求跳过安全检查以提升性能。"""
+ try:
+ path = str(getattr(request, "path", "") or "")
+ except Exception:
+ path = ""
+
+ if path.startswith("/static/"):
+ return True
+
+ try:
+ static_url_path = getattr(app, "static_url_path", None) or "/static"
+ if static_url_path and path.startswith(str(static_url_path).rstrip("/") + "/"):
+ return True
+ except Exception:
+ pass
+
+ try:
+ endpoint = getattr(request, "endpoint", None)
+ if endpoint in {"static", "serve_static"}:
+ return True
+ except Exception:
+ pass
+
+ return False
+
+
+def _safe_get_user_id() -> Optional[int]:
+ try:
+ if hasattr(current_user, "is_authenticated") and current_user.is_authenticated:
+ return getattr(current_user, "id", None)
+ except Exception:
+ return None
+ return None
+
+
+def _scan_request_threats(req) -> list[ThreatResult]:
+ """仅扫描 GET query 与 POST JSON body(降低开销与误报)。"""
+ threats: list[ThreatResult] = []
+
+ try:
+ # 1) Query 参数(所有方法均可能携带 query string)
+ try:
+ args = getattr(req, "args", None)
+ if args:
+ # MultiDict -> dict(list) 以保留多值
+ args_dict = args.to_dict(flat=False) if hasattr(args, "to_dict") else dict(args)
+ threats.extend(detector.scan_input(args_dict, "args"))
+ except Exception:
+ pass
+
+ # 2) JSON body(主要针对 POST;其他方法保持兼容)
+ try:
+ method = str(getattr(req, "method", "") or "").upper()
+ except Exception:
+ method = ""
+
+ if method in {"POST", "PUT", "PATCH", "DELETE"}:
+ try:
+ data = req.get_json(silent=True) if hasattr(req, "get_json") else None
+ except Exception:
+ data = None
+ if data is not None:
+ threats.extend(detector.scan_input(data, "json"))
+ except Exception:
+ # 扫描失败不应阻断业务
+ return []
+
+ threats.sort(key=lambda t: int(getattr(t, "score", 0) or 0), reverse=True)
+ return threats
+
+
+def init_security_middleware(app):
+ """初始化安全中间件到 Flask 应用。"""
+ try:
+ scorer.auto_ban_enabled = bool(app.config.get("AUTO_BAN_ENABLED", True))
+ except Exception:
+ pass
+
+ @app.before_request
+ def security_check():
+ if not bool(app.config.get("SECURITY_ENABLED", True)):
+ return None
+ if _is_static_request(app):
+ return None
+
+ try:
+ ip = get_rate_limit_ip()
+ except Exception:
+ ip = getattr(request, "remote_addr", "") or ""
+
+ user_id = _safe_get_user_id()
+
+ # 默认值,确保后续逻辑可用
+ g.risk_score = 0
+ g.response_strategy = ResponseStrategy(action=ResponseAction.ALLOW)
+ g.honeypot_mode = False
+ g.honeypot_endpoint = None
+ g.honeypot_generated = False
+
+ try:
+ # 1) 检查黑名单(静默拒绝,返回通用错误)
+ try:
+ if blacklist.is_ip_banned(ip):
+ _log(app, logging.WARNING, "安全拦截: IP封禁命中 ip=%s path=%s", ip, request.path[:256])
+ return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
+ except Exception:
+ _log(app, logging.ERROR, "黑名单检查失败(ip) ip=%s", ip, exc_info=True)
+
+ try:
+ if user_id is not None and blacklist.is_user_banned(user_id):
+ _log(app, logging.WARNING, "安全拦截: 用户封禁命中 user_id=%s path=%s", user_id, request.path[:256])
+ return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
+ except Exception:
+ _log(app, logging.ERROR, "黑名单检查失败(user) user_id=%s", user_id, exc_info=True)
+
+ # 2) 扫描威胁(GET query / POST JSON)
+ threats = _scan_request_threats(request)
+
+ if threats:
+ max_threat = threats[0]
+ _log(
+ app,
+ logging.WARNING,
+ "威胁检测: ip=%s user_id=%s type=%s score=%s field=%s rule=%s",
+ ip,
+ user_id,
+ getattr(max_threat, "threat_type", "unknown"),
+ getattr(max_threat, "score", 0),
+ getattr(max_threat, "field_name", ""),
+ getattr(max_threat, "rule", ""),
+ )
+
+ # 记录威胁事件(异常不应阻断业务)
+ try:
+ payload = getattr(max_threat, "value_preview", "") or getattr(max_threat, "matched", "") or ""
+ scorer.record_threat(
+ ip=ip,
+ user_id=user_id,
+ threat_type=getattr(max_threat, "threat_type", "unknown"),
+ score=int(getattr(max_threat, "score", 0) or 0),
+ request_path=getattr(request, "path", None),
+ payload=str(payload)[:500] if payload else None,
+ )
+ except Exception:
+ _log(app, logging.ERROR, "威胁事件记录失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
+
+ # 高危威胁启用蜜罐模式
+ if bool(app.config.get("HONEYPOT_ENABLED", True)):
+ try:
+ if int(getattr(max_threat, "score", 0) or 0) >= 80:
+ g.honeypot_mode = True
+ g.honeypot_endpoint = getattr(request, "endpoint", None)
+ except Exception:
+ pass
+
+ # 3) 综合风险分与响应策略
+ try:
+ risk_score = scorer.get_combined_score(ip, user_id)
+ except Exception:
+ _log(app, logging.ERROR, "风险分计算失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
+ risk_score = 0
+
+ try:
+ strategy = _get_handler().get_strategy(risk_score)
+ except Exception:
+ _log(app, logging.ERROR, "响应策略计算失败 risk_score=%s", risk_score, exc_info=True)
+ strategy = ResponseStrategy(action=ResponseAction.ALLOW)
+
+ g.risk_score = int(risk_score or 0)
+ g.response_strategy = strategy
+
+ # 风险分触发蜜罐模式(兼容 ResponseHandler 的 HONEYPOT 策略)
+ if bool(app.config.get("HONEYPOT_ENABLED", True)):
+ try:
+ if getattr(strategy, "action", None) == ResponseAction.HONEYPOT:
+ g.honeypot_mode = True
+ except Exception:
+ pass
+
+ # 4) 应用延迟
+ try:
+ if float(getattr(strategy, "delay_seconds", 0) or 0) > 0:
+ _get_handler().apply_delay(strategy)
+ except Exception:
+ _log(app, logging.ERROR, "延迟应用失败", exc_info=True)
+
+ # 优先短路:避免业务 side effects(例如发送邮件/修改状态)
+ if getattr(g, "honeypot_mode", False) and bool(app.config.get("HONEYPOT_ENABLED", True)):
+ try:
+ fake_payload = None
+ try:
+ fake_payload = request.get_json(silent=True)
+ except Exception:
+ fake_payload = None
+ fake_response = _get_honeypot().generate_fake_response(
+ getattr(g, "honeypot_endpoint", "default"),
+ fake_payload if isinstance(fake_payload, dict) else None,
+ )
+ g.honeypot_generated = True
+ return jsonify(fake_response), 200
+ except Exception:
+ _log(app, logging.ERROR, "蜜罐响应生成失败", exc_info=True)
+ return None
+ except Exception:
+ # 全局兜底:安全模块任何异常都不能阻断正常请求
+ _log(app, logging.ERROR, "安全中间件发生异常", exc_info=True)
+ return None
+
+ return None # 继续正常处理
+
+ @app.after_request
+ def security_response(response):
+ """请求后处理 - 兜底应用蜜罐响应。"""
+ if not bool(app.config.get("SECURITY_ENABLED", True)):
+ return response
+ if not bool(app.config.get("HONEYPOT_ENABLED", True)):
+ return response
+
+ try:
+ if _is_static_request(app):
+ return response
+ except Exception:
+ pass
+
+ # 如果在 before_request 已经生成过蜜罐响应,则不再覆盖,避免丢失其他 after_request 的改动
+ try:
+ if getattr(g, "honeypot_generated", False):
+ return response
+ except Exception:
+ pass
+
+ try:
+ if getattr(g, "honeypot_mode", False):
+ fake_payload = None
+ try:
+ fake_payload = request.get_json(silent=True)
+ except Exception:
+ fake_payload = None
+ fake_response = _get_honeypot().generate_fake_response(
+ getattr(g, "honeypot_endpoint", "default"),
+ fake_payload if isinstance(fake_payload, dict) else None,
+ )
+ return jsonify(fake_response), 200
+ except Exception:
+ _log(app, logging.ERROR, "请求后蜜罐覆盖失败", exc_info=True)
+ return response
+
+ return response
diff --git a/security/response_handler.py b/security/response_handler.py
new file mode 100644
index 0000000..6d781c9
--- /dev/null
+++ b/security/response_handler.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import random
+import time
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Optional
+
+from app_logger import get_logger
+
+
+class ResponseAction(Enum):
+ ALLOW = "allow" # 正常放行
+ ENHANCE_CAPTCHA = "enhance_captcha" # 增强验证码
+ DELAY = "delay" # 静默延迟
+ HONEYPOT = "honeypot" # 蜜罐响应
+ BLOCK = "block" # 直接拒绝
+
+
+@dataclass
+class ResponseStrategy:
+ action: ResponseAction
+ delay_seconds: float = 0
+ captcha_level: int = 1 # 1=普通4位, 2=6位, 3=滑块
+ message: str | None = None
+
+
+class ResponseHandler:
+ """响应策略处理器"""
+
+ def __init__(self, *, rng: Optional[random.Random] = None) -> None:
+ self._rng = rng or random.SystemRandom()
+ self._logger = get_logger("app")
+
+ def get_strategy(self, risk_score: int, is_banned: bool = False) -> ResponseStrategy:
+ """
+ 根据风险分获取响应策略
+
+ 0-20分: ALLOW, 无延迟, 普通验证码
+ 21-40分: ALLOW, 无延迟, 6位验证码
+ 41-60分: DELAY, 1-2秒延迟
+ 61-80分: DELAY, 2-5秒延迟
+ 81-100分: HONEYPOT, 3-8秒延迟
+ 已封禁: BLOCK
+ """
+ score = self._normalize_risk_score(risk_score)
+
+ if is_banned:
+ strategy = ResponseStrategy(action=ResponseAction.BLOCK, message="访问被拒绝")
+ self._logger.warning("响应策略: BLOCK (banned=%s, risk_score=%s)", is_banned, score)
+ return strategy
+
+ if score <= 20:
+ strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=1)
+ elif score <= 40:
+ strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=2)
+ elif score <= 60:
+ strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(1.0, 2.0)))
+ elif score <= 80:
+ strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(2.0, 5.0)))
+ else:
+ strategy = ResponseStrategy(action=ResponseAction.HONEYPOT, delay_seconds=float(self._rng.uniform(3.0, 8.0)))
+
+ strategy.captcha_level = self._normalize_captcha_level(strategy.captcha_level)
+
+ self._logger.info(
+ "响应策略: action=%s risk_score=%s delay=%.3f captcha_level=%s",
+ strategy.action.value,
+ score,
+ float(strategy.delay_seconds or 0),
+ int(strategy.captcha_level),
+ )
+ return strategy
+
+ def apply_delay(self, strategy: ResponseStrategy):
+ """应用延迟(使用time.sleep)"""
+ if strategy is None:
+ return
+ delay = 0.0
+ try:
+ delay = float(getattr(strategy, "delay_seconds", 0) or 0)
+ except Exception:
+ delay = 0.0
+
+ if delay <= 0:
+ return
+
+ self._logger.debug("应用延迟: action=%s delay=%.3f", getattr(strategy.action, "value", strategy.action), delay)
+ time.sleep(delay)
+
+ def get_captcha_requirement(self, strategy: ResponseStrategy) -> dict:
+ """返回验证码要求 {"required": True, "level": 2}"""
+ level = 1
+ try:
+ level = int(getattr(strategy, "captcha_level", 1) or 1)
+ except Exception:
+ level = 1
+ level = self._normalize_captcha_level(level)
+
+ required = True
+ try:
+ required = getattr(strategy, "action", None) != ResponseAction.BLOCK
+ except Exception:
+ required = True
+
+ payload = {"required": bool(required), "level": level}
+ self._logger.debug("验证码要求: %s", payload)
+ return payload
+
+ # ==================== Internal ====================
+
+ def _normalize_risk_score(self, risk_score: Any) -> int:
+ try:
+ score = int(risk_score)
+ except Exception:
+ score = 0
+ return max(0, min(100, score))
+
+ def _normalize_captcha_level(self, level: Any) -> int:
+ try:
+ i = int(level)
+ except Exception:
+ i = 1
+ if i <= 1:
+ return 1
+ if i == 2:
+ return 2
+ return 3
+
diff --git a/security/risk_scorer.py b/security/risk_scorer.py
new file mode 100644
index 0000000..f48fb29
--- /dev/null
+++ b/security/risk_scorer.py
@@ -0,0 +1,389 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from datetime import timedelta
+from typing import Optional
+
+import db_pool
+from db.utils import get_cst_now, get_cst_now_str, parse_cst_datetime
+
+from . import constants as C
+from .blacklist import BlacklistManager
+
+
+@dataclass(frozen=True)
+class _ScoreUpdateResult:
+ ip_score: int
+ user_score: int
+
+
+@dataclass(frozen=True)
+class _BanAction:
+ reason: str
+ duration_hours: Optional[int] = None
+ permanent: bool = False
+
+
+class RiskScorer:
+ """风险评分引擎 - 计算IP和用户的风险分数"""
+
+ def __init__(
+ self,
+ *,
+ auto_ban_enabled: bool = True,
+ auto_ban_duration_hours: int = 24,
+ high_risk_threshold: int = 80,
+ high_risk_window_hours: int = 1,
+ high_risk_permanent_ban_count: int = 3,
+ blacklist_manager: Optional[BlacklistManager] = None,
+ ) -> None:
+ self.auto_ban_enabled = bool(auto_ban_enabled)
+ self.auto_ban_duration_hours = max(1, int(auto_ban_duration_hours))
+ self.high_risk_threshold = max(0, int(high_risk_threshold))
+ self.high_risk_window_hours = max(1, int(high_risk_window_hours))
+ self.high_risk_permanent_ban_count = max(1, int(high_risk_permanent_ban_count))
+ self.blacklist = blacklist_manager or BlacklistManager()
+
+ def get_ip_score(self, ip_address: str) -> int:
+ """获取IP风险分(0-100),从数据库读取"""
+ ip_text = str(ip_address or "").strip()[:64]
+ if not ip_text:
+ return 0
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip_text,))
+ row = cursor.fetchone()
+ if not row:
+ return 0
+ try:
+ return max(0, min(100, int(row["risk_score"])))
+ except Exception:
+ return 0
+
+ def get_user_score(self, user_id: int) -> int:
+ """获取用户风险分(0-100)"""
+ if user_id is None:
+ return 0
+ user_id_int = int(user_id)
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (user_id_int,))
+ row = cursor.fetchone()
+ if not row:
+ return 0
+ try:
+ return max(0, min(100, int(row["risk_score"])))
+ except Exception:
+ return 0
+
+ def get_combined_score(self, ip: str, user_id: int = None) -> int:
+ """综合风险分 = max(IP分, 用户分) + 行为加成"""
+ base = max(self.get_ip_score(ip), self.get_user_score(user_id) if user_id is not None else 0)
+ bonus = self._get_behavior_bonus(ip, user_id)
+ return max(0, min(100, int(base + bonus)))
+
+ def record_threat(
+ self,
+ ip: str,
+ user_id: int,
+ threat_type: str,
+ score: int,
+ request_path: str = None,
+ payload: str = None,
+ ):
+ """记录威胁事件到数据库,并更新IP/用户风险分"""
+ ip_text = str(ip or "").strip()[:64]
+ user_id_int = int(user_id) if user_id is not None else None
+ threat_type_text = str(threat_type or "").strip()[:64] or "unknown"
+ score_int = max(0, int(score))
+ path_text = str(request_path or "").strip()[:512] if request_path else None
+ payload_text = str(payload or "").strip() if payload else None
+ if payload_text and len(payload_text) > 2048:
+ payload_text = payload_text[:2048]
+
+ now_str = get_cst_now_str()
+
+ ip_ban_action: Optional[_BanAction] = None
+ user_ban_action: Optional[_BanAction] = None
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+
+ cursor.execute(
+ """
+ INSERT INTO threat_events (
+ threat_type, score, ip, user_id, request_path, value_preview, created_at
+ ) VALUES (?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ threat_type_text,
+ score_int,
+ ip_text or None,
+ user_id_int,
+ path_text,
+ payload_text,
+ now_str,
+ ),
+ )
+
+ update = self._update_scores(cursor, ip_text, user_id_int, score_int, now_str)
+
+ if self.auto_ban_enabled:
+ ip_ban_action, user_ban_action = self._get_auto_ban_actions(
+ cursor,
+ ip_text,
+ user_id_int,
+ threat_type_text,
+ score_int,
+ update,
+ )
+
+ conn.commit()
+
+ if not self.auto_ban_enabled:
+ return
+
+ if ip_ban_action and ip_text:
+ self.blacklist.ban_ip(
+ ip_text,
+ reason=ip_ban_action.reason,
+ duration_hours=ip_ban_action.duration_hours or self.auto_ban_duration_hours,
+ permanent=ip_ban_action.permanent,
+ )
+ if user_ban_action and user_id_int is not None:
+ self.blacklist._ban_user_internal(
+ user_id_int,
+ reason=user_ban_action.reason,
+ duration_hours=user_ban_action.duration_hours or self.auto_ban_duration_hours,
+ permanent=user_ban_action.permanent,
+ )
+
+ def decay_scores(self):
+ """风险分衰减 - 定期调用,降低历史风险分"""
+ now = get_cst_now()
+ now_str = now.strftime("%Y-%m-%d %H:%M:%S")
+
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+
+ cursor.execute("SELECT ip, risk_score, updated_at, created_at FROM ip_risk_scores")
+ for row in cursor.fetchall():
+ ip = row["ip"]
+ current_score = int(row["risk_score"] or 0)
+ updated_at = row["updated_at"] or row["created_at"]
+ hours = self._hours_since(updated_at, now)
+ if hours <= 0:
+ continue
+ new_score = self._apply_hourly_decay(current_score, hours)
+ if new_score == current_score:
+ continue
+ cursor.execute(
+ "UPDATE ip_risk_scores SET risk_score = ?, updated_at = ? WHERE ip = ?",
+ (new_score, now_str, ip),
+ )
+
+ cursor.execute("SELECT user_id, risk_score, updated_at, created_at FROM user_risk_scores")
+ for row in cursor.fetchall():
+ user_id = int(row["user_id"])
+ current_score = int(row["risk_score"] or 0)
+ updated_at = row["updated_at"] or row["created_at"]
+ hours = self._hours_since(updated_at, now)
+ if hours <= 0:
+ continue
+ new_score = self._apply_hourly_decay(current_score, hours)
+ if new_score == current_score:
+ continue
+ cursor.execute(
+ "UPDATE user_risk_scores SET risk_score = ?, updated_at = ? WHERE user_id = ?",
+ (new_score, now_str, user_id),
+ )
+
+ conn.commit()
+
+ def _update_ip_score(self, ip: str, score_delta: int):
+ """更新IP风险分"""
+ ip_text = str(ip or "").strip()[:64]
+ if not ip_text:
+ return
+ delta = int(score_delta)
+ now_str = get_cst_now_str()
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ self._update_scores(cursor, ip_text, None, delta, now_str)
+ conn.commit()
+
+ def _update_user_score(self, user_id: int, score_delta: int):
+ """更新用户风险分"""
+ if user_id is None:
+ return
+ user_id_int = int(user_id)
+ delta = int(score_delta)
+ now_str = get_cst_now_str()
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ self._update_scores(cursor, "", user_id_int, delta, now_str)
+ conn.commit()
+
+ def reset_ip_score(self, ip: str) -> bool:
+ """清零指定IP的风险分"""
+ ip_text = str(ip or "").strip()[:64]
+ if not ip_text:
+ return False
+
+ now_str = get_cst_now_str()
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT ip FROM ip_risk_scores WHERE ip = ?", (ip_text,))
+ row = cursor.fetchone()
+ if row:
+ cursor.execute(
+ "UPDATE ip_risk_scores SET risk_score = 0, last_seen = ?, updated_at = ? WHERE ip = ?",
+ (now_str, now_str, ip_text),
+ )
+ else:
+ cursor.execute(
+ """
+ INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
+ VALUES (?, 0, ?, ?, ?)
+ """,
+ (ip_text, now_str, now_str, now_str),
+ )
+ conn.commit()
+ return True
+
+ def _update_scores(
+ self,
+ cursor,
+ ip: str,
+ user_id: Optional[int],
+ score_delta: int,
+ now_str: str,
+ ) -> _ScoreUpdateResult:
+ ip_score = 0
+ user_score = 0
+
+ if ip:
+ cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip,))
+ row = cursor.fetchone()
+ current = int(row["risk_score"]) if row else 0
+ ip_score = max(0, min(100, current + int(score_delta)))
+ if row:
+ cursor.execute(
+ "UPDATE ip_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE ip = ?",
+ (ip_score, now_str, now_str, ip),
+ )
+ else:
+ cursor.execute(
+ """
+ INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?)
+ """,
+ (ip, ip_score, now_str, now_str, now_str),
+ )
+
+ if user_id is not None:
+ cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (int(user_id),))
+ row = cursor.fetchone()
+ current = int(row["risk_score"]) if row else 0
+ user_score = max(0, min(100, current + int(score_delta)))
+ if row:
+ cursor.execute(
+ "UPDATE user_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE user_id = ?",
+ (user_score, now_str, now_str, int(user_id)),
+ )
+ else:
+ cursor.execute(
+ """
+ INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?)
+ """,
+ (int(user_id), user_score, now_str, now_str, now_str),
+ )
+
+ return _ScoreUpdateResult(ip_score=ip_score, user_score=user_score)
+
+ def _get_auto_ban_actions(
+ self,
+ cursor,
+ ip: str,
+ user_id: Optional[int],
+ threat_type: str,
+ score: int,
+ update: _ScoreUpdateResult,
+ ) -> tuple[Optional["_BanAction"], Optional["_BanAction"]]:
+ ip_action: Optional[_BanAction] = None
+ user_action: Optional[_BanAction] = None
+
+ if threat_type == C.THREAT_TYPE_JNDI_INJECTION:
+ if ip:
+ ip_action = _BanAction(reason="JNDI injection detected", permanent=True)
+ if user_id is not None:
+ user_action = _BanAction(reason="JNDI injection detected", permanent=True)
+ return ip_action, user_action
+
+ if ip and update.ip_score >= 100:
+ ip_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
+ if user_id is not None and update.user_score >= 100:
+ user_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
+
+ if score < self.high_risk_threshold:
+ return ip_action, user_action
+
+ window_start = (get_cst_now() - timedelta(hours=self.high_risk_window_hours)).strftime("%Y-%m-%d %H:%M:%S")
+
+ if ip:
+ cursor.execute(
+ """
+ SELECT COUNT(*) AS cnt
+ FROM threat_events
+ WHERE ip = ? AND score >= ? AND created_at >= ?
+ """,
+ (ip, int(self.high_risk_threshold), window_start),
+ )
+ row = cursor.fetchone()
+ cnt = int(row["cnt"]) if row else 0
+ if cnt >= self.high_risk_permanent_ban_count:
+ ip_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
+
+ if user_id is not None:
+ cursor.execute(
+ """
+ SELECT COUNT(*) AS cnt
+ FROM threat_events
+ WHERE user_id = ? AND score >= ? AND created_at >= ?
+ """,
+ (int(user_id), int(self.high_risk_threshold), window_start),
+ )
+ row = cursor.fetchone()
+ cnt = int(row["cnt"]) if row else 0
+ if cnt >= self.high_risk_permanent_ban_count:
+ user_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
+
+ return ip_action, user_action
+
+ def _get_behavior_bonus(self, ip: str, user_id: Optional[int]) -> int:
+ return 0
+
+ def _hours_since(self, dt_str: Optional[str], now) -> int:
+ if not dt_str:
+ return 0
+ try:
+ dt = parse_cst_datetime(str(dt_str))
+ except Exception:
+ return 0
+ seconds = (now - dt).total_seconds()
+ if seconds <= 0:
+ return 0
+ return int(seconds // 3600)
+
+ def _apply_hourly_decay(self, score: int, hours: int) -> int:
+ score_int = max(0, int(score))
+ if score_int <= 0 or hours <= 0:
+ return score_int
+ decayed = int(math.floor(score_int * (0.9**int(hours))))
+ return max(0, min(100, decayed))
diff --git a/security/threat_detector.py b/security/threat_detector.py
new file mode 100644
index 0000000..c38c375
--- /dev/null
+++ b/security/threat_detector.py
@@ -0,0 +1,410 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Iterable, List, Optional, Tuple
+from urllib.parse import unquote_plus
+
+from . import constants as C
+
+
+@dataclass
+class ThreatResult:
+ threat_type: str
+ score: int
+ field_name: str
+ rule: str = ""
+ matched: str = ""
+ value_preview: str = ""
+
+ def to_dict(self) -> dict:
+ return {
+ "threat_type": self.threat_type,
+ "score": int(self.score),
+ "field_name": self.field_name,
+ "rule": self.rule,
+ "matched": self.matched,
+ "value_preview": self.value_preview,
+ }
+
+
+class ThreatDetector:
+ def __init__(
+ self,
+ *,
+ max_value_length: int = 4096,
+ max_decode_rounds: int = 2,
+ ) -> None:
+ self.max_value_length = max(64, int(max_value_length))
+ self.max_decode_rounds = max(0, int(max_decode_rounds))
+
+ def scan_input(self, value: Any, field_name: str = "value") -> List[ThreatResult]:
+ """扫描单个输入值(支持 dict/list 等嵌套结构)。"""
+ results: List[ThreatResult] = []
+ for sub_field, leaf in self._flatten_value(value, field_name):
+ text = self._stringify(leaf)
+ if not text:
+ continue
+ if len(text) > self.max_value_length:
+ text = text[: self.max_value_length]
+ results.extend(self._scan_text(text, sub_field))
+ results.sort(key=lambda r: int(r.score), reverse=True)
+ return results
+
+ def scan_request(self, request: Any) -> List[ThreatResult]:
+ """扫描整个请求对象(兼容 Flask Request / dict 风格对象)。"""
+ results: List[ThreatResult] = []
+ for field_name, value in self._extract_request_fields(request):
+ results.extend(self.scan_input(value, field_name))
+ results.sort(key=lambda r: int(r.score), reverse=True)
+ return results
+
+ # ==================== Internal scanning ====================
+
+ def _scan_text(self, text: str, field_name: str) -> List[ThreatResult]:
+ hits: List[ThreatResult] = []
+
+ for check in [
+ self._check_jndi_injection,
+ self._check_sql_injection,
+ self._check_xss,
+ self._check_path_traversal,
+ self._check_command_injection,
+ self._check_ssrf,
+ self._check_xxe,
+ self._check_template_injection,
+ self._check_sensitive_path_probe,
+ ]:
+ result = check(text)
+ if result:
+ threat_type, score, rule, matched = result
+ hits.append(
+ ThreatResult(
+ threat_type=threat_type,
+ score=int(score),
+ field_name=field_name,
+ rule=rule,
+ matched=matched,
+ value_preview=self._preview(text),
+ )
+ )
+
+ return hits
+
+ def _check_jndi_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
+ # 1) Direct match
+ m = C.JNDI_DIRECT_RE.search(text)
+ if m:
+ return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT", m.group(0))
+
+ # 2) URL-decoded
+ decoded = self._multi_unquote(text)
+ if decoded != text:
+ m2 = C.JNDI_DIRECT_RE.search(decoded)
+ if m2:
+ return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT_URL_DECODED", m2.group(0))
+
+ # 3) Obfuscation patterns (raw/decoded)
+ for candidate, rule in [(text, "JNDI_OBFUSCATED"), (decoded, "JNDI_OBFUSCATED_URL_DECODED")]:
+ m3 = C.JNDI_OBFUSCATED_RE.search(candidate)
+ if m3:
+ return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, rule, m3.group(0))
+
+ # 4) Try limited de-obfuscation to reveal ${jndi:...}
+ deobf = self._deobfuscate_log4j(decoded)
+ if deobf and deobf != decoded:
+ m4 = C.JNDI_DIRECT_RE.search(deobf)
+ if m4:
+ return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, "JNDI_DEOBFUSCATED", m4.group(0))
+
+ # 5) Nested expression heuristic
+ for candidate in [text, decoded]:
+ m5 = C.NESTED_EXPRESSION_RE.search(candidate)
+ if m5:
+ return (C.THREAT_TYPE_NESTED_EXPRESSION, C.SCORE_NESTED_EXPRESSION, "NESTED_EXPRESSION", m5.group(0))
+
+ return None
+
+ def _check_sql_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
+ candidates = [text, self._multi_unquote(text)]
+ for candidate in candidates:
+ m = C.SQLI_UNION_SELECT_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_UNION_SELECT", m.group(0))
+ m = C.SQLI_OR_1_EQ_1_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_OR_1_EQ_1", m.group(0))
+ return None
+
+ def _check_xss(self, text: str) -> Optional[Tuple[str, int, str, str]]:
+ candidates = [text, self._multi_unquote(text)]
+ for candidate in candidates:
+ m = C.XSS_SCRIPT_TAG_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_SCRIPT_TAG", m.group(0))
+ m = C.XSS_JS_PROTOCOL_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_JS_PROTOCOL", m.group(0))
+ m = C.XSS_INLINE_EVENT_HANDLER_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_INLINE_EVENT_HANDLER", m.group(0))
+ return None
+
+ def _check_path_traversal(self, text: str) -> Optional[Tuple[str, int, str, str]]:
+ decoded = self._multi_unquote(text)
+ candidates = [text, decoded]
+ for candidate in candidates:
+ m = C.PATH_TRAVERSAL_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_PATH_TRAVERSAL, C.SCORE_PATH_TRAVERSAL, "PATH_TRAVERSAL", m.group(0))
+ return None
+
+ def _check_command_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
+ decoded = self._multi_unquote(text)
+ candidates = [text, decoded]
+ for candidate in candidates:
+ m = C.CMD_INJECTION_SUBSHELL_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_SUBSHELL", m.group(0))
+ m = C.CMD_INJECTION_OPERATOR_WITH_CMD_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_OPERATOR_WITH_CMD", m.group(0))
+ return None
+
+ def _check_ssrf(self, text: str) -> Optional[Tuple[str, int, str, str]]:
+ decoded = self._multi_unquote(text)
+ candidates: List[Tuple[str, str]] = [(text, "")]
+ if decoded != text:
+ candidates.append((decoded, "_URL_DECODED"))
+
+ for candidate, suffix in candidates:
+ m = C.SSRF_LOCALHOST_URL_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_LOCALHOST{suffix}", m.group(0))
+ m = C.SSRF_INTERNAL_IP_URL_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_INTERNAL_IP{suffix}", m.group(0))
+ m = C.SSRF_DANGEROUS_PROTOCOL_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_DANGEROUS_PROTOCOL{suffix}", m.group(0))
+
+ return None
+
+ def _check_xxe(self, text: str) -> Optional[Tuple[str, int, str, str]]:
+ decoded = self._multi_unquote(text)
+ candidates: List[Tuple[str, str]] = [(text, "")]
+ if decoded != text:
+ candidates.append((decoded, "_URL_DECODED"))
+
+ for candidate, suffix in candidates:
+ m_doctype = C.XXE_DOCTYPE_RE.search(candidate)
+ if not m_doctype:
+ continue
+ m_entity = C.XXE_ENTITY_RE.search(candidate)
+ if not m_entity:
+ continue
+ m_sys_pub = C.XXE_SYSTEM_PUBLIC_RE.search(candidate)
+ if not m_sys_pub:
+ continue
+ matched = f"{m_doctype.group(0)} {m_entity.group(0)} {m_sys_pub.group(0)}"
+ return (C.THREAT_TYPE_XXE, C.SCORE_XXE, f"XXE_KEYWORD_COMBO{suffix}", matched)
+
+ return None
+
+ def _check_template_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
+ decoded = self._multi_unquote(text)
+ candidates: List[Tuple[str, str]] = [(text, "")]
+ if decoded != text:
+ candidates.append((decoded, "_URL_DECODED"))
+
+ for candidate, suffix in candidates:
+ m = C.TEMPLATE_JINJA_EXPR_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_TEMPLATE_INJECTION, C.SCORE_TEMPLATE_INJECTION, f"TEMPLATE_JINJA_EXPR{suffix}", m.group(0))
+ m = C.TEMPLATE_JINJA_STMT_RE.search(candidate)
+ if m:
+ return (C.THREAT_TYPE_TEMPLATE_INJECTION, C.SCORE_TEMPLATE_INJECTION, f"TEMPLATE_JINJA_STMT{suffix}", m.group(0))
+ m = C.TEMPLATE_VELOCITY_DIRECTIVE_RE.search(candidate)
+ if m:
+ return (
+ C.THREAT_TYPE_TEMPLATE_INJECTION,
+ C.SCORE_TEMPLATE_INJECTION,
+ f"TEMPLATE_VELOCITY_DIRECTIVE{suffix}",
+ m.group(0),
+ )
+
+ return None
+
+ def _check_sensitive_path_probe(self, text: str) -> Optional[Tuple[str, int, str, str]]:
+ decoded = self._multi_unquote(text)
+ candidates: List[Tuple[str, str]] = [(text, "")]
+ if decoded != text:
+ candidates.append((decoded, "_URL_DECODED"))
+
+ for candidate, suffix in candidates:
+ m = C.SENSITIVE_PATH_DOTFILES_RE.search(candidate)
+ if m:
+ return (
+ C.THREAT_TYPE_SENSITIVE_PATH_PROBE,
+ C.SCORE_SENSITIVE_PATH_PROBE,
+ f"SENSITIVE_PATH_DOTFILES{suffix}",
+ m.group(0),
+ )
+ m = C.SENSITIVE_PATH_PROBE_RE.search(candidate)
+ if m:
+ return (
+ C.THREAT_TYPE_SENSITIVE_PATH_PROBE,
+ C.SCORE_SENSITIVE_PATH_PROBE,
+ f"SENSITIVE_PATH_PROBE{suffix}",
+ m.group(0),
+ )
+
+ return None
+
+ # ==================== Helpers ====================
+
+ def _preview(self, text: str, limit: int = 160) -> str:
+ s = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
+ if len(s) <= limit:
+ return s
+ return s[: limit - 3] + "..."
+
+ def _stringify(self, value: Any) -> str:
+ if value is None:
+ return ""
+ if isinstance(value, bytes):
+ try:
+ return value.decode("utf-8", errors="ignore")
+ except Exception:
+ return ""
+ try:
+ return str(value)
+ except Exception:
+ return ""
+
+ def _multi_unquote(self, text: str) -> str:
+ s = text
+ for _ in range(self.max_decode_rounds):
+ try:
+ nxt = unquote_plus(s)
+ except Exception:
+ break
+ if nxt == s:
+ break
+ s = nxt
+ return s
+
+ def _deobfuscate_log4j(self, text: str) -> str:
+ # Replace ${...:-x} with x (including ${::-x}).
+ # This is intentionally conservative to reduce false positives.
+ import re
+
+ s = text
+ pattern = re.compile(r"\$\{[^{}]{0,50}:-([a-zA-Z])\}")
+ for _ in range(3):
+ nxt = pattern.sub(lambda m: m.group(1), s)
+ if nxt == s:
+ break
+ s = nxt
+ return s
+
+ def _flatten_value(self, value: Any, field_name: str) -> Iterable[Tuple[str, Any]]:
+ if isinstance(value, dict):
+ for k, v in value.items():
+ key = self._stringify(k) or "key"
+ sub_name = f"{field_name}.{key}" if field_name else key
+ yield from self._flatten_value(v, sub_name)
+ return
+ if isinstance(value, (list, tuple, set)):
+ for i, v in enumerate(value):
+ sub_name = f"{field_name}[{i}]"
+ yield from self._flatten_value(v, sub_name)
+ return
+ yield (field_name, value)
+
+ def _extract_request_fields(self, request: Any) -> List[Tuple[str, Any]]:
+ # dict-like input (useful for unit tests / non-Flask callers)
+ if isinstance(request, dict):
+ out: List[Tuple[str, Any]] = []
+ for k, v in request.items():
+ out.append((self._stringify(k) or "request", v))
+ return out
+
+ out: List[Tuple[str, Any]] = []
+
+ # path / method
+ for attr_name in ["method", "path", "full_path", "url", "remote_addr"]:
+ try:
+ v = getattr(request, attr_name, None)
+ except Exception:
+ v = None
+ if v:
+ out.append((attr_name, v))
+
+ # args / form (Flask MultiDict)
+ out.extend(self._extract_multidict(getattr(request, "args", None), "args"))
+ out.extend(self._extract_multidict(getattr(request, "form", None), "form"))
+
+ # headers
+ try:
+ headers = getattr(request, "headers", None)
+ if headers is not None:
+ try:
+ items = headers.items()
+ except Exception:
+ items = []
+ for k, v in items:
+ out.append((f"headers.{self._stringify(k)}", v))
+ except Exception:
+ pass
+
+ # cookies
+ try:
+ cookies = getattr(request, "cookies", None)
+ if isinstance(cookies, dict):
+ for k, v in cookies.items():
+ out.append((f"cookies.{self._stringify(k)}", v))
+ except Exception:
+ pass
+
+ # json body
+ data = None
+ try:
+ get_json = getattr(request, "get_json", None)
+ if callable(get_json):
+ data = get_json(silent=True)
+ except Exception:
+ data = None
+
+ if data is not None:
+ for name, v in self._flatten_value(data, "json"):
+ out.append((name, v))
+ return out
+
+ # raw body (as a fallback)
+ try:
+ get_data = getattr(request, "get_data", None)
+ if callable(get_data):
+ raw = get_data(cache=True, as_text=True)
+ if raw:
+ out.append(("body", raw))
+ except Exception:
+ pass
+
+ return out
+
+ def _extract_multidict(self, md: Any, prefix: str) -> List[Tuple[str, Any]]:
+ out: List[Tuple[str, Any]] = []
+ if md is None:
+ return out
+ try:
+ items = md.items(multi=True)
+ except Exception:
+ try:
+ items = md.items()
+ except Exception:
+ return out
+ for k, v in items:
+ out.append((f"{prefix}.{self._stringify(k)}", v))
+ return out
diff --git a/services/kdocs_uploader.py b/services/kdocs_uploader.py
new file mode 100644
index 0000000..8789feb
--- /dev/null
+++ b/services/kdocs_uploader.py
@@ -0,0 +1,1494 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import base64
+import os
+import queue
+import re
+import threading
+import time
+from io import BytesIO
+from typing import Any, Dict, Optional
+from urllib.parse import urlparse
+
+import database
+import email_service
+from app_config import get_config
+from services.client_log import log_to_client
+from services.runtime import get_logger, get_socketio
+from services.state import safe_get_account
+
+try:
+ from playwright.sync_api import sync_playwright, TimeoutError as PlaywrightTimeoutError
+except Exception: # pragma: no cover - 运行环境缺少 playwright 时降级
+ sync_playwright = None
+
+ class PlaywrightTimeoutError(Exception):
+ pass
+
+
+logger = get_logger()
+config = get_config()
+
+
+class KDocsUploader:
+ def __init__(self) -> None:
+ self._queue: queue.Queue = queue.Queue(maxsize=int(os.environ.get("KDOCS_QUEUE_MAXSIZE", "200")))
+ self._thread = threading.Thread(target=self._run, name="kdocs-uploader", daemon=True)
+ self._running = False
+ self._last_error: Optional[str] = None
+ self._last_success_at: Optional[float] = None
+ self._login_required = False
+ self._playwright = None
+ self._browser = None
+ self._context = None
+ self._page = None
+ self._last_qr_image: Optional[bytes] = None
+ self._last_login_check: float = 0.0
+ self._last_login_ok: Optional[bool] = None
+ self._doc_url: Optional[str] = None
+
+ def start(self) -> None:
+ if self._running:
+ return
+ self._running = True
+ self._thread.start()
+
+ def stop(self) -> None:
+ if not self._running:
+ return
+ self._running = False
+ self._queue.put({"action": "shutdown"})
+
+ def get_status(self) -> Dict[str, Any]:
+ return {
+ "queue_size": self._queue.qsize(),
+ "login_required": self._login_required,
+ "last_error": self._last_error,
+ "last_success_at": self._last_success_at,
+ "last_login_ok": self._last_login_ok,
+ }
+
+ def enqueue_upload(
+ self,
+ *,
+ user_id: int,
+ account_id: str,
+ unit: str,
+ name: str,
+ image_path: str,
+ ) -> bool:
+ if not self._running:
+ self.start()
+
+ payload = {
+ "user_id": int(user_id),
+ "account_id": str(account_id),
+ "unit": unit,
+ "name": name,
+ "image_path": image_path,
+ }
+ try:
+ self._queue.put({"action": "upload", "payload": payload}, timeout=1)
+ return True
+ except queue.Full:
+ self._last_error = "上传队列已满"
+ return False
+
+ def request_qr(self, timeout: int = 60, *, force: bool = False) -> Dict[str, Any]:
+ return self._submit_command("qr", timeout=timeout, payload={"force": force})
+
+ def clear_login(self, timeout: int = 20) -> Dict[str, Any]:
+ return self._submit_command("clear_login", timeout=timeout)
+
+ def refresh_login_status(self, timeout: int = 20) -> Dict[str, Any]:
+ return self._submit_command("status", timeout=timeout)
+
+ def _submit_command(
+ self,
+ action: str,
+ timeout: int = 30,
+ payload: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, Any]:
+ if not self._running:
+ self.start()
+ resp_queue: queue.Queue = queue.Queue(maxsize=1)
+ self._queue.put({"action": action, "response": resp_queue, "payload": payload or {}})
+ try:
+ return resp_queue.get(timeout=timeout)
+ except queue.Empty:
+ return {"success": False, "error": "操作超时"}
+
+ def _run(self) -> None:
+ while True:
+ task = self._queue.get()
+ if not task:
+ continue
+ action = task.get("action")
+ if action == "shutdown":
+ break
+ try:
+ if action == "upload":
+ self._handle_upload(task.get("payload") or {})
+ elif action == "qr":
+ result = self._handle_qr(task.get("payload") or {})
+ task.get("response").put(result)
+ elif action == "clear_login":
+ result = self._handle_clear_login()
+ task.get("response").put(result)
+ elif action == "status":
+ result = self._handle_status_check()
+ task.get("response").put(result)
+ except Exception as e:
+ logger.warning(f"[KDocs] 处理任务失败: {e}")
+
+ self._cleanup_browser()
+
+ def _load_system_config(self) -> Dict[str, Any]:
+ return database.get_system_config() or {}
+
+ def _ensure_playwright(self, *, use_storage_state: bool = True) -> bool:
+ if sync_playwright is None:
+ self._last_error = "playwright 未安装"
+ return False
+ try:
+ if self._playwright is None:
+ self._playwright = sync_playwright().start()
+ if self._browser is None:
+ headless = os.environ.get("KDOCS_HEADLESS", "true").lower() != "false"
+ self._browser = self._playwright.chromium.launch(headless=headless)
+ if self._context is None:
+ storage_state = getattr(config, "KDOCS_LOGIN_STATE_FILE", "data/kdocs_login_state.json")
+ if use_storage_state and os.path.exists(storage_state):
+ self._context = self._browser.new_context(storage_state=storage_state)
+ else:
+ self._context = self._browser.new_context()
+ if self._page is None or self._page.is_closed():
+ self._page = self._context.new_page()
+ self._page.set_default_timeout(int(getattr(config, "DEFAULT_TIMEOUT", 60000)))
+ return True
+ except Exception as e:
+ self._last_error = f"浏览器启动失败: {e}"
+ self._cleanup_browser()
+ return False
+
+ def _cleanup_browser(self) -> None:
+ try:
+ if self._page:
+ self._page.close()
+ except Exception:
+ pass
+ self._page = None
+ try:
+ if self._context:
+ self._context.close()
+ except Exception:
+ pass
+ self._context = None
+ try:
+ if self._browser:
+ self._browser.close()
+ except Exception:
+ pass
+ self._browser = None
+ try:
+ if self._playwright:
+ self._playwright.stop()
+ except Exception:
+ pass
+ self._playwright = None
+
+ def _open_document(self, doc_url: str, *, fast: bool = False) -> bool:
+ try:
+ self._doc_url = doc_url
+ self._ensure_clipboard_permissions(doc_url)
+ if fast:
+ doc_pages = self._find_doc_pages(doc_url)
+ if doc_pages:
+ self._page = doc_pages[0]
+ return True
+ login_pages = []
+ for page in self._list_pages():
+ try:
+ url = getattr(page, "url", "") or ""
+ if self._is_login_url(url):
+ login_pages.append(page)
+ except Exception:
+ continue
+ if login_pages:
+ self._page = login_pages[0]
+ return True
+ goto_kwargs = {}
+ if fast:
+ fast_timeout = int(os.environ.get("KDOCS_FAST_GOTO_TIMEOUT_MS", "15000"))
+ goto_kwargs = {"wait_until": "domcontentloaded", "timeout": fast_timeout}
+ self._page.goto(doc_url, **goto_kwargs)
+ time.sleep(0.6)
+ doc_pages = self._find_doc_pages(doc_url)
+ if doc_pages and doc_pages[0] is not self._page:
+ self._page = doc_pages[0]
+ return True
+ except Exception as e:
+ self._last_error = f"打开文档失败: {e}"
+ return False
+
+ def _ensure_clipboard_permissions(self, doc_url: str) -> None:
+ if not self._context or not doc_url:
+ return
+ try:
+ parsed = urlparse(doc_url)
+ if not parsed.scheme or not parsed.netloc:
+ return
+ host = parsed.netloc
+ origins = {f"{parsed.scheme}://{host}"}
+ if host.startswith("www."):
+ origins.add(f"{parsed.scheme}://{host[4:]}")
+ else:
+ origins.add(f"{parsed.scheme}://www.{host}")
+ for origin in origins:
+ try:
+ self._context.grant_permissions(["clipboard-read", "clipboard-write"], origin=origin)
+ except Exception:
+ continue
+ except Exception:
+ return
+
+ def _normalize_doc_url(self, url: str) -> str:
+ if not url:
+ return ""
+ return url.split("#", 1)[0].split("?", 1)[0].rstrip("/")
+
+ def _list_pages(self) -> list:
+ pages = []
+ if self._context:
+ pages.extend(self._context.pages)
+ if self._page and self._page not in pages:
+ pages.insert(0, self._page)
+ return pages
+
+ def _is_login_url(self, url: str) -> bool:
+ if not url:
+ return False
+ lower = url.lower()
+ try:
+ host = urlparse(lower).netloc
+ except Exception:
+ host = ""
+ if "account.wps.cn" in host:
+ return True
+ if "passport" in lower:
+ return True
+ if "login" in lower and "kdocs.cn" not in host:
+ return True
+ return False
+
+ def _find_doc_pages(self, doc_url: Optional[str]) -> list:
+ doc_key = self._normalize_doc_url(doc_url or "")
+ pages = self._list_pages()
+ matches = []
+ for page in pages:
+ url = getattr(page, "url", "") or ""
+ if not url:
+ continue
+ if self._is_login_url(url):
+ continue
+ norm_url = self._normalize_doc_url(url)
+ if doc_key and doc_key in norm_url:
+ matches.append(page)
+ continue
+ try:
+ host = urlparse(url).netloc.lower()
+ except Exception:
+ host = ""
+ if "kdocs.cn" in host:
+ matches.append(page)
+ return matches
+
+ def _page_has_login_gate(self, page) -> bool:
+ url = getattr(page, "url", "") or ""
+ if self._is_login_url(url):
+ return True
+ login_texts = [
+ "登录并加入编辑",
+ "立即登录",
+ "微信登录",
+ "扫码登录",
+ "确认登录",
+ "确认登陆",
+ "账号登录",
+ "登录",
+ ]
+ for text in login_texts:
+ try:
+ if page.get_by_role("button", name=text).is_visible(timeout=800):
+ return True
+ except Exception:
+ pass
+ try:
+ if page.get_by_role("link", name=text).is_visible(timeout=800):
+ return True
+ except Exception:
+ pass
+ try:
+ if page.get_by_text(text, exact=True).is_visible(timeout=800):
+ return True
+ except Exception:
+ pass
+ try:
+ if page.locator("text=登录并加入编辑").first.is_visible(timeout=800):
+ return True
+ except Exception:
+ pass
+ return False
+
+ def _is_logged_in(self) -> bool:
+ doc_pages = self._find_doc_pages(self._doc_url)
+ if not doc_pages:
+ if self._page and not self._page.is_closed() and not self._page_has_login_gate(self._page):
+ return False
+ return False
+ page = doc_pages[0]
+ if self._page is None or self._page.is_closed() or self._page.url != page.url:
+ self._page = page
+ return not self._page_has_login_gate(page)
+
+ def _has_saved_login_state(self) -> bool:
+ storage_state = getattr(config, "KDOCS_LOGIN_STATE_FILE", "data/kdocs_login_state.json")
+ return os.path.exists(storage_state)
+
+ def _ensure_login_dialog(
+ self,
+ *,
+ timeout_ms: int = 1200,
+ frame_timeout_ms: int = 800,
+ quick: bool = False,
+ ) -> None:
+ login_names = ["登录并加入编辑", "立即登录", "登录"]
+ wechat_names = ["微信登录", "微信扫码登录", "微信扫码", "扫码登录"]
+ pages = self._iter_pages()
+ clicked = False
+ for page in pages:
+ if self._try_click_names(
+ page,
+ login_names,
+ timeout_ms=timeout_ms,
+ frame_timeout_ms=frame_timeout_ms,
+ quick=quick,
+ ):
+ clicked = True
+ break
+ if clicked:
+ time.sleep(1.5)
+ pages = self._iter_pages()
+ for page in pages:
+ if self._try_click_names(
+ page,
+ wechat_names,
+ timeout_ms=timeout_ms,
+ frame_timeout_ms=frame_timeout_ms,
+ quick=quick,
+ ):
+ return
+ self._try_confirm_login(timeout_ms=timeout_ms, frame_timeout_ms=frame_timeout_ms, quick=quick)
+
+ def _capture_qr_image(self) -> Optional[bytes]:
+ pages = self._iter_pages()
+ for page in pages:
+ for frame in page.frames:
+ target = self._find_qr_element_in_frame(frame)
+ if not target:
+ continue
+ try:
+ return target.screenshot()
+ except Exception:
+ continue
+ for page in pages:
+ dialog_image = self._capture_dialog_image(page)
+ if dialog_image:
+ return dialog_image
+ return None
+
+ def _iter_pages(self) -> list:
+ pages = []
+ if self._context:
+ pages.extend(self._context.pages)
+ if self._page and self._page not in pages:
+ pages.insert(0, self._page)
+ def rank(p) -> int:
+ url = (getattr(p, "url", "") or "").lower()
+ keywords = ("login", "account", "passport", "wechat", "qr")
+ return 0 if any(k in url for k in keywords) else 1
+ pages.sort(key=rank)
+ return pages
+
+ def _find_qr_element_in_frame(self, frame) -> Optional[Any]:
+ selectors = [
+ "canvas",
+ "img[alt*='二维码']",
+ "img[src*='qr']",
+ "img[src*='qrcode']",
+ "img[class*='qr']",
+ "canvas[class*='qr']",
+ "svg",
+ "div[role='img']",
+ "div[class*='qr']",
+ "div[id*='qr']",
+ "div[class*='qrcode']",
+ "div[id*='qrcode']",
+ "div[class*='wechat']",
+ "div[class*='weixin']",
+ "div[class*='wx']",
+ "img[src*='wx']",
+ ]
+ best = None
+ best_score = None
+ for selector in selectors:
+ try:
+ locator = frame.locator(selector)
+ count = min(locator.count(), 20)
+ except Exception:
+ continue
+ for i in range(count):
+ el = locator.nth(i)
+ try:
+ if not el.is_visible(timeout=800):
+ continue
+ box = el.bounding_box()
+ if not box:
+ continue
+ width = box.get("width", 0)
+ height = box.get("height", 0)
+ if width < 80 or height < 80 or width > 520 or height > 520:
+ continue
+ aspect_diff = abs(width - height)
+ if aspect_diff > 80:
+ continue
+ score = aspect_diff + abs(width - 260) + abs(height - 260)
+ if best_score is None or score < best_score:
+ best_score = score
+ best = el
+ except Exception:
+ continue
+ if best:
+ return best
+
+ handle = None
+ try:
+ handle = frame.evaluate_handle(
+ """() => {
+ const patterns = [/qr/i, /qrcode/i, /weixin/i, /wechat/i, /wx/i, /data:image/i];
+ const elements = Array.from(document.querySelectorAll('*'));
+ for (const el of elements) {
+ const style = window.getComputedStyle(el);
+ const bg = style.backgroundImage || '';
+ if (!bg || bg === 'none') continue;
+ if (!patterns.some((re) => re.test(bg))) continue;
+ const rect = el.getBoundingClientRect();
+ if (!rect.width || !rect.height) continue;
+ if (rect.width < 80 || rect.height < 80 || rect.width > 520 || rect.height > 520) continue;
+ const diff = Math.abs(rect.width - rect.height);
+ if (diff > 80) continue;
+ return el;
+ }
+ return null;
+ }"""
+ )
+ element = handle.as_element() if handle else None
+ if element:
+ return element
+ except Exception:
+ pass
+ finally:
+ try:
+ if handle:
+ handle.dispose()
+ except Exception:
+ pass
+ return best
+
+ def _try_click_role(self, page, role: str, name: str, timeout: int = 1500) -> bool:
+ try:
+ el = page.get_by_role(role, name=name)
+ if el.is_visible(timeout=timeout):
+ el.click()
+ time.sleep(1)
+ return True
+ except Exception:
+ return False
+ return False
+
+ def _try_click_names(
+ self,
+ page,
+ names: list,
+ *,
+ timeout_ms: int = 1200,
+ frame_timeout_ms: int = 800,
+ quick: bool = False,
+ ) -> bool:
+ for name in names:
+ if self._try_click_role(page, "button", name, timeout=timeout_ms):
+ return True
+ if not quick:
+ if self._try_click_role(page, "link", name, timeout=timeout_ms):
+ return True
+ try:
+ el = page.get_by_text(name, exact=True)
+ if el.is_visible(timeout=timeout_ms):
+ el.click()
+ time.sleep(1)
+ return True
+ except Exception:
+ pass
+ if not quick:
+ try:
+ el = page.get_by_text(name, exact=False)
+ if el.is_visible(timeout=timeout_ms):
+ el.click()
+ time.sleep(1)
+ return True
+ except Exception:
+ pass
+ try:
+ for frame in page.frames:
+ for name in names:
+ try:
+ el = frame.get_by_role("button", name=name)
+ if el.is_visible(timeout=frame_timeout_ms):
+ el.click()
+ time.sleep(1)
+ return True
+ except Exception:
+ pass
+ try:
+ el = frame.get_by_text(name, exact=True)
+ if el.is_visible(timeout=frame_timeout_ms):
+ el.click()
+ time.sleep(1)
+ return True
+ except Exception:
+ pass
+ if not quick:
+ try:
+ el = frame.get_by_text(name, exact=False)
+ if el.is_visible(timeout=frame_timeout_ms):
+ el.click()
+ time.sleep(1)
+ return True
+ except Exception:
+ pass
+ except Exception:
+ return False
+ return False
+
+ def _try_confirm_login(
+ self,
+ *,
+ timeout_ms: int = 1200,
+ frame_timeout_ms: int = 800,
+ quick: bool = False,
+ ) -> bool:
+ confirm_names = ["确认登录", "确认登陆"]
+ pages = self._iter_pages()
+ for page in pages:
+ if self._try_click_names(
+ page,
+ confirm_names,
+ timeout_ms=timeout_ms,
+ frame_timeout_ms=frame_timeout_ms,
+ quick=quick,
+ ):
+ return True
+ return False
+
+ def _capture_dialog_image(self, page) -> Optional[bytes]:
+ selectors = (
+ "[role='dialog'], .dialog, .modal, .popup, "
+ "div[class*='dialog'], div[class*='modal'], div[class*='popup'], "
+ "div[class*='login'], div[class*='passport'], div[class*='auth']"
+ )
+ try:
+ dialogs = page.locator(selectors)
+ count = min(dialogs.count(), 6)
+ except Exception:
+ return None
+ best = None
+ best_area = 0
+ viewport = page.viewport_size or {}
+ vp_width = viewport.get("width", 0)
+ vp_height = viewport.get("height", 0)
+ for i in range(count):
+ el = dialogs.nth(i)
+ try:
+ if not el.is_visible(timeout=800):
+ continue
+ box = el.bounding_box()
+ if not box:
+ continue
+ width = box.get("width", 0)
+ height = box.get("height", 0)
+ if width < 160 or height < 160:
+ continue
+ if vp_width and vp_height:
+ if width > vp_width * 0.92 and height > vp_height * 0.92:
+ continue
+ area = width * height
+ if area > best_area:
+ best_area = area
+ best = el
+ except Exception:
+ continue
+ if not best:
+ return None
+ try:
+ return best.screenshot()
+ except Exception:
+ return None
+
+ def _is_valid_qr_image(self, data: Optional[bytes]) -> bool:
+ if not data or len(data) < 1024:
+ return False
+ try:
+ from PIL import Image, ImageStat
+
+ img = Image.open(BytesIO(data))
+ width, height = img.size
+ if width < 120 or height < 120:
+ return False
+ ratio = width / float(height or 1)
+ if ratio < 0.6 or ratio > 1.4:
+ return False
+ gray = img.convert("L")
+ stat = ImageStat.Stat(gray)
+ if stat.stddev[0] < 5:
+ return False
+ return True
+ except Exception:
+ return len(data) >= 2048
+
+ def _save_login_state(self) -> None:
+ try:
+ storage_state = getattr(config, "KDOCS_LOGIN_STATE_FILE", "data/kdocs_login_state.json")
+ os.makedirs(os.path.dirname(storage_state), exist_ok=True)
+ self._context.storage_state(path=storage_state)
+ except Exception as e:
+ logger.warning(f"[KDocs] 保存登录态失败: {e}")
+
+ def _handle_qr(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+ cfg = self._load_system_config()
+ doc_url = (cfg.get("kdocs_doc_url") or "").strip()
+ if not doc_url:
+ return {"success": False, "error": "未配置金山文档链接"}
+ force = bool(payload.get("force"))
+ if force:
+ self._handle_clear_login()
+ if not self._ensure_playwright(use_storage_state=not force):
+ return {"success": False, "error": self._last_error or "浏览器不可用"}
+ if not self._open_document(doc_url, fast=True):
+ return {"success": False, "error": self._last_error or "打开文档失败"}
+
+ if not force and self._has_saved_login_state() and self._is_logged_in():
+ self._login_required = False
+ self._last_login_ok = True
+ self._save_login_state()
+ return {"success": True, "logged_in": True, "qr_image": ""}
+
+ fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
+ self._ensure_login_dialog(
+ timeout_ms=fast_login_timeout,
+ frame_timeout_ms=fast_login_timeout,
+ quick=True,
+ )
+ qr_image = None
+ invalid_qr = None
+ for attempt in range(10):
+ if attempt in (3, 7):
+ self._ensure_login_dialog(
+ timeout_ms=fast_login_timeout,
+ frame_timeout_ms=fast_login_timeout,
+ quick=True,
+ )
+ candidate = self._capture_qr_image()
+ if candidate and self._is_valid_qr_image(candidate):
+ qr_image = candidate
+ break
+ if candidate:
+ invalid_qr = candidate
+ time.sleep(1)
+ if not qr_image:
+ self._last_error = "二维码识别异常" if invalid_qr else "二维码获取失败"
+ try:
+ pages = self._iter_pages()
+ page_urls = [getattr(p, "url", "") for p in pages]
+ logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
+ ts = int(time.time())
+ saved = []
+ for idx, page in enumerate(pages[:3]):
+ try:
+ path = f"data/kdocs_debug_{ts}_{idx}.png"
+ page.screenshot(path=path, full_page=True)
+ saved.append(path)
+ except Exception:
+ continue
+ if saved:
+ logger.warning(f"[KDocs] 已保存调试截图: {saved}")
+ if invalid_qr:
+ try:
+ path = f"data/kdocs_invalid_qr_{ts}.png"
+ with open(path, "wb") as handle:
+ handle.write(invalid_qr)
+ logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
+ except Exception:
+ pass
+ except Exception:
+ pass
+ return {"success": False, "error": self._last_error}
+
+ try:
+ ts = int(time.time())
+ path = f"data/kdocs_last_qr_{ts}.png"
+ with open(path, "wb") as handle:
+ handle.write(qr_image)
+ logger.info(f"[KDocs] 已保存二维码截图: {path} ({len(qr_image)} bytes)")
+ except Exception as e:
+ logger.warning(f"[KDocs] 保存二维码截图失败: {e}")
+
+ self._last_qr_image = qr_image
+ self._login_required = True
+ return {
+ "success": True,
+ "logged_in": False,
+ "qr_image": base64.b64encode(qr_image).decode("ascii"),
+ }
+
+ def _handle_clear_login(self) -> Dict[str, Any]:
+ storage_state = getattr(config, "KDOCS_LOGIN_STATE_FILE", "data/kdocs_login_state.json")
+ try:
+ if os.path.exists(storage_state):
+ os.remove(storage_state)
+ except Exception as e:
+ return {"success": False, "error": f"清除登录态失败: {e}"}
+
+ self._login_required = False
+ self._last_login_ok = None
+ self._cleanup_browser()
+ return {"success": True}
+
+ def _handle_status_check(self) -> Dict[str, Any]:
+ cfg = self._load_system_config()
+ doc_url = (cfg.get("kdocs_doc_url") or "").strip()
+ if not doc_url:
+ return {"success": True, "logged_in": False, "error": "未配置文档链接"}
+ if not self._ensure_playwright():
+ return {"success": False, "logged_in": False, "error": self._last_error or "浏览器不可用"}
+ if not self._open_document(doc_url, fast=True):
+ return {"success": False, "logged_in": False, "error": self._last_error or "打开文档失败"}
+ fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
+ self._ensure_login_dialog(
+ timeout_ms=fast_login_timeout,
+ frame_timeout_ms=fast_login_timeout,
+ quick=True,
+ )
+ self._try_confirm_login(
+ timeout_ms=fast_login_timeout,
+ frame_timeout_ms=fast_login_timeout,
+ quick=True,
+ )
+ logged_in = self._is_logged_in()
+ self._last_login_ok = logged_in
+ self._login_required = not logged_in
+ if logged_in:
+ self._save_login_state()
+ return {"success": True, "logged_in": logged_in}
+
+ def _handle_upload(self, payload: Dict[str, Any]) -> None:
+ cfg = self._load_system_config()
+ if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
+ return
+ doc_url = (cfg.get("kdocs_doc_url") or "").strip()
+ if not doc_url:
+ return
+
+ unit = (payload.get("unit") or "").strip()
+ name = (payload.get("name") or "").strip()
+ image_path = payload.get("image_path")
+ user_id = payload.get("user_id")
+ account_id = payload.get("account_id")
+
+ if not unit or not name:
+ return
+ if not image_path or not os.path.exists(image_path):
+ return
+
+ account = None
+ prev_status = None
+ status_tracked = False
+
+ try:
+ try:
+ account = safe_get_account(user_id, account_id)
+ if account and self._should_mark_upload(account):
+ prev_status = getattr(account, "status", None)
+ account.status = "上传截图"
+ self._emit_account_update(user_id, account)
+ status_tracked = True
+ except Exception:
+ prev_status = None
+
+ if not self._ensure_playwright():
+ self._notify_admin(unit, name, image_path, self._last_error or "浏览器不可用")
+ return
+
+ if not self._open_document(doc_url):
+ self._notify_admin(unit, name, image_path, self._last_error or "打开文档失败")
+ return
+
+ if not self._is_logged_in():
+ self._login_required = True
+ self._last_login_ok = False
+ self._notify_admin(unit, name, image_path, "登录已失效,请管理员重新扫码登录")
+ try:
+ log_to_client("表格上传失败: 登录已失效,请管理员重新扫码登录", user_id, account_id)
+ except Exception:
+ pass
+ return
+ self._login_required = False
+ self._last_login_ok = True
+
+ sheet_name = (cfg.get("kdocs_sheet_name") or "").strip()
+ sheet_index = int(cfg.get("kdocs_sheet_index") or 0)
+ unit_col = (cfg.get("kdocs_unit_column") or "A").strip().upper()
+ image_col = (cfg.get("kdocs_image_column") or "D").strip().upper()
+ row_start = int(cfg.get("kdocs_row_start") or 0)
+ row_end = int(cfg.get("kdocs_row_end") or 0)
+
+ success = False
+ error_msg = ""
+ for attempt in range(2):
+ try:
+ if sheet_name or sheet_index:
+ self._select_sheet(sheet_name, sheet_index)
+ row_num = self._find_person_with_unit(unit, name, unit_col, row_start=row_start, row_end=row_end)
+ if row_num < 0:
+ error_msg = f"未找到人员: {unit}-{name}"
+ break
+ success = self._upload_image_to_cell(row_num, image_path, image_col)
+ if success:
+ break
+ except Exception as e:
+ error_msg = str(e)
+
+ if success:
+ self._last_success_at = time.time()
+ self._last_error = None
+ try:
+ log_to_client(f"已上传表格截图: {unit}-{name}", user_id, account_id)
+ except Exception:
+ pass
+ return
+
+ if not error_msg:
+ error_msg = "上传失败"
+ self._last_error = error_msg
+ self._notify_admin(unit, name, image_path, error_msg)
+ try:
+ log_to_client(f"表格上传失败: {error_msg}", user_id, account_id)
+ except Exception:
+ pass
+ finally:
+ if status_tracked:
+ self._restore_account_status(user_id, account, prev_status)
+
+ def _notify_admin(self, unit: str, name: str, image_path: str, error: str) -> None:
+ cfg = self._load_system_config()
+ if int(cfg.get("kdocs_admin_notify_enabled", 0) or 0) != 1:
+ return
+ to_email = (cfg.get("kdocs_admin_notify_email") or "").strip()
+ if not to_email:
+ return
+ settings = email_service.get_email_settings()
+ if not settings.get("enabled", False):
+ return
+ subject = "金山文档上传失败提醒"
+ body = (
+ f"上传失败\n\n人员: {unit}-{name}\n图片: {image_path}\n错误: {error}\n\n"
+ "请检查登录状态或表格配置。"
+ )
+ try:
+ email_service.send_email_async(
+ to_email=to_email,
+ subject=subject,
+ body=body,
+ email_type="kdocs_upload_failed",
+ )
+ except Exception as e:
+ logger.warning(f"[KDocs] 发送管理员邮件失败: {e}")
+
+ def _emit_account_update(self, user_id: int, account: Any) -> None:
+ try:
+ socketio = get_socketio()
+ socketio.emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ except Exception:
+ pass
+
+ def _restore_account_status(self, user_id: int, account: Any, prev_status: Optional[str]) -> None:
+ if not account or not user_id:
+ return
+ if getattr(account, "is_running", False):
+ return
+ if getattr(account, "status", "") != "上传截图":
+ return
+ account.status = prev_status or "未开始"
+ self._emit_account_update(user_id, account)
+
+ def _select_sheet(self, sheet_name: str, sheet_index: int) -> None:
+ if sheet_name:
+ candidates = [
+ self._page.locator("[role='tab']").filter(has_text=sheet_name),
+ self._page.locator(".sheet-tab").filter(has_text=sheet_name),
+ self._page.locator(".sheet-tab-name").filter(has_text=sheet_name),
+ ]
+ for locator in candidates:
+ try:
+ if locator.count() < 1:
+ continue
+ locator.first.click()
+ time.sleep(0.5)
+ return
+ except Exception:
+ continue
+
+ if sheet_index > 0:
+ idx = sheet_index - 1
+ candidates = [
+ self._page.locator("[role='tab']"),
+ self._page.locator(".sheet-tab"),
+ self._page.locator(".sheet-tab-name"),
+ ]
+ for locator in candidates:
+ try:
+ if locator.count() <= idx:
+ continue
+ locator.nth(idx).click()
+ time.sleep(0.5)
+ return
+ except Exception:
+ continue
+
+ def _get_current_cell_address(self) -> str:
+ """获取当前选中的单元格地址(如 A1, C66 等)"""
+ import re
+ # 等待一小段时间让名称框稳定
+ time.sleep(0.1)
+
+ for attempt in range(3):
+ try:
+ name_box = self._page.locator("input.edit-box").first
+ value = name_box.input_value()
+ # 验证是否为有效的单元格地址格式(如 A1, C66, AA100 等)
+ if value and re.match(r"^[A-Z]+\d+$", value.upper()):
+ return value.upper()
+ except Exception:
+ pass
+
+ try:
+ name_box = self._page.locator('#root input[type="text"]').first
+ value = name_box.input_value()
+ if value and re.match(r"^[A-Z]+\d+$", value.upper()):
+ return value.upper()
+ except Exception:
+ pass
+
+ # 等待一下再重试
+ time.sleep(0.2)
+
+ # 如果无法获取有效地址,返回空字符串
+ logger.warning("[KDocs调试] 无法获取有效的单元格地址")
+ return ""
+
+ def _navigate_to_cell(self, cell_address: str) -> None:
+ try:
+ name_box = self._page.locator("input.edit-box").first
+ name_box.click()
+ name_box.fill(cell_address)
+ name_box.press("Enter")
+ except Exception:
+ name_box = self._page.locator('#root input[type="text"]').first
+ name_box.click()
+ name_box.fill(cell_address)
+ name_box.press("Enter")
+ time.sleep(0.3)
+
+ def _focus_grid(self) -> None:
+ try:
+ info = self._page.evaluate(
+ """() => {
+ const canvases = Array.from(document.querySelectorAll("canvas"));
+ let best = null;
+ for (const c of canvases) {
+ const rect = c.getBoundingClientRect();
+ if (!rect.width || !rect.height) continue;
+ if (rect.width < 200 || rect.height < 200) continue;
+ const area = rect.width * rect.height;
+ if (!best || area > best.area) {
+ best = {x: rect.left + rect.width / 2, y: rect.top + rect.height / 2, area};
+ }
+ }
+ return best;
+ }"""
+ )
+ if info and info.get("x") and info.get("y"):
+ self._page.mouse.click(info["x"], info["y"])
+ time.sleep(0.1)
+ except Exception:
+ pass
+
+ def _read_clipboard_text(self) -> str:
+ try:
+ return self._page.evaluate("() => navigator.clipboard.readText()") or ""
+ except Exception:
+ return ""
+
+ def _get_cell_value(self, cell_address: str) -> str:
+ self._navigate_to_cell(cell_address)
+ time.sleep(0.3)
+ try:
+ self._page.evaluate("() => navigator.clipboard.writeText('')")
+ except Exception:
+ pass
+ self._focus_grid()
+
+ # 尝试方法1: 读取金山文档编辑栏/公式栏的内容
+ try:
+ # 金山文档的编辑栏选择器(可能需要调整)
+ formula_bar_selectors = [
+ ".formula-bar-input",
+ ".cell-editor-input",
+ "[class*='formulaBar'] input",
+ "[class*='formula'] textarea",
+ ".formula-editor",
+ "#formulaInput",
+ ]
+ for selector in formula_bar_selectors:
+ try:
+ el = self._page.query_selector(selector)
+ if el:
+ value = el.input_value() if hasattr(el, 'input_value') else el.inner_text()
+ if value and not value.startswith("=DISPIMG"):
+ logger.info(f"[KDocs调试] 从编辑栏读取到: '{value[:50]}...' (selector={selector})")
+ return value.strip()
+ except Exception:
+ pass
+ except Exception:
+ pass
+
+ # 尝试方法2: F2进入编辑模式,全选复制
+ try:
+ self._page.keyboard.press("F2")
+ time.sleep(0.2)
+ self._page.keyboard.press("Control+a")
+ time.sleep(0.1)
+ self._page.keyboard.press("Control+c")
+ time.sleep(0.2)
+ self._page.keyboard.press("Escape")
+ time.sleep(0.1)
+ value = self._read_clipboard_text()
+ if value and not value.startswith("=DISPIMG"):
+ return value.strip()
+ except Exception:
+ pass
+
+ # 尝试方法3: 直接复制单元格(备选)
+ try:
+ self._page.keyboard.press("Control+c")
+ time.sleep(0.2)
+ value = self._read_clipboard_text()
+ if value:
+ return value.strip()
+ except Exception:
+ pass
+ return ""
+
+ def _normalize_text(self, value: str) -> str:
+ if value is None:
+ return ""
+ cleaned = str(value)
+ cleaned = cleaned.replace("\u00a0", "").replace("\u3000", "")
+ cleaned = re.sub(r"\s+", "", cleaned)
+ return cleaned.strip()
+
+ def _unit_matches(self, cell_value: str, expected_unit: str) -> bool:
+ if not cell_value or not expected_unit:
+ return False
+ norm_cell = self._normalize_text(cell_value)
+ norm_expected = self._normalize_text(expected_unit)
+ if not norm_cell or not norm_expected:
+ return False
+ if norm_cell == norm_expected:
+ return True
+ if norm_expected in norm_cell or norm_cell in norm_expected:
+ return True
+ return False
+
+ def _should_mark_upload(self, account: Any) -> bool:
+ if not account:
+ return False
+ status_text = str(getattr(account, "status", "") or "")
+ if status_text:
+ if "运行" in status_text or "排队" in status_text:
+ return False
+ return True
+
+ def _search_person(self, name: str) -> None:
+ self._focus_grid()
+ self._page.keyboard.press("Control+f")
+ time.sleep(0.3)
+ search_input = None
+ selectors = [
+ "input[placeholder*='查找']",
+ "input[placeholder*='搜索']",
+ "input[aria-label*='查找']",
+ "input[aria-label*='搜索']",
+ "input[type='search']",
+ ]
+ try:
+ search_input = self._page.get_by_role("textbox").nth(3)
+ if search_input.is_visible(timeout=800):
+ search_input.fill(name)
+ except Exception:
+ search_input = None
+ if not search_input:
+ for selector in selectors:
+ try:
+ candidate = self._page.locator(selector).first
+ if candidate.is_visible(timeout=800):
+ search_input = candidate
+ search_input.fill(name)
+ break
+ except Exception:
+ continue
+ if not search_input:
+ try:
+ self._page.keyboard.type(name)
+ except Exception:
+ pass
+ time.sleep(0.2)
+ try:
+ find_btn = self._page.get_by_role("button", name="查找").nth(2)
+ find_btn.click()
+ except Exception:
+ try:
+ self._page.get_by_role("button", name="查找").first.click()
+ except Exception:
+ try:
+ self._page.keyboard.press("Enter")
+ except Exception:
+ pass
+ time.sleep(0.3)
+
+ def _find_next(self) -> None:
+ try:
+ find_btn = self._page.get_by_role("button", name="查找").nth(2)
+ find_btn.click()
+ except Exception:
+ try:
+ self._page.get_by_role("button", name="查找").first.click()
+ except Exception:
+ try:
+ self._page.keyboard.press("Enter")
+ except Exception:
+ pass
+ time.sleep(0.3)
+
+ def _close_search(self) -> None:
+ self._page.keyboard.press("Escape")
+ time.sleep(0.2)
+
+ def _extract_row_number(self, cell_address: str) -> int:
+ import re
+
+ match = re.search(r"(\d+)$", cell_address)
+ if match:
+ return int(match.group(1))
+ return -1
+
+ def _verify_unit_by_navigation(self, row_num: int, unit: str, unit_col: str) -> bool:
+ """验证县区 - 从目标行开始搜索县区"""
+ logger.info(f"[KDocs调试] 验证县区: 期望行={row_num}, 期望值='{unit}'")
+
+ # 方法: 先导航到目标行的A列,然后从那里搜索县区
+ try:
+ # 1. 先导航到目标行的 A 列
+ start_cell = f"{unit_col}{row_num}"
+ self._navigate_to_cell(start_cell)
+ time.sleep(0.3)
+ logger.info(f"[KDocs调试] 已导航到 {start_cell}")
+
+ # 2. 从当前位置搜索县区
+ self._page.keyboard.press("Control+f")
+ time.sleep(0.3)
+
+ # 找到搜索框并输入
+ try:
+ search_input = self._page.locator("input[placeholder*='查找'], input[placeholder*='搜索'], input[type='text']").first
+ search_input.fill(unit)
+ time.sleep(0.2)
+ self._page.keyboard.press("Enter")
+ time.sleep(0.5)
+ except Exception as e:
+ logger.warning(f"[KDocs调试] 填写搜索框失败: {e}")
+ self._page.keyboard.press("Escape")
+ return False
+
+ # 3. 关闭搜索框,检查当前位置
+ self._page.keyboard.press("Escape")
+ time.sleep(0.3)
+
+ current_address = self._get_current_cell_address()
+ found_row = self._extract_row_number(current_address)
+ logger.info(f"[KDocs调试] 搜索'{unit}'后: 当前单元格={current_address}, 行号={found_row}")
+
+ # 4. 检查是否在同一行(允许在目标行或之后的几行内,因为搜索可能从当前位置向下)
+ if found_row == row_num:
+ logger.info(f"[KDocs调试] ✓ 验证成功! 县区'{unit}'在第{row_num}行")
+ return True
+ else:
+ logger.info(f"[KDocs调试] 验证失败: 期望行{row_num}, 实际找到行{found_row}")
+ return False
+
+ except Exception as e:
+ logger.warning(f"[KDocs调试] 验证异常: {e}")
+ return False
+
+ def _debug_dump_page_elements(self) -> None:
+ """调试: 输出页面上可能包含单元格值的元素"""
+ logger.info("[KDocs调试] ========== 页面元素分析 ==========")
+ try:
+ # 查找可能的编辑栏元素
+ selectors_to_check = [
+ "input", "textarea",
+ "[class*='formula']", "[class*='Formula']",
+ "[class*='editor']", "[class*='Editor']",
+ "[class*='cell']", "[class*='Cell']",
+ "[class*='input']", "[class*='Input']",
+ ]
+ for selector in selectors_to_check:
+ try:
+ elements = self._page.query_selector_all(selector)
+ for i, el in enumerate(elements[:3]): # 只看前3个
+ try:
+ class_name = el.get_attribute("class") or ""
+ value = ""
+ try:
+ value = el.input_value()
+ except:
+ try:
+ value = el.inner_text()
+ except:
+ pass
+ if value:
+ logger.info(f"[KDocs调试] 元素 {selector}[{i}] class='{class_name[:50]}' value='{value[:30]}'")
+ except:
+ pass
+ except:
+ pass
+ except Exception as e:
+ logger.warning(f"[KDocs调试] 页面元素分析失败: {e}")
+ logger.info("[KDocs调试] ====================================")
+
+ def _debug_dump_table_structure(self, target_row: int = 66) -> None:
+ """调试: 输出表格结构"""
+ self._debug_dump_page_elements() # 先分析页面元素
+ logger.info("[KDocs调试] ========== 表格结构分析 ==========")
+ cols = ['A', 'B', 'C', 'D', 'E']
+ for row in [1, 2, 3, target_row]:
+ row_data = []
+ for col in cols:
+ val = self._get_cell_value(f"{col}{row}")
+ # 截断太长的值
+ if len(val) > 30:
+ val = val[:30] + "..."
+ row_data.append(f"{col}{row}='{val}'")
+ logger.info(f"[KDocs调试] 第{row}行: {' | '.join(row_data)}")
+ logger.info("[KDocs调试] ====================================")
+
+ def _find_person_with_unit(self, unit: str, name: str, unit_col: str, max_attempts: int = 50,
+ row_start: int = 0, row_end: int = 0) -> int:
+ """
+ 查找人员所在行号。
+ 策略:只搜索姓名,找到姓名列(C列)的匹配项
+ 注意:组合搜索会匹配到图片列的错误位置,已放弃该方案
+
+ :param row_start: 有效行范围起始(0表示不限制)
+ :param row_end: 有效行范围结束(0表示不限制)
+ """
+ logger.info(f"[KDocs调试] 开始搜索人员: name='{name}', unit='{unit}'")
+ if row_start > 0 or row_end > 0:
+ logger.info(f"[KDocs调试] 有效行范围: {row_start}-{row_end}")
+
+ # 只搜索姓名 - 这是目前唯一可靠的方式
+ logger.info(f"[KDocs调试] 搜索姓名: '{name}'")
+ row_num = self._search_and_get_row(name, max_attempts=max_attempts, expected_col='C',
+ row_start=row_start, row_end=row_end)
+ if row_num > 0:
+ logger.info(f"[KDocs调试] ✓ 姓名搜索成功! 找到行号={row_num}")
+ return row_num
+
+ logger.warning(f"[KDocs调试] 搜索失败,未找到人员 '{name}'")
+ return -1
+
+ def _search_and_get_row(self, search_text: str, max_attempts: int = 10, expected_col: str = None,
+ row_start: int = 0, row_end: int = 0) -> int:
+ """
+ 执行搜索并获取找到的行号
+ :param search_text: 要搜索的文本
+ :param max_attempts: 最大尝试次数
+ :param expected_col: 期望的列(如 'C'),如果指定则只接受该列的结果
+ :param row_start: 有效行范围起始(0表示不限制)
+ :param row_end: 有效行范围结束(0表示不限制)
+ """
+ self._focus_grid()
+ self._search_person(search_text)
+ found_positions = set() # 记录已找到的位置(列+行)
+
+ for attempt in range(max_attempts):
+ self._close_search()
+ time.sleep(0.3) # 等待名称框更新
+
+ current_address = self._get_current_cell_address()
+ if not current_address:
+ logger.warning(f"[KDocs调试] 第{attempt+1}次: 无法获取单元格地址")
+ # 继续尝试下一个
+ self._page.keyboard.press("Control+f")
+ time.sleep(0.2)
+ self._find_next()
+ continue
+
+ row_num = self._extract_row_number(current_address)
+ # 提取列字母(A, B, C, D 等)
+ col_letter = ''.join(c for c in current_address if c.isalpha()).upper()
+
+ logger.info(f"[KDocs调试] 第{attempt+1}次搜索'{search_text}': 单元格={current_address}, 列={col_letter}, 行号={row_num}")
+
+ if row_num <= 0:
+ logger.warning(f"[KDocs调试] 无法提取行号,搜索可能没有结果")
+ return -1
+
+ # 检查是否已经访问过这个位置
+ position_key = f"{col_letter}{row_num}"
+ if position_key in found_positions:
+ logger.info(f"[KDocs调试] 位置{position_key}已搜索过,循环结束")
+ # 检查是否有任何有效结果
+ valid_results = [pos for pos in found_positions
+ if (not expected_col or pos.startswith(expected_col))
+ and self._extract_row_number(pos) > 2]
+ if valid_results:
+ # 返回第一个有效结果的行号
+ return self._extract_row_number(valid_results[0])
+ return -1
+
+ found_positions.add(position_key)
+
+ # 跳过标题行和表头行(通常是第1-2行)
+ if row_num <= 2:
+ logger.info(f"[KDocs调试] 跳过标题/表头行: {row_num}")
+ self._page.keyboard.press("Control+f")
+ time.sleep(0.2)
+ self._find_next()
+ continue
+
+ # 如果指定了期望的列,检查是否匹配
+ if expected_col and col_letter != expected_col.upper():
+ logger.info(f"[KDocs调试] 列不匹配: 期望={expected_col}, 实际={col_letter},继续搜索下一个")
+ self._page.keyboard.press("Control+f")
+ time.sleep(0.2)
+ self._find_next()
+ continue
+
+ # 检查行号是否在有效范围内
+ if row_start > 0 and row_num < row_start:
+ logger.info(f"[KDocs调试] 行号{row_num}小于起始行{row_start},继续搜索下一个")
+ self._page.keyboard.press("Control+f")
+ time.sleep(0.2)
+ self._find_next()
+ continue
+
+ if row_end > 0 and row_num > row_end:
+ logger.info(f"[KDocs调试] 行号{row_num}大于结束行{row_end},继续搜索下一个")
+ self._page.keyboard.press("Control+f")
+ time.sleep(0.2)
+ self._find_next()
+ continue
+
+ # 找到有效的数据行,列匹配且在行范围内
+ logger.info(f"[KDocs调试] ✓ 找到有效位置: {current_address} (在有效范围内)")
+ return row_num
+
+ self._close_search()
+ logger.warning(f"[KDocs调试] 达到最大尝试次数{max_attempts},未找到有效结果")
+ return -1
+
+ def _upload_image_to_cell(self, row_num: int, image_path: str, image_col: str) -> bool:
+ cell_address = f"{image_col}{row_num}"
+ self._navigate_to_cell(cell_address)
+ time.sleep(0.3)
+
+ # 清除单元格现有内容
+ try:
+ # 1. 导航到单元格(名称框输入地址+Enter,会跳转并可能进入编辑模式)
+ self._navigate_to_cell(cell_address)
+ time.sleep(0.3)
+
+ # 2. 按 Escape 退出可能的编辑模式,回到选中状态
+ self._page.keyboard.press("Escape")
+ time.sleep(0.3)
+
+ # 3. 按 Delete 删除选中单元格的内容
+ self._page.keyboard.press("Delete")
+ time.sleep(0.5)
+ logger.info(f"[KDocs] 已删除 {cell_address} 的内容")
+ except Exception as e:
+ logger.warning(f"[KDocs] 清除单元格内容时出错: {e}")
+
+ logger.info(f"[KDocs] 准备上传图片到 {cell_address},已清除旧内容")
+
+ try:
+ insert_btn = self._page.get_by_role("button", name="插入")
+ insert_btn.click()
+ time.sleep(0.3)
+ except Exception as e:
+ raise RuntimeError(f"打开插入菜单失败: {e}")
+
+ try:
+ image_btn = self._page.get_by_role("button", name="图片")
+ image_btn.click()
+ time.sleep(0.3)
+ cell_image_option = self._page.get_by_role("option", name="单元格图片")
+ cell_image_option.click()
+ time.sleep(0.2)
+ except Exception as e:
+ raise RuntimeError(f"选择单元格图片失败: {e}")
+
+ try:
+ local_option = self._page.get_by_role("option", name="本地")
+ with self._page.expect_file_chooser() as fc_info:
+ local_option.click()
+ file_chooser = fc_info.value
+ file_chooser.set_files(image_path)
+ except Exception as e:
+ raise RuntimeError(f"上传文件失败: {e}")
+
+ time.sleep(2)
+ return True
+
+
+_kdocs_uploader: Optional[KDocsUploader] = None
+
+
+def get_kdocs_uploader() -> KDocsUploader:
+ global _kdocs_uploader
+ if _kdocs_uploader is None:
+ _kdocs_uploader = KDocsUploader()
+ _kdocs_uploader.start()
+ return _kdocs_uploader
diff --git a/services/maintenance.py b/services/maintenance.py
index 0503623..7dbfb90 100644
--- a/services/maintenance.py
+++ b/services/maintenance.py
@@ -4,6 +4,7 @@ from __future__ import annotations
import threading
import time
+from datetime import datetime
from app_config import get_config
from app_logger import get_logger
@@ -26,6 +27,9 @@ USER_ACCOUNTS_EXPIRE_SECONDS = int(getattr(config, "USER_ACCOUNTS_EXPIRE_SECONDS
BATCH_TASK_EXPIRE_SECONDS = int(getattr(config, "BATCH_TASK_EXPIRE_SECONDS", 21600))
PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECONDS", 7200))
+# 金山文档离线通知状态:每次掉线只通知一次,恢复在线后重置
+_kdocs_offline_notified: bool = False
+
def cleanup_expired_data() -> None:
"""定期清理过期数据,防止内存泄漏(逻辑保持不变)。"""
@@ -91,6 +95,87 @@ def cleanup_expired_data() -> None:
logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务")
+def check_kdocs_online_status() -> None:
+ """检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
+ global _kdocs_offline_notified
+
+ try:
+ import database
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ # 获取系统配置
+ cfg = database.get_system_config()
+ if not cfg:
+ return
+
+ # 检查是否启用了金山文档功能
+ kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
+ if not kdocs_enabled:
+ return
+
+ # 检查是否启用了管理员通知
+ admin_notify_enabled = int(cfg.get("kdocs_admin_notify_enabled") or 0)
+ admin_notify_email = (cfg.get("kdocs_admin_notify_email") or "").strip()
+ if not admin_notify_enabled or not admin_notify_email:
+ return
+
+ # 获取金山文档状态
+ kdocs = get_kdocs_uploader()
+ status = kdocs.get_status()
+ login_required = status.get("login_required", False)
+ last_login_ok = status.get("last_login_ok")
+
+ # 如果需要登录或最后登录状态不是成功
+ is_offline = login_required or (last_login_ok is False)
+
+ if is_offline:
+ # 已经通知过了,不再重复通知
+ if _kdocs_offline_notified:
+ logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
+ return
+
+ # 发送邮件通知
+ try:
+ import email_service
+
+ now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ subject = "【金山文档离线告警】需要重新登录"
+ body = f"""
+您好,
+
+系统检测到金山文档上传功能已离线,需要重新扫码登录。
+
+检测时间:{now_str}
+状态详情:
+- 需要登录:{login_required}
+- 上次登录状态:{last_login_ok}
+
+请尽快登录后台,在"系统配置"→"金山文档上传"中点击"获取登录二维码"重新登录。
+
+---
+此邮件由系统自动发送,请勿直接回复。
+"""
+ email_service.send_email_async(
+ to_email=admin_notify_email,
+ subject=subject,
+ body=body,
+ email_type="kdocs_offline_alert",
+ )
+ _kdocs_offline_notified = True # 标记为已通知
+ logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
+ except Exception as e:
+ logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
+ else:
+ # 恢复在线,重置通知状态
+ if _kdocs_offline_notified:
+ logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
+ _kdocs_offline_notified = False
+ logger.debug("[KDocs监控] 金山文档状态正常")
+
+ except Exception as e:
+ logger.error(f"[KDocs监控] 检测失败: {e}")
+
+
def start_cleanup_scheduler() -> None:
"""启动定期清理调度器"""
@@ -106,3 +191,22 @@ def start_cleanup_scheduler() -> None:
cleanup_thread.start()
logger.info("内存清理调度器已启动")
+
+def start_kdocs_monitor() -> None:
+ """启动金山文档状态监控"""
+
+ def monitor_loop():
+ # 启动后等待 60 秒再开始检测(给系统初始化的时间)
+ time.sleep(60)
+ while True:
+ try:
+ check_kdocs_online_status()
+ time.sleep(300) # 每5分钟检测一次
+ except Exception as e:
+ logger.error(f"[KDocs监控] 监控任务执行失败: {e}")
+ time.sleep(60)
+
+ monitor_thread = threading.Thread(target=monitor_loop, daemon=True, name="kdocs-monitor")
+ monitor_thread.start()
+ logger.info("[KDocs监控] 金山文档状态监控已启动(每5分钟检测一次)")
+
diff --git a/services/scheduler.py b/services/scheduler.py
index 8c915f3..c562bb1 100644
--- a/services/scheduler.py
+++ b/services/scheduler.py
@@ -87,19 +87,32 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
cfg = database.get_system_config()
enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1
+ user_accounts = {}
+ account_ids = []
for user in approved_users:
user_id = user["id"]
accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
+ if accounts:
+ user_accounts[user_id] = accounts
+ account_ids.extend(list(accounts.keys()))
+
+ account_statuses = database.get_account_status_batch(account_ids)
+
+ for user in approved_users:
+ user_id = user["id"]
+ accounts = user_accounts.get(user_id, {})
+ if not accounts:
+ continue
for account_id, account in accounts.items():
total_accounts += 1
if account.is_running:
continue
- account_status_info = database.get_account_status(account_id)
+ account_status_info = account_statuses.get(str(account_id))
if account_status_info:
status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
if status == "suspended":
@@ -150,6 +163,16 @@ def scheduled_task_worker() -> None:
"""定时任务工作线程"""
import schedule
+ def decay_risk_scores():
+ """风险分衰减:每天定时执行一次"""
+ try:
+ from security.risk_scorer import RiskScorer
+
+ RiskScorer().decay_scores()
+ logger.info("[定时任务] 风险分衰减已执行")
+ except Exception as e:
+ logger.exception(f"[定时任务] 风险分衰减执行失败: {e}")
+
def cleanup_expired_captcha():
try:
deleted_count = safe_cleanup_expired_captcha()
@@ -362,7 +385,12 @@ def scheduled_task_worker() -> None:
if schedule_time_cst != str(schedule_time_raw or "").strip():
logger.warning(f"[定时任务] 系统定时时间格式无效,已回退到 {schedule_time_cst} (原值: {schedule_time_raw!r})")
- signature = (schedule_enabled, schedule_time_cst)
+ risk_decay_time_raw = os.environ.get("RISK_SCORE_DECAY_TIME_CST", "04:00")
+ risk_decay_time_cst = _normalize_hhmm(risk_decay_time_raw, default="04:00")
+ if risk_decay_time_cst != str(risk_decay_time_raw or "").strip():
+ logger.warning(f"[定时任务] 风险分衰减时间格式无效,已回退到 {risk_decay_time_cst} (原值: {risk_decay_time_raw!r})")
+
+ signature = (schedule_enabled, schedule_time_cst, risk_decay_time_cst)
config_changed = schedule_state.get("signature") != signature
is_first_run = schedule_state.get("signature") is None
if (not force) and (not config_changed):
@@ -374,6 +402,8 @@ def scheduled_task_worker() -> None:
cleanup_time_cst = "03:00"
schedule.every().day.at(cleanup_time_cst).do(cleanup_old_data)
+ schedule.every().day.at(risk_decay_time_cst).do(decay_risk_scores)
+
schedule.every().hour.do(cleanup_expired_captcha)
quota_reset_time_cst = "00:00"
@@ -381,6 +411,7 @@ def scheduled_task_worker() -> None:
if is_first_run:
logger.info(f"[定时任务] 已设置数据清理任务: 每天 CST {cleanup_time_cst}")
+ logger.info(f"[定时任务] 已设置风险分衰减: 每天 CST {risk_decay_time_cst}")
logger.info(f"[定时任务] 已设置验证码清理任务: 每小时执行一次")
logger.info(f"[定时任务] 已设置SMTP配额重置: 每天 CST {quota_reset_time_cst}")
diff --git a/services/screenshots.py b/services/screenshots.py
index a66cefd..8be6afb 100644
--- a/services/screenshots.py
+++ b/services/screenshots.py
@@ -3,15 +3,16 @@
from __future__ import annotations
import os
+import shutil
+import subprocess
import time
import database
import email_service
+from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh
from app_config import get_config
from app_logger import get_logger
from browser_pool_worker import get_browser_worker_pool
-from playwright_automation import PlaywrightAutomation
-from services.browser_manager import get_browser_manager
from services.client_log import log_to_client
from services.runtime import get_socketio
from services.state import safe_get_account, safe_remove_task_status, safe_update_task_status
@@ -24,6 +25,165 @@ config = get_config()
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
+_WKHTMLTOIMAGE_TIMEOUT_SECONDS = int(os.environ.get("WKHTMLTOIMAGE_TIMEOUT_SECONDS", "60"))
+_WKHTMLTOIMAGE_JS_DELAY_MS = int(os.environ.get("WKHTMLTOIMAGE_JS_DELAY_MS", "3000"))
+_WKHTMLTOIMAGE_WIDTH = int(os.environ.get("WKHTMLTOIMAGE_WIDTH", "1920"))
+_WKHTMLTOIMAGE_HEIGHT = int(os.environ.get("WKHTMLTOIMAGE_HEIGHT", "1080"))
+_WKHTMLTOIMAGE_QUALITY = int(os.environ.get("WKHTMLTOIMAGE_QUALITY", "95"))
+_WKHTMLTOIMAGE_ZOOM = float(os.environ.get("WKHTMLTOIMAGE_ZOOM", "1.0"))
+_WKHTMLTOIMAGE_FULL_PAGE = str(os.environ.get("WKHTMLTOIMAGE_FULL_PAGE", "")).strip().lower() in (
+ "1",
+ "true",
+ "yes",
+ "on",
+)
+_env_crop_w = os.environ.get("WKHTMLTOIMAGE_CROP_WIDTH")
+_env_crop_h = os.environ.get("WKHTMLTOIMAGE_CROP_HEIGHT")
+_WKHTMLTOIMAGE_CROP_WIDTH = int(_env_crop_w) if _env_crop_w is not None else _WKHTMLTOIMAGE_WIDTH
+_WKHTMLTOIMAGE_CROP_HEIGHT = (
+ int(_env_crop_h) if _env_crop_h is not None else (_WKHTMLTOIMAGE_HEIGHT if _WKHTMLTOIMAGE_HEIGHT > 0 else 0)
+)
+_WKHTMLTOIMAGE_CROP_X = int(os.environ.get("WKHTMLTOIMAGE_CROP_X", "0"))
+_WKHTMLTOIMAGE_CROP_Y = int(os.environ.get("WKHTMLTOIMAGE_CROP_Y", "0"))
+_WKHTMLTOIMAGE_UA = os.environ.get(
+ "WKHTMLTOIMAGE_USER_AGENT",
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
+)
+
+
+def _resolve_wkhtmltoimage_path() -> str | None:
+ return os.environ.get("WKHTMLTOIMAGE_PATH") or shutil.which("wkhtmltoimage")
+
+
+def _read_cookie_pairs(cookies_path: str) -> list[tuple[str, str]]:
+ if not cookies_path or not os.path.exists(cookies_path):
+ return []
+ pairs = []
+ try:
+ with open(cookies_path, "r", encoding="utf-8", errors="ignore") as f:
+ for line in f:
+ line = line.strip()
+ if not line or line.startswith("#"):
+ continue
+ parts = line.split("\t")
+ if len(parts) < 7:
+ continue
+ name = parts[5].strip()
+ value = parts[6].strip()
+ if name:
+ pairs.append((name, value))
+ except Exception:
+ return []
+ return pairs
+
+
+def _select_cookie_pairs(pairs: list[tuple[str, str]]) -> list[tuple[str, str]]:
+ preferred_names = {"ASP.NET_SessionId", ".ASPXAUTH"}
+ preferred = [(name, value) for name, value in pairs if name in preferred_names and value]
+ if preferred:
+ return preferred
+ return [(name, value) for name, value in pairs if name and value and name.isascii() and value.isascii()]
+
+
+def _ensure_login_cookies(account, proxy_config, log_callback) -> bool:
+ """确保有可用的登录 cookies(通过 API 登录刷新)"""
+ try:
+ with APIBrowser(log_callback=log_callback, proxy_config=proxy_config) as api_browser:
+ if not api_browser.login(account.username, account.password):
+ return False
+ return api_browser.save_cookies_for_screenshot(account.username)
+ except Exception:
+ return False
+
+
+def take_screenshot_wkhtmltoimage(
+ url: str,
+ output_path: str,
+ cookies_path: str | None = None,
+ proxy_server: str | None = None,
+ run_script: str | None = None,
+ window_status: str | None = None,
+ log_callback=None,
+) -> bool:
+ wkhtmltoimage_path = _resolve_wkhtmltoimage_path()
+ if not wkhtmltoimage_path:
+ if log_callback:
+ log_callback("wkhtmltoimage 未安装或不在 PATH 中")
+ return False
+
+ ext = os.path.splitext(output_path)[1].lower()
+ image_format = "jpg" if ext in (".jpg", ".jpeg") else "png"
+
+ cmd = [
+ wkhtmltoimage_path,
+ "--format",
+ image_format,
+ "--width",
+ str(_WKHTMLTOIMAGE_WIDTH),
+ "--disable-smart-width",
+ "--javascript-delay",
+ str(_WKHTMLTOIMAGE_JS_DELAY_MS),
+ "--load-error-handling",
+ "ignore",
+ "--enable-local-file-access",
+ "--encoding",
+ "utf-8",
+ ]
+
+ if _WKHTMLTOIMAGE_UA:
+ cmd.extend(["--custom-header", "User-Agent", _WKHTMLTOIMAGE_UA, "--custom-header-propagation"])
+
+ if image_format in ("jpg", "jpeg"):
+ cmd.extend(["--quality", str(_WKHTMLTOIMAGE_QUALITY)])
+
+ if _WKHTMLTOIMAGE_HEIGHT > 0 and not _WKHTMLTOIMAGE_FULL_PAGE:
+ cmd.extend(["--height", str(_WKHTMLTOIMAGE_HEIGHT)])
+
+ if abs(_WKHTMLTOIMAGE_ZOOM - 1.0) > 1e-6:
+ cmd.extend(["--zoom", str(_WKHTMLTOIMAGE_ZOOM)])
+
+ if not _WKHTMLTOIMAGE_FULL_PAGE and (_WKHTMLTOIMAGE_CROP_WIDTH > 0 or _WKHTMLTOIMAGE_CROP_HEIGHT > 0):
+ cmd.extend(["--crop-x", str(_WKHTMLTOIMAGE_CROP_X), "--crop-y", str(_WKHTMLTOIMAGE_CROP_Y)])
+ if _WKHTMLTOIMAGE_CROP_WIDTH > 0:
+ cmd.extend(["--crop-w", str(_WKHTMLTOIMAGE_CROP_WIDTH)])
+ if _WKHTMLTOIMAGE_CROP_HEIGHT > 0:
+ cmd.extend(["--crop-h", str(_WKHTMLTOIMAGE_CROP_HEIGHT)])
+
+ if run_script:
+ cmd.extend(["--run-script", run_script])
+ if window_status:
+ cmd.extend(["--window-status", window_status])
+
+ if cookies_path:
+ cookie_pairs = _select_cookie_pairs(_read_cookie_pairs(cookies_path))
+ if cookie_pairs:
+ for name, value in cookie_pairs:
+ cmd.extend(["--cookie", name, value])
+ else:
+ cmd.extend(["--cookie-jar", cookies_path])
+
+ if proxy_server:
+ cmd.extend(["--proxy", proxy_server])
+
+ cmd.extend([url, output_path])
+
+ try:
+ result = subprocess.run(cmd, capture_output=True, text=True, timeout=_WKHTMLTOIMAGE_TIMEOUT_SECONDS)
+ if result.returncode != 0:
+ if log_callback:
+ err_msg = (result.stderr or result.stdout or "").strip()
+ log_callback(f"wkhtmltoimage 截图失败: {err_msg[:200]}")
+ return False
+ return True
+ except subprocess.TimeoutExpired:
+ if log_callback:
+ log_callback("wkhtmltoimage 截图超时")
+ return False
+ except Exception as e:
+ if log_callback:
+ log_callback(f"wkhtmltoimage 截图异常: {e}")
+ return False
+
def _emit(event: str, data: object, *, room: str | None = None) -> None:
try:
@@ -42,7 +202,7 @@ def take_screenshot_for_account(
task_start_time=None,
browse_result=None,
):
- """为账号任务完成后截图(使用工作线程池,真正的浏览器复用)"""
+ """为账号任务完成后截图(使用截图线程池并发执行)"""
account = safe_get_account(user_id, account_id)
if not account:
return
@@ -63,9 +223,11 @@ def take_screenshot_for_account(
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
max_retries = 3
+ proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
+ proxy_server = proxy_config.get("server") if proxy_config else None
+ cookie_path = get_cookie_jar_path(account.username)
for attempt in range(1, max_retries + 1):
- automation = None
try:
safe_update_task_status(
account_id,
@@ -75,100 +237,70 @@ def take_screenshot_for_account(
if attempt > 1:
log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id)
+ worker_id = browser_instance.get("worker_id", "?") if isinstance(browser_instance, dict) else "?"
+ use_count = browser_instance.get("use_count", 0) if isinstance(browser_instance, dict) else 0
log_to_client(
- f"使用Worker-{browser_instance['worker_id']}的浏览器(已使用{browser_instance['use_count']}次)",
+ f"使用Worker-{worker_id}执行截图(已执行{use_count}次)",
user_id,
account_id,
)
- proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
- automation = PlaywrightAutomation(get_browser_manager(), account_id, proxy_config=proxy_config)
- automation.playwright = browser_instance["playwright"]
- automation.browser = browser_instance["browser"]
-
def custom_log(message: str):
log_to_client(message, user_id, account_id)
- automation.log = custom_log
-
- log_to_client("登录中...", user_id, account_id)
- login_result = automation.quick_login(account.username, account.password, account.remember)
- if not login_result["success"]:
- error_message = login_result.get("message", "截图登录失败")
- log_to_client(f"截图登录失败: {error_message}", user_id, account_id)
- if attempt < max_retries:
- log_to_client("将重试...", user_id, account_id)
- time.sleep(2)
- continue
- log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
- return {"success": False, "error": "登录失败"}
+ if not is_cookie_jar_fresh(cookie_path) or attempt > 1:
+ log_to_client("正在刷新登录态...", user_id, account_id)
+ if not _ensure_login_cookies(account, proxy_config, custom_log):
+ log_to_client("截图登录失败", user_id, account_id)
+ if attempt < max_retries:
+ log_to_client("将重试...", user_id, account_id)
+ time.sleep(2)
+ continue
+ log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
+ return {"success": False, "error": "登录失败"}
log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id)
- # 截图场景:优先用 bz 参数直达页面(更稳定,避免页面按钮点击失败导致截图跑偏)
- navigated = False
- try:
- from urllib.parse import urlsplit
+ from urllib.parse import urlsplit
- parsed = urlsplit(config.ZSGL_LOGIN_URL)
- base = f"{parsed.scheme}://{parsed.netloc}"
- if "注册前" in str(browse_type):
- bz = 0
- else:
- bz = 2 # 应读
- target_url = f"{base}/admin/center.aspx?bz={bz}"
- # 目标:保留外层框架(左侧菜单/顶部栏),仅在 mainframe 内部导航到目标内容页
- iframe = None
- try:
- iframe = automation.get_iframe_safe(retry=True, max_retries=5)
- except Exception:
- iframe = None
-
- if iframe:
- iframe.goto(target_url, timeout=60000)
- current_url = getattr(iframe, "url", "") or ""
- if "center.aspx" not in current_url:
- raise RuntimeError(f"unexpected_iframe_url:{current_url}")
- try:
- iframe.wait_for_load_state("networkidle", timeout=10000)
- except Exception:
- pass
- try:
- iframe.wait_for_selector("table.ltable", timeout=5000)
- except Exception:
- pass
- else:
- # 兜底:若获取不到 iframe,则退回到主页面直达
- automation.main_page.goto(target_url, timeout=60000)
- current_url = getattr(automation.main_page, "url", "") or ""
- if "center.aspx" not in current_url:
- raise RuntimeError(f"unexpected_url:{current_url}")
- try:
- automation.main_page.wait_for_load_state("networkidle", timeout=10000)
- except Exception:
- pass
- try:
- automation.main_page.wait_for_selector("table.ltable", timeout=5000)
- except Exception:
- pass
- navigated = True
- except Exception as nav_error:
- log_to_client(f"直达页面失败,将尝试按钮切换: {str(nav_error)[:120]}", user_id, account_id)
-
- # 兼容兜底:若直达失败,则回退到原有按钮切换方式
- if not navigated:
- result = automation.browse_content(
- navigate_only=True,
- browse_type=browse_type,
- auto_next_page=False,
- auto_view_attachments=False,
- interval=0,
- should_stop_callback=None,
- )
- if not result.success and result.error_message:
- log_to_client(f"导航警告: {result.error_message}", user_id, account_id)
-
- time.sleep(2)
+ parsed = urlsplit(config.ZSGL_LOGIN_URL)
+ base = f"{parsed.scheme}://{parsed.netloc}"
+ if "注册前" in str(browse_type):
+ bz = 0
+ else:
+ bz = 2 # 应读
+ target_url = f"{base}/admin/center.aspx?bz={bz}"
+ index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
+ run_script = (
+ "(function(){"
+ "function done(){window.status='ready';}"
+ "function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
+ "function expandMenu(){"
+ "try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
+ "try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
+ "try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
+ "try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
+ "try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
+ "}"
+ "function navReady(){"
+ "try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
+ "}"
+ "function frameReady(){"
+ "try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
+ "}"
+ "function check(){"
+ "if(navReady() && frameReady()){done();return;}"
+ "setTimeout(check,300);"
+ "}"
+ "var f=document.getElementById('mainframe');"
+ "ensureNav();"
+ "expandMenu();"
+ "if(!f){done();return;}"
+ f"f.src='{target_url}';"
+ "f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
+ "setTimeout(check,5000);"
+ "})();"
+ )
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
@@ -178,7 +310,22 @@ def take_screenshot_for_account(
screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename)
- if automation.take_screenshot(screenshot_path):
+ cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
+ if take_screenshot_wkhtmltoimage(
+ index_url,
+ screenshot_path,
+ cookies_path=cookies_for_shot,
+ proxy_server=proxy_server,
+ run_script=run_script,
+ window_status="ready",
+ log_callback=custom_log,
+ ) or take_screenshot_wkhtmltoimage(
+ target_url,
+ screenshot_path,
+ cookies_path=cookies_for_shot,
+ proxy_server=proxy_server,
+ log_callback=custom_log,
+ ):
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
log_to_client(f"✓ 截图成功: {screenshot_filename}", user_id, account_id)
return {"success": True, "filename": screenshot_filename}
@@ -197,15 +344,6 @@ def take_screenshot_for_account(
if attempt < max_retries:
log_to_client("将重试...", user_id, account_id)
time.sleep(2)
- finally:
- if automation:
- try:
- if automation.context:
- automation.context.close()
- automation.context = None
- automation.page = None
- except Exception as e:
- logger.debug(f"关闭context时出错: {e}")
return {"success": False, "error": "截图失败,已重试3次"}
@@ -250,6 +388,35 @@ def take_screenshot_for_account(
account_name = account.remark if account.remark else account.username
+ try:
+ if screenshot_path and result and result.get("success"):
+ cfg = database.get_system_config() or {}
+ if int(cfg.get("kdocs_enabled", 0) or 0) == 1:
+ doc_url = (cfg.get("kdocs_doc_url") or "").strip()
+ if doc_url:
+ user_cfg = database.get_user_kdocs_settings(user_id) or {}
+ if int(user_cfg.get("kdocs_auto_upload", 0) or 0) == 1:
+ unit = (user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "").strip()
+ name = (account.remark or "").strip()
+ if unit and name:
+ from services.kdocs_uploader import get_kdocs_uploader
+ ok = get_kdocs_uploader().enqueue_upload(
+ user_id=user_id,
+ account_id=account_id,
+ unit=unit,
+ name=name,
+ image_path=screenshot_path,
+ )
+ if not ok:
+ log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
+ else:
+ if not unit:
+ log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
+ if not name:
+ log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
+ except Exception as kdocs_error:
+ logger.warning(f"表格上传任务提交失败: {kdocs_error}")
+
if batch_id:
_batch_task_record_result(
batch_id=batch_id,
diff --git a/services/tasks.py b/services/tasks.py
index 717e54a..aa5397f 100644
--- a/services/tasks.py
+++ b/services/tasks.py
@@ -573,8 +573,16 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m
with APIBrowser(log_callback=custom_log, proxy_config=proxy_config) as api_browser:
if api_browser.login(account.username, account.password):
- log_to_client("✓ 登录成功!", user_id, account_id)
- api_browser.save_cookies_for_playwright(account.username)
+ log_to_client("✓ 首次登录成功,刷新登录时间...", user_id, account_id)
+
+ # 二次登录:让"上次登录时间"变成刚才首次登录的时间
+ # 这样截图时显示的"上次登录时间"就是几秒前而不是昨天
+ if api_browser.login(account.username, account.password):
+ log_to_client("✓ 二次登录成功!", user_id, account_id)
+ else:
+ log_to_client("⚠ 二次登录失败,继续使用首次登录状态", user_id, account_id)
+
+ api_browser.save_cookies_for_screenshot(account.username)
database.reset_account_login_status(account_id)
if not account.remark:
diff --git a/static/admin/index.html b/static/admin/index.html
index 79703fd..30a0b13 100644
--- a/static/admin/index.html
+++ b/static/admin/index.html
@@ -5,8 +5,8 @@
后台管理 - 知识管理平台
-
-
+
+
diff --git a/static/app/index.html b/static/app/index.html
index df1f75d..7a9c54d 100644
--- a/static/app/index.html
+++ b/static/app/index.html
@@ -4,8 +4,8 @@
知识管理平台
-
-
+
+
diff --git a/templates/admin.html b/templates/admin.html
index e85f6ad..3566d20 100644
--- a/templates/admin.html
+++ b/templates/admin.html
@@ -5,13 +5,48 @@
后台管理 - 知识管理平台
{% for css_file in admin_spa_css_files %}
+ {% if admin_spa_build_id %}
+
+ {% else %}
+ {% endif %}
{% endfor %}
+
+ {% if admin_spa_build_id %}
+
+ {% else %}
+ {% endif %}