#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 应用状态管理模块 提供线程安全的全局状态管理 """ import threading from typing import Tuple from typing import Dict, Any, Optional from datetime import datetime, timedelta from app_logger import get_logger logger = get_logger('app_state') class ThreadSafeDict: """线程安全的字典包装类""" def __init__(self): self._dict = {} self._lock = threading.RLock() def get(self, key, default=None): """获取值""" with self._lock: return self._dict.get(key, default) def set(self, key, value): """设置值""" with self._lock: self._dict[key] = value def delete(self, key): """删除键""" with self._lock: if key in self._dict: del self._dict[key] def pop(self, key, default=None): """弹出键值""" with self._lock: return self._dict.pop(key, default) def keys(self): """获取所有键(返回副本)""" with self._lock: return list(self._dict.keys()) def items(self): """获取所有键值对(返回副本)""" with self._lock: return list(self._dict.items()) def __contains__(self, key): """检查键是否存在""" with self._lock: return key in self._dict def clear(self): """清空字典""" with self._lock: self._dict.clear() def __len__(self): """获取长度""" with self._lock: return len(self._dict) class LogCacheManager: """日志缓存管理器(线程安全)""" def __init__(self, max_logs_per_user=100, max_total_logs=1000): self._cache = {} # {user_id: [logs]} self._total_count = 0 self._lock = threading.RLock() self._max_logs_per_user = max_logs_per_user self._max_total_logs = max_total_logs def add_log(self, user_id: int, log_entry: Dict[str, Any]) -> bool: """添加日志到缓存""" with self._lock: # 检查总数限制 if self._total_count >= self._max_total_logs: logger.warning(f"日志缓存已满 ({self._max_total_logs}),拒绝添加") return False # 初始化用户日志列表 if user_id not in self._cache: self._cache[user_id] = [] user_logs = self._cache[user_id] # 检查用户日志数限制 if len(user_logs) >= self._max_logs_per_user: # 移除最旧的日志 user_logs.pop(0) self._total_count -= 1 # 添加新日志 user_logs.append(log_entry) self._total_count += 1 return True def get_logs(self, user_id: int) -> list: """获取用户的所有日志(返回副本)""" with self._lock: return list(self._cache.get(user_id, [])) def clear_user_logs(self, user_id: int): """清空用户的日志""" with self._lock: if user_id in self._cache: count = len(self._cache[user_id]) del self._cache[user_id] self._total_count -= count logger.info(f"清空用户 {user_id} 的 {count} 条日志") def get_total_count(self) -> int: """获取总日志数""" with self._lock: return self._total_count def get_stats(self) -> Dict[str, int]: """获取统计信息""" with self._lock: return { 'total_count': self._total_count, 'user_count': len(self._cache), 'max_per_user': self._max_logs_per_user, 'max_total': self._max_total_logs } class CaptchaManager: """验证码管理器(线程安全)""" def __init__(self, expire_seconds=300): self._storage = {} # {identifier: {'code': str, 'expire': datetime}} self._lock = threading.RLock() self._expire_seconds = expire_seconds def create(self, identifier: str, code: str) -> None: """创建验证码""" with self._lock: self._storage[identifier] = { 'code': code, 'expire': datetime.now() + timedelta(seconds=self._expire_seconds) } def verify(self, identifier: str, code: str) -> Tuple[bool, str]: """验证验证码""" with self._lock: if identifier not in self._storage: return False, "验证码不存在或已过期" captcha_data = self._storage[identifier] # 检查是否过期 if datetime.now() > captcha_data['expire']: del self._storage[identifier] return False, "验证码已过期,请重新获取" # 验证码码值 if captcha_data['code'] != code: return False, "验证码错误" # 验证成功,删除验证码 del self._storage[identifier] return True, "验证成功" def cleanup_expired(self) -> int: """清理过期的验证码""" with self._lock: now = datetime.now() expired_keys = [ key for key, data in self._storage.items() if now > data['expire'] ] for key in expired_keys: del self._storage[key] if expired_keys: logger.info(f"清理了 {len(expired_keys)} 个过期验证码") return len(expired_keys) def get_count(self) -> int: """获取当前验证码数量""" with self._lock: return len(self._storage) class ApplicationState: """应用全局状态管理器(单例模式)""" _instance = None _lock = threading.Lock() def __new__(cls): if cls._instance is None: with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self): if self._initialized: return # 浏览器管理器 self.browser_manager = None self._browser_lock = threading.Lock() # 用户账号管理 {user_id: {account_id: Account对象}} self.user_accounts = ThreadSafeDict() # 活动任务管理 {account_id: Thread对象} self.active_tasks = ThreadSafeDict() # 日志缓存管理 self.log_cache = LogCacheManager() # 验证码管理 self.captcha = CaptchaManager() # 用户信号量管理 {account_id: Semaphore} self.user_semaphores = ThreadSafeDict() # 全局信号量 self.global_semaphore = None self.screenshot_semaphore = threading.Semaphore(1) self._initialized = True logger.info("应用状态管理器初始化完成") def set_browser_manager(self, manager): """设置浏览器管理器""" with self._browser_lock: self.browser_manager = manager def get_browser_manager(self): """获取浏览器管理器""" with self._browser_lock: return self.browser_manager def get_user_semaphore(self, account_id: int, max_concurrent: int = 1): """获取或创建用户信号量""" if account_id not in self.user_semaphores: self.user_semaphores.set(account_id, threading.Semaphore(max_concurrent)) return self.user_semaphores.get(account_id) def set_global_semaphore(self, max_concurrent: int): """设置全局信号量""" self.global_semaphore = threading.Semaphore(max_concurrent) def get_stats(self) -> Dict[str, Any]: """获取状态统计信息""" return { 'user_accounts_count': len(self.user_accounts), 'active_tasks_count': len(self.active_tasks), 'log_cache_stats': self.log_cache.get_stats(), 'captcha_count': self.captcha.get_count(), 'user_semaphores_count': len(self.user_semaphores), 'browser_manager': 'initialized' if self.browser_manager else 'not_initialized' } # 全局单例实例 app_state = ApplicationState() # 向后兼容的辅助函数 def verify_captcha(identifier: str, code: str) -> Tuple[bool, str]: """验证验证码(向后兼容接口)""" return app_state.captcha.verify(identifier, code) def create_captcha(identifier: str, code: str) -> None: """创建验证码(向后兼容接口)""" app_state.captcha.create(identifier, code) def cleanup_expired_captchas() -> int: """清理过期验证码(向后兼容接口)""" return app_state.captcha.cleanup_expired() if __name__ == '__main__': # 测试代码 print("测试线程安全状态管理器...") print("=" * 60) # 测试 ThreadSafeDict print("\n1. 测试 ThreadSafeDict:") td = ThreadSafeDict() td.set('key1', 'value1') print(f" 设置 key1 = {td.get('key1')}") print(f" 长度: {len(td)}") # 测试 LogCacheManager print("\n2. 测试 LogCacheManager:") lcm = LogCacheManager(max_logs_per_user=3, max_total_logs=10) for i in range(5): lcm.add_log(1, {'message': f'log {i}'}) print(f" 用户1日志数: {len(lcm.get_logs(1))}") print(f" 总日志数: {lcm.get_total_count()}") print(f" 统计: {lcm.get_stats()}") # 测试 CaptchaManager print("\n3. 测试 CaptchaManager:") cm = CaptchaManager(expire_seconds=2) cm.create('test@example.com', '1234') success, msg = cm.verify('test@example.com', '1234') print(f" 验证结果: {success}, {msg}") # 测试 ApplicationState print("\n4. 测试 ApplicationState (单例):") state1 = ApplicationState() state2 = ApplicationState() print(f" 单例验证: {state1 is state2}") print(f" 状态统计: {state1.get_stats()}") print("\n" + "=" * 60) print("✓ 所有测试通过!")