333 lines
10 KiB
Python
Executable File
333 lines
10 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
应用状态管理模块
|
||
提供线程安全的全局状态管理
|
||
|
||
说明(P0–P3 优化后):
|
||
- 该模块为历史遗留实现,保留用于兼容与参考
|
||
- 当前实际生效的全局状态入口为 `services/state.py`(safe_* API)
|
||
"""
|
||
|
||
import threading
|
||
from typing import Tuple
|
||
from typing import Dict, Any, Optional
|
||
from datetime import datetime, timedelta
|
||
from app_logger import get_logger
|
||
|
||
logger = get_logger('app_state')
|
||
|
||
|
||
class ThreadSafeDict:
|
||
"""线程安全的字典包装类"""
|
||
|
||
def __init__(self):
|
||
self._dict = {}
|
||
self._lock = threading.RLock()
|
||
|
||
def get(self, key, default=None):
|
||
"""获取值"""
|
||
with self._lock:
|
||
return self._dict.get(key, default)
|
||
|
||
def set(self, key, value):
|
||
"""设置值"""
|
||
with self._lock:
|
||
self._dict[key] = value
|
||
|
||
def delete(self, key):
|
||
"""删除键"""
|
||
with self._lock:
|
||
if key in self._dict:
|
||
del self._dict[key]
|
||
|
||
def pop(self, key, default=None):
|
||
"""弹出键值"""
|
||
with self._lock:
|
||
return self._dict.pop(key, default)
|
||
|
||
def keys(self):
|
||
"""获取所有键(返回副本)"""
|
||
with self._lock:
|
||
return list(self._dict.keys())
|
||
|
||
def items(self):
|
||
"""获取所有键值对(返回副本)"""
|
||
with self._lock:
|
||
return list(self._dict.items())
|
||
|
||
def __contains__(self, key):
|
||
"""检查键是否存在"""
|
||
with self._lock:
|
||
return key in self._dict
|
||
|
||
def clear(self):
|
||
"""清空字典"""
|
||
with self._lock:
|
||
self._dict.clear()
|
||
|
||
def __len__(self):
|
||
"""获取长度"""
|
||
with self._lock:
|
||
return len(self._dict)
|
||
|
||
|
||
class LogCacheManager:
|
||
"""日志缓存管理器(线程安全)"""
|
||
|
||
def __init__(self, max_logs_per_user=100, max_total_logs=1000):
|
||
self._cache = {} # {user_id: [logs]}
|
||
self._total_count = 0
|
||
self._lock = threading.RLock()
|
||
self._max_logs_per_user = max_logs_per_user
|
||
self._max_total_logs = max_total_logs
|
||
|
||
def add_log(self, user_id: int, log_entry: Dict[str, Any]) -> bool:
|
||
"""添加日志到缓存"""
|
||
with self._lock:
|
||
# 检查总数限制
|
||
if self._total_count >= self._max_total_logs:
|
||
logger.warning(f"日志缓存已满 ({self._max_total_logs}),拒绝添加")
|
||
return False
|
||
|
||
# 初始化用户日志列表
|
||
if user_id not in self._cache:
|
||
self._cache[user_id] = []
|
||
|
||
user_logs = self._cache[user_id]
|
||
|
||
# 检查用户日志数限制
|
||
if len(user_logs) >= self._max_logs_per_user:
|
||
# 移除最旧的日志
|
||
user_logs.pop(0)
|
||
self._total_count -= 1
|
||
|
||
# 添加新日志
|
||
user_logs.append(log_entry)
|
||
self._total_count += 1
|
||
|
||
return True
|
||
|
||
def get_logs(self, user_id: int) -> list:
|
||
"""获取用户的所有日志(返回副本)"""
|
||
with self._lock:
|
||
return list(self._cache.get(user_id, []))
|
||
|
||
def clear_user_logs(self, user_id: int):
|
||
"""清空用户的日志"""
|
||
with self._lock:
|
||
if user_id in self._cache:
|
||
count = len(self._cache[user_id])
|
||
del self._cache[user_id]
|
||
self._total_count -= count
|
||
logger.info(f"清空用户 {user_id} 的 {count} 条日志")
|
||
|
||
def get_total_count(self) -> int:
|
||
"""获取总日志数"""
|
||
with self._lock:
|
||
return self._total_count
|
||
|
||
def get_stats(self) -> Dict[str, int]:
|
||
"""获取统计信息"""
|
||
with self._lock:
|
||
return {
|
||
'total_count': self._total_count,
|
||
'user_count': len(self._cache),
|
||
'max_per_user': self._max_logs_per_user,
|
||
'max_total': self._max_total_logs
|
||
}
|
||
|
||
|
||
class CaptchaManager:
|
||
"""验证码管理器(线程安全)"""
|
||
|
||
def __init__(self, expire_seconds=300):
|
||
self._storage = {} # {identifier: {'code': str, 'expire': datetime}}
|
||
self._lock = threading.RLock()
|
||
self._expire_seconds = expire_seconds
|
||
|
||
def create(self, identifier: str, code: str) -> None:
|
||
"""创建验证码"""
|
||
with self._lock:
|
||
self._storage[identifier] = {
|
||
'code': code,
|
||
'expire': datetime.now() + timedelta(seconds=self._expire_seconds)
|
||
}
|
||
|
||
def verify(self, identifier: str, code: str) -> Tuple[bool, str]:
|
||
"""验证验证码"""
|
||
with self._lock:
|
||
if identifier not in self._storage:
|
||
return False, "验证码不存在或已过期"
|
||
|
||
captcha_data = self._storage[identifier]
|
||
|
||
# 检查是否过期
|
||
if datetime.now() > captcha_data['expire']:
|
||
del self._storage[identifier]
|
||
return False, "验证码已过期,请重新获取"
|
||
|
||
# 验证码码值
|
||
if captcha_data['code'] != code:
|
||
return False, "验证码错误"
|
||
|
||
# 验证成功,删除验证码
|
||
del self._storage[identifier]
|
||
return True, "验证成功"
|
||
|
||
def cleanup_expired(self) -> int:
|
||
"""清理过期的验证码"""
|
||
with self._lock:
|
||
now = datetime.now()
|
||
expired_keys = [
|
||
key for key, data in self._storage.items()
|
||
if now > data['expire']
|
||
]
|
||
for key in expired_keys:
|
||
del self._storage[key]
|
||
|
||
if expired_keys:
|
||
logger.info(f"清理了 {len(expired_keys)} 个过期验证码")
|
||
|
||
return len(expired_keys)
|
||
|
||
def get_count(self) -> int:
|
||
"""获取当前验证码数量"""
|
||
with self._lock:
|
||
return len(self._storage)
|
||
|
||
|
||
class ApplicationState:
|
||
"""应用全局状态管理器(单例模式)"""
|
||
|
||
_instance = None
|
||
_lock = threading.Lock()
|
||
|
||
def __new__(cls):
|
||
if cls._instance is None:
|
||
with cls._lock:
|
||
if cls._instance is None:
|
||
cls._instance = super().__new__(cls)
|
||
cls._instance._initialized = False
|
||
return cls._instance
|
||
|
||
def __init__(self):
|
||
if self._initialized:
|
||
return
|
||
|
||
# 浏览器管理器
|
||
self.browser_manager = None
|
||
self._browser_lock = threading.Lock()
|
||
|
||
# 用户账号管理 {user_id: {account_id: Account对象}}
|
||
self.user_accounts = ThreadSafeDict()
|
||
|
||
# 活动任务管理 {account_id: Thread对象}
|
||
self.active_tasks = ThreadSafeDict()
|
||
|
||
# 日志缓存管理
|
||
self.log_cache = LogCacheManager()
|
||
|
||
# 验证码管理
|
||
self.captcha = CaptchaManager()
|
||
|
||
# 用户信号量管理 {account_id: Semaphore}
|
||
self.user_semaphores = ThreadSafeDict()
|
||
|
||
# 全局信号量
|
||
self.global_semaphore = None
|
||
self.screenshot_semaphore = threading.Semaphore(1)
|
||
|
||
self._initialized = True
|
||
logger.info("应用状态管理器初始化完成")
|
||
|
||
def set_browser_manager(self, manager):
|
||
"""设置浏览器管理器"""
|
||
with self._browser_lock:
|
||
self.browser_manager = manager
|
||
|
||
def get_browser_manager(self):
|
||
"""获取浏览器管理器"""
|
||
with self._browser_lock:
|
||
return self.browser_manager
|
||
|
||
def get_user_semaphore(self, account_id: int, max_concurrent: int = 1):
|
||
"""获取或创建用户信号量"""
|
||
if account_id not in self.user_semaphores:
|
||
self.user_semaphores.set(account_id, threading.Semaphore(max_concurrent))
|
||
return self.user_semaphores.get(account_id)
|
||
|
||
def set_global_semaphore(self, max_concurrent: int):
|
||
"""设置全局信号量"""
|
||
self.global_semaphore = threading.Semaphore(max_concurrent)
|
||
|
||
def get_stats(self) -> Dict[str, Any]:
|
||
"""获取状态统计信息"""
|
||
return {
|
||
'user_accounts_count': len(self.user_accounts),
|
||
'active_tasks_count': len(self.active_tasks),
|
||
'log_cache_stats': self.log_cache.get_stats(),
|
||
'captcha_count': self.captcha.get_count(),
|
||
'user_semaphores_count': len(self.user_semaphores),
|
||
'browser_manager': 'initialized' if self.browser_manager else 'not_initialized'
|
||
}
|
||
|
||
|
||
# 全局单例实例
|
||
app_state = ApplicationState()
|
||
|
||
|
||
# 向后兼容的辅助函数
|
||
def verify_captcha(identifier: str, code: str) -> Tuple[bool, str]:
|
||
"""验证验证码(向后兼容接口)"""
|
||
return app_state.captcha.verify(identifier, code)
|
||
|
||
|
||
def create_captcha(identifier: str, code: str) -> None:
|
||
"""创建验证码(向后兼容接口)"""
|
||
app_state.captcha.create(identifier, code)
|
||
|
||
|
||
def cleanup_expired_captchas() -> int:
|
||
"""清理过期验证码(向后兼容接口)"""
|
||
return app_state.captcha.cleanup_expired()
|
||
|
||
|
||
if __name__ == '__main__':
|
||
# 测试代码
|
||
print("测试线程安全状态管理器...")
|
||
print("=" * 60)
|
||
|
||
# 测试 ThreadSafeDict
|
||
print("\n1. 测试 ThreadSafeDict:")
|
||
td = ThreadSafeDict()
|
||
td.set('key1', 'value1')
|
||
print(f" 设置 key1 = {td.get('key1')}")
|
||
print(f" 长度: {len(td)}")
|
||
|
||
# 测试 LogCacheManager
|
||
print("\n2. 测试 LogCacheManager:")
|
||
lcm = LogCacheManager(max_logs_per_user=3, max_total_logs=10)
|
||
for i in range(5):
|
||
lcm.add_log(1, {'message': f'log {i}'})
|
||
print(f" 用户1日志数: {len(lcm.get_logs(1))}")
|
||
print(f" 总日志数: {lcm.get_total_count()}")
|
||
print(f" 统计: {lcm.get_stats()}")
|
||
|
||
# 测试 CaptchaManager
|
||
print("\n3. 测试 CaptchaManager:")
|
||
cm = CaptchaManager(expire_seconds=2)
|
||
cm.create('test@example.com', '1234')
|
||
success, msg = cm.verify('test@example.com', '1234')
|
||
print(f" 验证结果: {success}, {msg}")
|
||
|
||
# 测试 ApplicationState
|
||
print("\n4. 测试 ApplicationState (单例):")
|
||
state1 = ApplicationState()
|
||
state2 = ApplicationState()
|
||
print(f" 单例验证: {state1 is state2}")
|
||
print(f" 状态统计: {state1.get_stats()}")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("✓ 所有测试通过!")
|