220 lines
6.3 KiB
Python
220 lines
6.3 KiB
Python
"""
|
||
Redis客户端工具类
|
||
"""
|
||
import json
|
||
import asyncio
|
||
import hashlib
|
||
from functools import wraps
|
||
from typing import Optional, Any, List, Callable
|
||
from redis.asyncio import Redis, ConnectionPool
|
||
from app.core.config import settings
|
||
|
||
|
||
class RedisClient:
    """Thin async wrapper around a ``redis.asyncio`` client.

    The connection pool is created lazily: every data method goes through
    :meth:`_ensure_connected`, so callers do not have to call :meth:`connect`
    first. The pool is built with ``decode_responses=True``, so values come
    back as ``str`` rather than ``bytes``.
    """

    # Number of keys removed per DEL round-trip (and SCAN COUNT hint) in
    # delete_pattern().
    _DELETE_BATCH_SIZE = 500

    def __init__(self):
        """Create an unconnected client; the pool is opened on first use."""
        self.pool: Optional[ConnectionPool] = None
        self.redis: Optional[Redis] = None

    async def connect(self):
        """Create the connection pool and the client if they do not exist yet.

        Connection parameters come from ``settings.REDIS_URL`` and
        ``settings.REDIS_MAX_CONNECTIONS``. Calling this while already
        connected reuses the existing pool and client.
        """
        if self.pool is None:
            self.pool = ConnectionPool.from_url(
                settings.REDIS_URL,
                max_connections=settings.REDIS_MAX_CONNECTIONS,
                decode_responses=True,
            )
        if self.redis is None:
            self.redis = Redis(connection_pool=self.pool)

    async def _ensure_connected(self) -> "Redis":
        """Return the live client, connecting lazily on first use."""
        if self.redis is None:
            await self.connect()
        return self.redis

    async def close(self):
        """Close the client and disconnect every pooled connection.

        Both attributes are reset to ``None`` before awaiting the shutdown, so
        any later data call reconnects instead of reusing closed objects. The
        pool is disconnected even if closing the client raises.
        """
        client, pool = self.redis, self.pool
        self.redis = None
        self.pool = None
        try:
            if client is not None:
                await client.close()
        finally:
            if pool is not None:
                await pool.disconnect()

    async def get(self, key: str) -> Optional[str]:
        """Return the value stored at *key*, or ``None`` when the key is absent."""
        client = await self._ensure_connected()
        return await client.get(key)

    async def set(
        self,
        key: str,
        value: str,
        expire: Optional[int] = None
    ) -> bool:
        """Store *value* at *key*.

        Args:
            key: Redis key.
            value: String value to store.
            expire: Optional TTL in seconds; ``None`` stores without expiry.
        """
        client = await self._ensure_connected()
        return await client.set(key, value, ex=expire)

    async def delete(self, key: str) -> int:
        """Delete *key*; return the number of keys actually removed (0 or 1)."""
        client = await self._ensure_connected()
        return await client.delete(key)

    async def exists(self, key: str) -> bool:
        """Return ``True`` when *key* is present."""
        client = await self._ensure_connected()
        return await client.exists(key) > 0

    async def expire(self, key: str, seconds: int) -> bool:
        """Set a TTL of *seconds* on an existing *key*."""
        client = await self._ensure_connected()
        return await client.expire(key, seconds)

    async def keys(self, pattern: str) -> List[str]:
        """Return every key matching the glob *pattern*.

        NOTE: this issues ``KEYS``, which walks the whole keyspace in a single
        blocking command. Prefer :meth:`delete_pattern` (SCAN-based) for bulk
        work on large databases.
        """
        client = await self._ensure_connected()
        return await client.keys(pattern)

    async def delete_pattern(self, pattern: str) -> int:
        """Delete every key matching the glob *pattern*.

        Keys are discovered with incremental ``SCAN`` instead of ``KEYS``, so
        the server is never blocked walking the whole keyspace at once. They
        are deleted in batches of ``_DELETE_BATCH_SIZE``.

        Returns:
            Total number of keys removed.
        """
        client = await self._ensure_connected()
        removed = 0
        batch: List[str] = []
        async for found in client.scan_iter(match=pattern, count=self._DELETE_BATCH_SIZE):
            batch.append(found)
            if len(batch) >= self._DELETE_BATCH_SIZE:
                removed += await client.delete(*batch)
                batch.clear()
        if batch:
            removed += await client.delete(*batch)
        return removed

    async def setex(self, key: str, time: int, value: str) -> bool:
        """Store *value* at *key* with a TTL of *time* seconds."""
        client = await self._ensure_connected()
        return await client.setex(key, time, value)

    # ---- JSON helpers -------------------------------------------------

    async def get_json(self, key: str) -> Optional[Any]:
        """Return the JSON-decoded value at *key*.

        Returns ``None`` for a missing (or empty) value. A stored value that
        is not valid JSON is returned as the raw string.
        """
        raw = await self.get(key)
        if not raw:
            return None
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            # Values written through plain set() need not be JSON.
            return raw

    async def set_json(
        self,
        key: str,
        value: Any,
        expire: Optional[int] = None
    ) -> bool:
        """JSON-encode *value* (non-ASCII kept verbatim) and store it at *key*."""
        json_str = json.dumps(value, ensure_ascii=False)
        return await self.set(key, json_str, expire)

    # ---- caching decorators --------------------------------------------

    @staticmethod
    def _make_cache_key(
        key_prefix: str,
        args: tuple,
        kwargs: dict,
        key_builder: Optional[Callable[..., Any]] = None,
    ) -> str:
        """Build the ``cache:<prefix>:<md5>`` key for one decorated call.

        kwargs are sorted so keyword order at the call site does not change
        the key. The prefix is kept in clear text so a whole family can be
        invalidated with ``delete_pattern(f"cache:{prefix}:*")``.
        """
        if key_builder is not None:
            raw = str(key_builder(*args, **kwargs))
        else:
            raw = f"{args}:{sorted(kwargs.items())}"
        digest = hashlib.md5(f"{key_prefix}:{raw}".encode("utf-8")).hexdigest()
        return f"cache:{key_prefix}:{digest}"

    def cache(
        self,
        key_prefix: str,
        expire: int = 300,
        key_builder: Optional[Callable[..., Any]] = None,
    ):
        """Decorator that caches a function's JSON-serializable result in Redis.

        Args:
            key_prefix: Cache key prefix; keys look like ``cache:<prefix>:<md5>``.
            expire: TTL in seconds, default 300 (5 minutes).
            key_builder: Optional callable receiving the call's ``*args,
                **kwargs`` and returning a string used instead of
                ``repr(args)``/``repr(kwargs)``. Use it when an argument's
                repr is unstable (e.g. objects printed with a memory address,
                such as a DB session), otherwise every call misses the cache.

        The wrapped function may be ``async`` or a plain sync function; the
        returned wrapper is always a coroutine function. A result of ``None``
        is never served from cache, and cache hits return the JSON round-trip
        of the original result (e.g. tuples come back as lists).

        Example:
            @redis_client.cache("device_types", expire=1800)
            async def get_device_types(...):
                pass
        """
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                cache_key = self._make_cache_key(key_prefix, args, kwargs, key_builder)

                cached = await self.get_json(cache_key)
                if cached is not None:
                    return cached

                result = func(*args, **kwargs)
                # Await only when the call produced an awaitable, so sync
                # functions can be decorated too.
                if asyncio.iscoroutine(result) or asyncio.isfuture(result):
                    result = await result

                await self.set_json(cache_key, result, expire)
                return result
            return wrapper
        return decorator

    # ---- statistics cache helpers ---------------------------------------

    async def cache_statistics(
        self,
        key: str,
        data: Any,
        expire: int = 600
    ):
        """Store statistics *data* as JSON at *key* (default TTL 600 s)."""
        return await self.set_json(key, data, expire)

    async def get_cached_statistics(self, key: str) -> Optional[Any]:
        """Return cached statistics stored at *key*, or ``None``."""
        return await self.get_json(key)

    async def invalidate_statistics_cache(self, pattern: str = "statistics:*"):
        """Delete every statistics cache entry matching *pattern*."""
        return await self.delete_pattern(pattern)

    def cached_async(
        self,
        key_prefix: str,
        expire: int = 300,
        key_builder: Optional[Callable[..., Any]] = None,
    ):
        """Caching decorator usable on sync or async functions.

        Equivalent to :meth:`cache`, to which it delegates, so both decorators
        produce identical keys and behavior.

        Args:
            key_prefix: Cache key prefix.
            expire: TTL in seconds, default 300 (5 minutes).
            key_builder: See :meth:`cache`.

        Example:
            @redis_client.cached_async("device_types", expire=1800)
            async def cached_get_device_types(db, skip, limit, ...):
                return device_type_service.get_device_types(...)
        """
        return self.cache(key_prefix, expire, key_builder)
|
||
|
||
|
||
# Module-wide shared client instance; intended to be imported and reused so
# callers share a single RedisClient (and therefore a single pool).
redis_client: RedisClient = RedisClient()
|
||
|
||
|
||
async def init_redis():
    """Connect the module-level ``redis_client`` (delegates to ``RedisClient.connect``)."""
    client = redis_client
    await client.connect()
|
||
|
||
|
||
async def close_redis():
    """Shut down the module-level ``redis_client`` (delegates to ``RedisClient.close``)."""
    client = redis_client
    await client.close()
|