""" Redis客户端工具类 """ import json import asyncio import hashlib from functools import wraps from typing import Optional, Any, List, Callable from redis.asyncio import Redis, ConnectionPool from app.core.config import settings class RedisClient: """Redis客户端""" def __init__(self): """初始化Redis客户端""" self.pool: Optional[ConnectionPool] = None self.redis: Optional[Redis] = None async def connect(self): """连接Redis""" if not self.pool: self.pool = ConnectionPool.from_url( settings.REDIS_URL, max_connections=settings.REDIS_MAX_CONNECTIONS, decode_responses=True ) self.redis = Redis(connection_pool=self.pool) async def close(self): """关闭连接""" if self.redis: await self.redis.close() if self.pool: await self.pool.disconnect() async def get(self, key: str) -> Optional[str]: """获取缓存""" if not self.redis: await self.connect() return await self.redis.get(key) async def set( self, key: str, value: str, expire: Optional[int] = None ) -> bool: """设置缓存""" if not self.redis: await self.connect() return await self.redis.set(key, value, ex=expire) async def delete(self, key: str) -> int: """删除缓存""" if not self.redis: await self.connect() return await self.redis.delete(key) async def exists(self, key: str) -> bool: """检查键是否存在""" if not self.redis: await self.connect() return await self.redis.exists(key) > 0 async def expire(self, key: str, seconds: int) -> bool: """设置过期时间""" if not self.redis: await self.connect() return await self.redis.expire(key, seconds) async def keys(self, pattern: str) -> List[str]: """获取匹配的键""" if not self.redis: await self.connect() return await self.redis.keys(pattern) async def delete_pattern(self, pattern: str) -> int: """删除匹配的键""" keys = await self.keys(pattern) if keys: return await self.redis.delete(*keys) return 0 async def setex(self, key: str, time: int, value: str) -> bool: """设置缓存并指定过期时间(秒)""" if not self.redis: await self.connect() return await self.redis.setex(key, time, value) # JSON操作辅助方法 async def get_json(self, key: str) -> Optional[Any]: """获取JSON数据""" value = await self.get(key) if value: try: return json.loads(value) except json.JSONDecodeError: return value return None async def set_json( self, key: str, value: Any, expire: Optional[int] = None ) -> bool: """设置JSON数据""" json_str = json.dumps(value, ensure_ascii=False) return await self.set(key, json_str, expire) # 缓存装饰器 def cache(self, key_prefix: str, expire: int = 300): """ Redis缓存装饰器(改进版) Args: key_prefix: 缓存键前缀 expire: 过期时间(秒),默认300秒(5分钟) Example: @redis_client.cache("device_types", expire=1800) async def get_device_types(...): pass """ def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): # 使用MD5生成更稳定的缓存键 key_data = f"{key_prefix}:{str(args)}:{str(kwargs)}" cache_key = f"cache:{hashlib.md5(key_data.encode()).hexdigest()}" # 尝试从缓存获取 cached = await self.get_json(cache_key) if cached is not None: return cached # 执行函数 result = await func(*args, **kwargs) # 存入缓存 await self.set_json(cache_key, result, expire) return result return wrapper return decorator # 统计缓存辅助方法 async def cache_statistics( self, key: str, data: Any, expire: int = 600 ): """缓存统计数据""" return await self.set_json(key, data, expire) async def get_cached_statistics(self, key: str) -> Optional[Any]: """获取缓存的统计数据""" return await self.get_json(key) async def invalidate_statistics_cache(self, pattern: str = "statistics:*"): """清除统计数据缓存""" return await self.delete_pattern(pattern) # 同步函数的异步缓存包装器 def cached_async(self, key_prefix: str, expire: int = 300): """ 为同步函数提供异步缓存包装的装饰器 Args: key_prefix: 缓存键前缀 expire: 过期时间(秒),默认300秒(5分钟) Example: @redis_client.cached_async("device_types", expire=1800) async def cached_get_device_types(db, skip, limit, ...): return device_type_service.get_device_types(...) """ def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): # 使用MD5生成更稳定的缓存键 key_data = f"{key_prefix}:{str(args)}:{str(kwargs)}" cache_key = f"cache:{hashlib.md5(key_data.encode()).hexdigest()}" # 尝试从缓存获取 cached = await self.get_json(cache_key) if cached is not None: return cached # 执行函数 result = await func(*args, **kwargs) # 存入缓存 await self.set_json(cache_key, result, expire) return result return wrapper return decorator # 创建全局实例 redis_client = RedisClient() async def init_redis(): """初始化Redis连接""" await redis_client.connect() async def close_redis(): """关闭Redis连接""" await redis_client.close()