commit 0fd7137cea3601bc6e9b5057376f85d85b89c7a8 Author: Yu Yon Date: Sun Nov 16 19:03:07 2025 +0800 Initial commit: 知识管理平台 主要功能: - 多用户管理系统 - 浏览器自动化(Playwright) - 任务编排和执行 - Docker容器化部署 - 数据持久化和日志管理 技术栈: - Flask 3.0.0 - Playwright 1.40.0 - SQLite with connection pooling - Docker + Docker Compose 部署说明详见README.md diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2cadcb5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,44 @@ +# 浏览器二进制文件 +playwright/ +ms-playwright/ + +# 数据库文件(敏感数据) +data/*.db +data/*.db-shm +data/*.db-wal +data/secret_key.txt + +# 日志文件 +logs/ +*.log + +# 截图文件 +截图/ + +# Python缓存 +__pycache__/ +*.py[cod] +*.class +*.so +.Python +env/ +venv/ +ENV/ + +# Docker volumes +volumes/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# 系统文件 +.DS_Store +Thumbs.db + +# 临时文件 +*.tmp +*.bak +*.backup diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..70b0bd6 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,46 @@ +# 使用国内镜像源加速 +FROM mcr.microsoft.com/playwright/python:v1.40.0-jammy + +# 设置工作目录 +WORKDIR /app + +# 设置环境变量 +ENV PYTHONUNBUFFERED=1 +ENV PLAYWRIGHT_BROWSERS_PATH=/ms-playwright +ENV TZ=Asia/Shanghai + +# 配置 pip 使用国内镜像源 +RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ && pip config set install.trusted-host mirrors.aliyun.com + +# 复制依赖文件 +COPY requirements.txt . + +# 安装Python依赖 +RUN pip install --no-cache-dir -r requirements.txt + +# 复制应用程序文件 +COPY app.py . +COPY database.py . +COPY db_pool.py . +COPY playwright_automation.py . +COPY browser_installer.py . +COPY password_utils.py . + +# 复制新的优化模块 +COPY app_config.py . +COPY app_logger.py . +COPY app_security.py . +COPY app_state.py . +COPY app_utils.py . + +COPY templates/ ./templates/ +COPY static/ ./static/ + +# 创建必要的目录 +RUN mkdir -p data logs 截图 + +# 暴露端口 +EXPOSE 5000 + +# 启动命令 +CMD ["python", "app.py"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..e346e94 --- /dev/null +++ b/README.md @@ -0,0 +1,695 @@ +# 知识管理平台自动化工具 - Docker部署版 + +这是一个基于 Docker 的知识管理平台自动化工具,支持多用户、定时任务、代理IP、VIP管理等功能。 + +--- + +## 项目简介 + +本项目是一个 **Docker 容器化应用**,使用 Flask + Playwright + SQLite 构建,提供: + +- 多用户注册登录系统 +- 浏览器自动化任务 +- 定时任务调度 +- 截图管理 +- VIP用户管理 +- 代理IP支持 +- 后台管理系统 + +--- + +## 技术栈 + +- **后端**: Python 3.8+, Flask +- **数据库**: SQLite +- **自动化**: Playwright (Chromium) +- **容器化**: Docker + Docker Compose +- **前端**: HTML + JavaScript + Socket.IO + +--- + +## 项目结构 + +``` +zsgpt2/ +├── app.py # 主应用程序 +├── database.py # 数据库模块 +├── playwright_automation.py # 浏览器自动化 +├── browser_installer.py # 浏览器安装检查 +├── app_config.py # 配置管理 +├── app_logger.py # 日志系统 +├── app_security.py # 安全模块 +├── app_state.py # 状态管理 +├── app_utils.py # 工具函数 +├── db_pool.py # 数据库连接池 +├── password_utils.py # 密码工具 +├── requirements.txt # Python依赖 +├── Dockerfile # Docker镜像构建文件 +├── docker-compose.yml # Docker编排文件 +├── templates/ # HTML模板 +│ ├── index.html # 主页面 +│ ├── login.html # 登录页 +│ ├── register.html # 注册页 +│ ├── admin.html # 后台管理 +│ └── ... +└── static/ # 静态资源 + └── js/ # JavaScript文件 +``` + +--- + +## 部署前准备 + +### 1. 环境要求 + +- **服务器**: Linux (Ubuntu 20.04+ / CentOS 7+) +- **Docker**: 20.10+ +- **Docker Compose**: 1.29+ +- **内存**: 4GB+ (推荐8GB) +- **磁盘**: 20GB+ + +### 2. SSH连接 + +**注意**: 本文档假设你已经有服务器的SSH访问权限。 + +你需要准备: +- 服务器IP地址 +- SSH用户名和密码(或SSH密钥) +- SSH端口(默认22) + +**SSH连接示例**: +```bash +ssh root@your-server-ip +# 或使用密钥 +ssh -i /path/to/key root@your-server-ip +``` + +--- + +## 快速部署 + +### 步骤1: 上传项目文件 + +将整个 `zsgpt2` 文件夹上传到服务器的 `/www/wwwroot/` 目录: + +```bash +# 在本地执行(Windows PowerShell 或 Git Bash) +scp -r C:\Users\Administrator\Desktop\zsgpt2 root@your-server-ip:/www/wwwroot/ + +# 或者使用 FileZilla、WinSCP 等工具上传 +``` + +上传后,服务器上的路径应该是:`/www/wwwroot/zsgpt2/` + +### 步骤2: SSH登录服务器 + +```bash +ssh root@your-server-ip +``` + +### 步骤3: 进入项目目录 + +```bash +cd /www/wwwroot/zsgpt2 +``` + +### 步骤4: 创建必要的目录 + +```bash +mkdir -p data logs 截图 playwright +chmod 777 data logs 截图 playwright +``` + +### 步骤5: 构建并启动Docker容器 + +```bash +# 构建镜像 +docker build -t knowledge-automation . + +# 启动容器 +docker-compose up -d + +# 查看容器状态 +docker ps | grep knowledge-automation +``` + +### 步骤6: 检查容器日志 + +```bash +docker logs -f knowledge-automation-multiuser +``` + +如果看到以下信息,说明启动成功: +``` +服务器启动中... +用户访问地址: http://0.0.0.0:5000 +后台管理地址: http://0.0.0.0:5000/yuyx +``` + +--- + +## 配置Nginx反向代理(可选但推荐) + +如果你想通过域名访问,需要配置Nginx反向代理。 + +### 1. 安装Nginx + +```bash +# Ubuntu/Debian +apt update && apt install nginx -y + +# CentOS/RHEL +yum install nginx -y +``` + +### 2. 创建Nginx配置文件 + +创建文件 `/etc/nginx/conf.d/zsgpt.conf`: + +```nginx +server { + listen 80; + server_name your-domain.com; # 替换为你的域名 + + # 日志 + access_log /var/log/nginx/zsgpt_access.log; + error_log /var/log/nginx/zsgpt_error.log; + + # 反向代理 + location / { + proxy_pass http://127.0.0.1:5001; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + client_max_body_size 50M; + + # WebSocket支持 + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_read_timeout 86400; + } +} +``` + +### 3. 重启Nginx + +```bash +nginx -t # 测试配置 +nginx -s reload # 重新加载配置 +``` + +### 4. 配置SSL(推荐) + +```bash +# 安装certbot +apt install certbot python3-certbot-nginx -y + +# 申请证书 +certbot --nginx -d your-domain.com + +# 自动续期 +certbot renew --dry-run +``` + +--- + +## 访问系统 + +### 用户端 + +- **HTTP**: `http://your-server-ip:5001` +- **域名**: `http://your-domain.com` (配置Nginx后) +- **HTTPS**: `https://your-domain.com` (配置SSL后) + +### 后台管理 + +- **路径**: `/yuyx` +- **默认账号**: `admin` +- **默认密码**: `admin` + +**首次登录后请立即修改密码!** + +--- + +## 系统配置 + +登录后台后,在"系统配置"页面可以设置: + +### 1. 并发控制 +- **全局最大并发**: 2 (根据服务器配置调整) +- **单用户并发**: 1 + +### 2. 定时任务 +- **启用定时浏览**: 是/否 +- **执行时间**: 02:00 (CST时间) +- **浏览类型**: 应读/注册前未读/未读 +- **执行日期**: 周一到周日 + +### 3. 代理配置 +- **启用代理**: 是/否 +- **API地址**: http://your-proxy-api.com +- **IP有效期**: 3分钟 + +--- + +## Docker常用命令 + +### 容器管理 + +```bash +# 启动容器 +docker start knowledge-automation-multiuser + +# 停止容器 +docker stop knowledge-automation-multiuser + +# 重启容器 +docker restart knowledge-automation-multiuser + +# 删除容器 +docker rm -f knowledge-automation-multiuser + +# 查看容器状态 +docker ps -a | grep knowledge-automation +``` + +### 日志查看 + +```bash +# 查看实时日志 +docker logs -f knowledge-automation-multiuser + +# 查看最近100行日志 +docker logs --tail 100 knowledge-automation-multiuser + +# 查看应用日志文件 +tail -f /www/wwwroot/zsgpt2/logs/app.log +``` + +### 进入容器 + +```bash +# 进入容器Shell +docker exec -it knowledge-automation-multiuser bash + +# 在容器内执行命令 +docker exec knowledge-automation-multiuser python -c "print('Hello')" +``` + +### 重新构建 + +如果修改了代码,需要重新构建: + +```bash +cd /www/wwwroot/zsgpt2 + +# 停止并删除旧容器 +docker-compose down + +# 重新构建并启动 +docker-compose build +docker-compose up -d +``` + +--- + +## 数据备份与恢复 + +### 1. 备份数据 + +```bash +cd /www/wwwroot + +# 备份整个项目 +tar -czf zsgpt2_backup_$(date +%Y%m%d).tar.gz zsgpt2/ + +# 仅备份数据库 +cp /www/wwwroot/zsgpt2/data/app_data.db /backup/app_data_$(date +%Y%m%d).db + +# 备份截图 +tar -czf screenshots_$(date +%Y%m%d).tar.gz /www/wwwroot/zsgpt2/截图/ +``` + +### 2. 恢复数据 + +```bash +# 停止容器 +docker stop knowledge-automation-multiuser + +# 恢复整个项目 +cd /www/wwwroot +tar -xzf zsgpt2_backup_20251027.tar.gz + +# 恢复数据库 +cp /backup/app_data_20251027.db /www/wwwroot/zsgpt2/data/app_data.db + +# 重启容器 +docker start knowledge-automation-multiuser +``` + +### 3. 定时备份 + +添加cron任务自动备份: + +```bash +crontab -e +``` + +添加以下内容: + +```bash +# 每天凌晨3点备份 +0 3 * * * tar -czf /backup/zsgpt2_$(date +\%Y\%m\%d).tar.gz /www/wwwroot/zsgpt2/data +``` + +--- + +## 常见问题 + +### 1. 容器启动失败 + +**问题**: `docker-compose up -d` 失败 + +**解决方案**: +```bash +# 查看详细错误 +docker-compose logs + +# 检查端口占用 +netstat -tlnp | grep 5001 + +# 重新构建 +docker-compose build --no-cache +docker-compose up -d +``` + +### 2. 502 Bad Gateway + +**问题**: Nginx返回502错误 + +**解决方案**: +```bash +# 检查容器是否运行 +docker ps | grep knowledge-automation + +# 检查端口是否监听 +netstat -tlnp | grep 5001 + +# 测试直接访问 +curl http://127.0.0.1:5001 + +# 检查Nginx配置 +nginx -t +``` + +### 3. 数据库锁定 + +**问题**: "database is locked" + +**解决方案**: +```bash +# 重启容器 +docker restart knowledge-automation-multiuser + +# 如果问题持续,优化数据库 +cd /www/wwwroot/zsgpt2 +cp data/app_data.db data/app_data.db.backup +sqlite3 data/app_data.db "VACUUM;" +``` + +### 4. 内存不足 + +**问题**: 容器OOM (Out of Memory) + +**解决方案**: + +修改 `docker-compose.yml`: + +```yaml +services: + knowledge-automation: + mem_limit: 2g + memswap_limit: 2g +``` + +然后重启: + +```bash +docker-compose down +docker-compose up -d +``` + +### 5. 浏览器下载失败 + +**问题**: Playwright浏览器下载失败 + +**解决方案**: +```bash +# 进入容器手动安装 +docker exec -it knowledge-automation-multiuser bash +playwright install chromium + +# 或使用国内镜像 +export PLAYWRIGHT_DOWNLOAD_HOST=https://npmmirror.com/mirrors/playwright/ +playwright install chromium +``` + +--- + +## 性能优化 + +### 1. 调整并发参数 + +根据服务器配置调整: +- **2核4GB**: 全局并发=1, 单用户并发=1 +- **4核8GB**: 全局并发=2, 单用户并发=1 +- **8核16GB**: 全局并发=4, 单用户并发=2 + +### 2. 启用代理IP + +避免IP被封,提高成功率: +- 选择稳定的代理服务商 +- 设置合适的IP有效期(3-5分钟) +- 启用自动重试机制 + +### 3. 定期清理数据 + +系统会自动清理7天前的数据,也可以手动清理: + +```bash +# 清理7天前的截图 +find /www/wwwroot/zsgpt2/截图 -name "*.jpg" -mtime +7 -delete + +# 清理旧日志 +find /www/wwwroot/zsgpt2/logs -name "*.log" -mtime +30 -delete + +# 优化数据库 +sqlite3 /www/wwwroot/zsgpt2/data/app_data.db "VACUUM;" +``` + +--- + +## 安全建议 + +### 1. 修改默认密码 + +首次登录后立即修改: +- 管理员密码 +- 用户密码 + +### 2. 配置防火墙 + +```bash +# 只开放必要端口 +firewall-cmd --permanent --add-port=80/tcp +firewall-cmd --permanent --add-port=443/tcp +firewall-cmd --reload + +# 禁止直接访问5001端口(仅Nginx可访问) +iptables -A INPUT -p tcp --dport 5001 -s 127.0.0.1 -j ACCEPT +iptables -A INPUT -p tcp --dport 5001 -j DROP +``` + +### 3. 启用HTTPS + +强烈建议使用HTTPS加密传输: + +```bash +certbot --nginx -d your-domain.com +``` + +### 4. 限制SSH访问 + +```bash +# 修改SSH端口(可选) +vi /etc/ssh/sshd_config +# Port 22222 + +# 禁止root密码登录(使用密钥) +PermitRootLogin prohibit-password +PasswordAuthentication no + +# 重启SSH服务 +systemctl restart sshd +``` + +--- + +## 监控与维护 + +### 1. 系统监控 + +推荐使用以下工具: +- **Docker Stats**: `docker stats knowledge-automation-multiuser` +- **Grafana + Prometheus**: 可视化监控 +- **Uptime Kuma**: 服务可用性监控 + +### 2. 日志分析 + +```bash +# 统计今日任务数 +grep "浏览完成" /www/wwwroot/zsgpt2/logs/app.log | grep $(date +%Y-%m-%d) | wc -l + +# 查看错误日志 +grep "ERROR" /www/wwwroot/zsgpt2/logs/app.log | tail -20 + +# 查看最近的登录 +grep "登录成功" /www/wwwroot/zsgpt2/logs/app.log | tail -10 +``` + +### 3. 数据库维护 + +```bash +# 定期优化数据库(每月一次) +docker exec knowledge-automation-multiuser python3 << 'EOF' +import sqlite3 +conn = sqlite3.connect('/app/data/app_data.db') +conn.execute('VACUUM') +conn.close() +print("数据库优化完成") +EOF +``` + +--- + +## 更新升级 + +### 1. 更新代码 + +```bash +# 停止容器 +docker-compose down + +# 备份数据 +cp -r data data.backup +cp -r 截图 截图.backup + +# 上传新代码(覆盖旧文件) +# 使用 scp 或 FTP 工具上传 + +# 重新构建并启动 +docker-compose build +docker-compose up -d +``` + +### 2. 数据库迁移 + +如果数据库结构有变化,应用会自动迁移。 + +查看迁移日志: + +```bash +docker logs knowledge-automation-multiuser | grep "数据库" +``` + +--- + +## 端口说明 + +| 端口 | 说明 | 映射 | +|------|------|------| +| 5000 | 容器内应用端口 | - | +| 5001 | 主机映射端口 | 容器5000 → 主机5001 | +| 80 | HTTP端口 | Nginx | +| 443 | HTTPS端口 | Nginx | + +--- + +## 环境变量 + +可以在 `docker-compose.yml` 中设置的环境变量: + +| 变量名 | 说明 | 默认值 | +|--------|------|--------| +| TZ | 时区 | Asia/Shanghai | +| PYTHONUNBUFFERED | Python输出缓冲 | 1 | +| PLAYWRIGHT_BROWSERS_PATH | 浏览器路径 | /ms-playwright | + +--- + +## 技术支持 + +### 项目信息 + +- **项目名称**: 知识管理平台自动化工具 +- **版本**: Docker 多用户版 +- **技术栈**: Python + Flask + Playwright + SQLite + Docker + +### 常用文档链接 + +- [Docker 官方文档](https://docs.docker.com/) +- [Flask 官方文档](https://flask.palletsprojects.com/) +- [Playwright 官方文档](https://playwright.dev/python/) + +### 故障排查 + +遇到问题时,请按以下顺序检查: + +1. **容器日志**: `docker logs knowledge-automation-multiuser` +2. **应用日志**: `cat /www/wwwroot/zsgpt2/logs/app.log` +3. **Nginx日志**: `cat /var/log/nginx/zsgpt_error.log` +4. **系统资源**: `docker stats`, `htop`, `df -h` + +--- + +## 许可证 + +本项目仅供学习和研究使用。 + +--- + +**文档版本**: v1.0 +**更新日期**: 2025-10-29 +**适用版本**: Docker多用户版 + +--- + +## 快速上手命令清单 + +```bash +# 1. 上传文件 +scp -r zsgpt2 root@your-ip:/www/wwwroot/ + +# 2. SSH登录 +ssh root@your-ip + +# 3. 进入目录并创建必要目录 +cd /www/wwwroot/zsgpt2 +mkdir -p data logs 截图 playwright +chmod 777 data logs 截图 playwright + +# 4. 启动容器 +docker-compose up -d + +# 5. 查看日志 +docker logs -f knowledge-automation-multiuser + +# 6. 访问系统 +# 浏览器打开: http://your-ip:5001 +# 后台管理: http://your-ip:5001/yuyx +# 默认账号: admin / admin +``` + +完成!🎉 diff --git a/app.py b/app.py new file mode 100755 index 0000000..394d7c7 --- /dev/null +++ b/app.py @@ -0,0 +1,2223 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +知识管理平台自动化工具 - 多用户版本 +支持用户注册登录、后台管理、数据隔离 +""" + +# 设置时区为中国标准时间(CST, UTC+8) +import os +os.environ['TZ'] = 'Asia/Shanghai' +try: + import time + time.tzset() +except AttributeError: + pass # Windows系统不支持tzset() + +import pytz +from datetime import datetime +from flask import Flask, render_template, request, jsonify, send_from_directory, redirect, url_for, session +from flask_socketio import SocketIO, emit, join_room, leave_room +from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user +import threading +import time +import json +import os +from datetime import datetime, timedelta, timezone +from functools import wraps + +# 导入数据库模块和核心模块 +import database +import requests +from playwright_automation import PlaywrightBrowserManager, PlaywrightAutomation, BrowseResult +from browser_installer import check_and_install_browser +# ========== 优化模块导入 ========== +from app_config import get_config +from app_logger import init_logging, get_logger, audit_logger +from app_security import ( + ip_rate_limiter, require_ip_not_locked, + validate_username, validate_password, validate_email, + is_safe_path, sanitize_filename, get_client_ip +) + + + +# ========== 初始化配置 ========== +config = get_config() +app = Flask(__name__) +# SECRET_KEY持久化,避免重启后所有用户登出 +SECRET_KEY_FILE = 'data/secret_key.txt' +if os.path.exists(SECRET_KEY_FILE): + with open(SECRET_KEY_FILE, 'r') as f: + SECRET_KEY = f.read().strip() +else: + SECRET_KEY = os.urandom(24).hex() + os.makedirs('data', exist_ok=True) + with open(SECRET_KEY_FILE, 'w') as f: + f.write(SECRET_KEY) + print(f"✓ 已生成新的SECRET_KEY并保存") +app.config.from_object(config) +socketio = SocketIO(app, cors_allowed_origins="*") + +# ========== 初始化日志系统 ========== +init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE) +logger = get_logger('app') +logger.info("="*60) +logger.info("知识管理平台自动化工具 - 多用户版") +logger.info("="*60) + + +# Flask-Login 配置 +login_manager = LoginManager() +login_manager.init_app(app) +login_manager.login_view = 'login_page' + +# 截图目录 +SCREENSHOTS_DIR = config.SCREENSHOTS_DIR +os.makedirs(SCREENSHOTS_DIR, exist_ok=True) + +# 全局变量 +browser_manager = None +user_accounts = {} # {user_id: {account_id: Account对象}} +active_tasks = {} # {account_id: Thread对象} +log_cache = {} # {user_id: [logs]} 每个用户独立的日志缓存 +log_cache_total_count = 0 # 全局日志总数,防止无限增长 + +# 日志缓存限制 +MAX_LOGS_PER_USER = config.MAX_LOGS_PER_USER # 每个用户最多100条 +MAX_TOTAL_LOGS = config.MAX_TOTAL_LOGS # 全局最多1000条,防止内存泄漏 + +# 并发控制:每个用户同时最多运行1个账号(避免内存不足) +# 验证码存储:{session_id: {"code": "1234", "expire_time": timestamp, "failed_attempts": 0}} +captcha_storage = {} + +# IP限流存储:{ip: {"attempts": count, "lock_until": timestamp, "first_attempt": timestamp}} +ip_rate_limit = {} + +# 限流配置 +MAX_CAPTCHA_ATTEMPTS = 5 # 每个验证码最多尝试次数 +MAX_IP_ATTEMPTS_PER_HOUR = 10 # 每小时每个IP最多验证码错误次数 +IP_LOCK_DURATION = 3600 # IP锁定时长(秒) - 1小时 +# 全局限制:整个系统同时最多运行2个账号(线程本地架构,每个线程独立浏览器,内存占用约200MB/浏览器) +max_concurrent_per_account = 1 # 每个用户最多1个 +max_concurrent_global = 2 # 全局最多2个(线程本地架构内存需求更高) +user_semaphores = {} # {user_id: Semaphore} +global_semaphore = threading.Semaphore(max_concurrent_global) + +# 截图专用信号量:限制同时进行的截图任务数量为1(避免资源竞争) +screenshot_semaphore = threading.Semaphore(1) + + +class User(UserMixin): + """Flask-Login 用户类""" + def __init__(self, user_id): + self.id = user_id + + +class Admin(UserMixin): + """管理员类""" + def __init__(self, admin_id): + self.id = admin_id + self.is_admin = True + + +class Account: + """账号类""" + def __init__(self, account_id, user_id, username, password, remember=True, remark=''): + self.id = account_id + self.user_id = user_id + self.username = username + self.password = password + self.remember = remember + self.remark = remark + self.status = "未开始" + self.is_running = False + self.should_stop = False + self.total_items = 0 + self.total_attachments = 0 + self.automation = None + self.last_browse_type = "注册前未读" + self.proxy_config = None # 保存代理配置,浏览和截图共用 + + def to_dict(self): + return { + "id": self.id, + "username": self.username, + "status": self.status, + "remark": self.remark, + "total_items": self.total_items, + "total_attachments": self.total_attachments, + "is_running": self.is_running + } + + +@login_manager.user_loader +def load_user(user_id): + """Flask-Login 用户加载""" + user = database.get_user_by_id(int(user_id)) + if user: + return User(user['id']) + return None + + +def admin_required(f): + """管理员权限装饰器""" + @wraps(f) + def decorated_function(*args, **kwargs): + if 'admin_id' not in session: + return jsonify({"error": "需要管理员权限"}), 403 + return f(*args, **kwargs) + return decorated_function + + +def log_to_client(message, user_id=None, account_id=None): + """发送日志到Web客户端(用户隔离)""" + beijing_tz = timezone(timedelta(hours=8)) + timestamp = datetime.now(beijing_tz).strftime('%H:%M:%S') + log_data = { + 'timestamp': timestamp, + 'message': message, + 'account_id': account_id + } + + # 如果指定了user_id,则缓存到该用户的日志 + if user_id: + global log_cache_total_count + if user_id not in log_cache: + log_cache[user_id] = [] + log_cache[user_id].append(log_data) + log_cache_total_count += 1 + + # 持久化到数据库 (已禁用,使用task_logs表代替) + # try: + # database.save_operation_log(user_id, message, account_id, 'INFO') + # except Exception as e: + # print(f"保存日志到数据库失败: {e}") + + # 单用户限制 + if len(log_cache[user_id]) > MAX_LOGS_PER_USER: + log_cache[user_id].pop(0) + log_cache_total_count -= 1 + + # 全局限制 - 如果超过总数限制,清理日志最多的用户 + while log_cache_total_count > MAX_TOTAL_LOGS: + if log_cache: + max_user = max(log_cache.keys(), key=lambda u: len(log_cache[u])) + if log_cache[max_user]: + log_cache[max_user].pop(0) + log_cache_total_count -= 1 + else: + break + else: + break + + # 发送到该用户的room + socketio.emit('log', log_data, room=f'user_{user_id}') + + print(f"[{timestamp}] User:{user_id} {message}") + + + +def get_proxy_from_api(api_url, max_retries=3): + """从API获取代理IP(支持重试) + + Args: + api_url: 代理API地址 + max_retries: 最大重试次数 + + Returns: + 代理服务器地址(格式: http://IP:PORT)或 None + """ + for attempt in range(max_retries): + try: + response = requests.get(api_url, timeout=10) + if response.status_code == 200: + ip_port = response.text.strip() + if ip_port and ':' in ip_port: + proxy_server = f"http://{ip_port}" + print(f"✓ 获取代理成功: {proxy_server} (尝试 {attempt + 1}/{max_retries})") + return proxy_server + else: + print(f"✗ 代理格式错误: {ip_port} (尝试 {attempt + 1}/{max_retries})") + else: + print(f"✗ 获取代理失败: HTTP {response.status_code} (尝试 {attempt + 1}/{max_retries})") + except Exception as e: + print(f"✗ 获取代理异常: {str(e)} (尝试 {attempt + 1}/{max_retries})") + + if attempt < max_retries - 1: + time.sleep(1) + + print(f"✗ 获取代理失败,已重试 {max_retries} 次,将不使用代理继续") + return None + +def init_browser_manager(): + """初始化浏览器管理器""" + global browser_manager + if browser_manager is None: + print("正在初始化Playwright浏览器管理器...") + + if not check_and_install_browser(log_callback=lambda msg, account_id=None: print(msg)): + print("浏览器环境检查失败!") + return False + + browser_manager = PlaywrightBrowserManager( + headless=True, + log_callback=lambda msg, account_id=None: print(msg) + ) + + try: + # 不再需要initialize(),每个账号会创建独立浏览器 + print("Playwright浏览器管理器创建成功!") + return True + except Exception as e: + print(f"Playwright初始化失败: {str(e)}") + return False + return True + + +# ==================== 前端路由 ==================== + +@app.route('/') +def index(): + """主页 - 重定向到登录或应用""" + if current_user.is_authenticated: + return redirect(url_for('app_page')) + return redirect(url_for('login_page')) + + +@app.route('/login') +def login_page(): + """登录页面""" + return render_template('login.html') + + +@app.route('/register') +def register_page(): + """注册页面""" + return render_template('register.html') + + +@app.route('/app') +@login_required +def app_page(): + """主应用页面""" + return render_template('index.html') + + +@app.route('/yuyx') +def admin_login_page(): + """后台登录页面""" + if 'admin_id' in session: + return redirect(url_for('admin_page')) + return render_template('admin_login.html') + + +@app.route('/yuyx/admin') +@admin_required +def admin_page(): + """后台管理页面""" + return render_template('admin.html') + + + + +@app.route('/yuyx/vip') +@admin_required +def vip_admin_page(): + """VIP管理页面""" + return render_template('vip_admin.html') + + +# ==================== 用户认证API ==================== + +@app.route('/api/register', methods=['POST']) +@require_ip_not_locked # IP限流保护 +def register(): + """用户注册""" + data = request.json + username = data.get('username', '').strip() + password = data.get('password', '').strip() + email = data.get('email', '').strip() + captcha_session = data.get('captcha_session', '') + captcha_code = data.get('captcha', '').strip() + + if not username or not password: + return jsonify({"error": "用户名和密码不能为空"}), 400 + + # 验证验证码 + if not captcha_session or captcha_session not in captcha_storage: + return jsonify({"error": "验证码已过期,请重新获取"}), 400 + + captcha_data = captcha_storage[captcha_session] + if captcha_data["expire_time"] < time.time(): + del captcha_storage[captcha_session] + return jsonify({"error": "验证码已过期,请重新获取"}), 400 + + # 获取客户端IP + client_ip = request.headers.get('X-Forwarded-For', request.headers.get('X-Real-IP', request.remote_addr)) + if client_ip and ',' in client_ip: + client_ip = client_ip.split(',')[0].strip() + + # 检查IP限流 + allowed, error_msg = check_ip_rate_limit(client_ip) + if not allowed: + return jsonify({"error": error_msg}), 429 + + # 检查验证码尝试次数 + if captcha_data.get("failed_attempts", 0) >= MAX_CAPTCHA_ATTEMPTS: + del captcha_storage[captcha_session] + return jsonify({"error": "验证码尝试次数过多,请重新获取"}), 400 + + if captcha_data["code"] != captcha_code: + # 记录失败次数 + captcha_data["failed_attempts"] = captcha_data.get("failed_attempts", 0) + 1 + + # 记录IP失败尝试 + is_locked = record_failed_captcha(client_ip) + if is_locked: + return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429 + + return jsonify({"error": "验证码错误(剩余{}次机会)".format( + MAX_CAPTCHA_ATTEMPTS - captcha_data["failed_attempts"])}), 400 + + # 验证成功,删除已使用的验证码 + del captcha_storage[captcha_session] + + user_id = database.create_user(username, password, email) + if user_id: + return jsonify({"success": True, "message": "注册成功,请等待管理员审核"}) + else: + return jsonify({"error": "用户名已存在"}), 400 + + +# ==================== 验证码API ==================== +import random + + +def check_ip_rate_limit(ip_address): + """检查IP是否被限流""" + current_time = time.time() + + # 清理过期的IP记录 + expired_ips = [ip for ip, data in ip_rate_limit.items() + if data.get("lock_until", 0) < current_time and + current_time - data.get("first_attempt", current_time) > 3600] + for ip in expired_ips: + del ip_rate_limit[ip] + + # 检查IP是否被锁定 + if ip_address in ip_rate_limit: + ip_data = ip_rate_limit[ip_address] + + # 如果IP被锁定且未到解锁时间 + if ip_data.get("lock_until", 0) > current_time: + remaining_time = int(ip_data["lock_until"] - current_time) + return False, "IP已被锁定,请{}分钟后再试".format(remaining_time // 60 + 1) + + # 如果超过1小时,重置计数 + if current_time - ip_data.get("first_attempt", current_time) > 3600: + ip_rate_limit[ip_address] = { + "attempts": 0, + "first_attempt": current_time + } + + return True, None + + +def record_failed_captcha(ip_address): + """记录验证码失败尝试""" + current_time = time.time() + + if ip_address not in ip_rate_limit: + ip_rate_limit[ip_address] = { + "attempts": 1, + "first_attempt": current_time + } + else: + ip_rate_limit[ip_address]["attempts"] += 1 + + # 检查是否超过限制 + if ip_rate_limit[ip_address]["attempts"] >= MAX_IP_ATTEMPTS_PER_HOUR: + ip_rate_limit[ip_address]["lock_until"] = current_time + IP_LOCK_DURATION + return True # 表示IP已被锁定 + + return False # 表示还未锁定 + + +@app.route("/api/generate_captcha", methods=["POST"]) +def generate_captcha(): + """生成4位数字验证码""" + import uuid + session_id = str(uuid.uuid4()) + + # 生成4位随机数字 + code = "".join([str(random.randint(0, 9)) for _ in range(4)]) + + # 存储验证码,5分钟过期 + captcha_storage[session_id] = { + "code": code, + "expire_time": time.time() + 300, + "failed_attempts": 0 + } + + # 清理过期验证码 + expired_keys = [k for k, v in captcha_storage.items() if v["expire_time"] < time.time()] + for k in expired_keys: + del captcha_storage[k] + + return jsonify({"session_id": session_id, "captcha": code}) + + +@app.route('/api/login', methods=['POST']) +@require_ip_not_locked # IP限流保护 +def login(): + """用户登录""" + data = request.json + username = data.get('username', '').strip() + password = data.get('password', '').strip() + captcha_session = data.get('captcha_session', '') + captcha_code = data.get('captcha', '').strip() + need_captcha = data.get('need_captcha', False) + + # 如果需要验证码,验证验证码 + if need_captcha: + if not captcha_session or captcha_session not in captcha_storage: + return jsonify({"error": "验证码已过期,请重新获取"}), 400 + + captcha_data = captcha_storage[captcha_session] + if captcha_data["expire_time"] < time.time(): + del captcha_storage[captcha_session] + return jsonify({"error": "验证码已过期,请重新获取"}), 400 + + if captcha_data["code"] != captcha_code: + return jsonify({"error": "验证码错误"}), 400 + + # 验证成功,删除已使用的验证码 + del captcha_storage[captcha_session] + + # 先检查用户是否存在 + user_exists = database.get_user_by_username(username) + if not user_exists: + return jsonify({"error": "账号未注册", "need_captcha": True}), 401 + + # 检查密码是否正确 + user = database.verify_user(username, password) + if not user: + # 密码错误 + return jsonify({"error": "密码错误", "need_captcha": True}), 401 + + # 检查审核状态 + if user['status'] != 'approved': + return jsonify({"error": "账号未审核,请等待管理员审核", "need_captcha": False}), 401 + + # 登录成功 + user_obj = User(user['id']) + login_user(user_obj) + load_user_accounts(user['id']) + return jsonify({"success": True}) + + +@app.route('/api/logout', methods=['POST']) +@login_required +def logout(): + """用户登出""" + logout_user() + return jsonify({"success": True}) + + +# ==================== 管理员认证API ==================== + +@app.route('/yuyx/api/login', methods=['POST']) +@require_ip_not_locked # IP限流保护 +def admin_login(): + """管理员登录""" + data = request.json + username = data.get('username', '').strip() + password = data.get('password', '').strip() + captcha_session = data.get('captcha_session', '') + captcha_code = data.get('captcha', '').strip() + need_captcha = data.get('need_captcha', False) + + # 如果需要验证码,验证验证码 + if need_captcha: + if not captcha_session or captcha_session not in captcha_storage: + return jsonify({"error": "验证码已过期,请重新获取"}), 400 + + captcha_data = captcha_storage[captcha_session] + if captcha_data["expire_time"] < time.time(): + del captcha_storage[captcha_session] + return jsonify({"error": "验证码已过期,请重新获取"}), 400 + + if captcha_data["code"] != captcha_code: + return jsonify({"error": "验证码错误"}), 400 + + # 验证成功,删除已使用的验证码 + del captcha_storage[captcha_session] + + admin = database.verify_admin(username, password) + if admin: + session['admin_id'] = admin['id'] + session['admin_username'] = admin['username'] + return jsonify({"success": True}) + else: + return jsonify({"error": "管理员用户名或密码错误", "need_captcha": True}), 401 + + +@app.route('/yuyx/api/logout', methods=['POST']) +@admin_required +def admin_logout(): + """管理员登出""" + session.pop('admin_id', None) + session.pop('admin_username', None) + return jsonify({"success": True}) + + +@app.route('/yuyx/api/users', methods=['GET']) +@admin_required +def get_all_users(): + """获取所有用户""" + users = database.get_all_users() + return jsonify(users) + + +@app.route('/yuyx/api/users/pending', methods=['GET']) +@admin_required +def get_pending_users(): + """获取待审核用户""" + users = database.get_pending_users() + return jsonify(users) + + +@app.route('/yuyx/api/users//approve', methods=['POST']) +@admin_required +def approve_user_route(user_id): + """审核通过用户""" + if database.approve_user(user_id): + return jsonify({"success": True}) + return jsonify({"error": "审核失败"}), 400 + + +@app.route('/yuyx/api/users//reject', methods=['POST']) +@admin_required +def reject_user_route(user_id): + """拒绝用户""" + if database.reject_user(user_id): + return jsonify({"success": True}) + return jsonify({"error": "拒绝失败"}), 400 + + +@app.route('/yuyx/api/users/', methods=['DELETE']) +@admin_required +def delete_user_route(user_id): + """删除用户""" + if database.delete_user(user_id): + # 清理内存中的账号数据 + if user_id in user_accounts: + del user_accounts[user_id] + + # 清理用户信号量,防止内存泄漏 + if user_id in user_semaphores: + del user_semaphores[user_id] + + # 清理用户日志缓存,防止内存泄漏 + global log_cache_total_count + if user_id in log_cache: + log_cache_total_count -= len(log_cache[user_id]) + del log_cache[user_id] + + return jsonify({"success": True}) + return jsonify({"error": "删除失败"}), 400 + + +@app.route('/yuyx/api/stats', methods=['GET']) +@admin_required +def get_system_stats(): + """获取系统统计""" + stats = database.get_system_stats() + # 从session获取管理员用户名 + stats["admin_username"] = session.get('admin_username', 'admin') + return jsonify(stats) + + +@app.route('/yuyx/api/docker_stats', methods=['GET']) +@admin_required +def get_docker_stats(): + """获取Docker容器运行状态""" + import subprocess + + docker_status = { + 'running': False, + 'container_name': 'N/A', + 'uptime': 'N/A', + 'memory_usage': 'N/A', + 'memory_limit': 'N/A', + 'memory_percent': 'N/A', + 'cpu_percent': 'N/A', + 'status': 'Unknown' + } + + try: + # 检查是否在Docker容器内 + if os.path.exists('/.dockerenv'): + docker_status['running'] = True + + # 获取容器名称 + try: + with open('/etc/hostname', 'r') as f: + docker_status['container_name'] = f.read().strip() + except: + pass + + # 获取内存使用情况 (cgroup v2) + try: + # 尝试cgroup v2路径 + if os.path.exists('/sys/fs/cgroup/memory.current'): + # Read total memory + with open('/sys/fs/cgroup/memory.current', 'r') as f: + mem_total = int(f.read().strip()) + + # Read cache from memory.stat + cache = 0 + if os.path.exists('/sys/fs/cgroup/memory.stat'): + with open('/sys/fs/cgroup/memory.stat', 'r') as f: + for line in f: + if line.startswith('inactive_file '): + cache = int(line.split()[1]) + break + + # Actual memory = total - cache + mem_bytes = mem_total - cache + docker_status['memory_usage'] = "{:.2f} MB".format(mem_bytes / 1024 / 1024) + + # 获取内存限制 + if os.path.exists('/sys/fs/cgroup/memory.max'): + with open('/sys/fs/cgroup/memory.max', 'r') as f: + limit_str = f.read().strip() + if limit_str != 'max': + limit_bytes = int(limit_str) + docker_status['memory_limit'] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024) + docker_status['memory_percent'] = "{:.2f}%".format(mem_bytes / limit_bytes * 100) + # 尝试cgroup v1路径 + elif os.path.exists('/sys/fs/cgroup/memory/memory.usage_in_bytes'): + # 从 memory.stat 读取内存信息 + mem_bytes = 0 + if os.path.exists('/sys/fs/cgroup/memory/memory.stat'): + with open('/sys/fs/cgroup/memory/memory.stat', 'r') as f: + rss = 0 + cache = 0 + for line in f: + if line.startswith('total_rss '): + rss = int(line.split()[1]) + elif line.startswith('total_cache '): + cache = int(line.split()[1]) + # 使用 RSS + (一部分活跃的cache),更接近docker stats的计算 + # 但为了准确性,我们只用RSS + mem_bytes = rss + + # 如果找不到,则使用总内存减去缓存作为后备 + if mem_bytes == 0: + with open('/sys/fs/cgroup/memory/memory.usage_in_bytes', 'r') as f: + total_mem = int(f.read().strip()) + + cache = 0 + if os.path.exists('/sys/fs/cgroup/memory/memory.stat'): + with open('/sys/fs/cgroup/memory/memory.stat', 'r') as f: + for line in f: + if line.startswith('total_inactive_file '): + cache = int(line.split()[1]) + break + + mem_bytes = total_mem - cache + + docker_status['memory_usage'] = "{:.2f} MB".format(mem_bytes / 1024 / 1024) + + # 获取内存限制 + if os.path.exists('/sys/fs/cgroup/memory/memory.limit_in_bytes'): + with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f: + limit_bytes = int(f.read().strip()) + # 检查是否是实际限制(不是默认的超大值) + if limit_bytes < 9223372036854771712: + docker_status['memory_limit'] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024) + docker_status['memory_percent'] = "{:.2f}%".format(mem_bytes / limit_bytes * 100) + except Exception as e: + docker_status['memory_usage'] = 'Error: {}'.format(str(e)) + + # 获取容器运行时间(基于PID 1的启动时间) + try: + # Get PID 1 start time + with open('/proc/1/stat', 'r') as f: + stat_data = f.read().split() + starttime_ticks = int(stat_data[21]) + + # Get system uptime + with open('/proc/uptime', 'r') as f: + system_uptime = float(f.read().split()[0]) + + # Get clock ticks per second + import os as os_module + ticks_per_sec = os_module.sysconf(os_module.sysconf_names['SC_CLK_TCK']) + + # Calculate container uptime + process_start = starttime_ticks / ticks_per_sec + uptime_seconds = int(system_uptime - process_start) + + days = uptime_seconds // 86400 + hours = (uptime_seconds % 86400) // 3600 + minutes = (uptime_seconds % 3600) // 60 + + if days > 0: + docker_status['uptime'] = "{}天 {}小时 {}分钟".format(days, hours, minutes) + elif hours > 0: + docker_status['uptime'] = "{}小时 {}分钟".format(hours, minutes) + else: + docker_status['uptime'] = "{}分钟".format(minutes) + except: + pass + + docker_status['status'] = 'Running' + else: + docker_status['status'] = 'Not in Docker' + + except Exception as e: + docker_status['status'] = 'Error: {}'.format(str(e)) + + return jsonify(docker_status) + +@app.route('/yuyx/api/admin/password', methods=['PUT']) +@admin_required +def update_admin_password(): + """修改管理员密码""" + data = request.json + new_password = data.get('new_password', '').strip() + + if not new_password: + return jsonify({"error": "密码不能为空"}), 400 + + username = session.get('admin_username') + if database.update_admin_password(username, new_password): + return jsonify({"success": True}) + return jsonify({"error": "修改失败"}), 400 + + +@app.route('/yuyx/api/admin/username', methods=['PUT']) +@admin_required +def update_admin_username(): + """修改管理员用户名""" + data = request.json + new_username = data.get('new_username', '').strip() + + if not new_username: + return jsonify({"error": "用户名不能为空"}), 400 + + old_username = session.get('admin_username') + if database.update_admin_username(old_username, new_username): + session['admin_username'] = new_username + return jsonify({"success": True}) + return jsonify({"error": "用户名已存在"}), 400 + + + +# ==================== 密码重置API ==================== + +# 管理员直接重置用户密码 +@app.route('/yuyx/api/users//reset_password', methods=['POST']) +@admin_required +def admin_reset_password_route(user_id): + """管理员直接重置用户密码(无需审核)""" + data = request.json + new_password = data.get('new_password', '').strip() + + if not new_password: + return jsonify({"error": "新密码不能为空"}), 400 + + if len(new_password) < 6: + return jsonify({"error": "密码长度不能少于6位"}), 400 + + if database.admin_reset_user_password(user_id, new_password): + return jsonify({"message": "密码重置成功"}) + return jsonify({"error": "重置失败,用户不存在"}), 400 + + +# 获取密码重置申请列表 +@app.route('/yuyx/api/password_resets', methods=['GET']) +@admin_required +def get_password_resets_route(): + """获取所有待审核的密码重置申请""" + resets = database.get_pending_password_resets() + return jsonify(resets) + + +# 批准密码重置申请 +@app.route('/yuyx/api/password_resets//approve', methods=['POST']) +@admin_required +def approve_password_reset_route(request_id): + """批准密码重置申请""" + if database.approve_password_reset(request_id): + return jsonify({"message": "密码重置申请已批准"}) + return jsonify({"error": "批准失败"}), 400 + + +# 拒绝密码重置申请 +@app.route('/yuyx/api/password_resets//reject', methods=['POST']) +@admin_required +def reject_password_reset_route(request_id): + """拒绝密码重置申请""" + if database.reject_password_reset(request_id): + return jsonify({"message": "密码重置申请已拒绝"}) + return jsonify({"error": "拒绝失败"}), 400 + + +# 用户申请重置密码(需要审核) +@app.route('/api/reset_password_request', methods=['POST']) +def request_password_reset(): + """用户申请重置密码""" + data = request.json + username = data.get('username', '').strip() + email = data.get('email', '').strip() + new_password = data.get('new_password', '').strip() + + if not username or not new_password: + return jsonify({"error": "用户名和新密码不能为空"}), 400 + + if len(new_password) < 6: + return jsonify({"error": "密码长度不能少于6位"}), 400 + + # 验证用户存在 + user = database.get_user_by_username(username) + if not user: + return jsonify({"error": "用户不存在"}), 404 + + # 如果提供了邮箱,验证邮箱是否匹配 + if email and user.get('email') != email: + return jsonify({"error": "邮箱不匹配"}), 400 + + # 创建重置申请 + request_id = database.create_password_reset_request(user['id'], new_password) + if request_id: + return jsonify({"message": "密码重置申请已提交,请等待管理员审核"}) + else: + return jsonify({"error": "申请提交失败"}), 500 + + +# ==================== 账号管理API (用户隔离) ==================== + +def load_user_accounts(user_id): + """从数据库加载用户的账号到内存""" + if user_id not in user_accounts: + user_accounts[user_id] = {} + + accounts_data = database.get_user_accounts(user_id) + for acc_data in accounts_data: + account = Account( + account_id=acc_data['id'], + user_id=user_id, + username=acc_data['username'], + password=acc_data['password'], + remember=bool(acc_data['remember']), + remark=acc_data['remark'] or '' + ) + user_accounts[user_id][account.id] = account + + +@app.route('/api/accounts', methods=['GET']) +@login_required +def get_accounts(): + """获取当前用户的所有账号""" + user_id = current_user.id + if user_id not in user_accounts: + load_user_accounts(user_id) + + accounts = user_accounts.get(user_id, {}) + return jsonify([acc.to_dict() for acc in accounts.values()]) + + +@app.route('/api/accounts', methods=['POST']) +@login_required +def add_account(): + """添加账号""" + user_id = current_user.id + + # VIP账号数量限制检查 + if not database.is_user_vip(user_id): + current_count = len(database.get_user_accounts(user_id)) + if current_count >= 1: + return jsonify({"error": "非VIP用户只能添加1个账号,请联系管理员开通VIP"}), 403 + data = request.json + username = data.get('username', '').strip() + password = data.get('password', '').strip() + remember = data.get('remember', True) + + if not username or not password: + return jsonify({"error": "用户名和密码不能为空"}), 400 + + # 检查当前用户是否已存在该账号 + if user_id in user_accounts: + for acc in user_accounts[user_id].values(): + if acc.username == username: + return jsonify({"error": f"账号 '{username}' 已存在"}), 400 + + # 生成账号ID + import uuid + account_id = str(uuid.uuid4())[:8] + + # 保存到数据库 + database.create_account(user_id, account_id, username, password, remember, '') + + # 加载到内存 + account = Account(account_id, user_id, username, password, remember, '') + if user_id not in user_accounts: + user_accounts[user_id] = {} + user_accounts[user_id][account_id] = account + + log_to_client(f"添加账号: {username}", user_id) + return jsonify(account.to_dict()) + + +@app.route('/api/accounts/', methods=['DELETE']) +@login_required +def delete_account(account_id): + """删除账号""" + user_id = current_user.id + + if user_id not in user_accounts or account_id not in user_accounts[user_id]: + return jsonify({"error": "账号不存在"}), 404 + + account = user_accounts[user_id][account_id] + + # 停止正在运行的任务 + if account.is_running: + account.should_stop = True + if account.automation: + account.automation.close() + + username = account.username + + # 从数据库删除 + database.delete_account(account_id) + + # 从内存删除 + del user_accounts[user_id][account_id] + + log_to_client(f"删除账号: {username}", user_id) + return jsonify({"success": True}) + + +@app.route('/api/accounts//remark', methods=['PUT']) +@login_required +def update_remark(account_id): + """更新备注""" + user_id = current_user.id + + if user_id not in user_accounts or account_id not in user_accounts[user_id]: + return jsonify({"error": "账号不存在"}), 404 + + data = request.json + remark = data.get('remark', '').strip()[:200] + + # 更新数据库 + database.update_account_remark(account_id, remark) + + # 更新内存 + user_accounts[user_id][account_id].remark = remark + log_to_client(f"更新备注: {user_accounts[user_id][account_id].username} -> {remark}", user_id) + + return jsonify({"success": True}) + + +@app.route('/api/accounts//start', methods=['POST']) +@login_required +def start_account(account_id): + """启动账号任务""" + user_id = current_user.id + + if user_id not in user_accounts or account_id not in user_accounts[user_id]: + return jsonify({"error": "账号不存在"}), 404 + + account = user_accounts[user_id][account_id] + + if account.is_running: + return jsonify({"error": "任务已在运行中"}), 400 + + data = request.json + browse_type = data.get('browse_type', '应读') + enable_screenshot = data.get('enable_screenshot', True) # 默认启用截图 + + # 确保浏览器管理器已初始化 + if not init_browser_manager(): + return jsonify({"error": "浏览器初始化失败"}), 500 + + # 启动任务线程 + account.is_running = True + account.should_stop = False + account.status = "运行中" + + thread = threading.Thread( + target=run_task, + args=(user_id, account_id, browse_type, enable_screenshot), + daemon=True + ) + thread.start() + active_tasks[account_id] = thread + + log_to_client(f"启动任务: {account.username} - {browse_type}", user_id) + socketio.emit('account_update', account.to_dict(), room=f'user_{user_id}') + + return jsonify({"success": True}) + + +@app.route('/api/accounts//stop', methods=['POST']) +@login_required +def stop_account(account_id): + """停止账号任务""" + user_id = current_user.id + + if user_id not in user_accounts or account_id not in user_accounts[user_id]: + return jsonify({"error": "账号不存在"}), 404 + + account = user_accounts[user_id][account_id] + + if not account.is_running: + return jsonify({"error": "任务未在运行"}), 400 + + account.should_stop = True + account.status = "正在停止" + + log_to_client(f"停止任务: {account.username}", user_id) + socketio.emit('account_update', account.to_dict(), room=f'user_{user_id}') + + return jsonify({"success": True}) + + +def get_user_semaphore(user_id): + """获取或创建用户的信号量""" + if user_id not in user_semaphores: + user_semaphores[user_id] = threading.Semaphore(max_concurrent_per_account) + return user_semaphores[user_id] + + +def run_task(user_id, account_id, browse_type, enable_screenshot=True): + """运行自动化任务""" + if user_id not in user_accounts or account_id not in user_accounts[user_id]: + return + + account = user_accounts[user_id][account_id] + + # 记录任务开始时间 + import time as time_module + task_start_time = time_module.time() + + # 两级并发控制:用户级 + 全局级 + user_sem = get_user_semaphore(user_id) + + # 获取用户级信号量(同一用户的账号排队) + log_to_client(f"等待资源分配...", user_id, account_id) + account.status = "排队中" + socketio.emit('account_update', account.to_dict(), room=f'user_{user_id}') + + user_sem.acquire() + + try: + # 如果在排队期间被停止,直接返回 + if account.should_stop: + log_to_client(f"任务已取消", user_id, account_id) + account.status = "已停止" + account.is_running = False + socketio.emit('account_update', account.to_dict(), room=f'user_{user_id}') + return + + # 获取全局信号量(防止所有用户同时运行导致资源耗尽) + global_semaphore.acquire() + + try: + # 再次检查是否被停止 + if account.should_stop: + log_to_client(f"任务已取消", user_id, account_id) + account.status = "已停止" + account.is_running = False + socketio.emit('account_update', account.to_dict(), room=f'user_{user_id}') + return + + account.status = "运行中" + socketio.emit('account_update', account.to_dict(), room=f'user_{user_id}') + account.last_browse_type = browse_type + + # 重试机制:最多尝试3次,超时则换IP重试 + max_attempts = 3 + last_error = None + + for attempt in range(1, max_attempts + 1): + try: + if attempt > 1: + log_to_client(f"🔄 第 {attempt} 次尝试(共{max_attempts}次)...", user_id, account_id) + + # 检查是否需要使用代理 + proxy_config = None + config = database.get_system_config() + if config.get('proxy_enabled') == 1: + proxy_api_url = config.get('proxy_api_url', '').strip() + if proxy_api_url: + log_to_client(f"正在获取代理IP...", user_id, account_id) + proxy_server = get_proxy_from_api(proxy_api_url, max_retries=3) + if proxy_server: + proxy_config = {'server': proxy_server} + log_to_client(f"✓ 将使用代理: {proxy_server}", user_id, account_id) + account.proxy_config = proxy_config # 保存代理配置供截图使用 + else: + log_to_client(f"✗ 代理获取失败,将不使用代理继续", user_id, account_id) + else: + log_to_client(f"⚠ 代理已启用但未配置API地址", user_id, account_id) + + log_to_client(f"创建自动化实例...", user_id, account_id) + account.automation = PlaywrightAutomation(browser_manager, account_id, proxy_config=proxy_config) + + # 为automation注入包含user_id的自定义log方法,使其能够实时发送日志到WebSocket + def custom_log(message: str): + log_to_client(message, user_id, account_id) + account.automation.log = custom_log + + log_to_client(f"开始登录...", user_id, account_id) + if not account.automation.login(account.username, account.password, account.remember): + log_to_client(f"❌ 登录失败,请检查用户名和密码", user_id, account_id) + account.status = "登录失败" + account.is_running = False + # 记录登录失败日志 + database.create_task_log( + user_id=user_id, + account_id=account_id, + username=account.username, + browse_type=browse_type, + status='failed', + total_items=0, + total_attachments=0, + error_message='登录失败,请检查用户名和密码', + duration=int(time_module.time() - task_start_time) + ) + socketio.emit('account_update', account.to_dict(), room=f'user_{user_id}') + return + + log_to_client(f"✓ 登录成功!", user_id, account_id) + log_to_client(f"开始浏览 '{browse_type}' 内容...", user_id, account_id) + + def should_stop(): + return account.should_stop + + result = account.automation.browse_content( + browse_type=browse_type, + auto_next_page=True, + auto_view_attachments=True, + interval=2.0, + should_stop_callback=should_stop + ) + + account.total_items = result.total_items + account.total_attachments = result.total_attachments + + if result.success: + log_to_client(f"浏览完成! 共 {result.total_items} 条内容,{result.total_attachments} 个附件", user_id, account_id) + account.status = "已完成" + # 记录成功日志 + database.create_task_log( + user_id=user_id, + account_id=account_id, + username=account.username, + browse_type=browse_type, + status='success', + total_items=result.total_items, + total_attachments=result.total_attachments, + error_message='', + duration=int(time_module.time() - task_start_time) + ) + # 成功则跳出重试循环 + break + else: + # 浏览出错,检查是否是超时错误 + error_msg = result.error_message + if 'Timeout' in error_msg or 'timeout' in error_msg: + last_error = error_msg + log_to_client(f"⚠ 检测到超时错误: {error_msg}", user_id, account_id) + + # 关闭当前浏览器 + if account.automation: + try: + account.automation.close() + log_to_client(f"已关闭超时的浏览器实例", user_id, account_id) + except: + pass + account.automation = None + + if attempt < max_attempts: + log_to_client(f"⚠ 代理可能速度过慢,将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id) + time_module.sleep(2) # 等待2秒再重试 + continue + else: + log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id) + account.status = "出错" + database.create_task_log( + user_id=user_id, + account_id=account_id, + username=account.username, + browse_type=browse_type, + status='failed', + total_items=result.total_items, + total_attachments=result.total_attachments, + error_message=f"重试{max_attempts}次后仍失败: {error_msg}", + duration=int(time_module.time() - task_start_time) + ) + break + else: + # 非超时错误,直接失败不重试 + log_to_client(f"浏览出错: {error_msg}", user_id, account_id) + account.status = "出错" + database.create_task_log( + user_id=user_id, + account_id=account_id, + username=account.username, + browse_type=browse_type, + status='failed', + total_items=result.total_items, + total_attachments=result.total_attachments, + error_message=error_msg, + duration=int(time_module.time() - task_start_time) + ) + break + + except Exception as retry_error: + # 捕获重试过程中的异常 + error_msg = str(retry_error) + last_error = error_msg + + # 关闭可能存在的浏览器实例 + if account.automation: + try: + account.automation.close() + except: + pass + account.automation = None + + if 'Timeout' in error_msg or 'timeout' in error_msg: + log_to_client(f"⚠ 执行超时: {error_msg}", user_id, account_id) + if attempt < max_attempts: + log_to_client(f"⚠ 将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id) + time_module.sleep(2) + continue + else: + log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id) + account.status = "出错" + database.create_task_log( + user_id=user_id, + account_id=account_id, + username=account.username, + browse_type=browse_type, + status='failed', + total_items=account.total_items, + total_attachments=account.total_attachments, + error_message=f"重试{max_attempts}次后仍失败: {error_msg}", + duration=int(time_module.time() - task_start_time) + ) + break + else: + # 非超时异常,直接失败 + log_to_client(f"任务执行异常: {error_msg}", user_id, account_id) + account.status = "出错" + database.create_task_log( + user_id=user_id, + account_id=account_id, + username=account.username, + browse_type=browse_type, + status='failed', + total_items=account.total_items, + total_attachments=account.total_attachments, + error_message=error_msg, + duration=int(time_module.time() - task_start_time) + ) + break + + + except Exception as e: + error_msg = str(e) + log_to_client(f"任务执行出错: {error_msg}", user_id, account_id) + account.status = "出错" + # 记录异常失败日志 + database.create_task_log( + user_id=user_id, + account_id=account_id, + username=account.username, + browse_type=browse_type, + status='failed', + total_items=account.total_items, + total_attachments=account.total_attachments, + error_message=error_msg, + duration=int(time_module.time() - task_start_time) + ) + + finally: + # 释放全局信号量 + global_semaphore.release() + + account.is_running = False + + if account.automation: + try: + account.automation.close() + log_to_client(f"主任务浏览器已关闭", user_id, account_id) + except Exception as e: + log_to_client(f"关闭主任务浏览器时出错: {str(e)}", user_id, account_id) + finally: + account.automation = None + + if account_id in active_tasks: + del active_tasks[account_id] + + socketio.emit('account_update', account.to_dict(), room=f'user_{user_id}') + + # 任务完成后自动截图(增加2秒延迟,确保资源完全释放) + # 根据enable_screenshot参数决定是否截图 + if account.status == "已完成" and not account.should_stop: + if enable_screenshot: + log_to_client(f"等待2秒后开始截图...", user_id, account_id) + time.sleep(2) # 延迟启动截图,确保主任务资源已完全释放 + threading.Thread(target=take_screenshot_for_account, args=(user_id, account_id), daemon=True).start() + else: + log_to_client(f"截图功能已禁用,跳过截图", user_id, account_id) + + finally: + # 释放用户级信号量 + user_sem.release() + + +def take_screenshot_for_account(user_id, account_id): + """为账号任务完成后截图(带并发控制,避免资源竞争)""" + if user_id not in user_accounts or account_id not in user_accounts[user_id]: + return + + account = user_accounts[user_id][account_id] + + # 使用截图信号量,确保同时只有1个截图任务在执行 + log_to_client(f"等待截图资源分配...", user_id, account_id) + screenshot_acquired = screenshot_semaphore.acquire(blocking=True, timeout=300) # 最多等待5分钟 + + if not screenshot_acquired: + log_to_client(f"截图资源获取超时,跳过截图", user_id, account_id) + return + + automation = None + try: + log_to_client(f"开始截图流程...", user_id, account_id) + + # 使用与浏览任务相同的代理配置 + proxy_config = account.proxy_config if hasattr(account, 'proxy_config') else None + if proxy_config: + log_to_client(f"截图将使用相同代理: {proxy_config.get('server', 'Unknown')}", user_id, account_id) + + automation = PlaywrightAutomation(browser_manager, account_id, proxy_config=proxy_config) + + # 为截图automation也注入自定义log方法 + def custom_log(message: str): + log_to_client(message, user_id, account_id) + automation.log = custom_log + + log_to_client(f"重新登录以进行截图...", user_id, account_id) + if not automation.login(account.username, account.password, account.remember): + log_to_client(f"截图登录失败", user_id, account_id) + return + + browse_type = account.last_browse_type + log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id) + + # 不使用should_stop_callback,让页面加载完成显示"暂无记录" + result = automation.browse_content( + browse_type=browse_type, + auto_next_page=False, + auto_view_attachments=False, + interval=0, + should_stop_callback=None + ) + + if not result.success and result.error_message != "": + log_to_client(f"导航失败: {result.error_message}", user_id, account_id) + + time.sleep(2) + + # 生成截图文件名(使用北京时间并简化格式) + beijing_tz = pytz.timezone('Asia/Shanghai') + now_beijing = datetime.now(beijing_tz) + timestamp = now_beijing.strftime('%Y%m%d_%H%M%S') + + # 简化文件名:用户名_登录账号_浏览类型_时间.jpg + # 获取用户名前缀 + user_info = database.get_user_by_id(user_id) + username_prefix = user_info['username'] if user_info else f"user{user_id}" + # 使用登录账号(account.username)而不是备注 + login_account = account.username + screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg" + screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename) + + if automation.take_screenshot(screenshot_path): + log_to_client(f"✓ 截图已保存: {screenshot_filename}", user_id, account_id) + else: + log_to_client(f"✗ 截图失败", user_id, account_id) + + except Exception as e: + log_to_client(f"✗ 截图过程中出错: {str(e)}", user_id, account_id) + + finally: + # 确保浏览器资源被正确关闭 + if automation: + try: + automation.close() + log_to_client(f"截图浏览器已关闭", user_id, account_id) + except Exception as e: + log_to_client(f"关闭截图浏览器时出错: {str(e)}", user_id, account_id) + + # 释放截图信号量 + screenshot_semaphore.release() + log_to_client(f"截图资源已释放", user_id, account_id) + + +@app.route('/api/accounts//screenshot', methods=['POST']) +@login_required +def manual_screenshot(account_id): + """手动为指定账号截图""" + user_id = current_user.id + + if user_id not in user_accounts or account_id not in user_accounts[user_id]: + return jsonify({"error": "账号不存在"}), 404 + + account = user_accounts[user_id][account_id] + if account.is_running: + return jsonify({"error": "任务运行中,无法截图"}), 400 + + data = request.json or {} + browse_type = data.get('browse_type', account.last_browse_type) + + account.last_browse_type = browse_type + + threading.Thread(target=take_screenshot_for_account, args=(user_id, account_id), daemon=True).start() + log_to_client(f"手动截图: {account.username} - {browse_type}", user_id) + return jsonify({"success": True}) + + +# ==================== 截图管理API ==================== + +@app.route('/api/screenshots', methods=['GET']) +@login_required +def get_screenshots(): + """获取当前用户的截图列表""" + user_id = current_user.id + user_info = database.get_user_by_id(user_id) + username_prefix = user_info['username'] if user_info else f"user{user_id}" + + try: + screenshots = [] + if os.path.exists(SCREENSHOTS_DIR): + for filename in os.listdir(SCREENSHOTS_DIR): + # 只显示属于当前用户的截图(支持png和jpg格式) + if (filename.lower().endswith(('.png', '.jpg', '.jpeg'))) and filename.startswith(username_prefix + '_'): + filepath = os.path.join(SCREENSHOTS_DIR, filename) + stat = os.stat(filepath) + # 转换为北京时间 + beijing_tz = pytz.timezone('Asia/Shanghai') + created_time = datetime.fromtimestamp(stat.st_mtime, tz=beijing_tz) + # 解析文件名获取显示名称 + # 文件名格式:用户名_登录账号_浏览类型_时间.jpg + parts = filename.rsplit('.', 1)[0].split('_', 1) # 移除扩展名并分割 + if len(parts) > 1: + # 显示名称:登录账号_浏览类型_时间.jpg + display_name = parts[1] + '.' + filename.rsplit('.', 1)[1] + else: + display_name = filename + + screenshots.append({ + 'filename': filename, + 'display_name': display_name, + 'size': stat.st_size, + 'created': created_time.strftime('%Y-%m-%d %H:%M:%S') + }) + screenshots.sort(key=lambda x: x['created'], reverse=True) + return jsonify(screenshots) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route('/screenshots/') +@login_required +def serve_screenshot(filename): + """提供截图文件访问""" + user_id = current_user.id + user_info = database.get_user_by_id(user_id) + username_prefix = user_info['username'] if user_info else f"user{user_id}" + + # 验证文件属于当前用户 + if not filename.startswith(username_prefix + '_'): + return jsonify({"error": "无权访问"}), 403 + + return send_from_directory(SCREENSHOTS_DIR, filename) + + +@app.route('/api/screenshots/', methods=['DELETE']) +@login_required +def delete_screenshot(filename): + """删除指定截图""" + user_id = current_user.id + user_info = database.get_user_by_id(user_id) + username_prefix = user_info['username'] if user_info else f"user{user_id}" + + # 验证文件属于当前用户 + if not filename.startswith(username_prefix + '_'): + return jsonify({"error": "无权删除"}), 403 + + try: + filepath = os.path.join(SCREENSHOTS_DIR, filename) + if os.path.exists(filepath): + os.remove(filepath) + log_to_client(f"删除截图: {filename}", user_id) + return jsonify({"success": True}) + else: + return jsonify({"error": "文件不存在"}), 404 + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route('/api/screenshots/clear', methods=['POST']) +@login_required +def clear_all_screenshots(): + """清空当前用户的所有截图""" + user_id = current_user.id + user_info = database.get_user_by_id(user_id) + username_prefix = user_info['username'] if user_info else f"user{user_id}" + + try: + deleted_count = 0 + if os.path.exists(SCREENSHOTS_DIR): + for filename in os.listdir(SCREENSHOTS_DIR): + if (filename.lower().endswith(('.png', '.jpg', '.jpeg'))) and filename.startswith(username_prefix + '_'): + filepath = os.path.join(SCREENSHOTS_DIR, filename) + os.remove(filepath) + deleted_count += 1 + log_to_client(f"清理了 {deleted_count} 个截图文件", user_id) + return jsonify({"success": True, "deleted": deleted_count}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +# ==================== WebSocket事件 ==================== + +@socketio.on('connect') +def handle_connect(): + """客户端连接""" + if current_user.is_authenticated: + user_id = current_user.id + join_room(f'user_{user_id}') + log_to_client("客户端已连接", user_id) + + # 发送账号列表 + accounts = user_accounts.get(user_id, {}) + emit('accounts_list', [acc.to_dict() for acc in accounts.values()]) + + # 发送历史日志 + if user_id in log_cache: + for log_entry in log_cache[user_id]: + emit('log', log_entry) + + +@socketio.on('disconnect') +def handle_disconnect(): + """客户端断开""" + if current_user.is_authenticated: + user_id = current_user.id + leave_room(f'user_{user_id}') + + +# ==================== 静态文件 ==================== + +@app.route('/static/') +def serve_static(filename): + """提供静态文件访问""" + return send_from_directory('static', filename) + + +# ==================== 启动 ==================== + + +# ==================== 管理员VIP管理API ==================== + +@app.route('/yuyx/api/vip/config', methods=['GET']) +def get_vip_config_api(): + """获取VIP配置""" + if 'admin_id' not in session: + return jsonify({"error": "需要管理员权限"}), 403 + config = database.get_vip_config() + return jsonify(config) + + +@app.route('/yuyx/api/vip/config', methods=['POST']) +def set_vip_config_api(): + """设置默认VIP天数""" + if 'admin_id' not in session: + return jsonify({"error": "需要管理员权限"}), 403 + + data = request.json + days = data.get('default_vip_days', 0) + + if not isinstance(days, int) or days < 0: + return jsonify({"error": "VIP天数必须是非负整数"}), 400 + + database.set_default_vip_days(days) + return jsonify({"message": "VIP配置已更新", "default_vip_days": days}) + + +@app.route('/yuyx/api/users//vip', methods=['POST']) +def set_user_vip_api(user_id): + """设置用户VIP""" + if 'admin_id' not in session: + return jsonify({"error": "需要管理员权限"}), 403 + + data = request.json + days = data.get('days', 30) + + # 验证days参数 + valid_days = [7, 30, 365, 999999] + if days not in valid_days: + return jsonify({"error": "VIP天数必须是 7/30/365/999999 之一"}), 400 + + if database.set_user_vip(user_id, days): + vip_type = {7: "一周", 30: "一个月", 365: "一年", 999999: "永久"}[days] + return jsonify({"message": f"VIP设置成功: {vip_type}"}) + return jsonify({"error": "设置失败,用户不存在"}), 400 + + +@app.route('/yuyx/api/users//vip', methods=['DELETE']) +def remove_user_vip_api(user_id): + """移除用户VIP""" + if 'admin_id' not in session: + return jsonify({"error": "需要管理员权限"}), 403 + + if database.remove_user_vip(user_id): + return jsonify({"message": "VIP已移除"}) + return jsonify({"error": "移除失败"}), 400 + + +@app.route('/yuyx/api/users//vip', methods=['GET']) +def get_user_vip_info_api(user_id): + """获取用户VIP信息(管理员)""" + if 'admin_id' not in session: + return jsonify({"error": "需要管理员权限"}), 403 + + vip_info = database.get_user_vip_info(user_id) + return jsonify(vip_info) + + + +# ==================== 用户端VIP查询API ==================== + +@app.route('/api/user/vip', methods=['GET']) +@login_required +def get_current_user_vip(): + """获取当前用户VIP信息""" + vip_info = database.get_user_vip_info(current_user.id) + return jsonify(vip_info) + + +@app.route('/api/run_stats', methods=['GET']) +@login_required +def get_run_stats(): + """获取当前用户的运行统计""" + user_id = current_user.id + + # 获取今日任务统计 + stats = database.get_user_run_stats(user_id) + + # 计算当前正在运行的账号数 + current_running = 0 + if user_id in user_accounts: + current_running = sum(1 for acc in user_accounts[user_id].values() if acc.is_running) + + return jsonify({ + 'today_completed': stats.get('completed', 0), + 'current_running': current_running, + 'today_failed': stats.get('failed', 0), + 'today_items': stats.get('total_items', 0), + 'today_attachments': stats.get('total_attachments', 0) + }) + + +# ==================== 系统配置API ==================== + +@app.route('/yuyx/api/system/config', methods=['GET']) +@admin_required +def get_system_config_api(): + """获取系统配置""" + config = database.get_system_config() + return jsonify(config) + + +@app.route('/yuyx/api/system/config', methods=['POST']) +@admin_required +def update_system_config_api(): + """更新系统配置""" + global max_concurrent_global, global_semaphore, max_concurrent_per_account + + data = request.json + max_concurrent = data.get('max_concurrent_global') + schedule_enabled = data.get('schedule_enabled') + schedule_time = data.get('schedule_time') + schedule_browse_type = data.get('schedule_browse_type') + schedule_weekdays = data.get('schedule_weekdays') + new_max_concurrent_per_account = data.get('max_concurrent_per_account') + + # 验证参数 + if max_concurrent is not None: + if not isinstance(max_concurrent, int) or max_concurrent < 1 or max_concurrent > 20: + return jsonify({"error": "全局并发数必须在1-20之间"}), 400 + + if new_max_concurrent_per_account is not None: + if not isinstance(new_max_concurrent_per_account, int) or new_max_concurrent_per_account < 1 or new_max_concurrent_per_account > 5: + return jsonify({"error": "单账号并发数必须在1-5之间"}), 400 + + if schedule_time is not None: + # 验证时间格式 HH:MM + import re + if not re.match(r'^([01]\d|2[0-3]):([0-5]\d)$', schedule_time): + return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400 + + if schedule_browse_type is not None: + if schedule_browse_type not in ['注册前未读', '应读', '未读']: + return jsonify({"error": "浏览类型无效"}), 400 + + if schedule_weekdays is not None: + # 验证星期格式,应该是逗号分隔的数字字符串 "1,2,3,4,5,6,7" + try: + days = [int(d.strip()) for d in schedule_weekdays.split(',') if d.strip()] + if not all(1 <= d <= 7 for d in days): + return jsonify({"error": "星期数字必须在1-7之间"}), 400 + except (ValueError, AttributeError): + return jsonify({"error": "星期格式错误"}), 400 + + # 更新数据库 + if database.update_system_config( + max_concurrent=max_concurrent, + schedule_enabled=schedule_enabled, + schedule_time=schedule_time, + schedule_browse_type=schedule_browse_type, + schedule_weekdays=schedule_weekdays, + max_concurrent_per_account=new_max_concurrent_per_account + ): + # 如果修改了并发数,更新全局变量和信号量 + if max_concurrent is not None and max_concurrent != max_concurrent_global: + max_concurrent_global = max_concurrent + global_semaphore = threading.Semaphore(max_concurrent) + print(f"全局并发数已更新为: {max_concurrent}") + + # 如果修改了单用户并发数,更新全局变量(已有的信号量会在下次创建时使用新值) + if new_max_concurrent_per_account is not None and new_max_concurrent_per_account != max_concurrent_per_account: + max_concurrent_per_account = new_max_concurrent_per_account + print(f"单用户并发数已更新为: {max_concurrent_per_account}") + + return jsonify({"message": "系统配置已更新"}) + + return jsonify({"error": "更新失败"}), 400 + + + + +# ==================== 代理配置API ==================== + +@app.route('/yuyx/api/proxy/config', methods=['GET']) +@admin_required +def get_proxy_config_api(): + """获取代理配置""" + config = database.get_system_config() + return jsonify({ + 'proxy_enabled': config.get('proxy_enabled', 0), + 'proxy_api_url': config.get('proxy_api_url', ''), + 'proxy_expire_minutes': config.get('proxy_expire_minutes', 3) + }) + + +@app.route('/yuyx/api/proxy/config', methods=['POST']) +@admin_required +def update_proxy_config_api(): + """更新代理配置""" + data = request.json + proxy_enabled = data.get('proxy_enabled') + proxy_api_url = data.get('proxy_api_url', '').strip() + proxy_expire_minutes = data.get('proxy_expire_minutes') + + if proxy_enabled is not None and proxy_enabled not in [0, 1]: + return jsonify({"error": "proxy_enabled必须是0或1"}), 400 + + if proxy_expire_minutes is not None: + if not isinstance(proxy_expire_minutes, int) or proxy_expire_minutes < 1: + return jsonify({"error": "代理有效期必须是大于0的整数"}), 400 + + if database.update_system_config( + proxy_enabled=proxy_enabled, + proxy_api_url=proxy_api_url, + proxy_expire_minutes=proxy_expire_minutes + ): + return jsonify({"message": "代理配置已更新"}) + + return jsonify({"error": "更新失败"}), 400 + + +@app.route('/yuyx/api/proxy/test', methods=['POST']) +@admin_required +def test_proxy_api(): + """测试代理连接""" + data = request.json + api_url = data.get('api_url', '').strip() + + if not api_url: + return jsonify({"error": "请提供API地址"}), 400 + + try: + response = requests.get(api_url, timeout=10) + if response.status_code == 200: + ip_port = response.text.strip() + if ip_port and ':' in ip_port: + return jsonify({ + "success": True, + "proxy": ip_port, + "message": f"代理获取成功: {ip_port}" + }) + else: + return jsonify({ + "success": False, + "message": f"代理格式错误: {ip_port}" + }), 400 + else: + return jsonify({ + "success": False, + "message": f"HTTP错误: {response.status_code}" + }), 400 + except Exception as e: + return jsonify({ + "success": False, + "message": f"连接失败: {str(e)}" + }), 500 + +# ==================== 服务器信息API ==================== + +@app.route('/yuyx/api/server/info', methods=['GET']) +@admin_required +def get_server_info_api(): + """获取服务器信息""" + import psutil + import datetime + + # CPU使用率 + cpu_percent = psutil.cpu_percent(interval=1) + + # 内存信息 + memory = psutil.virtual_memory() + memory_total = f"{memory.total / (1024**3):.1f}GB" + memory_used = f"{memory.used / (1024**3):.1f}GB" + memory_percent = memory.percent + + # 磁盘信息 + disk = psutil.disk_usage('/') + disk_total = f"{disk.total / (1024**3):.1f}GB" + disk_used = f"{disk.used / (1024**3):.1f}GB" + disk_percent = disk.percent + + # 运行时长 + boot_time = datetime.datetime.fromtimestamp(psutil.boot_time()) + uptime_delta = datetime.datetime.now() - boot_time + days = uptime_delta.days + hours = uptime_delta.seconds // 3600 + uptime = f"{days}天{hours}小时" + + return jsonify({ + 'cpu_percent': cpu_percent, + 'memory_total': memory_total, + 'memory_used': memory_used, + 'memory_percent': memory_percent, + 'disk_total': disk_total, + 'disk_used': disk_used, + 'disk_percent': disk_percent, + 'uptime': uptime + }) + + +# ==================== 任务统计和日志API ==================== + +@app.route('/yuyx/api/task/stats', methods=['GET']) +@admin_required +def get_task_stats_api(): + """获取任务统计数据""" + date_filter = request.args.get('date') # YYYY-MM-DD格式 + stats = database.get_task_stats(date_filter) + return jsonify(stats) + + +@app.route('/yuyx/api/task/logs', methods=['GET']) +@admin_required +def get_task_logs_api(): + """获取任务日志列表""" + limit = int(request.args.get('limit', 100)) + offset = int(request.args.get('offset', 0)) + date_filter = request.args.get('date') # YYYY-MM-DD格式 + status_filter = request.args.get('status') # success/failed + + logs = database.get_task_logs(limit, offset, date_filter, status_filter) + return jsonify(logs) + + +@app.route('/yuyx/api/task/logs/clear', methods=['POST']) +@admin_required +def clear_old_task_logs_api(): + """清理旧的任务日志""" + data = request.json or {} + days = data.get('days', 30) + + if not isinstance(days, int) or days < 1: + return jsonify({"error": "天数必须是大于0的整数"}), 400 + + deleted_count = database.delete_old_task_logs(days) + return jsonify({"message": f"已删除{days}天前的{deleted_count}条日志"}) + + +# ==================== 定时任务调度器 ==================== + +def scheduled_task_worker(): + """定时任务工作线程""" + import schedule + from datetime import datetime + + def run_all_accounts_task(): + """执行所有账号的浏览任务(过滤重复账号)""" + try: + config = database.get_system_config() + browse_type = config.get('schedule_browse_type', '应读') + + # 检查今天是否在允许执行的星期列表中 + from datetime import datetime + import pytz + + # 获取北京时间的星期几 (1=周一, 7=周日) + beijing_tz = pytz.timezone('Asia/Shanghai') + now_beijing = datetime.now(beijing_tz) + current_weekday = now_beijing.isoweekday() # 1-7 + + # 获取配置的星期列表 + schedule_weekdays = config.get('schedule_weekdays', '1,2,3,4,5,6,7') + allowed_weekdays = [int(d.strip()) for d in schedule_weekdays.split(',') if d.strip()] + + if current_weekday not in allowed_weekdays: + weekday_names = ['', '周一', '周二', '周三', '周四', '周五', '周六', '周日'] + print(f"[定时任务] 今天是{weekday_names[current_weekday]},不在执行日期内,跳过执行") + return + + print(f"[定时任务] 开始执行 - 浏览类型: {browse_type}") + + # 获取所有已审核用户的所有账号 + all_users = database.get_all_users() + approved_users = [u for u in all_users if u['status'] == 'approved'] + + # 用于记录已执行的账号用户名,避免重复 + executed_usernames = set() + total_accounts = 0 + skipped_duplicates = 0 + executed_accounts = 0 + + for user in approved_users: + user_id = user['id'] + if user_id not in user_accounts: + load_user_accounts(user_id) + + accounts = user_accounts.get(user_id, {}) + for account_id, account in accounts.items(): + total_accounts += 1 + + # 跳过正在运行的账号 + if account.is_running: + continue + + # 检查账号用户名是否已经执行过(重复账号过滤) + if account.username in executed_usernames: + skipped_duplicates += 1 + print(f"[定时任务] 跳过重复账号: {account.username} (用户:{user['username']}) - 该账号已被其他用户执行") + continue + + # 记录该账号用户名,避免后续重复执行 + executed_usernames.add(account.username) + + print(f"[定时任务] 启动账号: {account.username} (用户:{user['username']})") + + # 启动任务 + account.is_running = True + account.should_stop = False + account.status = "运行中" + + # 获取系统配置的截图开关 + config = database.get_system_config() + enable_screenshot_scheduled = config.get("enable_screenshot", 0) == 1 + + thread = threading.Thread( + target=run_task, + args=(user_id, account_id, browse_type, enable_screenshot_scheduled), + daemon=True + ) + thread.start() + active_tasks[account_id] = thread + executed_accounts += 1 + + # 发送更新到用户 + socketio.emit('account_update', account.to_dict(), room=f'user_{user_id}') + + # 间隔启动,避免瞬间并发过高 + time.sleep(2) + + print(f"[定时任务] 执行完成 - 总账号数:{total_accounts}, 已执行:{executed_accounts}, 跳过重复:{skipped_duplicates}") + + except Exception as e: + print(f"[定时任务] 执行出错: {str(e)}") + + def cleanup_expired_captcha(): + """清理过期验证码,防止内存泄漏""" + try: + current_time = time.time() + expired_keys = [k for k, v in captcha_storage.items() + if v["expire_time"] < current_time] + deleted_count = len(expired_keys) + for k in expired_keys: + del captcha_storage[k] + if deleted_count > 0: + print(f"[定时清理] 已清理 {deleted_count} 个过期验证码") + except Exception as e: + print(f"[定时清理] 清理验证码出错: {str(e)}") + + def cleanup_old_data(): + """清理7天前的截图和日志""" + try: + print(f"[定时清理] 开始清理7天前的数据...") + + # 清理7天前的任务日志 + deleted_logs = database.delete_old_task_logs(7) + print(f"[定时清理] 已删除 {deleted_logs} 条任务日志") + + # 清理30天前的操作日志 + deleted_operation_logs = database.clean_old_operation_logs(30) + print(f"[定时清理] 已删除 {deleted_operation_logs} 条操作日志") + # 清理7天前的截图 + deleted_screenshots = 0 + if os.path.exists(SCREENSHOTS_DIR): + cutoff_time = time.time() - (7 * 24 * 60 * 60) # 7天前的时间戳 + for filename in os.listdir(SCREENSHOTS_DIR): + if filename.lower().endswith(('.png', '.jpg', '.jpeg')): + filepath = os.path.join(SCREENSHOTS_DIR, filename) + try: + # 检查文件修改时间 + if os.path.getmtime(filepath) < cutoff_time: + os.remove(filepath) + deleted_screenshots += 1 + except Exception as e: + print(f"[定时清理] 删除截图失败 {filename}: {str(e)}") + + print(f"[定时清理] 已删除 {deleted_screenshots} 个截图文件") + print(f"[定时清理] 清理完成!") + + except Exception as e: + print(f"[定时清理] 清理任务出错: {str(e)}") + + # 每分钟检查一次配置 + def check_and_schedule(): + config = database.get_system_config() + + # 清除旧的任务 + schedule.clear() + + # 时区转换函数:将CST时间转换为UTC时间(容器使用UTC) + def cst_to_utc_time(cst_time_str): + """将CST时间字符串(HH:MM)转换为UTC时间字符串 + + Args: + cst_time_str: CST时间字符串,格式为 HH:MM + + Returns: + UTC时间字符串,格式为 HH:MM + """ + from datetime import datetime, timedelta + # 解析CST时间 + hour, minute = map(int, cst_time_str.split(':')) + # CST是UTC+8,所以UTC时间 = CST时间 - 8小时 + utc_hour = (hour - 8) % 24 + return f"{utc_hour:02d}:{minute:02d}" + + # 始终添加每天凌晨3点(CST)的数据清理任务 + cleanup_utc_time = cst_to_utc_time("03:00") + schedule.every().day.at(cleanup_utc_time).do(cleanup_old_data) + print(f"[定时任务] 已设置数据清理任务: 每天 CST 03:00 (UTC {cleanup_utc_time})") + + # 每小时清理过期验证码 + schedule.every().hour.do(cleanup_expired_captcha) + print(f"[定时任务] 已设置验证码清理任务: 每小时执行一次") + + # 如果启用了定时浏览任务,则添加 + if config.get('schedule_enabled'): + schedule_time_cst = config.get('schedule_time', '02:00') + schedule_time_utc = cst_to_utc_time(schedule_time_cst) + schedule.every().day.at(schedule_time_utc).do(run_all_accounts_task) + print(f"[定时任务] 已设置浏览任务: 每天 CST {schedule_time_cst} (UTC {schedule_time_utc})") + + # 初始检查 + check_and_schedule() + last_check = time.time() + + while True: + try: + # 执行待执行的任务 + schedule.run_pending() + + # 每60秒重新检查一次配置 + if time.time() - last_check > 60: + check_and_schedule() + last_check = time.time() + + time.sleep(1) + except Exception as e: + print(f"[定时任务] 调度器出错: {str(e)}") + time.sleep(5) + + +if __name__ == '__main__': + print("=" * 60) + print("知识管理平台自动化工具 - 多用户版") + print("=" * 60) + + # 初始化数据库 + database.init_database() + + # 加载系统配置(并发设置) + try: + config = database.get_system_config() + if config: + # 使用globals()修改全局变量 + globals()['max_concurrent_global'] = config.get('max_concurrent_global', 2) + globals()['max_concurrent_per_account'] = config.get('max_concurrent_per_account', 1) + + # 重新创建信号量 + globals()['global_semaphore'] = threading.Semaphore(globals()['max_concurrent_global']) + + print(f"✓ 已加载并发配置: 全局={globals()['max_concurrent_global']}, 单账号={globals()['max_concurrent_per_account']}") + except Exception as e: + print(f"警告: 加载并发配置失败,使用默认值: {e}") + + # 主线程初始化浏览器(Playwright不支持跨线程) + print("\n正在初始化浏览器管理器...") + init_browser_manager() + + # 启动定时任务调度器 + print("\n启动定时任务调度器...") + scheduler_thread = threading.Thread(target=scheduled_task_worker, daemon=True) + scheduler_thread.start() + print("✓ 定时任务调度器已启动") + + # 启动Web服务器 + print("\n服务器启动中...") + print("用户访问地址: http://0.0.0.0:5000") + print("后台管理地址: http://0.0.0.0:5000/yuyx") + print("默认管理员: admin/admin") + print("=" * 60 + "\n") + + socketio.run(app, host='0.0.0.0', port=5000, debug=False) diff --git a/app_config.py b/app_config.py new file mode 100755 index 0000000..30d59e7 --- /dev/null +++ b/app_config.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +配置管理模块 +集中管理所有配置项,支持环境变量 +""" + +import os +from datetime import timedelta + + +# 常量定义 +SECRET_KEY_FILE = 'data/secret_key.txt' + + +def get_secret_key(): + """获取SECRET_KEY(优先环境变量)""" + # 优先从环境变量读取 + secret_key = os.environ.get('SECRET_KEY') + if secret_key: + return secret_key + + # 从文件读取 + if os.path.exists(SECRET_KEY_FILE): + with open(SECRET_KEY_FILE, 'r') as f: + return f.read().strip() + + # 生成新的 + new_key = os.urandom(24).hex() + os.makedirs('data', exist_ok=True) + with open(SECRET_KEY_FILE, 'w') as f: + f.write(new_key) + print(f"✓ 已生成新的SECRET_KEY并保存到 {SECRET_KEY_FILE}") + return new_key + + +class Config: + """应用配置基类""" + + # ==================== Flask核心配置 ==================== + SECRET_KEY = get_secret_key() + + # ==================== 会话安全配置 ==================== + SESSION_COOKIE_SECURE = os.environ.get('SESSION_COOKIE_SECURE', 'False').lower() == 'true' + SESSION_COOKIE_HTTPONLY = True # 防止XSS攻击 + SESSION_COOKIE_SAMESITE = 'Lax' # 防止CSRF攻击 + PERMANENT_SESSION_LIFETIME = timedelta(hours=int(os.environ.get('SESSION_LIFETIME_HOURS', '24'))) + + # ==================== 数据库配置 ==================== + DB_FILE = os.environ.get('DB_FILE', 'data/app_data.db') + DB_POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '5')) + + # ==================== 浏览器配置 ==================== + SCREENSHOTS_DIR = os.environ.get('SCREENSHOTS_DIR', '截图') + + # ==================== 并发控制配置 ==================== + MAX_CONCURRENT_GLOBAL = int(os.environ.get('MAX_CONCURRENT_GLOBAL', '2')) + MAX_CONCURRENT_PER_ACCOUNT = int(os.environ.get('MAX_CONCURRENT_PER_ACCOUNT', '1')) + + # ==================== 日志缓存配置 ==================== + MAX_LOGS_PER_USER = int(os.environ.get('MAX_LOGS_PER_USER', '100')) + MAX_TOTAL_LOGS = int(os.environ.get('MAX_TOTAL_LOGS', '1000')) + + # ==================== 验证码配置 ==================== + MAX_CAPTCHA_ATTEMPTS = int(os.environ.get('MAX_CAPTCHA_ATTEMPTS', '5')) + CAPTCHA_EXPIRE_SECONDS = int(os.environ.get('CAPTCHA_EXPIRE_SECONDS', '300')) + + # ==================== IP限流配置 ==================== + MAX_IP_ATTEMPTS_PER_HOUR = int(os.environ.get('MAX_IP_ATTEMPTS_PER_HOUR', '10')) + IP_LOCK_DURATION = int(os.environ.get('IP_LOCK_DURATION', '3600')) # 秒 + + # ==================== 超时配置 ==================== + PAGE_LOAD_TIMEOUT = int(os.environ.get('PAGE_LOAD_TIMEOUT', '60000')) # 毫秒 + DEFAULT_TIMEOUT = int(os.environ.get('DEFAULT_TIMEOUT', '60000')) # 毫秒 + + # ==================== SocketIO配置 ==================== + SOCKETIO_CORS_ALLOWED_ORIGINS = os.environ.get('SOCKETIO_CORS_ALLOWED_ORIGINS', '*') + + # ==================== 日志配置 ==================== + LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO') + LOG_FILE = os.environ.get('LOG_FILE', 'logs/app.log') + LOG_MAX_BYTES = int(os.environ.get('LOG_MAX_BYTES', '10485760')) # 10MB + LOG_BACKUP_COUNT = int(os.environ.get('LOG_BACKUP_COUNT', '5')) + + # ==================== 安全配置 ==================== + DEBUG = os.environ.get('FLASK_DEBUG', 'False').lower() == 'true' + ALLOWED_SCREENSHOT_EXTENSIONS = {'.png', '.jpg', '.jpeg'} + MAX_SCREENSHOT_SIZE = int(os.environ.get('MAX_SCREENSHOT_SIZE', '10485760')) # 10MB + + @classmethod + def validate(cls): + """验证配置的有效性""" + errors = [] + + # 验证SECRET_KEY + if not cls.SECRET_KEY or len(cls.SECRET_KEY) < 32: + errors.append("SECRET_KEY长度必须至少32个字符") + + # 验证并发配置 + if cls.MAX_CONCURRENT_GLOBAL < 1: + errors.append("MAX_CONCURRENT_GLOBAL必须大于0") + + if cls.MAX_CONCURRENT_PER_ACCOUNT < 1: + errors.append("MAX_CONCURRENT_PER_ACCOUNT必须大于0") + + # 验证数据库配置 + if not cls.DB_FILE: + errors.append("DB_FILE不能为空") + + if cls.DB_POOL_SIZE < 1: + errors.append("DB_POOL_SIZE必须大于0") + + # 验证日志配置 + if cls.LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: + errors.append(f"LOG_LEVEL无效: {cls.LOG_LEVEL}") + + return errors + + @classmethod + def print_config(cls): + """打印当前配置(隐藏敏感信息)""" + print("=" * 60) + print("应用配置") + print("=" * 60) + print(f"DEBUG模式: {cls.DEBUG}") + print(f"SECRET_KEY: {'*' * 20} (长度: {len(cls.SECRET_KEY)})") + print(f"会话超时: {cls.PERMANENT_SESSION_LIFETIME}") + print(f"Cookie安全: HTTPS={cls.SESSION_COOKIE_SECURE}, HttpOnly={cls.SESSION_COOKIE_HTTPONLY}") + print(f"数据库文件: {cls.DB_FILE}") + print(f"数据库连接池: {cls.DB_POOL_SIZE}") + print(f"并发配置: 全局={cls.MAX_CONCURRENT_GLOBAL}, 单账号={cls.MAX_CONCURRENT_PER_ACCOUNT}") + print(f"日志级别: {cls.LOG_LEVEL}") + print(f"日志文件: {cls.LOG_FILE}") + print(f"截图目录: {cls.SCREENSHOTS_DIR}") + print("=" * 60) + + +class DevelopmentConfig(Config): + """开发环境配置""" + DEBUG = True + SESSION_COOKIE_SECURE = False + + +class ProductionConfig(Config): + """生产环境配置""" + DEBUG = False + SESSION_COOKIE_SECURE = True # 生产环境必须使用HTTPS + + +class TestingConfig(Config): + """测试环境配置""" + DEBUG = True + TESTING = True + DB_FILE = 'data/test_app_data.db' + + +# 根据环境变量选择配置 +config_map = { + 'development': DevelopmentConfig, + 'production': ProductionConfig, + 'testing': TestingConfig, +} + + +def get_config(): + """获取当前环境的配置""" + env = os.environ.get('FLASK_ENV', 'production') + return config_map.get(env, ProductionConfig) + + +if __name__ == '__main__': + # 配置验证测试 + config = get_config() + errors = config.validate() + + if errors: + print("配置验证失败:") + for error in errors: + print(f" ✗ {error}") + else: + print("✓ 配置验证通过") + config.print_config() diff --git a/app_logger.py b/app_logger.py new file mode 100755 index 0000000..ad0196a --- /dev/null +++ b/app_logger.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +日志管理模块 +提供标准化的日志记录功能 +""" + +import logging +import os +from logging.handlers import RotatingFileHandler +from datetime import datetime +import threading + +# 全局日志配置 +_loggers = {} +_logger_lock = threading.Lock() + + +class ColoredFormatter(logging.Formatter): + """带颜色的日志格式化器(用于控制台)""" + + # ANSI颜色代码 + COLORS = { + 'DEBUG': '\033[36m', # 青色 + 'INFO': '\033[32m', # 绿色 + 'WARNING': '\033[33m', # 黄色 + 'ERROR': '\033[31m', # 红色 + 'CRITICAL': '\033[35m', # 紫色 + } + RESET = '\033[0m' + + def format(self, record): + """格式化日志记录""" + # 添加颜色 + levelname = record.levelname + if levelname in self.COLORS: + record.levelname = f"{self.COLORS[levelname]}{levelname}{self.RESET}" + + # 格式化 + result = super().format(record) + + # 恢复原始levelname(避免影响其他handler) + record.levelname = levelname + + return result + + +def setup_logger(name='app', level=None, log_file=None, max_bytes=10*1024*1024, backup_count=5): + """ + 设置日志记录器 + + Args: + name: 日志器名称 + level: 日志级别(DEBUG, INFO, WARNING, ERROR, CRITICAL) + log_file: 日志文件路径 + max_bytes: 日志文件最大大小(字节) + backup_count: 保留的备份文件数量 + + Returns: + logging.Logger: 配置好的日志器 + """ + with _logger_lock: + # 如果已经存在,直接返回 + if name in _loggers: + return _loggers[name] + + # 创建日志器 + logger = logging.getLogger(name) + + # 设置日志级别 + if level is None: + level = os.environ.get('LOG_LEVEL', 'INFO') + logger.setLevel(getattr(logging, level.upper())) + + # 清除已有的处理器(避免重复) + logger.handlers.clear() + + # 日志格式 + detailed_formatter = logging.Formatter( + '[%(asctime)s] [%(name)s] [%(levelname)s] [%(filename)s:%(lineno)d] - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + simple_formatter = logging.Formatter( + '[%(asctime)s] [%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + colored_formatter = ColoredFormatter( + '[%(asctime)s] [%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # 控制台处理器(带颜色) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(colored_formatter) + logger.addHandler(console_handler) + + # 文件处理器(如果指定了文件路径) + if log_file: + # 确保日志目录存在 + log_dir = os.path.dirname(log_file) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + + # 主日志文件(详细格式) + file_handler = RotatingFileHandler( + log_file, + maxBytes=max_bytes, + backupCount=backup_count, + encoding='utf-8' + ) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(detailed_formatter) + logger.addHandler(file_handler) + + # 错误日志文件(仅记录WARNING及以上) + error_file = log_file.replace('.log', '_error.log') + error_handler = RotatingFileHandler( + error_file, + maxBytes=max_bytes, + backupCount=backup_count, + encoding='utf-8' + ) + error_handler.setLevel(logging.WARNING) + error_handler.setFormatter(detailed_formatter) + logger.addHandler(error_handler) + + # 防止日志向上传播(避免重复) + logger.propagate = False + + # 缓存日志器 + _loggers[name] = logger + + return logger + + +def get_logger(name='app'): + """ + 获取日志记录器 + + Args: + name: 日志器名称 + + Returns: + logging.Logger: 日志器实例 + """ + if name in _loggers: + return _loggers[name] + else: + # 如果不存在,创建一个默认的 + return setup_logger(name) + + +class LoggerAdapter: + """日志适配器,提供便捷的日志记录方法""" + + def __init__(self, logger_name='app', context=None): + """ + 初始化日志适配器 + + Args: + logger_name: 日志器名称 + context: 上下文信息(如用户ID、账号ID等) + """ + self.logger = get_logger(logger_name) + self.context = context or {} + + def _format_message(self, message): + """格式化消息,添加上下文信息""" + if self.context: + context_str = ' '.join([f"[{k}={v}]" for k, v in self.context.items()]) + return f"{context_str} {message}" + return message + + def debug(self, message): + """记录调试信息""" + self.logger.debug(self._format_message(message)) + + def info(self, message): + """记录普通信息""" + self.logger.info(self._format_message(message)) + + def warning(self, message): + """记录警告信息""" + self.logger.warning(self._format_message(message)) + + def error(self, message, exc_info=False): + """记录错误信息""" + self.logger.error(self._format_message(message), exc_info=exc_info) + + def critical(self, message, exc_info=False): + """记录严重错误信息""" + self.logger.critical(self._format_message(message), exc_info=exc_info) + + def exception(self, message): + """记录异常信息(自动包含堆栈跟踪)""" + self.logger.exception(self._format_message(message)) + + +class AuditLogger: + """审计日志记录器(用于记录关键操作)""" + + def __init__(self, log_file='logs/audit.log'): + """初始化审计日志""" + self.logger = setup_logger('audit', level='INFO', log_file=log_file) + + def log_user_login(self, user_id, username, ip_address, success=True): + """记录用户登录""" + status = "成功" if success else "失败" + self.logger.info(f"用户登录{status}: user_id={user_id}, username={username}, ip={ip_address}") + + def log_admin_login(self, username, ip_address, success=True): + """记录管理员登录""" + status = "成功" if success else "失败" + self.logger.info(f"管理员登录{status}: username={username}, ip={ip_address}") + + def log_user_created(self, user_id, username, created_by=None): + """记录用户创建""" + self.logger.info(f"用户创建: user_id={user_id}, username={username}, created_by={created_by}") + + def log_user_deleted(self, user_id, username, deleted_by): + """记录用户删除""" + self.logger.warning(f"用户删除: user_id={user_id}, username={username}, deleted_by={deleted_by}") + + def log_password_reset(self, user_id, username, reset_by): + """记录密码重置""" + self.logger.warning(f"密码重置: user_id={user_id}, username={username}, reset_by={reset_by}") + + def log_config_change(self, config_name, old_value, new_value, changed_by): + """记录配置修改""" + self.logger.warning(f"配置修改: {config_name} 从 {old_value} 改为 {new_value}, changed_by={changed_by}") + + def log_security_event(self, event_type, description, ip_address=None): + """记录安全事件""" + self.logger.warning(f"安全事件 [{event_type}]: {description}, ip={ip_address}") + + +# 全局审计日志实例 +audit_logger = AuditLogger() + + +# 辅助函数 +def log_exception(logger, message="发生异常"): + """记录异常(包含堆栈跟踪)""" + if isinstance(logger, str): + logger = get_logger(logger) + logger.exception(message) + + +def log_performance(logger, operation, duration_ms, threshold_ms=1000): + """记录性能信息""" + if isinstance(logger, str): + logger = get_logger(logger) + + if duration_ms > threshold_ms: + logger.warning(f"性能警告: {operation} 耗时 {duration_ms}ms (阈值: {threshold_ms}ms)") + else: + logger.debug(f"性能: {operation} 耗时 {duration_ms}ms") + + +# 初始化默认日志器 +def init_logging(log_level='INFO', log_file='logs/app.log'): + """ + 初始化日志系统 + + Args: + log_level: 日志级别 + log_file: 日志文件路径 + """ + # 创建主应用日志器 + setup_logger('app', level=log_level, log_file=log_file) + + # 创建数据库日志器 + setup_logger('database', level=log_level, log_file='logs/database.log') + + # 创建自动化日志器 + setup_logger('automation', level=log_level, log_file='logs/automation.log') + + # 创建审计日志器(已在AuditLogger中创建) + + print("✓ 日志系统初始化完成") + + +if __name__ == '__main__': + # 测试日志系统 + init_logging(log_level='DEBUG') + + logger = get_logger('app') + logger.debug("这是调试信息") + logger.info("这是普通信息") + logger.warning("这是警告信息") + logger.error("这是错误信息") + logger.critical("这是严重错误信息") + + # 测试上下文日志 + adapter = LoggerAdapter('app', {'user_id': 123, 'username': 'test'}) + adapter.info("用户操作日志") + + # 测试审计日志 + audit_logger.log_user_login(123, 'test_user', '127.0.0.1', success=True) + audit_logger.log_security_event('LOGIN_ATTEMPT', '多次登录失败', '192.168.1.1') + + print("\n日志测试完成,请检查 logs/ 目录") diff --git a/app_security.py b/app_security.py new file mode 100755 index 0000000..9c92b99 --- /dev/null +++ b/app_security.py @@ -0,0 +1,435 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +安全工具模块 +提供各种安全相关的功能 +""" + +import os +import re +import time +import hashlib +import secrets +from pathlib import Path +from typing import Optional +from functools import wraps +from flask import request, jsonify, session +from collections import defaultdict +import threading + + +# ==================== 文件路径安全 ==================== + +def is_safe_path(basedir, path, follow_symlinks=True): + """ + 检查路径是否安全(防止路径遍历攻击) + + Args: + basedir: 基础目录 + path: 要检查的路径 + follow_symlinks: 是否跟随符号链接 + + Returns: + bool: 路径是否安全 + """ + # 检查路径中是否包含危险字符 + if '..' in path or path.startswith('/') or path.startswith('\\'): + return False + + # 解析路径 + if follow_symlinks: + matchpath = os.path.realpath(os.path.join(basedir, path)) + else: + matchpath = os.path.abspath(os.path.join(basedir, path)) + + # 检查是否在基础目录内 + return matchpath.startswith(os.path.abspath(basedir)) + + +def sanitize_filename(filename): + """ + 清理文件名,移除危险字符 + + Args: + filename: 原始文件名 + + Returns: + str: 清理后的文件名 + """ + # 移除路径分隔符 + filename = filename.replace('/', '_').replace('\\', '_') + + # 只保留安全字符 + filename = re.sub(r'[^a-zA-Z0-9._-]', '_', filename) + + # 限制长度 + if len(filename) > 255: + name, ext = os.path.splitext(filename) + filename = name[:255-len(ext)] + ext + + return filename + + +# ==================== IP限流和黑名单 ==================== + +class IPRateLimiter: + """IP访问频率限制器""" + + def __init__(self, max_attempts=10, window_seconds=3600, lock_duration=3600): + """ + 初始化限流器 + + Args: + max_attempts: 时间窗口内的最大尝试次数 + window_seconds: 时间窗口大小(秒) + lock_duration: 锁定时长(秒) + """ + self.max_attempts = max_attempts + self.window_seconds = window_seconds + self.lock_duration = lock_duration + + # IP访问记录: {ip: [(timestamp, success), ...]} + self._attempts = defaultdict(list) + # IP锁定记录: {ip: lock_until_timestamp} + self._locked = {} + self._lock = threading.Lock() + + def is_locked(self, ip_address): + """ + 检查IP是否被锁定 + + Args: + ip_address: IP地址 + + Returns: + bool: 是否被锁定 + """ + with self._lock: + if ip_address in self._locked: + if time.time() < self._locked[ip_address]: + return True + else: + # 锁定已过期,移除 + del self._locked[ip_address] + return False + + def record_attempt(self, ip_address, success=True): + """ + 记录访问尝试 + + Args: + ip_address: IP地址 + success: 是否成功 + + Returns: + bool: 是否应该锁定该IP + """ + with self._lock: + now = time.time() + + # 清理过期记录 + cutoff_time = now - self.window_seconds + self._attempts[ip_address] = [ + (ts, succ) for ts, succ in self._attempts[ip_address] + if ts > cutoff_time + ] + + # 记录本次尝试 + self._attempts[ip_address].append((now, success)) + + # 检查失败次数 + failed_attempts = sum(1 for ts, succ in self._attempts[ip_address] if not succ) + + if failed_attempts >= self.max_attempts: + # 锁定IP + self._locked[ip_address] = now + self.lock_duration + return True + + return False + + def get_remaining_attempts(self, ip_address): + """ + 获取剩余尝试次数 + + Args: + ip_address: IP地址 + + Returns: + int: 剩余尝试次数 + """ + with self._lock: + now = time.time() + cutoff_time = now - self.window_seconds + + # 清理过期记录 + self._attempts[ip_address] = [ + (ts, succ) for ts, succ in self._attempts[ip_address] + if ts > cutoff_time + ] + + failed_attempts = sum(1 for ts, succ in self._attempts[ip_address] if not succ) + return max(0, self.max_attempts - failed_attempts) + + def cleanup(self): + """清理过期数据""" + with self._lock: + now = time.time() + + # 清理过期的尝试记录 + cutoff_time = now - self.window_seconds + for ip in list(self._attempts.keys()): + self._attempts[ip] = [ + (ts, succ) for ts, succ in self._attempts[ip] + if ts > cutoff_time + ] + if not self._attempts[ip]: + del self._attempts[ip] + + # 清理过期的锁定 + for ip in list(self._locked.keys()): + if now >= self._locked[ip]: + del self._locked[ip] + + +# 全局IP限流器实例 +ip_rate_limiter = IPRateLimiter() + + +def require_ip_not_locked(f): + """装饰器:检查IP是否被锁定""" + @wraps(f) + def decorated_function(*args, **kwargs): + ip_address = request.remote_addr + + if ip_rate_limiter.is_locked(ip_address): + return jsonify({ + "error": "由于多次失败尝试,您的IP已被临时锁定", + "locked_until": ip_rate_limiter._locked.get(ip_address, 0) + }), 429 + + return f(*args, **kwargs) + + return decorated_function + + +# ==================== 输入验证 ==================== + +def validate_username(username): + """ + 验证用户名格式 + + Args: + username: 用户名 + + Returns: + tuple: (is_valid, error_message) + """ + if not username: + return False, "用户名不能为空" + + if len(username) < 3: + return False, "用户名长度不能少于3个字符" + + if len(username) > 50: + return False, "用户名长度不能超过50个字符" + + # 只允许字母、数字、下划线、中文 + if not re.match(r'^[\w\u4e00-\u9fa5]+$', username): + return False, "用户名只能包含字母、数字、下划线和中文字符" + + return True, None + + +def validate_password(password): + """ + 验证密码强度 + + Args: + password: 密码 + + Returns: + tuple: (is_valid, error_message) + """ + if not password: + return False, "密码不能为空" + + if len(password) < 6: + return False, "密码长度不能少于6个字符" + + if len(password) > 128: + return False, "密码长度不能超过128个字符" + + # 可选:强制密码复杂度 + # has_upper = bool(re.search(r'[A-Z]', password)) + # has_lower = bool(re.search(r'[a-z]', password)) + # has_digit = bool(re.search(r'\d', password)) + # + # if not (has_upper and has_lower and has_digit): + # return False, "密码必须包含大写字母、小写字母和数字" + + return True, None + + +def validate_email(email): + """ + 验证邮箱格式 + + Args: + email: 邮箱地址 + + Returns: + tuple: (is_valid, error_message) + """ + if not email: + return True, None # 邮箱可选 + + # 简单的邮箱正则 + pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + if not re.match(pattern, email): + return False, "邮箱格式不正确" + + if len(email) > 255: + return False, "邮箱长度不能超过255个字符" + + return True, None + + +# ==================== 会话安全 ==================== + +def generate_session_token(): + """生成安全的会话令牌""" + return secrets.token_urlsafe(32) + + +def hash_token(token): + """哈希令牌(用于存储)""" + return hashlib.sha256(token.encode()).hexdigest() + + +# ==================== CSRF保护 ==================== + +def generate_csrf_token(): + """生成CSRF令牌""" + if 'csrf_token' not in session: + session['csrf_token'] = secrets.token_urlsafe(32) + return session['csrf_token'] + + +def validate_csrf_token(token): + """验证CSRF令牌""" + return token == session.get('csrf_token') + + +# ==================== 内容安全 ==================== + +def escape_html(text): + """转义HTML特殊字符(防止XSS)""" + if not text: + return text + + replacements = { + '&': '&', + '<': '<', + '>': '>', + '"': '"', + "'": ''', + '/': '/', + } + + for char, escaped in replacements.items(): + text = text.replace(char, escaped) + + return text + + +def sanitize_sql_like_pattern(pattern): + """ + 清理SQL LIKE模式中的特殊字符 + + Args: + pattern: LIKE模式字符串 + + Returns: + str: 清理后的模式 + """ + # 转义LIKE中的特殊字符 + pattern = pattern.replace('\\', '\\\\') + pattern = pattern.replace('%', '\\%') + pattern = pattern.replace('_', '\\_') + return pattern + + +# ==================== 安全配置检查 ==================== + +def check_security_config(): + """ + 检查安全配置 + + Returns: + list: 安全问题列表 + """ + issues = [] + + # 检查SECRET_KEY + from flask import current_app + secret_key = current_app.config.get('SECRET_KEY') + if not secret_key or len(secret_key) < 32: + issues.append("SECRET_KEY过短或未设置") + + # 检查DEBUG模式 + if current_app.config.get('DEBUG'): + issues.append("DEBUG模式在生产环境应该关闭") + + # 检查Cookie安全设置 + if not current_app.config.get('SESSION_COOKIE_HTTPONLY'): + issues.append("SESSION_COOKIE_HTTPONLY应该设置为True") + + if not current_app.config.get('SESSION_COOKIE_SECURE'): + issues.append("生产环境应该启用SESSION_COOKIE_SECURE(需要HTTPS)") + + return issues + + +# ==================== 辅助函数 ==================== + +def get_client_ip(): + """ + 获取客户端真实IP地址 + + Returns: + str: IP地址 + """ + # 检查代理头 + if request.headers.get('X-Forwarded-For'): + return request.headers.get('X-Forwarded-For').split(',')[0].strip() + elif request.headers.get('X-Real-IP'): + return request.headers.get('X-Real-IP') + else: + return request.remote_addr + + +if __name__ == '__main__': + # 测试文件路径安全 + print("文件路径安全测试:") + print(f" 安全路径: {is_safe_path('/tmp', 'test.txt')}") + print(f" 危险路径: {is_safe_path('/tmp', '../etc/passwd')}") + + # 测试文件名清理 + print(f"\n文件名清理: {sanitize_filename('../../../etc/passwd')}") + + # 测试输入验证 + print("\n输入验证测试:") + print(f" 用户名: {validate_username('test_user')}") + print(f" 密码: {validate_password('Test123456')}") + print(f" 邮箱: {validate_email('test@example.com')}") + + # 测试IP限流 + print("\nIP限流测试:") + limiter = IPRateLimiter(max_attempts=3, window_seconds=60) + ip = '192.168.1.1' + + for i in range(5): + locked = limiter.record_attempt(ip, success=False) + print(f" 尝试 {i+1}: 剩余次数={limiter.get_remaining_attempts(ip)}, 是否锁定={locked}") + + print(f" IP被锁定: {limiter.is_locked(ip)}") diff --git a/app_state.py b/app_state.py new file mode 100755 index 0000000..2707663 --- /dev/null +++ b/app_state.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +应用状态管理模块 +提供线程安全的全局状态管理 +""" + +import threading +from typing import Tuple +from typing import Dict, Any, Optional +from datetime import datetime, timedelta +from app_logger import get_logger + +logger = get_logger('app_state') + + +class ThreadSafeDict: + """线程安全的字典包装类""" + + def __init__(self): + self._dict = {} + self._lock = threading.RLock() + + def get(self, key, default=None): + """获取值""" + with self._lock: + return self._dict.get(key, default) + + def set(self, key, value): + """设置值""" + with self._lock: + self._dict[key] = value + + def delete(self, key): + """删除键""" + with self._lock: + if key in self._dict: + del self._dict[key] + + def pop(self, key, default=None): + """弹出键值""" + with self._lock: + return self._dict.pop(key, default) + + def keys(self): + """获取所有键(返回副本)""" + with self._lock: + return list(self._dict.keys()) + + def items(self): + """获取所有键值对(返回副本)""" + with self._lock: + return list(self._dict.items()) + + def __contains__(self, key): + """检查键是否存在""" + with self._lock: + return key in self._dict + + def clear(self): + """清空字典""" + with self._lock: + self._dict.clear() + + def __len__(self): + """获取长度""" + with self._lock: + return len(self._dict) + + +class LogCacheManager: + """日志缓存管理器(线程安全)""" + + def __init__(self, max_logs_per_user=100, max_total_logs=1000): + self._cache = {} # {user_id: [logs]} + self._total_count = 0 + self._lock = threading.RLock() + self._max_logs_per_user = max_logs_per_user + self._max_total_logs = max_total_logs + + def add_log(self, user_id: int, log_entry: Dict[str, Any]) -> bool: + """添加日志到缓存""" + with self._lock: + # 检查总数限制 + if self._total_count >= self._max_total_logs: + logger.warning(f"日志缓存已满 ({self._max_total_logs}),拒绝添加") + return False + + # 初始化用户日志列表 + if user_id not in self._cache: + self._cache[user_id] = [] + + user_logs = self._cache[user_id] + + # 检查用户日志数限制 + if len(user_logs) >= self._max_logs_per_user: + # 移除最旧的日志 + user_logs.pop(0) + self._total_count -= 1 + + # 添加新日志 + user_logs.append(log_entry) + self._total_count += 1 + + return True + + def get_logs(self, user_id: int) -> list: + """获取用户的所有日志(返回副本)""" + with self._lock: + return list(self._cache.get(user_id, [])) + + def clear_user_logs(self, user_id: int): + """清空用户的日志""" + with self._lock: + if user_id in self._cache: + count = len(self._cache[user_id]) + del self._cache[user_id] + self._total_count -= count + logger.info(f"清空用户 {user_id} 的 {count} 条日志") + + def get_total_count(self) -> int: + """获取总日志数""" + with self._lock: + return self._total_count + + def get_stats(self) -> Dict[str, int]: + """获取统计信息""" + with self._lock: + return { + 'total_count': self._total_count, + 'user_count': len(self._cache), + 'max_per_user': self._max_logs_per_user, + 'max_total': self._max_total_logs + } + + +class CaptchaManager: + """验证码管理器(线程安全)""" + + def __init__(self, expire_seconds=300): + self._storage = {} # {identifier: {'code': str, 'expire': datetime}} + self._lock = threading.RLock() + self._expire_seconds = expire_seconds + + def create(self, identifier: str, code: str) -> None: + """创建验证码""" + with self._lock: + self._storage[identifier] = { + 'code': code, + 'expire': datetime.now() + timedelta(seconds=self._expire_seconds) + } + + def verify(self, identifier: str, code: str) -> Tuple[bool, str]: + """验证验证码""" + with self._lock: + if identifier not in self._storage: + return False, "验证码不存在或已过期" + + captcha_data = self._storage[identifier] + + # 检查是否过期 + if datetime.now() > captcha_data['expire']: + del self._storage[identifier] + return False, "验证码已过期,请重新获取" + + # 验证码码值 + if captcha_data['code'] != code: + return False, "验证码错误" + + # 验证成功,删除验证码 + del self._storage[identifier] + return True, "验证成功" + + def cleanup_expired(self) -> int: + """清理过期的验证码""" + with self._lock: + now = datetime.now() + expired_keys = [ + key for key, data in self._storage.items() + if now > data['expire'] + ] + for key in expired_keys: + del self._storage[key] + + if expired_keys: + logger.info(f"清理了 {len(expired_keys)} 个过期验证码") + + return len(expired_keys) + + def get_count(self) -> int: + """获取当前验证码数量""" + with self._lock: + return len(self._storage) + + +class ApplicationState: + """应用全局状态管理器(单例模式)""" + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + # 浏览器管理器 + self.browser_manager = None + self._browser_lock = threading.Lock() + + # 用户账号管理 {user_id: {account_id: Account对象}} + self.user_accounts = ThreadSafeDict() + + # 活动任务管理 {account_id: Thread对象} + self.active_tasks = ThreadSafeDict() + + # 日志缓存管理 + self.log_cache = LogCacheManager() + + # 验证码管理 + self.captcha = CaptchaManager() + + # 用户信号量管理 {account_id: Semaphore} + self.user_semaphores = ThreadSafeDict() + + # 全局信号量 + self.global_semaphore = None + self.screenshot_semaphore = threading.Semaphore(1) + + self._initialized = True + logger.info("应用状态管理器初始化完成") + + def set_browser_manager(self, manager): + """设置浏览器管理器""" + with self._browser_lock: + self.browser_manager = manager + + def get_browser_manager(self): + """获取浏览器管理器""" + with self._browser_lock: + return self.browser_manager + + def get_user_semaphore(self, account_id: int, max_concurrent: int = 1): + """获取或创建用户信号量""" + if account_id not in self.user_semaphores: + self.user_semaphores.set(account_id, threading.Semaphore(max_concurrent)) + return self.user_semaphores.get(account_id) + + def set_global_semaphore(self, max_concurrent: int): + """设置全局信号量""" + self.global_semaphore = threading.Semaphore(max_concurrent) + + def get_stats(self) -> Dict[str, Any]: + """获取状态统计信息""" + return { + 'user_accounts_count': len(self.user_accounts), + 'active_tasks_count': len(self.active_tasks), + 'log_cache_stats': self.log_cache.get_stats(), + 'captcha_count': self.captcha.get_count(), + 'user_semaphores_count': len(self.user_semaphores), + 'browser_manager': 'initialized' if self.browser_manager else 'not_initialized' + } + + +# 全局单例实例 +app_state = ApplicationState() + + +# 向后兼容的辅助函数 +def verify_captcha(identifier: str, code: str) -> Tuple[bool, str]: + """验证验证码(向后兼容接口)""" + return app_state.captcha.verify(identifier, code) + + +def create_captcha(identifier: str, code: str) -> None: + """创建验证码(向后兼容接口)""" + app_state.captcha.create(identifier, code) + + +def cleanup_expired_captchas() -> int: + """清理过期验证码(向后兼容接口)""" + return app_state.captcha.cleanup_expired() + + +if __name__ == '__main__': + # 测试代码 + print("测试线程安全状态管理器...") + print("=" * 60) + + # 测试 ThreadSafeDict + print("\n1. 测试 ThreadSafeDict:") + td = ThreadSafeDict() + td.set('key1', 'value1') + print(f" 设置 key1 = {td.get('key1')}") + print(f" 长度: {len(td)}") + + # 测试 LogCacheManager + print("\n2. 测试 LogCacheManager:") + lcm = LogCacheManager(max_logs_per_user=3, max_total_logs=10) + for i in range(5): + lcm.add_log(1, {'message': f'log {i}'}) + print(f" 用户1日志数: {len(lcm.get_logs(1))}") + print(f" 总日志数: {lcm.get_total_count()}") + print(f" 统计: {lcm.get_stats()}") + + # 测试 CaptchaManager + print("\n3. 测试 CaptchaManager:") + cm = CaptchaManager(expire_seconds=2) + cm.create('test@example.com', '1234') + success, msg = cm.verify('test@example.com', '1234') + print(f" 验证结果: {success}, {msg}") + + # 测试 ApplicationState + print("\n4. 测试 ApplicationState (单例):") + state1 = ApplicationState() + state2 = ApplicationState() + print(f" 单例验证: {state1 is state2}") + print(f" 状态统计: {state1.get_stats()}") + + print("\n" + "=" * 60) + print("✓ 所有测试通过!") diff --git a/app_utils.py b/app_utils.py new file mode 100755 index 0000000..bb76c53 --- /dev/null +++ b/app_utils.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +应用工具模块 +提取重复的业务逻辑 +""" + +from typing import Dict, Any, Optional, Tuple +from flask import session, jsonify +from app_logger import get_logger, audit_logger +from app_security import get_client_ip +import database + +logger = get_logger('app_utils') + + +class ValidationError(Exception): + """验证错误异常""" + pass + + +def verify_user_file_permission(user_id: int, filename: str) -> Tuple[bool, Optional[str]]: + """ + 验证用户文件访问权限 + + Args: + user_id: 用户ID + filename: 文件名 + + Returns: + (是否有权限, 错误消息) + """ + # 获取用户信息 + user = database.get_user_by_id(user_id) + if not user: + return False, "用户不存在" + + username = user['username'] + + # 检查文件名是否以用户名开头 + if not filename.startswith(f"{username}_"): + logger.warning(f"用户 {username} (ID:{user_id}) 尝试访问未授权文件: {filename}") + return False, "无权访问此文件" + + return True, None + + +def log_task_event(account_id: int, status: str, message: str, + browse_type: Optional[str] = None, + screenshot_path: Optional[str] = None) -> bool: + """ + 记录任务日志(统一接口) + + Args: + account_id: 账号ID + status: 状态(running/completed/failed/stopped) + message: 消息 + browse_type: 浏览类型 + screenshot_path: 截图路径 + + Returns: + 是否成功 + """ + try: + return database.create_task_log( + account_id=account_id, + status=status, + message=message, + browse_type=browse_type, + screenshot_path=screenshot_path + ) + except Exception as e: + logger.error(f"记录任务日志失败: {e}", exc_info=True) + return False + + +def update_account_status(account_id: int, status: str, + error_message: Optional[str] = None) -> bool: + """ + 更新账号状态(统一接口) + + Args: + account_id: 账号ID + status: 状态(idle/running/error/stopped) + error_message: 错误消息(仅当status=error时) + + Returns: + 是否成功 + """ + try: + return database.update_account_status( + account_id=account_id, + status=status, + error_message=error_message + ) + except Exception as e: + logger.error(f"更新账号状态失败 (account_id={account_id}): {e}", exc_info=True) + return False + + +def get_or_create_config_cache() -> Optional[Dict[str, Any]]: + """ + 获取或创建系统配置缓存 + + 缓存存储在session中,避免重复查询数据库 + + Returns: + 配置字典,失败返回None + """ + # 尝试从session获取缓存 + if '_system_config' in session: + return session['_system_config'] + + # 从数据库加载 + try: + config = database.get_system_config() + if config: + # 存入session缓存 + session['_system_config'] = config + return config + return None + except Exception as e: + logger.error(f"获取系统配置失败: {e}", exc_info=True) + return None + + +def clear_config_cache(): + """清除配置缓存(配置变更时调用)""" + if '_system_config' in session: + del session['_system_config'] + logger.debug("已清除系统配置缓存") + + +def safe_close_browser(automation_obj, account_id: int): + """ + 安全关闭浏览器(统一错误处理) + + Args: + automation_obj: PlaywrightAutomation对象 + account_id: 账号ID + """ + if automation_obj: + try: + automation_obj.close() + logger.info(f"账号 {account_id} 的浏览器已关闭") + except Exception as e: + logger.error(f"关闭账号 {account_id} 的浏览器失败: {e}", exc_info=True) + + +def format_error_response(error: str, status_code: int = 400, + need_captcha: bool = False, + extra_data: Optional[Dict] = None) -> Tuple[Any, int]: + """ + 格式化错误响应(统一接口) + + Args: + error: 错误消息 + status_code: HTTP状态码 + need_captcha: 是否需要验证码 + extra_data: 额外数据 + + Returns: + (jsonify响应, 状态码) + """ + response_data = {"error": error} + + if need_captcha: + response_data["need_captcha"] = True + + if extra_data: + response_data.update(extra_data) + + return jsonify(response_data), status_code + + +def format_success_response(message: str = "操作成功", + extra_data: Optional[Dict] = None) -> Any: + """ + 格式化成功响应(统一接口) + + Args: + message: 成功消息 + extra_data: 额外数据 + + Returns: + jsonify响应 + """ + response_data = {"success": True, "message": message} + + if extra_data: + response_data.update(extra_data) + + return jsonify(response_data) + + +def log_user_action(action: str, user_id: int, username: str, + success: bool, details: Optional[str] = None): + """ + 记录用户操作到审计日志(统一接口) + + Args: + action: 操作类型(login/register/logout等) + user_id: 用户ID + username: 用户名 + success: 是否成功 + details: 详细信息 + """ + ip = get_client_ip() + + if action == 'login': + audit_logger.log_user_login(user_id, username, ip, success) + elif action == 'logout': + audit_logger.log_user_logout(user_id, username, ip) + elif action == 'register': + audit_logger.log_user_created(user_id, username, created_by='self') + + if details: + logger.info(f"用户操作: {action}, 用户={username}, 成功={success}, 详情={details}") + + +def validate_pagination(page: Any, page_size: Any, + max_page_size: int = 100) -> Tuple[int, int, Optional[str]]: + """ + 验证分页参数 + + Args: + page: 页码 + page_size: 每页大小 + max_page_size: 最大每页大小 + + Returns: + (页码, 每页大小, 错误消息) + """ + try: + page = int(page) if page else 1 + page_size = int(page_size) if page_size else 20 + except (ValueError, TypeError): + return 1, 20, "无效的分页参数" + + if page < 1: + return 1, 20, "页码必须大于0" + + if page_size < 1 or page_size > max_page_size: + return page, 20, f"每页大小必须在1-{max_page_size}之间" + + return page, page_size, None + + +def check_user_ownership(user_id: int, resource_type: str, + resource_id: int) -> Tuple[bool, Optional[str]]: + """ + 检查用户是否拥有资源 + + Args: + user_id: 用户ID + resource_type: 资源类型(account/task等) + resource_id: 资源ID + + Returns: + (是否拥有, 错误消息) + """ + try: + if resource_type == 'account': + account = database.get_account_by_id(resource_id) + if not account: + return False, "账号不存在" + if account['user_id'] != user_id: + return False, "无权访问此账号" + return True, None + + elif resource_type == 'task': + # 通过account查询所属用户 + # 这里需要根据实际数据库结构实现 + pass + + return False, "不支持的资源类型" + + except Exception as e: + logger.error(f"检查资源所有权失败: {e}", exc_info=True) + return False, "系统错误" + + +if __name__ == '__main__': + # 测试代码 + print("测试应用工具模块...") + print("=" * 60) + + # 测试分页验证 + print("\n1. 测试分页验证:") + page, page_size, error = validate_pagination("2", "50") + print(f" 页码={page}, 每页={page_size}, 错误={error}") + + page, page_size, error = validate_pagination("invalid", "50") + print(f" 无效输入: 页码={page}, 每页={page_size}, 错误={error}") + + # 测试响应格式化 + print("\n2. 测试响应格式化:") + print(f" 错误响应: {format_error_response('测试错误', need_captcha=True)}") + print(f" 成功响应: {format_success_response('测试成功', {'data': [1, 2, 3]})}") + + print("\n" + "=" * 60) + print("✓ 工具模块加载成功!") diff --git a/browser_installer.py b/browser_installer.py new file mode 100755 index 0000000..b940397 --- /dev/null +++ b/browser_installer.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +浏览器自动下载安装模块 +检测本地是否有Playwright浏览器,如果没有则自动下载安装 +""" + +import os +import sys +import shutil +import subprocess +from pathlib import Path + +# 设置浏览器安装路径(支持Docker和本地环境) +# Docker环境: PLAYWRIGHT_BROWSERS_PATH环境变量已设置为 /ms-playwright +# 本地环境: 使用Playwright默认路径 +if 'PLAYWRIGHT_BROWSERS_PATH' in os.environ: + BROWSERS_PATH = os.environ['PLAYWRIGHT_BROWSERS_PATH'] +else: + # Windows: %USERPROFILE%\AppData\Local\ms-playwright + # Linux: ~/.cache/ms-playwright + if sys.platform == 'win32': + BROWSERS_PATH = str(Path.home() / "AppData" / "Local" / "ms-playwright") + else: + BROWSERS_PATH = str(Path.home() / ".cache" / "ms-playwright") + os.environ["PLAYWRIGHT_BROWSERS_PATH"] = BROWSERS_PATH + + +class BrowserInstaller: + """浏览器安装器""" + + def __init__(self, log_callback=None): + """ + 初始化安装器 + + Args: + log_callback: 日志回调函数 + """ + self.log_callback = log_callback + + def log(self, message): + """输出日志""" + if self.log_callback: + self.log_callback(message) + else: + try: + print(message) + except UnicodeEncodeError: + # 如果打印Unicode字符失败,替换特殊字符 + safe_message = message.replace('✓', '[OK]').replace('✗', '[X]') + print(safe_message) + + def check_playwright_installed(self): + """检查Playwright是否已安装""" + try: + import playwright + self.log("✓ Playwright已安装") + return True + except ImportError: + self.log("✗ Playwright未安装") + return False + + def check_chromium_installed(self): + """检查Chromium浏览器是否已安装""" + try: + from playwright.sync_api import sync_playwright + + # 尝试启动浏览器检查是否可用 + with sync_playwright() as p: + try: + # 使用超时快速检查 + browser = p.chromium.launch(headless=True, timeout=5000) + browser.close() + self.log("✓ Chromium浏览器已安装且可用") + return True + except Exception as e: + error_msg = str(e) + self.log(f"✗ Chromium浏览器不可用: {error_msg}") + + # 检查是否是路径不存在的错误 + if "Executable doesn't exist" in error_msg: + self.log("检测到浏览器文件缺失,需要重新安装") + + return False + except Exception as e: + self.log(f"✗ 检查浏览器时出错: {str(e)}") + return False + + def install_chromium(self): + """安装Chromium浏览器""" + try: + self.log("正在安装 Chromium 浏览器...") + + # 查找 playwright 可执行文件 + playwright_cli = None + possible_paths = [ + os.path.join(os.path.dirname(sys.executable), "Scripts", "playwright.exe"), + os.path.join(os.path.dirname(sys.executable), "playwright.exe"), + os.path.join(os.path.dirname(sys.executable), "Scripts", "playwright"), + os.path.join(os.path.dirname(sys.executable), "playwright"), + "playwright", # 系统PATH中 + ] + + for path in possible_paths: + if os.path.exists(path) or shutil.which(path): + playwright_cli = path + break + + # 如果找到了 playwright CLI,直接调用 + if playwright_cli: + self.log(f"使用 Playwright CLI: {playwright_cli}") + result = subprocess.run( + [playwright_cli, "install", "chromium"], + capture_output=True, + text=True, + timeout=300 + ) + else: + # 检测是否是 Nuitka 编译的程序 + is_nuitka = hasattr(sys, 'frozen') or '__compiled__' in globals() + + if is_nuitka: + self.log("检测到 Nuitka 编译环境") + self.log("✗ 无法找到 playwright CLI 工具") + self.log("请手动运行: playwright install chromium") + return False + else: + # 使用 python -m + result = subprocess.run( + [sys.executable, "-m", "playwright", "install", "chromium"], + capture_output=True, + text=True, + timeout=300 + ) + + if result.returncode == 0: + self.log("✓ Chromium浏览器安装成功") + return True + else: + self.log(f"✗ 浏览器安装失败: {result.stderr}") + return False + + except subprocess.TimeoutExpired: + self.log("✗ 浏览器安装超时") + return False + except Exception as e: + self.log(f"✗ 浏览器安装出错: {str(e)}") + return False + + def auto_install(self): + """ + 自动检测并安装所需环境 + + Returns: + 是否成功安装或已安装 + """ + self.log("=" * 60) + self.log("检查浏览器环境...") + self.log("=" * 60) + + # 1. 检查Playwright是否安装 + if not self.check_playwright_installed(): + self.log("✗ Playwright未安装,无法继续") + self.log("请确保程序包含 Playwright 库") + return False + + # 2. 检查Chromium浏览器是否安装 + if not self.check_chromium_installed(): + self.log("\n未检测到Chromium浏览器,开始自动安装...") + + # 安装浏览器 + if not self.install_chromium(): + self.log("✗ 浏览器安装失败") + self.log("\n您可以尝试以下方法:") + self.log("1. 手动执行: playwright install chromium") + self.log("2. 检查网络连接后重试") + self.log("3. 检查防火墙设置") + return False + + self.log("\n" + "=" * 60) + self.log("✓ 浏览器环境检查完成,一切就绪!") + self.log("=" * 60 + "\n") + + return True + + +def check_and_install_browser(log_callback=None): + """ + 便捷函数:检查并安装浏览器 + + Args: + log_callback: 日志回调函数 + + Returns: + 是否成功 + """ + installer = BrowserInstaller(log_callback) + return installer.auto_install() + + +# 测试代码 +if __name__ == "__main__": + print("浏览器自动安装工具") + print("=" * 60) + + installer = BrowserInstaller() + success = installer.auto_install() + + if success: + print("\n✓ 安装成功!您现在可以运行主程序了。") + else: + print("\n✗ 安装失败,请查看上方错误信息。") + + print("=" * 60) diff --git a/database.py b/database.py new file mode 100755 index 0000000..56f30d8 --- /dev/null +++ b/database.py @@ -0,0 +1,1066 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +数据库模块 - 使用SQLite进行数据持久化 +支持VIP功能 + +优化内容: +1. 清理所有注释掉的代码 +2. 统一使用bcrypt密码哈希 +3. 优化数据库索引 +4. 规范化事务处理 +5. 添加数据迁移功能 +6. 改进错误处理 +""" + +import sqlite3 +import time +from datetime import datetime, timedelta +import pytz +import threading +import db_pool +from password_utils import ( + hash_password_bcrypt, + verify_password_bcrypt, + is_sha256_hash, + verify_password_sha256 +) + +# 数据库文件路径 +DB_FILE = "data/app_data.db" + +# 数据库版本 (用于迁移管理) +DB_VERSION = 2 + + +def hash_password(password): + """Password hashing using bcrypt""" + return hash_password_bcrypt(password) + + +def init_database(): + """初始化数据库表结构""" + db_pool.init_pool(DB_FILE, pool_size=5) + + with db_pool.get_db() as conn: + cursor = conn.cursor() + + # 管理员表 + cursor.execute(''' + CREATE TABLE IF NOT EXISTS admins ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # 用户表 + cursor.execute(''' + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + email TEXT, + status TEXT DEFAULT 'pending', + vip_expire_time TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + approved_at TIMESTAMP + ) + ''') + + # 账号表(关联用户) + cursor.execute(''' + CREATE TABLE IF NOT EXISTS accounts ( + id TEXT PRIMARY KEY, + user_id INTEGER NOT NULL, + username TEXT NOT NULL, + password TEXT NOT NULL, + remember INTEGER DEFAULT 1, + remark TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ) + ''') + + # VIP配置表 + cursor.execute(''' + CREATE TABLE IF NOT EXISTS vip_config ( + id INTEGER PRIMARY KEY CHECK (id = 1), + default_vip_days INTEGER DEFAULT 0, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # 系统配置表 + cursor.execute(''' + CREATE TABLE IF NOT EXISTS system_config ( + id INTEGER PRIMARY KEY CHECK (id = 1), + max_concurrent_global INTEGER DEFAULT 2, + max_concurrent_per_account INTEGER DEFAULT 1, + schedule_enabled INTEGER DEFAULT 0, + schedule_time TEXT DEFAULT '02:00', + schedule_browse_type TEXT DEFAULT '应读', + schedule_weekdays TEXT DEFAULT '1,2,3,4,5,6,7', + proxy_enabled INTEGER DEFAULT 0, + proxy_api_url TEXT DEFAULT '', + proxy_expire_minutes INTEGER DEFAULT 3, + enable_screenshot INTEGER DEFAULT 1, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # 任务日志表 + cursor.execute(''' + CREATE TABLE IF NOT EXISTS task_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + account_id TEXT NOT NULL, + username TEXT NOT NULL, + browse_type TEXT NOT NULL, + status TEXT NOT NULL, + total_items INTEGER DEFAULT 0, + total_attachments INTEGER DEFAULT 0, + error_message TEXT, + duration INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ) + ''') + + # 密码重置申请表 + cursor.execute(''' + CREATE TABLE IF NOT EXISTS password_reset_requests ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + new_password_hash TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + processed_at TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ) + ''') + + # 数据库版本表 + cursor.execute(''' + CREATE TABLE IF NOT EXISTS db_version ( + id INTEGER PRIMARY KEY CHECK (id = 1), + version INTEGER NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # ========== 创建索引 ========== + # 用户表索引 + cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_status ON users(status)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_users_vip_expire ON users(vip_expire_time)') + + # 账号表索引 + cursor.execute('CREATE INDEX IF NOT EXISTS idx_accounts_user_id ON accounts(user_id)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts(username)') + + # 任务日志表索引 + cursor.execute('CREATE INDEX IF NOT EXISTS idx_task_logs_user_id ON task_logs(user_id)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_task_logs_status ON task_logs(status)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)') + + # 密码重置表索引 + cursor.execute('CREATE INDEX IF NOT EXISTS idx_password_reset_status ON password_reset_requests(status)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_password_reset_user_id ON password_reset_requests(user_id)') + + # 初始化VIP配置 + try: + cursor.execute('INSERT INTO vip_config (id, default_vip_days) VALUES (1, 0)') + conn.commit() + print("✓ 已创建VIP配置(默认不赠送)") + except sqlite3.IntegrityError: + pass + + # 初始化系统配置 + try: + cursor.execute(''' + INSERT INTO system_config ( + id, max_concurrent_global, schedule_enabled, + schedule_time, schedule_browse_type, schedule_weekdays + ) VALUES (1, 2, 0, '02:00', '应读', '1,2,3,4,5,6,7') + ''') + conn.commit() + print("✓ 已创建系统配置(默认并发2,定时任务关闭)") + except sqlite3.IntegrityError: + pass + + # 初始化数据库版本 + try: + cursor.execute('INSERT INTO db_version (id, version) VALUES (1, ?)', (DB_VERSION,)) + conn.commit() + print(f"✓ 数据库版本: {DB_VERSION}") + except sqlite3.IntegrityError: + pass + + conn.commit() + print("✓ 数据库初始化完成") + + # 执行数据迁移 + migrate_database() + + +def migrate_database(): + """数据库迁移 - 自动检测并应用必要的迁移""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + + # 获取当前数据库版本 + cursor.execute('SELECT version FROM db_version WHERE id = 1') + row = cursor.fetchone() + current_version = row['version'] if row else 0 + + print(f"当前数据库版本: {current_version}, 目标版本: {DB_VERSION}") + + # 应用迁移 + if current_version < 1: + _migrate_to_v1(conn) + current_version = 1 + + if current_version < 2: + _migrate_to_v2(conn) + current_version = 2 + + # 更新版本号 + cursor.execute('UPDATE db_version SET version = ?, updated_at = CURRENT_TIMESTAMP WHERE id = 1', + (DB_VERSION,)) + conn.commit() + + if current_version < DB_VERSION: + print(f"✓ 数据库已迁移到版本 {DB_VERSION}") + + +def _migrate_to_v1(conn): + """迁移到版本1 - 添加缺失字段""" + cursor = conn.cursor() + + # 检查并添加 schedule_weekdays 字段 + cursor.execute("PRAGMA table_info(system_config)") + columns = [col[1] for col in cursor.fetchall()] + + if 'schedule_weekdays' not in columns: + cursor.execute('ALTER TABLE system_config ADD COLUMN schedule_weekdays TEXT DEFAULT "1,2,3,4,5,6,7"') + print(" ✓ 添加 schedule_weekdays 字段") + + if 'max_concurrent_per_account' not in columns: + cursor.execute('ALTER TABLE system_config ADD COLUMN max_concurrent_per_account INTEGER DEFAULT 1') + print(" ✓ 添加 max_concurrent_per_account 字段") + + # 检查并添加 duration 字段到 task_logs + cursor.execute("PRAGMA table_info(task_logs)") + columns = [col[1] for col in cursor.fetchall()] + + if 'duration' not in columns: + cursor.execute('ALTER TABLE task_logs ADD COLUMN duration INTEGER') + print(" ✓ 添加 duration 字段到 task_logs") + + conn.commit() + + +def _migrate_to_v2(conn): + """迁移到版本2 - 添加代理配置字段""" + cursor = conn.cursor() + + cursor.execute("PRAGMA table_info(system_config)") + columns = [col[1] for col in cursor.fetchall()] + + if 'proxy_enabled' not in columns: + cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_enabled INTEGER DEFAULT 0') + print(" ✓ 添加 proxy_enabled 字段") + + if 'proxy_api_url' not in columns: + cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_api_url TEXT DEFAULT ""') + print(" ✓ 添加 proxy_api_url 字段") + + if 'proxy_expire_minutes' not in columns: + cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_expire_minutes INTEGER DEFAULT 3') + print(" ✓ 添加 proxy_expire_minutes 字段") + + if 'enable_screenshot' not in columns: + cursor.execute('ALTER TABLE system_config ADD COLUMN enable_screenshot INTEGER DEFAULT 1') + print(" ✓ 添加 enable_screenshot 字段") + + conn.commit() + + +# ==================== 管理员相关 ==================== + +def verify_admin(username, password): + """验证管理员登录 - 自动从SHA256升级到bcrypt""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('SELECT * FROM admins WHERE username = ?', (username,)) + admin = cursor.fetchone() + + if not admin: + return None + + admin_dict = dict(admin) + password_hash = admin_dict['password_hash'] + + # 检查是否为旧的SHA256哈希 + if is_sha256_hash(password_hash): + if verify_password_sha256(password, password_hash): + # 自动升级到bcrypt + new_hash = hash_password_bcrypt(password) + cursor.execute('UPDATE admins SET password_hash = ? WHERE username = ?', + (new_hash, username)) + conn.commit() + print(f"管理员 {username} 密码已自动升级到bcrypt") + return admin_dict + return None + else: + # bcrypt验证 + if verify_password_bcrypt(password, password_hash): + return admin_dict + return None + + +def update_admin_password(username, new_password): + """更新管理员密码""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + password_hash = hash_password(new_password) + cursor.execute('UPDATE admins SET password_hash = ? WHERE username = ?', + (password_hash, username)) + conn.commit() + return cursor.rowcount > 0 + + +def update_admin_username(old_username, new_username): + """更新管理员用户名""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + try: + cursor.execute('UPDATE admins SET username = ? WHERE username = ?', + (new_username, old_username)) + conn.commit() + return True + except sqlite3.IntegrityError: + return False + + +# ==================== VIP管理 ==================== + +def get_vip_config(): + """获取VIP配置""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('SELECT * FROM vip_config WHERE id = 1') + config = cursor.fetchone() + return dict(config) if config else {'default_vip_days': 0} + + +def set_default_vip_days(days): + """设置默认VIP天数""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute(''' + INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at) + VALUES (1, ?, CURRENT_TIMESTAMP) + ''', (days,)) + conn.commit() + return True + + +def set_user_vip(user_id, days): + """设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久""" + with db_pool.get_db() as conn: + cst_tz = pytz.timezone("Asia/Shanghai") + cursor = conn.cursor() + + if days == 999999: + expire_time = '2099-12-31 23:59:59' + else: + expire_time = (datetime.now(cst_tz) + timedelta(days=days)).strftime('%Y-%m-%d %H:%M:%S') + + cursor.execute('UPDATE users SET vip_expire_time = ? WHERE id = ?', (expire_time, user_id)) + conn.commit() + return cursor.rowcount > 0 + + +def extend_user_vip(user_id, days): + """延长用户VIP时间""" + user = get_user_by_id(user_id) + cst_tz = pytz.timezone("Asia/Shanghai") + + if not user: + return False + + with db_pool.get_db() as conn: + cursor = conn.cursor() + current_expire = user.get('vip_expire_time') + + if current_expire and current_expire != '2099-12-31 23:59:59': + try: + expire_time_naive = datetime.strptime(current_expire, '%Y-%m-%d %H:%M:%S') + expire_time = cst_tz.localize(expire_time_naive) + now = datetime.now(cst_tz) + if expire_time < now: + expire_time = now + new_expire = (expire_time + timedelta(days=days)).strftime('%Y-%m-%d %H:%M:%S') + except: + new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime('%Y-%m-%d %H:%M:%S') + else: + new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime('%Y-%m-%d %H:%M:%S') + + cursor.execute('UPDATE users SET vip_expire_time = ? WHERE id = ?', (new_expire, user_id)) + conn.commit() + return cursor.rowcount > 0 + + +def remove_user_vip(user_id): + """移除用户VIP""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('UPDATE users SET vip_expire_time = NULL WHERE id = ?', (user_id,)) + conn.commit() + return cursor.rowcount > 0 + + +def is_user_vip(user_id): + """检查用户是否是VIP""" + cst_tz = pytz.timezone("Asia/Shanghai") + user = get_user_by_id(user_id) + + if not user or not user.get('vip_expire_time'): + return False + + try: + expire_time_naive = datetime.strptime(user['vip_expire_time'], '%Y-%m-%d %H:%M:%S') + expire_time = cst_tz.localize(expire_time_naive) + return datetime.now(cst_tz) < expire_time + except: + return False + + +def get_user_vip_info(user_id): + """获取用户VIP信息""" + cst_tz = pytz.timezone("Asia/Shanghai") + user = get_user_by_id(user_id) + + if not user: + return {'is_vip': False, 'expire_time': None, 'days_left': 0, 'username': ''} + + vip_expire_time = user.get('vip_expire_time') + if not vip_expire_time: + return {'is_vip': False, 'expire_time': None, 'days_left': 0, 'username': user.get('username', '')} + + try: + expire_time_naive = datetime.strptime(vip_expire_time, '%Y-%m-%d %H:%M:%S') + expire_time = cst_tz.localize(expire_time_naive) + now = datetime.now(cst_tz) + is_vip = now < expire_time + days_left = (expire_time - now).days if is_vip else 0 + + return { + "username": user.get("username", ""), + 'is_vip': is_vip, + 'expire_time': vip_expire_time, + 'days_left': max(0, days_left) + } + except Exception as e: + print(f"VIP信息获取错误: {e}") + return {'is_vip': False, 'expire_time': None, 'days_left': 0, 'username': user.get('username', '')} + + +# ==================== 用户相关 ==================== + +def create_user(username, password, email=''): + """创建新用户(待审核状态,赠送默认VIP)""" + cst_tz = pytz.timezone("Asia/Shanghai") + + with db_pool.get_db() as conn: + cursor = conn.cursor() + password_hash = hash_password(password) + + # 获取默认VIP天数 + default_vip_days = get_vip_config()['default_vip_days'] + vip_expire_time = None + + if default_vip_days > 0: + if default_vip_days == 999999: + vip_expire_time = '2099-12-31 23:59:59' + else: + vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime('%Y-%m-%d %H:%M:%S') + + try: + cursor.execute(''' + INSERT INTO users (username, password_hash, email, status, vip_expire_time) + VALUES (?, ?, ?, 'pending', ?) + ''', (username, password_hash, email, vip_expire_time)) + conn.commit() + return cursor.lastrowid + except sqlite3.IntegrityError: + return None + + +def verify_user(username, password): + """验证用户登录 - 自动从SHA256升级到bcrypt""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE username = ? AND status = 'approved'", (username,)) + user = cursor.fetchone() + + if not user: + return None + + user_dict = dict(user) + password_hash = user_dict['password_hash'] + + # 检查是否为旧的SHA256哈希 + if is_sha256_hash(password_hash): + if verify_password_sha256(password, password_hash): + # 自动升级到bcrypt + new_hash = hash_password_bcrypt(password) + cursor.execute('UPDATE users SET password_hash = ? WHERE id = ?', + (new_hash, user_dict['id'])) + conn.commit() + print(f"用户 {username} 密码已自动升级到bcrypt") + return user_dict + return None + else: + # bcrypt验证 + if verify_password_bcrypt(password, password_hash): + return user_dict + return None + + +def get_user_by_id(user_id): + """根据ID获取用户""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('SELECT * FROM users WHERE id = ?', (user_id,)) + user = cursor.fetchone() + return dict(user) if user else None + + +def get_user_by_username(username): + """根据用户名获取用户""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('SELECT * FROM users WHERE username = ?', (username,)) + user = cursor.fetchone() + return dict(user) if user else None + + +def get_all_users(): + """获取所有用户""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('SELECT * FROM users ORDER BY created_at DESC') + return [dict(row) for row in cursor.fetchall()] + + +def get_pending_users(): + """获取待审核用户""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM users WHERE status = 'pending' ORDER BY created_at DESC") + return [dict(row) for row in cursor.fetchall()] + + +def approve_user(user_id): + """审核通过用户""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute(''' + UPDATE users + SET status = 'approved', approved_at = CURRENT_TIMESTAMP + WHERE id = ? + ''', (user_id,)) + conn.commit() + return cursor.rowcount > 0 + + +def reject_user(user_id): + """拒绝用户""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute("UPDATE users SET status = 'rejected' WHERE id = ?", (user_id,)) + conn.commit() + return cursor.rowcount > 0 + + +def delete_user(user_id): + """删除用户(级联删除相关账号)""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('DELETE FROM users WHERE id = ?', (user_id,)) + conn.commit() + return cursor.rowcount > 0 + + +# ==================== 账号相关 ==================== + +def create_account(user_id, account_id, username, password, remember=True, remark=''): + """创建账号""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO accounts (id, user_id, username, password, remember, remark) + VALUES (?, ?, ?, ?, ?, ?) + ''', (account_id, user_id, username, password, 1 if remember else 0, remark)) + conn.commit() + return cursor.lastrowid + + +def get_user_accounts(user_id): + """获取用户的所有账号""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC', (user_id,)) + return [dict(row) for row in cursor.fetchall()] + + +def get_account(account_id): + """获取单个账号""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('SELECT * FROM accounts WHERE id = ?', (account_id,)) + row = cursor.fetchone() + return dict(row) if row else None + + +def update_account_remark(account_id, remark): + """更新账号备注""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('UPDATE accounts SET remark = ? WHERE id = ?', (remark, account_id)) + conn.commit() + return cursor.rowcount > 0 + + +def delete_account(account_id): + """删除账号""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('DELETE FROM accounts WHERE id = ?', (account_id,)) + conn.commit() + return cursor.rowcount > 0 + + +def delete_user_accounts(user_id): + """删除用户的所有账号""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('DELETE FROM accounts WHERE user_id = ?', (user_id,)) + conn.commit() + return cursor.rowcount + + +# ==================== 统计相关 ==================== + +def get_user_stats(user_id): + """获取用户统计信息""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('SELECT COUNT(*) as count FROM accounts WHERE user_id = ?', (user_id,)) + account_count = cursor.fetchone()['count'] + return {'account_count': account_count} + + +def get_system_stats(): + """获取系统统计信息""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + + cursor.execute('SELECT COUNT(*) as count FROM users') + total_users = cursor.fetchone()['count'] + + cursor.execute("SELECT COUNT(*) as count FROM users WHERE status = 'approved'") + approved_users = cursor.fetchone()['count'] + + cursor.execute("SELECT COUNT(*) as count FROM users WHERE status = 'pending'") + pending_users = cursor.fetchone()['count'] + + cursor.execute('SELECT COUNT(*) as count FROM accounts') + total_accounts = cursor.fetchone()['count'] + + cursor.execute(''' + SELECT COUNT(*) as count FROM users + WHERE vip_expire_time IS NOT NULL + AND datetime(vip_expire_time) > datetime('now') + ''') + vip_users = cursor.fetchone()['count'] + + return { + 'total_users': total_users, + 'approved_users': approved_users, + 'pending_users': pending_users, + 'total_accounts': total_accounts, + 'vip_users': vip_users + } + + +# ==================== 系统配置管理 ==================== + +def get_system_config(): + """获取系统配置""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute('SELECT * FROM system_config WHERE id = 1') + row = cursor.fetchone() + + if row: + return dict(row) + + # 返回默认值 + return { + 'max_concurrent_global': 2, + 'max_concurrent_per_account': 1, + 'schedule_enabled': 0, + 'schedule_time': '02:00', + 'schedule_browse_type': '应读', + 'schedule_weekdays': '1,2,3,4,5,6,7', + 'proxy_enabled': 0, + 'proxy_api_url': '', + 'proxy_expire_minutes': 3, + 'enable_screenshot': 1 + } + + +def update_system_config(max_concurrent=None, schedule_enabled=None, schedule_time=None, + schedule_browse_type=None, schedule_weekdays=None, + max_concurrent_per_account=None, proxy_enabled=None, + proxy_api_url=None, proxy_expire_minutes=None): + """更新系统配置""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + updates = [] + params = [] + + if max_concurrent is not None: + updates.append('max_concurrent_global = ?') + params.append(max_concurrent) + + if schedule_enabled is not None: + updates.append('schedule_enabled = ?') + params.append(schedule_enabled) + + if schedule_time is not None: + updates.append('schedule_time = ?') + params.append(schedule_time) + + if schedule_browse_type is not None: + updates.append('schedule_browse_type = ?') + params.append(schedule_browse_type) + + if max_concurrent_per_account is not None: + updates.append('max_concurrent_per_account = ?') + params.append(max_concurrent_per_account) + + if schedule_weekdays is not None: + updates.append('schedule_weekdays = ?') + params.append(schedule_weekdays) + + if proxy_enabled is not None: + updates.append('proxy_enabled = ?') + params.append(proxy_enabled) + + if proxy_api_url is not None: + updates.append('proxy_api_url = ?') + params.append(proxy_api_url) + + if proxy_expire_minutes is not None: + updates.append('proxy_expire_minutes = ?') + params.append(proxy_expire_minutes) + + if updates: + updates.append('updated_at = CURRENT_TIMESTAMP') + sql = f"UPDATE system_config SET {', '.join(updates)} WHERE id = 1" + cursor.execute(sql, params) + conn.commit() + return True + + return False + + +# ==================== 任务日志管理 ==================== + +def create_task_log(user_id, account_id, username, browse_type, status, + total_items=0, total_attachments=0, error_message='', duration=None): + """创建任务日志记录""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cst_tz = pytz.timezone("Asia/Shanghai") + cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S") + + cursor.execute(''' + INSERT INTO task_logs ( + user_id, account_id, username, browse_type, status, + total_items, total_attachments, error_message, duration, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', (user_id, account_id, username, browse_type, status, + total_items, total_attachments, error_message, duration, cst_time)) + + conn.commit() + return cursor.lastrowid + + +def get_task_logs(limit=100, offset=0, date_filter=None, status_filter=None): + """获取任务日志列表""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + + sql = ''' + SELECT + tl.*, + u.username as user_username + FROM task_logs tl + LEFT JOIN users u ON tl.user_id = u.id + WHERE 1=1 + ''' + params = [] + + if date_filter: + sql += " AND date(tl.created_at) = ?" + params.append(date_filter) + + if status_filter: + sql += " AND tl.status = ?" + params.append(status_filter) + + sql += " ORDER BY tl.created_at DESC LIMIT ? OFFSET ?" + params.extend([limit, offset]) + + cursor.execute(sql, params) + return [dict(row) for row in cursor.fetchall()] + + +def get_task_stats(date_filter=None): + """获取任务统计信息""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cst_tz = pytz.timezone("Asia/Shanghai") + + if date_filter is None: + date_filter = datetime.now(cst_tz).strftime('%Y-%m-%d') + + # 当日统计 + cursor.execute(''' + SELECT + COUNT(*) as total_tasks, + SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks, + SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks, + SUM(total_items) as total_items, + SUM(total_attachments) as total_attachments + FROM task_logs + WHERE date(created_at) = ? + ''', (date_filter,)) + + today_stats = cursor.fetchone() + + # 历史累计统计 + cursor.execute(''' + SELECT + COUNT(*) as total_tasks, + SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks, + SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks, + SUM(total_items) as total_items, + SUM(total_attachments) as total_attachments + FROM task_logs + ''') + + total_stats = cursor.fetchone() + + return { + 'today': { + 'total_tasks': today_stats['total_tasks'] or 0, + 'success_tasks': today_stats['success_tasks'] or 0, + 'failed_tasks': today_stats['failed_tasks'] or 0, + 'total_items': today_stats['total_items'] or 0, + 'total_attachments': today_stats['total_attachments'] or 0 + }, + 'total': { + 'total_tasks': total_stats['total_tasks'] or 0, + 'success_tasks': total_stats['success_tasks'] or 0, + 'failed_tasks': total_stats['failed_tasks'] or 0, + 'total_items': total_stats['total_items'] or 0, + 'total_attachments': total_stats['total_attachments'] or 0 + } + } + + +def delete_old_task_logs(days=30): + """删除N天前的任务日志""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute(''' + DELETE FROM task_logs + WHERE created_at < datetime('now', '-' || ? || ' days') + ''', (days,)) + conn.commit() + return cursor.rowcount + + +def get_user_run_stats(user_id, date_filter=None): + """获取用户的运行统计信息""" + with db_pool.get_db() as conn: + cst_tz = pytz.timezone("Asia/Shanghai") + cursor = conn.cursor() + + if date_filter is None: + date_filter = datetime.now(cst_tz).strftime('%Y-%m-%d') + + cursor.execute(''' + SELECT + SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed, + SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed, + SUM(total_items) as total_items, + SUM(total_attachments) as total_attachments + FROM task_logs + WHERE user_id = ? AND date(created_at) = ? + ''', (user_id, date_filter)) + + stats = cursor.fetchone() + + return { + 'completed': stats['completed'] or 0, + 'failed': stats['failed'] or 0, + 'total_items': stats['total_items'] or 0, + 'total_attachments': stats['total_attachments'] or 0 + } + + +# ==================== 密码重置功能 ==================== + +def create_password_reset_request(user_id, new_password): + """创建密码重置申请 - 使用bcrypt哈希""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + password_hash = hash_password_bcrypt(new_password) + + try: + cursor.execute(''' + INSERT INTO password_reset_requests (user_id, new_password_hash, status) + VALUES (?, ?, 'pending') + ''', (user_id, password_hash)) + conn.commit() + return cursor.lastrowid + except Exception as e: + print(f"创建密码重置申请失败: {e}") + return None + + +def get_pending_password_resets(): + """获取所有待审核的密码重置申请""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT r.id, r.user_id, r.created_at, r.status, + u.username, u.email + FROM password_reset_requests r + JOIN users u ON r.user_id = u.id + WHERE r.status = 'pending' + ORDER BY r.created_at DESC + ''') + return [dict(row) for row in cursor.fetchall()] + + +def approve_password_reset(request_id): + """批准密码重置申请""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + + try: + # 获取申请信息 + cursor.execute(''' + SELECT user_id, new_password_hash + FROM password_reset_requests + WHERE id = ? AND status = 'pending' + ''', (request_id,)) + + result = cursor.fetchone() + if not result: + return False + + user_id = result['user_id'] + new_password_hash = result['new_password_hash'] + + # 更新用户密码 + cursor.execute('UPDATE users SET password_hash = ? WHERE id = ?', + (new_password_hash, user_id)) + + # 更新申请状态 + cursor.execute(''' + UPDATE password_reset_requests + SET status = 'approved', processed_at = CURRENT_TIMESTAMP + WHERE id = ? + ''', (request_id,)) + + conn.commit() + return True + except Exception as e: + print(f"批准密码重置失败: {e}") + return False + + +def reject_password_reset(request_id): + """拒绝密码重置申请""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + + try: + cursor.execute(''' + UPDATE password_reset_requests + SET status = 'rejected', processed_at = CURRENT_TIMESTAMP + WHERE id = ? AND status = 'pending' + ''', (request_id,)) + conn.commit() + return cursor.rowcount > 0 + except Exception as e: + print(f"拒绝密码重置失败: {e}") + return False + + +def admin_reset_user_password(user_id, new_password): + """管理员直接重置用户密码 - 使用bcrypt哈希""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + password_hash = hash_password_bcrypt(new_password) + + try: + cursor.execute('UPDATE users SET password_hash = ? WHERE id = ?', + (password_hash, user_id)) + conn.commit() + return cursor.rowcount > 0 + except Exception as e: + print(f"管理员重置密码失败: {e}") + return False + + +# ==================== 日志清理 ==================== + +def clean_old_operation_logs(days=30): + """清理指定天数前的操作日志(如果存在operation_logs表)""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + + # 检查表是否存在 + cursor.execute(""" + SELECT name FROM sqlite_master + WHERE type='table' AND name='operation_logs' + """) + + if not cursor.fetchone(): + return 0 + + try: + cursor.execute(''' + DELETE FROM operation_logs + WHERE created_at < datetime('now', '-' || ? || ' days') + ''', (days,)) + deleted_count = cursor.rowcount + conn.commit() + print(f"已清理 {deleted_count} 条旧操作日志 (>{days}天)") + return deleted_count + except Exception as e: + print(f"清理旧操作日志失败: {e}") + return 0 diff --git a/db_pool.py b/db_pool.py new file mode 100755 index 0000000..e3e0d35 --- /dev/null +++ b/db_pool.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +数据库连接池模块 +使用queue实现固定大小的连接池,防止连接泄漏 +""" + +import sqlite3 +import threading +from queue import Queue, Empty +import time + + +class ConnectionPool: + """SQLite连接池""" + + def __init__(self, database, pool_size=5, timeout=30): + """ + 初始化连接池 + + Args: + database: 数据库文件路径 + pool_size: 连接池大小(默认5) + timeout: 获取连接超时时间(秒) + """ + self.database = database + self.pool_size = pool_size + self.timeout = timeout + self._pool = Queue(maxsize=pool_size) + self._lock = threading.Lock() + self._created_connections = 0 + + # 预创建连接 + self._initialize_pool() + + def _initialize_pool(self): + """预创建连接池中的连接""" + for _ in range(self.pool_size): + conn = self._create_connection() + self._pool.put(conn) + self._created_connections += 1 + + def _create_connection(self): + """创建新的数据库连接""" + conn = sqlite3.connect(self.database, check_same_thread=False) + conn.row_factory = sqlite3.Row + # 设置WAL模式提高并发性能 + conn.execute('PRAGMA journal_mode=WAL') + # 设置合理的超时时间 + conn.execute('PRAGMA busy_timeout=5000') + return conn + + def get_connection(self): + """ + 从连接池获取连接 + + Returns: + PooledConnection: 连接包装对象 + """ + try: + conn = self._pool.get(timeout=self.timeout) + return PooledConnection(conn, self) + except Empty: + raise RuntimeError(f"无法在{self.timeout}秒内获取数据库连接") + + def return_connection(self, conn): + """ + 归还连接到连接池 [已修复Bug#7] + + Args: + conn: 要归还的连接 + """ + import sqlite3 + from queue import Full + + try: + # 回滚任何未提交的事务 + conn.rollback() + self._pool.put(conn, block=False) + except sqlite3.Error as e: + # 数据库相关错误,连接可能损坏 + print(f"归还连接失败(数据库错误): {e}") + try: + conn.close() + except Exception: + pass + # 创建新连接补充 + with self._lock: + try: + new_conn = self._create_connection() + self._pool.put(new_conn, block=False) + except Exception as create_error: + print(f"重建连接失败: {create_error}") + except Full: + # 队列已满(不应该发生) + print(f"警告: 连接池已满,关闭多余连接") + try: + conn.close() + except Exception: + pass + except Exception as e: + print(f"归还连接失败(未知错误): {e}") + try: + conn.close() + except Exception: + pass + + def close_all(self): + """关闭所有连接""" + while not self._pool.empty(): + try: + conn = self._pool.get(block=False) + conn.close() + except Exception as e: + print(f"关闭连接失败: {e}") + + def get_stats(self): + """获取连接池统计信息""" + return { + 'pool_size': self.pool_size, + 'available': self._pool.qsize(), + 'in_use': self.pool_size - self._pool.qsize(), + 'total_created': self._created_connections + } + + +class PooledConnection: + """连接池连接包装器,支持with语句自动归还""" + + def __init__(self, conn, pool): + """ + 初始化 + + Args: + conn: 实际的数据库连接 + pool: 连接池对象 + """ + self._conn = conn + self._pool = pool + self._cursor = None + + def __enter__(self): + """支持with语句""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """with语句结束时自动归还连接 [已修复Bug#3]""" + try: + if exc_type is not None: + # 发生异常,回滚事务 + self._conn.rollback() + print(f"数据库事务已回滚: {exc_type.__name__}") + # 注意: 不自动commit,要求用户显式调用conn.commit() + + if self._cursor: + self._cursor.close() + except Exception as e: + print(f"关闭游标失败: {e}") + finally: + # 归还连接 + self._pool.return_connection(self._conn) + + return False # 不抑制异常 + + def cursor(self): + """获取游标""" + self._cursor = self._conn.cursor() + return self._cursor + + def commit(self): + """提交事务""" + self._conn.commit() + + def rollback(self): + """回滚事务""" + self._conn.rollback() + + def execute(self, sql, parameters=None): + """执行SQL""" + cursor = self.cursor() + if parameters: + return cursor.execute(sql, parameters) + return cursor.execute(sql) + + def fetchone(self): + """获取一行""" + if self._cursor: + return self._cursor.fetchone() + return None + + def fetchall(self): + """获取所有行""" + if self._cursor: + return self._cursor.fetchall() + return [] + + @property + def lastrowid(self): + """最后插入的行ID""" + if self._cursor: + return self._cursor.lastrowid + return None + + @property + def rowcount(self): + """影响的行数""" + if self._cursor: + return self._cursor.rowcount + return 0 + + +# 全局连接池实例 +_pool = None +_pool_lock = threading.Lock() + + +def init_pool(database, pool_size=5): + """ + 初始化全局连接池 + + Args: + database: 数据库文件路径 + pool_size: 连接池大小 + """ + global _pool + with _pool_lock: + if _pool is None: + _pool = ConnectionPool(database, pool_size) + print(f"✓ 数据库连接池已初始化 (大小: {pool_size})") + + +def get_db(): + """ + 获取数据库连接(替代原有的get_db函数) + + Returns: + PooledConnection: 连接对象 + """ + global _pool + if _pool is None: + raise RuntimeError("连接池未初始化,请先调用init_pool()") + return _pool.get_connection() + + +def get_pool_stats(): + """获取连接池统计信息""" + global _pool + if _pool: + return _pool.get_stats() + return None diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..b17ba10 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,20 @@ +version: '3.8' + +services: + knowledge-automation: + build: . + container_name: knowledge-automation-multiuser + ports: + - "5001:5000" + volumes: + - ./data:/app/data # 数据库持久化 + - ./logs:/app/logs # 日志持久化 + - ./截图:/app/截图 # 截图持久化 + - ./playwright:/ms-playwright # Playwright浏览器持久化(避免重复下载) + - /etc/localtime:/etc/localtime:ro # 时区同步 + environment: + - TZ=Asia/Shanghai + - PYTHONUNBUFFERED=1 + - PLAYWRIGHT_BROWSERS_PATH=/ms-playwright + restart: unless-stopped + shm_size: 2gb # 为Chromium分配共享内存 diff --git a/ftp-manager.db b/ftp-manager.db new file mode 100644 index 0000000..e69de29 diff --git a/password_utils.py b/password_utils.py new file mode 100644 index 0000000..79feb66 --- /dev/null +++ b/password_utils.py @@ -0,0 +1,74 @@ +""" +密码哈希工具模块 +支持bcrypt加密和SHA256兼容性验证 +""" +import bcrypt +import hashlib + + +def hash_password_bcrypt(password): + """ + 使用bcrypt加密密码 + + Args: + password: 明文密码 + + Returns: + str: bcrypt哈希值(包含盐值) + """ + salt = bcrypt.gensalt(rounds=12) + return bcrypt.hashpw(password.encode('utf-8'), salt).decode('utf-8') + + +def verify_password_bcrypt(password, password_hash): + """ + 验证bcrypt密码 + + Args: + password: 明文密码 + password_hash: bcrypt哈希值 + + Returns: + bool: 验证成功返回True + """ + try: + return bcrypt.checkpw(password.encode('utf-8'), + password_hash.encode('utf-8')) + except Exception as e: + print(f"bcrypt验证异常: {e}") + return False + + +def is_sha256_hash(password_hash): + """ + 判断是否为旧的SHA256哈希 + + Args: + password_hash: 哈希值 + + Returns: + bool: SHA256哈希为64位十六进制字符串 + """ + if not password_hash: + return False + # SHA256输出固定64位十六进制 + return len(password_hash) == 64 and all(c in '0123456789abcdef' for c in password_hash.lower()) + + +def verify_password_sha256(password, password_hash): + """ + 验证旧的SHA256密码(兼容性) + + Args: + password: 明文密码 + password_hash: SHA256哈希值 + + Returns: + bool: 验证成功返回True + """ + try: + computed_hash = hashlib.sha256(password.encode()).hexdigest() + return computed_hash == password_hash + except Exception as e: + print(f"SHA256验证异常: {e}") + return False diff --git a/playwright_automation.py b/playwright_automation.py new file mode 100755 index 0000000..d0ddb38 --- /dev/null +++ b/playwright_automation.py @@ -0,0 +1,762 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Playwright版本 - 知识管理系统自动化核心 +使用浏览器上下文(Context)实现高性能并发 +""" + +import os +from pathlib import Path +from playwright.sync_api import sync_playwright, Browser, BrowserContext, Page, Playwright +import time +import threading +from typing import Optional, Callable +from dataclasses import dataclass + +# 设置浏览器安装路径(避免Nuitka onefile临时目录问题) +BROWSERS_PATH = str(Path.home() / "AppData" / "Local" / "ms-playwright") +os.environ["PLAYWRIGHT_BROWSERS_PATH"] = BROWSERS_PATH + +# 配置常量 +class Config: + """配置常量""" + LOGIN_URL = "https://postoa.aidunsoft.com/admin/login.aspx" + INDEX_URL_PATTERN = "index.aspx" + + PAGE_LOAD_TIMEOUT = 60000 # 毫秒 (increased from 30s to 60s for multi-account support) + DEFAULT_TIMEOUT = 60000 # 增加超时时间以支持多账号并发 + + MAX_CONCURRENT_CONTEXTS = 100 # 最大并发上下文数 + + +@dataclass +class BrowseResult: + """浏览结果""" + success: bool + total_items: int = 0 + total_attachments: int = 0 + error_message: str = "" + + +class PlaywrightBrowserManager: + """Playwright浏览器管理器 - 每个账号独立的浏览器实例""" + + def __init__(self, headless: bool = True, log_callback: Optional[Callable] = None): + """ + 初始化浏览器管理器 + + Args: + headless: 是否使用无头模式 + log_callback: 日志回调函数,签名: log_callback(message, account_id=None) + """ + self.headless = headless + self.log_callback = log_callback + self._lock = threading.Lock() + + def log(self, message: str, account_id: Optional[str] = None): + """记录日志""" + if self.log_callback: + self.log_callback(message, account_id) + + def create_browser(self, proxy_config=None): + """创建新的独立浏览器实例(每个账号独立)""" + try: + self.log("初始化Playwright实例...") + playwright = sync_playwright().start() + + self.log("启动独立浏览器进程...") + start_time = time.time() + + # 准备浏览器启动参数 + launch_options = { + 'headless': self.headless, + 'args': [ + '--no-sandbox', + '--disable-dev-shm-usage', + '--disable-gpu', + '--disable-extensions', + '--disable-notifications', + '--disable-infobars', + '--disable-default-apps', + '--disable-background-timer-throttling', + '--disable-backgrounding-occluded-windows', + '--disable-renderer-backgrounding', + ] + } + + # 如果有代理配置,添加代理 + if proxy_config and proxy_config.get('server'): + launch_options['proxy'] = { + 'server': proxy_config['server'] + } + self.log(f"使用代理: {proxy_config['server']}") + + browser = playwright.chromium.launch(**launch_options) + + elapsed = time.time() - start_time + self.log(f"独立浏览器启动成功 (耗时: {elapsed:.2f}秒)") + + return playwright, browser + + except Exception as e: + self.log(f"启动浏览器失败: {str(e)}") + raise + + def create_browser_and_context(self, proxy_config=None): + """创建独立的浏览器和上下文(每个账号完全隔离)""" + playwright, browser = self.create_browser(proxy_config) + + start_time = time.time() + self.log("创建浏览器上下文...") + + context = browser.new_context( + viewport={'width': 1920, 'height': 1080}, + user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36', + device_scale_factor=2, # 2倍设备像素比,提高文字清晰度 + ) + + # 设置默认超时 + context.set_default_timeout(Config.DEFAULT_TIMEOUT) + context.set_default_navigation_timeout(Config.PAGE_LOAD_TIMEOUT) + + elapsed = time.time() - start_time + self.log(f"上下文创建完成 (耗时: {elapsed:.3f}秒)") + + return playwright, browser, context + + +class PlaywrightAutomation: + """Playwright自动化操作类""" + + def __init__(self, browser_manager: PlaywrightBrowserManager, account_id: str, proxy_config: Optional[dict] = None): + """ + 初始化自动化操作 + + Args: + browser_manager: 浏览器管理器 + account_id: 账号ID(用于日志) + """ + self.browser_manager = browser_manager + self.account_id = account_id + self.proxy_config = proxy_config + self.playwright: Optional[Playwright] = None + self.browser: Optional[Browser] = None + self.context: Optional[BrowserContext] = None + self.page: Optional[Page] = None + self.main_page: Optional[Page] = None + + def log(self, message: str): + """记录日志""" + self.browser_manager.log(message, self.account_id) + + def login(self, username: str, password: str, remember: bool = True) -> bool: + """ + 登录系统 + + Args: + username: 用户名 + password: 密码 + remember: 是否记住密码 + + Returns: + 是否登录成功 + """ + try: + self.log("创建浏览器上下文...") + start_time = time.time() + self.playwright, self.browser, self.context = self.browser_manager.create_browser_and_context(self.proxy_config) + elapsed = time.time() - start_time + self.log(f"浏览器和上下文创建完成 (耗时: {elapsed:.3f}秒)") + + self.log("创建页面...") + self.page = self.context.new_page() + self.main_page = self.page + + self.log("访问登录页面...") + # 使用重试机制处理超时 + max_retries = 2 + for attempt in range(max_retries): + try: + self.page.goto(Config.LOGIN_URL, timeout=60000) + break + except Exception as e: + if attempt < max_retries - 1: + self.log(f"页面加载超时,重试中... ({attempt + 1}/{max_retries})") + time.sleep(2) + else: + raise + + self.log("填写登录信息...") + self.page.fill('#txtUserName', username) + self.page.fill('#txtPassword', password) + + if remember: + self.page.check('#chkRemember') + + self.log("点击登录按钮...") + self.page.click('#btnSubmit') + + # 等待跳转 + self.log("等待登录处理...") + self.page.wait_for_load_state('networkidle', timeout=30000) # 增加到30秒 + + # 检查登录结果 + current_url = self.page.url + self.log(f"当前URL: {current_url}") + + if Config.INDEX_URL_PATTERN in current_url: + self.log("登录成功!") + return True + else: + self.log("登录失败,请检查用户名和密码") + return False + + except Exception as e: + self.log(f"登录过程中出错: {str(e)}") + return False + + def switch_to_iframe(self) -> bool: + """切换到mainframe iframe""" + try: + self.log("查找并切换到iframe...") + + # 使用Playwright的等待机制 + max_retries = 3 + for i in range(max_retries): + try: + # 等待iframe元素出现 + self.main_page.wait_for_selector("iframe[name='mainframe']", timeout=2000) + + # 获取iframe + iframe = self.main_page.frame('mainframe') + if iframe: + self.page = iframe + self.log(f"✓ 成功切换到iframe (尝试 {i+1}/{max_retries})") + return True + except Exception as e: + if i < max_retries - 1: + self.log(f"未找到iframe,重试中... ({i+1}/{max_retries})") + time.sleep(1) + else: + self.log(f"所有重试都失败,未找到iframe") + + return False + + except Exception as e: + self.log(f"切换到iframe时出错: {str(e)}") + return False + + def switch_browse_type(self, browse_type: str, max_retries: int = 2) -> bool: + """ + 切换浏览类型(带重试机制) + + Args: + browse_type: 浏览类型(注册前未读/应读/已读) + max_retries: 最大重试次数(默认2次) + + Returns: + 是否切换成功 + """ + for attempt in range(max_retries + 1): + try: + if attempt > 0: + self.log(f"⚠ 第 {attempt + 1} 次尝试切换浏览类型...") + else: + self.log(f"切换到'{browse_type}'类型...") + + # 切换到iframe + if not self.switch_to_iframe(): + if attempt < max_retries: + self.log(f"iframe切换失败,等待1秒后重试...") + time.sleep(1) + continue + return False + + # 方法1: 尝试查找标签(如果JavaScript创建了的话) + selector = f"//div[contains(@class, 'rule-multi-radio')]//a[contains(text(), '{browse_type}')]" + + try: + # 等待并点击 + self.page.locator(selector).click(timeout=5000) + self.log(f"点击'{browse_type}'按钮成功") + + # 等待页面刷新并加载内容 + time.sleep(1.5) + + # 等待表格加载(最多等待30秒) + try: + self.page.locator("//table[@class='ltable']").wait_for(timeout=30000) + self.log("内容表格已加载") + except Exception as e: + self.log("等待表格加载超时,继续...") + + return True + except Exception as e: + error_msg = str(e) + if "Execution context was destroyed" in error_msg: + self.log(f"⚠ 检测到执行上下文被销毁") + if attempt < max_retries: + self.log(f"等待2秒后重试...") + time.sleep(2) + continue + self.log(f"未找到标签,尝试点击 + + + + diff --git a/templates/index.html b/templates/index.html new file mode 100644 index 0000000..2c73c5d --- /dev/null +++ b/templates/index.html @@ -0,0 +1,2334 @@ + + + + + + 知识管理平台自动化工具 - Web版 + + + + + + + +
+ +
+
+
+
+

