feat: 知识管理平台精简版 - PyQt6桌面应用
主要功能: - 账号管理:添加/编辑/删除账号,测试登录 - 浏览任务:批量浏览应读/选读内容并标记已读 - 截图管理:wkhtmltoimage截图,查看历史 - 金山文档:扫码登录/微信快捷登录,自动上传截图 技术栈: - PyQt6 GUI框架 - Playwright 浏览器自动化 - SQLite 本地数据存储 - wkhtmltoimage 网页截图 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
13
utils/__init__.py
Normal file
13
utils/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""工具模块"""
|
||||
|
||||
from .storage import load_config, save_config
|
||||
from .crypto import encrypt_password, decrypt_password, is_encrypted
|
||||
from .worker import Worker, WorkerSignals
|
||||
|
||||
__all__ = [
|
||||
'load_config', 'save_config',
|
||||
'encrypt_password', 'decrypt_password', 'is_encrypted',
|
||||
'Worker', 'WorkerSignals'
|
||||
]
|
||||
156
utils/crypto.py
Normal file
156
utils/crypto.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
加密工具模块 - 精简版
|
||||
用于加密存储敏感信息(如第三方账号密码)
|
||||
使用Fernet对称加密
|
||||
"""
|
||||
|
||||
import os
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
|
||||
def _get_key_paths():
|
||||
"""获取密钥文件路径"""
|
||||
from config import ENCRYPTION_KEY_FILE, ENCRYPTION_SALT_FILE
|
||||
return ENCRYPTION_KEY_FILE, ENCRYPTION_SALT_FILE
|
||||
|
||||
|
||||
def _get_or_create_salt() -> bytes:
|
||||
"""获取或创建盐值"""
|
||||
_, salt_path = _get_key_paths()
|
||||
|
||||
if salt_path.exists():
|
||||
with open(salt_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
# 生成新的盐值
|
||||
salt = os.urandom(16)
|
||||
salt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(salt_path, 'wb') as f:
|
||||
f.write(salt)
|
||||
return salt
|
||||
|
||||
|
||||
def _derive_key(password: bytes, salt: bytes) -> bytes:
|
||||
"""从密码派生加密密钥"""
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=480000, # OWASP推荐的迭代次数
|
||||
)
|
||||
return base64.urlsafe_b64encode(kdf.derive(password))
|
||||
|
||||
|
||||
def get_encryption_key() -> bytes:
|
||||
"""获取加密密钥(优先环境变量,否则从文件读取或生成)"""
|
||||
key_path, _ = _get_key_paths()
|
||||
|
||||
# 优先从环境变量读取
|
||||
env_key = os.environ.get('ENCRYPTION_KEY')
|
||||
if env_key:
|
||||
salt = _get_or_create_salt()
|
||||
return _derive_key(env_key.encode(), salt)
|
||||
|
||||
# 从文件读取
|
||||
if key_path.exists():
|
||||
with open(key_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
# 生成新的密钥
|
||||
key = Fernet.generate_key()
|
||||
key_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(key_path, 'wb') as f:
|
||||
f.write(key)
|
||||
print(f"[OK] 已生成新的加密密钥")
|
||||
return key
|
||||
|
||||
|
||||
# 全局Fernet实例
|
||||
_fernet = None
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
"""获取Fernet加密器(懒加载)"""
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
key = get_encryption_key()
|
||||
_fernet = Fernet(key)
|
||||
return _fernet
|
||||
|
||||
|
||||
def encrypt_password(plain_password: str) -> str:
|
||||
"""
|
||||
加密密码
|
||||
|
||||
Args:
|
||||
plain_password: 明文密码
|
||||
|
||||
Returns:
|
||||
str: 加密后的密码(base64编码)
|
||||
"""
|
||||
if not plain_password:
|
||||
return ''
|
||||
|
||||
fernet = _get_fernet()
|
||||
encrypted = fernet.encrypt(plain_password.encode('utf-8'))
|
||||
return encrypted.decode('utf-8')
|
||||
|
||||
|
||||
def decrypt_password(encrypted_password: str) -> str:
|
||||
"""
|
||||
解密密码
|
||||
|
||||
Args:
|
||||
encrypted_password: 加密的密码
|
||||
|
||||
Returns:
|
||||
str: 明文密码
|
||||
"""
|
||||
if not encrypted_password:
|
||||
return ''
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
decrypted = fernet.decrypt(encrypted_password.encode('utf-8'))
|
||||
return decrypted.decode('utf-8')
|
||||
except Exception as e:
|
||||
# 解密失败,可能是旧的明文密码
|
||||
print(f"[Warning] 密码解密失败,可能是未加密的旧数据: {e}")
|
||||
return encrypted_password
|
||||
|
||||
|
||||
def is_encrypted(password: str) -> bool:
|
||||
"""
|
||||
检查密码是否已加密
|
||||
Fernet加密的数据以'gAAAAA'开头
|
||||
|
||||
Args:
|
||||
password: 要检查的密码
|
||||
|
||||
Returns:
|
||||
bool: 是否已加密
|
||||
"""
|
||||
if not password:
|
||||
return False
|
||||
return password.startswith('gAAAAA')
|
||||
|
||||
|
||||
def migrate_password(password: str) -> str:
|
||||
"""
|
||||
迁移密码:如果是明文则加密,如果已加密则保持不变
|
||||
|
||||
Args:
|
||||
password: 密码(可能是明文或已加密)
|
||||
|
||||
Returns:
|
||||
str: 加密后的密码
|
||||
"""
|
||||
if is_encrypted(password):
|
||||
return password
|
||||
return encrypt_password(password)
|
||||
398
utils/storage.py
Normal file
398
utils/storage.py
Normal file
@@ -0,0 +1,398 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
SQLite storage module - local database for config and accounts
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from contextlib import contextmanager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from config import AppConfig, AccountConfig
|
||||
|
||||
|
||||
def _get_db_path() -> Path:
|
||||
"""Get database file path"""
|
||||
from config import DATA_DIR
|
||||
return DATA_DIR / "zsglpt.db"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_connection():
|
||||
"""Get database connection with context manager"""
|
||||
db_path = _get_db_path()
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def init_database():
|
||||
"""Initialize database tables"""
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Accounts table
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password TEXT NOT NULL,
|
||||
remark TEXT DEFAULT '',
|
||||
enabled INTEGER DEFAULT 1,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# Settings table (key-value store)
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS settings (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# Create index
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts(username)
|
||||
''')
|
||||
|
||||
|
||||
def _ensure_db():
|
||||
"""Ensure database is initialized"""
|
||||
db_path = _get_db_path()
|
||||
if not db_path.exists():
|
||||
init_database()
|
||||
|
||||
|
||||
# ==================== Account Operations ====================
|
||||
|
||||
def get_all_accounts() -> List[dict]:
|
||||
"""Get all accounts from database"""
|
||||
_ensure_db()
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT * FROM accounts ORDER BY id')
|
||||
rows = cursor.fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
|
||||
def get_account_by_username(username: str) -> Optional[dict]:
|
||||
"""Get account by username"""
|
||||
_ensure_db()
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT * FROM accounts WHERE username = ?', (username,))
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
def add_account(username: str, password: str, remark: str = '', enabled: bool = True) -> int:
|
||||
"""Add new account, returns account id"""
|
||||
_ensure_db()
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO accounts (username, password, remark, enabled)
|
||||
VALUES (?, ?, ?, ?)
|
||||
''', (username, password, remark, 1 if enabled else 0))
|
||||
return cursor.lastrowid
|
||||
|
||||
|
||||
def update_account(account_id: int, username: str, password: str, remark: str, enabled: bool):
|
||||
"""Update existing account"""
|
||||
_ensure_db()
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
UPDATE accounts
|
||||
SET username = ?, password = ?, remark = ?, enabled = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = ?
|
||||
''', (username, password, remark, 1 if enabled else 0, account_id))
|
||||
|
||||
|
||||
def delete_account(account_id: int):
|
||||
"""Delete account by id"""
|
||||
_ensure_db()
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM accounts WHERE id = ?', (account_id,))
|
||||
|
||||
|
||||
def delete_account_by_username(username: str):
|
||||
"""Delete account by username"""
|
||||
_ensure_db()
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM accounts WHERE username = ?', (username,))
|
||||
|
||||
|
||||
# ==================== Settings Operations ====================
|
||||
|
||||
def get_setting(key: str, default: str = '') -> str:
|
||||
"""Get setting value by key"""
|
||||
_ensure_db()
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT value FROM settings WHERE key = ?', (key,))
|
||||
row = cursor.fetchone()
|
||||
return row['value'] if row else default
|
||||
|
||||
|
||||
def set_setting(key: str, value: str):
|
||||
"""Set setting value"""
|
||||
_ensure_db()
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT OR REPLACE INTO settings (key, value, updated_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
''', (key, value))
|
||||
|
||||
|
||||
def get_all_settings() -> dict:
|
||||
"""Get all settings as dictionary"""
|
||||
_ensure_db()
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT key, value FROM settings')
|
||||
rows = cursor.fetchall()
|
||||
return {row['key']: row['value'] for row in rows}
|
||||
|
||||
|
||||
# ==================== Config Bridge (compatibility with existing code) ====================
|
||||
|
||||
def load_config() -> "AppConfig":
|
||||
"""
|
||||
Load config from SQLite database
|
||||
Returns AppConfig object for compatibility
|
||||
"""
|
||||
from config import AppConfig, AccountConfig, KDocsConfig, ScreenshotConfig, ProxyConfig, ZSGLConfig, SCREENSHOTS_DIR
|
||||
|
||||
_ensure_db()
|
||||
|
||||
config = AppConfig()
|
||||
|
||||
# Load accounts
|
||||
accounts = get_all_accounts()
|
||||
config.accounts = [
|
||||
AccountConfig(
|
||||
username=a['username'],
|
||||
password=a['password'],
|
||||
remark=a['remark'] or '',
|
||||
enabled=bool(a['enabled'])
|
||||
) for a in accounts
|
||||
]
|
||||
|
||||
# Load settings
|
||||
settings = get_all_settings()
|
||||
|
||||
# KDocs config - 默认文档链接
|
||||
DEFAULT_KDOCS_URL = 'https://kdocs.cn/l/cpwEOo5ynKX4'
|
||||
config.kdocs = KDocsConfig(
|
||||
enabled=settings.get('kdocs_enabled', 'false').lower() == 'true',
|
||||
doc_url=settings.get('kdocs_doc_url', '') or DEFAULT_KDOCS_URL, # 空字符串也用默认值
|
||||
sheet_name=settings.get('kdocs_sheet_name', 'Sheet1'),
|
||||
sheet_index=int(settings.get('kdocs_sheet_index', '0')),
|
||||
unit_column=settings.get('kdocs_unit_column', 'A'),
|
||||
image_column=settings.get('kdocs_image_column', 'D'),
|
||||
unit=settings.get('kdocs_unit', ''),
|
||||
name_column=settings.get('kdocs_name_column', 'C'),
|
||||
row_start=int(settings.get('kdocs_row_start', '0')),
|
||||
row_end=int(settings.get('kdocs_row_end', '0')),
|
||||
)
|
||||
|
||||
# Screenshot config
|
||||
config.screenshot = ScreenshotConfig(
|
||||
dir=settings.get('screenshot_dir', str(SCREENSHOTS_DIR)),
|
||||
quality=int(settings.get('screenshot_quality', '95')),
|
||||
width=int(settings.get('screenshot_width', '1920')),
|
||||
height=int(settings.get('screenshot_height', '1080')),
|
||||
js_delay_ms=int(settings.get('screenshot_js_delay_ms', '3000')),
|
||||
timeout_seconds=int(settings.get('screenshot_timeout_seconds', '60')),
|
||||
wkhtmltoimage_path=settings.get('screenshot_wkhtmltoimage_path', ''),
|
||||
)
|
||||
|
||||
# Proxy config
|
||||
config.proxy = ProxyConfig(
|
||||
enabled=settings.get('proxy_enabled', 'false').lower() == 'true',
|
||||
server=settings.get('proxy_server', ''),
|
||||
)
|
||||
|
||||
# ZSGL config
|
||||
config.zsgl = ZSGLConfig(
|
||||
base_url=settings.get('zsgl_base_url', 'https://postoa.aidunsoft.com'),
|
||||
login_url=settings.get('zsgl_login_url', 'https://postoa.aidunsoft.com/admin/login.aspx'),
|
||||
index_url_pattern=settings.get('zsgl_index_url_pattern', 'index.aspx'),
|
||||
)
|
||||
|
||||
# Theme
|
||||
config.theme = settings.get('theme', 'light')
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def save_config(config: "AppConfig") -> bool:
|
||||
"""
|
||||
Save config to SQLite database
|
||||
"""
|
||||
_ensure_db()
|
||||
|
||||
try:
|
||||
with get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Save accounts - first get existing accounts
|
||||
existing_usernames = set()
|
||||
cursor.execute('SELECT username FROM accounts')
|
||||
for row in cursor.fetchall():
|
||||
existing_usernames.add(row['username'])
|
||||
|
||||
# Update or insert accounts
|
||||
config_usernames = set()
|
||||
for account in config.accounts:
|
||||
config_usernames.add(account.username)
|
||||
|
||||
if account.username in existing_usernames:
|
||||
# Update existing
|
||||
cursor.execute('''
|
||||
UPDATE accounts
|
||||
SET password = ?, remark = ?, enabled = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE username = ?
|
||||
''', (account.password, account.remark, 1 if account.enabled else 0, account.username))
|
||||
else:
|
||||
# Insert new
|
||||
cursor.execute('''
|
||||
INSERT INTO accounts (username, password, remark, enabled)
|
||||
VALUES (?, ?, ?, ?)
|
||||
''', (account.username, account.password, account.remark, 1 if account.enabled else 0))
|
||||
|
||||
# Delete removed accounts
|
||||
removed = existing_usernames - config_usernames
|
||||
for username in removed:
|
||||
cursor.execute('DELETE FROM accounts WHERE username = ?', (username,))
|
||||
|
||||
# Save settings
|
||||
settings_to_save = {
|
||||
# KDocs
|
||||
'kdocs_enabled': str(config.kdocs.enabled).lower(),
|
||||
'kdocs_doc_url': config.kdocs.doc_url,
|
||||
'kdocs_sheet_name': config.kdocs.sheet_name,
|
||||
'kdocs_sheet_index': str(config.kdocs.sheet_index),
|
||||
'kdocs_unit_column': config.kdocs.unit_column,
|
||||
'kdocs_image_column': config.kdocs.image_column,
|
||||
'kdocs_unit': config.kdocs.unit,
|
||||
'kdocs_name_column': config.kdocs.name_column,
|
||||
'kdocs_row_start': str(config.kdocs.row_start),
|
||||
'kdocs_row_end': str(config.kdocs.row_end),
|
||||
|
||||
# Screenshot
|
||||
'screenshot_dir': config.screenshot.dir,
|
||||
'screenshot_quality': str(config.screenshot.quality),
|
||||
'screenshot_width': str(config.screenshot.width),
|
||||
'screenshot_height': str(config.screenshot.height),
|
||||
'screenshot_js_delay_ms': str(config.screenshot.js_delay_ms),
|
||||
'screenshot_timeout_seconds': str(config.screenshot.timeout_seconds),
|
||||
'screenshot_wkhtmltoimage_path': config.screenshot.wkhtmltoimage_path,
|
||||
|
||||
# Proxy
|
||||
'proxy_enabled': str(config.proxy.enabled).lower(),
|
||||
'proxy_server': config.proxy.server,
|
||||
|
||||
# ZSGL
|
||||
'zsgl_base_url': config.zsgl.base_url,
|
||||
'zsgl_login_url': config.zsgl.login_url,
|
||||
'zsgl_index_url_pattern': config.zsgl.index_url_pattern,
|
||||
|
||||
# Theme
|
||||
'theme': config.theme,
|
||||
}
|
||||
|
||||
for key, value in settings_to_save.items():
|
||||
cursor.execute('''
|
||||
INSERT OR REPLACE INTO settings (key, value, updated_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
''', (key, value))
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[Error] Save config failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def backup_config() -> bool:
|
||||
"""Backup database file"""
|
||||
db_path = _get_db_path()
|
||||
|
||||
if not db_path.exists():
|
||||
return False
|
||||
|
||||
backup_path = db_path.with_suffix('.db.bak')
|
||||
|
||||
try:
|
||||
import shutil
|
||||
shutil.copy2(db_path, backup_path)
|
||||
return True
|
||||
except IOError as e:
|
||||
print(f"[Error] Backup failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def restore_config() -> bool:
|
||||
"""Restore database from backup"""
|
||||
db_path = _get_db_path()
|
||||
backup_path = db_path.with_suffix('.db.bak')
|
||||
|
||||
if not backup_path.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
import shutil
|
||||
shutil.copy2(backup_path, db_path)
|
||||
return True
|
||||
except IOError as e:
|
||||
print(f"[Error] Restore failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def migrate_from_json():
|
||||
"""Migrate data from old JSON config to SQLite"""
|
||||
from config import CONFIG_FILE
|
||||
|
||||
if not CONFIG_FILE.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Load using old format
|
||||
from config import AppConfig
|
||||
old_config = AppConfig.from_dict(data)
|
||||
|
||||
# Save to SQLite
|
||||
save_config(old_config)
|
||||
|
||||
# Rename old file
|
||||
backup = CONFIG_FILE.with_suffix('.json.migrated')
|
||||
CONFIG_FILE.rename(backup)
|
||||
|
||||
print(f"[Info] Migrated from JSON to SQLite, old file renamed to {backup}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[Error] Migration failed: {e}")
|
||||
return False
|
||||
193
utils/worker.py
Normal file
193
utils/worker.py
Normal file
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
后台线程管理模块
|
||||
使用QThread实现非阻塞的后台任务
|
||||
"""
|
||||
|
||||
from typing import Callable, Any, Optional
|
||||
from PyQt6.QtCore import QObject, QThread, pyqtSignal
|
||||
|
||||
|
||||
class WorkerSignals(QObject):
|
||||
"""工作线程信号类"""
|
||||
# 进度信号:(百分比, 消息)
|
||||
progress = pyqtSignal(int, str)
|
||||
# 日志信号:日志消息
|
||||
log = pyqtSignal(str)
|
||||
# 完成信号:(成功/失败, 结果/错误消息)
|
||||
finished = pyqtSignal(bool, str)
|
||||
# 截图完成信号:截图文件路径
|
||||
screenshot_ready = pyqtSignal(str)
|
||||
# 通用结果信号:任意结果对象
|
||||
result = pyqtSignal(object)
|
||||
# 错误信号
|
||||
error = pyqtSignal(str)
|
||||
|
||||
|
||||
class Worker(QThread):
|
||||
"""
|
||||
通用后台工作线程
|
||||
|
||||
用法:
|
||||
worker = Worker(some_function, arg1, arg2, kwarg1=value1)
|
||||
worker.signals.finished.connect(on_finished)
|
||||
worker.signals.progress.connect(on_progress)
|
||||
worker.start()
|
||||
"""
|
||||
|
||||
def __init__(self, fn: Callable, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.signals = WorkerSignals()
|
||||
self._is_stopped = False
|
||||
|
||||
def run(self):
|
||||
"""执行后台任务"""
|
||||
try:
|
||||
# 把信号传递给任务函数,方便回调
|
||||
self.kwargs['_signals'] = self.signals
|
||||
self.kwargs['_should_stop'] = self.should_stop
|
||||
|
||||
result = self.fn(*self.args, **self.kwargs)
|
||||
|
||||
if not self._is_stopped:
|
||||
self.signals.result.emit(result)
|
||||
self.signals.finished.emit(True, "完成")
|
||||
except Exception as e:
|
||||
if not self._is_stopped:
|
||||
error_msg = str(e)
|
||||
self.signals.error.emit(error_msg)
|
||||
self.signals.finished.emit(False, error_msg)
|
||||
|
||||
def stop(self):
|
||||
"""停止线程"""
|
||||
self._is_stopped = True
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
"""检查是否应该停止"""
|
||||
return self._is_stopped
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
"""
|
||||
任务运行器
|
||||
管理多个后台任务,确保不会同时运行太多任务
|
||||
"""
|
||||
|
||||
def __init__(self, max_workers: int = 1):
|
||||
self.max_workers = max_workers
|
||||
self._workers: list[Worker] = []
|
||||
self._queue: list[tuple] = [] # (fn, args, kwargs, callbacks)
|
||||
|
||||
def submit(self, fn: Callable, *args,
|
||||
on_progress: Optional[Callable] = None,
|
||||
on_log: Optional[Callable] = None,
|
||||
on_finished: Optional[Callable] = None,
|
||||
on_result: Optional[Callable] = None,
|
||||
on_error: Optional[Callable] = None,
|
||||
**kwargs) -> Optional[Worker]:
|
||||
"""
|
||||
提交任务
|
||||
|
||||
Args:
|
||||
fn: 要执行的函数
|
||||
*args: 位置参数
|
||||
on_progress: 进度回调 (percent, message)
|
||||
on_log: 日志回调 (message)
|
||||
on_finished: 完成回调 (success, message)
|
||||
on_result: 结果回调 (result)
|
||||
on_error: 错误回调 (error_message)
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
Worker对象,如果队列满了返回None
|
||||
"""
|
||||
callbacks = {
|
||||
'on_progress': on_progress,
|
||||
'on_log': on_log,
|
||||
'on_finished': on_finished,
|
||||
'on_result': on_result,
|
||||
'on_error': on_error,
|
||||
}
|
||||
|
||||
# 清理已完成的worker
|
||||
self._workers = [w for w in self._workers if w.isRunning()]
|
||||
|
||||
if len(self._workers) >= self.max_workers:
|
||||
# 加入队列等待
|
||||
self._queue.append((fn, args, kwargs, callbacks))
|
||||
return None
|
||||
|
||||
return self._start_worker(fn, args, kwargs, callbacks)
|
||||
|
||||
def _start_worker(self, fn: Callable, args: tuple, kwargs: dict,
|
||||
callbacks: dict) -> Worker:
|
||||
"""启动一个worker"""
|
||||
worker = Worker(fn, *args, **kwargs)
|
||||
|
||||
# 连接信号
|
||||
if callbacks.get('on_progress'):
|
||||
worker.signals.progress.connect(callbacks['on_progress'])
|
||||
if callbacks.get('on_log'):
|
||||
worker.signals.log.connect(callbacks['on_log'])
|
||||
if callbacks.get('on_result'):
|
||||
worker.signals.result.connect(callbacks['on_result'])
|
||||
if callbacks.get('on_error'):
|
||||
worker.signals.error.connect(callbacks['on_error'])
|
||||
|
||||
# 完成时的处理
|
||||
def on_worker_finished(success, message):
|
||||
if callbacks.get('on_finished'):
|
||||
callbacks['on_finished'](success, message)
|
||||
# 处理队列中的下一个任务
|
||||
self._process_queue()
|
||||
|
||||
worker.signals.finished.connect(on_worker_finished)
|
||||
|
||||
self._workers.append(worker)
|
||||
worker.start()
|
||||
return worker
|
||||
|
||||
def _process_queue(self):
|
||||
"""处理队列中的任务"""
|
||||
# 清理已完成的worker
|
||||
self._workers = [w for w in self._workers if w.isRunning()]
|
||||
|
||||
if self._queue and len(self._workers) < self.max_workers:
|
||||
fn, args, kwargs, callbacks = self._queue.pop(0)
|
||||
self._start_worker(fn, args, kwargs, callbacks)
|
||||
|
||||
def stop_all(self):
|
||||
"""停止所有任务"""
|
||||
for worker in self._workers:
|
||||
worker.stop()
|
||||
self._queue.clear()
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""检查是否有任务在运行"""
|
||||
return any(w.isRunning() for w in self._workers)
|
||||
|
||||
@property
|
||||
def running_count(self) -> int:
|
||||
"""运行中的任务数"""
|
||||
return sum(1 for w in self._workers if w.isRunning())
|
||||
|
||||
@property
|
||||
def queue_size(self) -> int:
|
||||
"""队列中等待的任务数"""
|
||||
return len(self._queue)
|
||||
|
||||
|
||||
# 全局任务运行器
|
||||
_task_runner: Optional[TaskRunner] = None
|
||||
|
||||
|
||||
def get_task_runner() -> TaskRunner:
|
||||
"""获取全局任务运行器"""
|
||||
global _task_runner
|
||||
if _task_runner is None:
|
||||
_task_runner = TaskRunner(max_workers=1)
|
||||
return _task_runner
|
||||
Reference in New Issue
Block a user