Initial commit: 知识管理平台
主要功能: - 多用户管理系统 - 浏览器自动化(Playwright) - 任务编排和执行 - Docker容器化部署 - 数据持久化和日志管理 技术栈: - Flask 3.0.0 - Playwright 1.40.0 - SQLite with connection pooling - Docker + Docker Compose 部署说明详见README.md
This commit is contained in:
328
app_state.py
Executable file
328
app_state.py
Executable file
@@ -0,0 +1,328 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
应用状态管理模块
|
||||
提供线程安全的全局状态管理
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import Tuple
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from app_logger import get_logger
|
||||
|
||||
logger = get_logger('app_state')
|
||||
|
||||
|
||||
class ThreadSafeDict:
|
||||
"""线程安全的字典包装类"""
|
||||
|
||||
def __init__(self):
|
||||
self._dict = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""获取值"""
|
||||
with self._lock:
|
||||
return self._dict.get(key, default)
|
||||
|
||||
def set(self, key, value):
|
||||
"""设置值"""
|
||||
with self._lock:
|
||||
self._dict[key] = value
|
||||
|
||||
def delete(self, key):
|
||||
"""删除键"""
|
||||
with self._lock:
|
||||
if key in self._dict:
|
||||
del self._dict[key]
|
||||
|
||||
def pop(self, key, default=None):
|
||||
"""弹出键值"""
|
||||
with self._lock:
|
||||
return self._dict.pop(key, default)
|
||||
|
||||
def keys(self):
|
||||
"""获取所有键(返回副本)"""
|
||||
with self._lock:
|
||||
return list(self._dict.keys())
|
||||
|
||||
def items(self):
|
||||
"""获取所有键值对(返回副本)"""
|
||||
with self._lock:
|
||||
return list(self._dict.items())
|
||||
|
||||
def __contains__(self, key):
|
||||
"""检查键是否存在"""
|
||||
with self._lock:
|
||||
return key in self._dict
|
||||
|
||||
def clear(self):
|
||||
"""清空字典"""
|
||||
with self._lock:
|
||||
self._dict.clear()
|
||||
|
||||
def __len__(self):
|
||||
"""获取长度"""
|
||||
with self._lock:
|
||||
return len(self._dict)
|
||||
|
||||
|
||||
class LogCacheManager:
|
||||
"""日志缓存管理器(线程安全)"""
|
||||
|
||||
def __init__(self, max_logs_per_user=100, max_total_logs=1000):
|
||||
self._cache = {} # {user_id: [logs]}
|
||||
self._total_count = 0
|
||||
self._lock = threading.RLock()
|
||||
self._max_logs_per_user = max_logs_per_user
|
||||
self._max_total_logs = max_total_logs
|
||||
|
||||
def add_log(self, user_id: int, log_entry: Dict[str, Any]) -> bool:
|
||||
"""添加日志到缓存"""
|
||||
with self._lock:
|
||||
# 检查总数限制
|
||||
if self._total_count >= self._max_total_logs:
|
||||
logger.warning(f"日志缓存已满 ({self._max_total_logs}),拒绝添加")
|
||||
return False
|
||||
|
||||
# 初始化用户日志列表
|
||||
if user_id not in self._cache:
|
||||
self._cache[user_id] = []
|
||||
|
||||
user_logs = self._cache[user_id]
|
||||
|
||||
# 检查用户日志数限制
|
||||
if len(user_logs) >= self._max_logs_per_user:
|
||||
# 移除最旧的日志
|
||||
user_logs.pop(0)
|
||||
self._total_count -= 1
|
||||
|
||||
# 添加新日志
|
||||
user_logs.append(log_entry)
|
||||
self._total_count += 1
|
||||
|
||||
return True
|
||||
|
||||
def get_logs(self, user_id: int) -> list:
|
||||
"""获取用户的所有日志(返回副本)"""
|
||||
with self._lock:
|
||||
return list(self._cache.get(user_id, []))
|
||||
|
||||
def clear_user_logs(self, user_id: int):
|
||||
"""清空用户的日志"""
|
||||
with self._lock:
|
||||
if user_id in self._cache:
|
||||
count = len(self._cache[user_id])
|
||||
del self._cache[user_id]
|
||||
self._total_count -= count
|
||||
logger.info(f"清空用户 {user_id} 的 {count} 条日志")
|
||||
|
||||
def get_total_count(self) -> int:
|
||||
"""获取总日志数"""
|
||||
with self._lock:
|
||||
return self._total_count
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""获取统计信息"""
|
||||
with self._lock:
|
||||
return {
|
||||
'total_count': self._total_count,
|
||||
'user_count': len(self._cache),
|
||||
'max_per_user': self._max_logs_per_user,
|
||||
'max_total': self._max_total_logs
|
||||
}
|
||||
|
||||
|
||||
class CaptchaManager:
|
||||
"""验证码管理器(线程安全)"""
|
||||
|
||||
def __init__(self, expire_seconds=300):
|
||||
self._storage = {} # {identifier: {'code': str, 'expire': datetime}}
|
||||
self._lock = threading.RLock()
|
||||
self._expire_seconds = expire_seconds
|
||||
|
||||
def create(self, identifier: str, code: str) -> None:
|
||||
"""创建验证码"""
|
||||
with self._lock:
|
||||
self._storage[identifier] = {
|
||||
'code': code,
|
||||
'expire': datetime.now() + timedelta(seconds=self._expire_seconds)
|
||||
}
|
||||
|
||||
def verify(self, identifier: str, code: str) -> Tuple[bool, str]:
|
||||
"""验证验证码"""
|
||||
with self._lock:
|
||||
if identifier not in self._storage:
|
||||
return False, "验证码不存在或已过期"
|
||||
|
||||
captcha_data = self._storage[identifier]
|
||||
|
||||
# 检查是否过期
|
||||
if datetime.now() > captcha_data['expire']:
|
||||
del self._storage[identifier]
|
||||
return False, "验证码已过期,请重新获取"
|
||||
|
||||
# 验证码码值
|
||||
if captcha_data['code'] != code:
|
||||
return False, "验证码错误"
|
||||
|
||||
# 验证成功,删除验证码
|
||||
del self._storage[identifier]
|
||||
return True, "验证成功"
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""清理过期的验证码"""
|
||||
with self._lock:
|
||||
now = datetime.now()
|
||||
expired_keys = [
|
||||
key for key, data in self._storage.items()
|
||||
if now > data['expire']
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._storage[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.info(f"清理了 {len(expired_keys)} 个过期验证码")
|
||||
|
||||
return len(expired_keys)
|
||||
|
||||
def get_count(self) -> int:
|
||||
"""获取当前验证码数量"""
|
||||
with self._lock:
|
||||
return len(self._storage)
|
||||
|
||||
|
||||
class ApplicationState:
|
||||
"""应用全局状态管理器(单例模式)"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 浏览器管理器
|
||||
self.browser_manager = None
|
||||
self._browser_lock = threading.Lock()
|
||||
|
||||
# 用户账号管理 {user_id: {account_id: Account对象}}
|
||||
self.user_accounts = ThreadSafeDict()
|
||||
|
||||
# 活动任务管理 {account_id: Thread对象}
|
||||
self.active_tasks = ThreadSafeDict()
|
||||
|
||||
# 日志缓存管理
|
||||
self.log_cache = LogCacheManager()
|
||||
|
||||
# 验证码管理
|
||||
self.captcha = CaptchaManager()
|
||||
|
||||
# 用户信号量管理 {account_id: Semaphore}
|
||||
self.user_semaphores = ThreadSafeDict()
|
||||
|
||||
# 全局信号量
|
||||
self.global_semaphore = None
|
||||
self.screenshot_semaphore = threading.Semaphore(1)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("应用状态管理器初始化完成")
|
||||
|
||||
def set_browser_manager(self, manager):
|
||||
"""设置浏览器管理器"""
|
||||
with self._browser_lock:
|
||||
self.browser_manager = manager
|
||||
|
||||
def get_browser_manager(self):
|
||||
"""获取浏览器管理器"""
|
||||
with self._browser_lock:
|
||||
return self.browser_manager
|
||||
|
||||
def get_user_semaphore(self, account_id: int, max_concurrent: int = 1):
|
||||
"""获取或创建用户信号量"""
|
||||
if account_id not in self.user_semaphores:
|
||||
self.user_semaphores.set(account_id, threading.Semaphore(max_concurrent))
|
||||
return self.user_semaphores.get(account_id)
|
||||
|
||||
def set_global_semaphore(self, max_concurrent: int):
|
||||
"""设置全局信号量"""
|
||||
self.global_semaphore = threading.Semaphore(max_concurrent)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取状态统计信息"""
|
||||
return {
|
||||
'user_accounts_count': len(self.user_accounts),
|
||||
'active_tasks_count': len(self.active_tasks),
|
||||
'log_cache_stats': self.log_cache.get_stats(),
|
||||
'captcha_count': self.captcha.get_count(),
|
||||
'user_semaphores_count': len(self.user_semaphores),
|
||||
'browser_manager': 'initialized' if self.browser_manager else 'not_initialized'
|
||||
}
|
||||
|
||||
|
||||
# 全局单例实例
|
||||
app_state = ApplicationState()
|
||||
|
||||
|
||||
# 向后兼容的辅助函数
|
||||
def verify_captcha(identifier: str, code: str) -> Tuple[bool, str]:
|
||||
"""验证验证码(向后兼容接口)"""
|
||||
return app_state.captcha.verify(identifier, code)
|
||||
|
||||
|
||||
def create_captcha(identifier: str, code: str) -> None:
|
||||
"""创建验证码(向后兼容接口)"""
|
||||
app_state.captcha.create(identifier, code)
|
||||
|
||||
|
||||
def cleanup_expired_captchas() -> int:
|
||||
"""清理过期验证码(向后兼容接口)"""
|
||||
return app_state.captcha.cleanup_expired()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 测试代码
|
||||
print("测试线程安全状态管理器...")
|
||||
print("=" * 60)
|
||||
|
||||
# 测试 ThreadSafeDict
|
||||
print("\n1. 测试 ThreadSafeDict:")
|
||||
td = ThreadSafeDict()
|
||||
td.set('key1', 'value1')
|
||||
print(f" 设置 key1 = {td.get('key1')}")
|
||||
print(f" 长度: {len(td)}")
|
||||
|
||||
# 测试 LogCacheManager
|
||||
print("\n2. 测试 LogCacheManager:")
|
||||
lcm = LogCacheManager(max_logs_per_user=3, max_total_logs=10)
|
||||
for i in range(5):
|
||||
lcm.add_log(1, {'message': f'log {i}'})
|
||||
print(f" 用户1日志数: {len(lcm.get_logs(1))}")
|
||||
print(f" 总日志数: {lcm.get_total_count()}")
|
||||
print(f" 统计: {lcm.get_stats()}")
|
||||
|
||||
# 测试 CaptchaManager
|
||||
print("\n3. 测试 CaptchaManager:")
|
||||
cm = CaptchaManager(expire_seconds=2)
|
||||
cm.create('test@example.com', '1234')
|
||||
success, msg = cm.verify('test@example.com', '1234')
|
||||
print(f" 验证结果: {success}, {msg}")
|
||||
|
||||
# 测试 ApplicationState
|
||||
print("\n4. 测试 ApplicationState (单例):")
|
||||
state1 = ApplicationState()
|
||||
state2 = ApplicationState()
|
||||
print(f" 单例验证: {state1 is state2}")
|
||||
print(f" 状态统计: {state1.get_stats()}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ 所有测试通过!")
|
||||
Reference in New Issue
Block a user