知识管理平台

+

基于Playwright的多账号自动化管理系统

+
+
+ 欢迎, +
+ +
+ +
+
+
+ +
+ +
+ +
+
账号管理
+ + + + + +
+ + + +
+ + +
+
+ + +
+
运行统计
+
+
+
0
+
今日完成
+
+
+
0
+
正在运行
+
+
+
0
+
今日失败
+
+
+
0
+
浏览内容
+
+
+
0
+
查看附件
+
+
+
+
+ + +
+ +
+
+ 截图管理 +
+ + +
+
+
+
加载中...
+
+
+
+
+
+ + +
+
+
+

选择浏览类型

+ +
+
+
+ + + +
+ + +
+ +
+
+ +
+
+ + +
+
+
×
+
+ + + +
+
拖动图片 | 滚轮缩放 | 双击重置
+ 预览 +
+
+
+ + + + diff --git a/templates/login.html b/templates/login.html new file mode 100644 index 0000000..c3c12a3 --- /dev/null +++ b/templates/login.html @@ -0,0 +1,480 @@ + + + + + + 用户登录 - 知识管理平台 + + + + + + + + + + + diff --git a/templates/register.html b/templates/register.html new file mode 100644 index 0000000..0dad955 --- /dev/null +++ b/templates/register.html @@ -0,0 +1,258 @@ + + + + + + 用户注册 - 知识管理平台 + + + +
+
+

用户注册

+
+ +
+
+ +
+
+ + + 至少3个字符 +
+ +
+ + + 至少6个字符 +
+ +
+ + +
+ +
+ + + 选填,用于接收审核通知 +
+
+ +
+ + ---- + +
+
+ + +
+ + +
+ + + +