Files
zsglpt/app_state.py

333 lines
10 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
应用状态管理模块
提供线程安全的全局状态管理
说明P0P3 优化后):
- 该模块为历史遗留实现,保留用于兼容与参考
- 当前实际生效的全局状态入口为 `services/state.py`safe_* API
"""
import threading
from typing import Tuple
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from app_logger import get_logger
logger = get_logger('app_state')
class ThreadSafeDict:
"""线程安全的字典包装类"""
def __init__(self):
self._dict = {}
self._lock = threading.RLock()
def get(self, key, default=None):
"""获取值"""
with self._lock:
return self._dict.get(key, default)
def set(self, key, value):
"""设置值"""
with self._lock:
self._dict[key] = value
def delete(self, key):
"""删除键"""
with self._lock:
if key in self._dict:
del self._dict[key]
def pop(self, key, default=None):
"""弹出键值"""
with self._lock:
return self._dict.pop(key, default)
def keys(self):
"""获取所有键(返回副本)"""
with self._lock:
return list(self._dict.keys())
def items(self):
"""获取所有键值对(返回副本)"""
with self._lock:
return list(self._dict.items())
def __contains__(self, key):
"""检查键是否存在"""
with self._lock:
return key in self._dict
def clear(self):
"""清空字典"""
with self._lock:
self._dict.clear()
def __len__(self):
"""获取长度"""
with self._lock:
return len(self._dict)
class LogCacheManager:
"""日志缓存管理器(线程安全)"""
def __init__(self, max_logs_per_user=100, max_total_logs=1000):
self._cache = {} # {user_id: [logs]}
self._total_count = 0
self._lock = threading.RLock()
self._max_logs_per_user = max_logs_per_user
self._max_total_logs = max_total_logs
def add_log(self, user_id: int, log_entry: Dict[str, Any]) -> bool:
"""添加日志到缓存"""
with self._lock:
# 检查总数限制
if self._total_count >= self._max_total_logs:
logger.warning(f"日志缓存已满 ({self._max_total_logs}),拒绝添加")
return False
# 初始化用户日志列表
if user_id not in self._cache:
self._cache[user_id] = []
user_logs = self._cache[user_id]
# 检查用户日志数限制
if len(user_logs) >= self._max_logs_per_user:
# 移除最旧的日志
user_logs.pop(0)
self._total_count -= 1
# 添加新日志
user_logs.append(log_entry)
self._total_count += 1
return True
def get_logs(self, user_id: int) -> list:
"""获取用户的所有日志(返回副本)"""
with self._lock:
return list(self._cache.get(user_id, []))
def clear_user_logs(self, user_id: int):
"""清空用户的日志"""
with self._lock:
if user_id in self._cache:
count = len(self._cache[user_id])
del self._cache[user_id]
self._total_count -= count
logger.info(f"清空用户 {user_id}{count} 条日志")
def get_total_count(self) -> int:
"""获取总日志数"""
with self._lock:
return self._total_count
def get_stats(self) -> Dict[str, int]:
"""获取统计信息"""
with self._lock:
return {
'total_count': self._total_count,
'user_count': len(self._cache),
'max_per_user': self._max_logs_per_user,
'max_total': self._max_total_logs
}
class CaptchaManager:
"""验证码管理器(线程安全)"""
def __init__(self, expire_seconds=300):
self._storage = {} # {identifier: {'code': str, 'expire': datetime}}
self._lock = threading.RLock()
self._expire_seconds = expire_seconds
def create(self, identifier: str, code: str) -> None:
"""创建验证码"""
with self._lock:
self._storage[identifier] = {
'code': code,
'expire': datetime.now() + timedelta(seconds=self._expire_seconds)
}
def verify(self, identifier: str, code: str) -> Tuple[bool, str]:
"""验证验证码"""
with self._lock:
if identifier not in self._storage:
return False, "验证码不存在或已过期"
captcha_data = self._storage[identifier]
# 检查是否过期
if datetime.now() > captcha_data['expire']:
del self._storage[identifier]
return False, "验证码已过期,请重新获取"
# 验证码码值
if captcha_data['code'] != code:
return False, "验证码错误"
# 验证成功,删除验证码
del self._storage[identifier]
return True, "验证成功"
def cleanup_expired(self) -> int:
"""清理过期的验证码"""
with self._lock:
now = datetime.now()
expired_keys = [
key for key, data in self._storage.items()
if now > data['expire']
]
for key in expired_keys:
del self._storage[key]
if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 个过期验证码")
return len(expired_keys)
def get_count(self) -> int:
"""获取当前验证码数量"""
with self._lock:
return len(self._storage)
class ApplicationState:
"""应用全局状态管理器(单例模式)"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
# 浏览器管理器
self.browser_manager = None
self._browser_lock = threading.Lock()
# 用户账号管理 {user_id: {account_id: Account对象}}
self.user_accounts = ThreadSafeDict()
# 活动任务管理 {account_id: Thread对象}
self.active_tasks = ThreadSafeDict()
# 日志缓存管理
self.log_cache = LogCacheManager()
# 验证码管理
self.captcha = CaptchaManager()
# 用户信号量管理 {account_id: Semaphore}
self.user_semaphores = ThreadSafeDict()
# 全局信号量
self.global_semaphore = None
self.screenshot_semaphore = threading.Semaphore(1)
self._initialized = True
logger.info("应用状态管理器初始化完成")
def set_browser_manager(self, manager):
"""设置浏览器管理器"""
with self._browser_lock:
self.browser_manager = manager
def get_browser_manager(self):
"""获取浏览器管理器"""
with self._browser_lock:
return self.browser_manager
def get_user_semaphore(self, account_id: int, max_concurrent: int = 1):
"""获取或创建用户信号量"""
if account_id not in self.user_semaphores:
self.user_semaphores.set(account_id, threading.Semaphore(max_concurrent))
return self.user_semaphores.get(account_id)
def set_global_semaphore(self, max_concurrent: int):
"""设置全局信号量"""
self.global_semaphore = threading.Semaphore(max_concurrent)
def get_stats(self) -> Dict[str, Any]:
"""获取状态统计信息"""
return {
'user_accounts_count': len(self.user_accounts),
'active_tasks_count': len(self.active_tasks),
'log_cache_stats': self.log_cache.get_stats(),
'captcha_count': self.captcha.get_count(),
'user_semaphores_count': len(self.user_semaphores),
'browser_manager': 'initialized' if self.browser_manager else 'not_initialized'
}
# 全局单例实例
app_state = ApplicationState()
# 向后兼容的辅助函数
def verify_captcha(identifier: str, code: str) -> Tuple[bool, str]:
"""验证验证码(向后兼容接口)"""
return app_state.captcha.verify(identifier, code)
def create_captcha(identifier: str, code: str) -> None:
"""创建验证码(向后兼容接口)"""
app_state.captcha.create(identifier, code)
def cleanup_expired_captchas() -> int:
"""清理过期验证码(向后兼容接口)"""
return app_state.captcha.cleanup_expired()
if __name__ == '__main__':
# 测试代码
print("测试线程安全状态管理器...")
print("=" * 60)
# 测试 ThreadSafeDict
print("\n1. 测试 ThreadSafeDict:")
td = ThreadSafeDict()
td.set('key1', 'value1')
print(f" 设置 key1 = {td.get('key1')}")
print(f" 长度: {len(td)}")
# 测试 LogCacheManager
print("\n2. 测试 LogCacheManager:")
lcm = LogCacheManager(max_logs_per_user=3, max_total_logs=10)
for i in range(5):
lcm.add_log(1, {'message': f'log {i}'})
print(f" 用户1日志数: {len(lcm.get_logs(1))}")
print(f" 总日志数: {lcm.get_total_count()}")
print(f" 统计: {lcm.get_stats()}")
# 测试 CaptchaManager
print("\n3. 测试 CaptchaManager:")
cm = CaptchaManager(expire_seconds=2)
cm.create('test@example.com', '1234')
success, msg = cm.verify('test@example.com', '1234')
print(f" 验证结果: {success}, {msg}")
# 测试 ApplicationState
print("\n4. 测试 ApplicationState (单例):")
state1 = ApplicationState()
state2 = ApplicationState()
print(f" 单例验证: {state1 is state2}")
print(f" 状态统计: {state1.get_stats()}")
print("\n" + "=" * 60)
print("✓ 所有测试通过!")