195 lines
6.8 KiB
Python
195 lines
6.8 KiB
Python
"""
操作日志中间件 (Operation log middleware).
"""
|
||
import json
import logging
import time
from typing import Callable
from urllib.parse import parse_qsl

from fastapi import Request, Response
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware

from app.db.session import async_session_maker
from app.schemas.operation_log import OperationLogCreate, OperationModuleEnum, OperationTypeEnum, OperationResultEnum
from app.services.operation_log_service import operation_log_service
|
||
|
||
|
||
class OperationLogMiddleware(BaseHTTPMiddleware):
    """Operation log middleware (操作日志中间件).

    For every ``/api/`` request that is not excluded by ``EXCLUDE_PATHS`` and
    whose ``request.state.user`` is set, writes one entry through
    ``operation_log_service.create_log``. Errors while building or saving the
    entry are logged and never change the response.
    """

    # Path prefixes that are never logged (matched with str.startswith).
    EXCLUDE_PATHS = [
        "/health",
        "/docs",
        "/openapi.json",
        "/api/v1/auth/login",
        "/api/v1/auth/captcha",
    ]

    # URL path segment (with a leading "/") -> business module.
    PATH_MODULE_MAP = {
        "/auth": OperationModuleEnum.AUTH,
        "/device-types": OperationModuleEnum.DEVICE_TYPE,
        "/organizations": OperationModuleEnum.ORGANIZATION,
        "/assets": OperationModuleEnum.ASSET,
        "/brands": OperationModuleEnum.BRAND_SUPPLIER,
        "/suppliers": OperationModuleEnum.BRAND_SUPPLIER,
        "/allocation-orders": OperationModuleEnum.ALLOCATION,
        "/maintenance-records": OperationModuleEnum.MAINTENANCE,
        "/system-config": OperationModuleEnum.SYSTEM_CONFIG,
        "/users": OperationModuleEnum.USER,
        "/statistics": OperationModuleEnum.STATISTICS,
        "/operation-logs": OperationModuleEnum.SYSTEM_CONFIG,
        "/notifications": OperationModuleEnum.SYSTEM_CONFIG,
    }

    # HTTP method -> operation type; unmapped methods fall back to QUERY.
    METHOD_OPERATION_MAP = {
        "GET": OperationTypeEnum.QUERY,
        "POST": OperationTypeEnum.CREATE,
        "PUT": OperationTypeEnum.UPDATE,
        "PATCH": OperationTypeEnum.UPDATE,
        "DELETE": OperationTypeEnum.DELETE,
    }

    # Methods whose request body is recorded as the logged parameters.
    BODY_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})

    # Keys (compared case-insensitively) whose values are masked before logging.
    SENSITIVE_FIELDS = frozenset({
        "password",
        "old_password",
        "new_password",
        "confirm_password",
        "token",
        "access_token",
        "refresh_token",
        "secret",
    })

    # Placeholder written in place of a masked value.
    MASK = "******"

    # Maximum number of characters kept from a body that is neither JSON nor a form.
    MAX_RAW_BODY_LENGTH = 500

    _logger = logging.getLogger(__name__)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request and record an operation-log entry when applicable.

        Args:
            request: The incoming request.
            call_next: Callback that runs the downstream application.

        Returns:
            The downstream response, unchanged.

        Raises:
            Whatever the downstream application raises. A FAILED entry with
            status 500 is written before the exception propagates.
        """
        if not self._should_log(request):
            return await call_next(request)

        # Capture the parameters *before* the endpoint runs. The endpoint
        # consumes the body stream, so reading it afterwards fails and the
        # body would never be recorded.
        # NOTE(review): relies on BaseHTTPMiddleware caching a body read in
        # dispatch and replaying it downstream (recent Starlette releases);
        # confirm for the installed Starlette version.
        params = await self._get_request_params(request)
        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            await self._log_if_authenticated(request, 500, start, params)
            raise

        await self._log_if_authenticated(request, response.status_code, start, params)
        return response

    async def _log_if_authenticated(
        self,
        request: Request,
        status_code: int,
        start: float,
        params: str,
    ) -> None:
        """Write a log entry when a user is attached to the request state."""
        duration = int((time.perf_counter() - start) * 1000)
        # Looked up after the endpoint ran, so a user attached while handling
        # the request is seen as well as one set by an earlier middleware.
        user = getattr(request.state, "user", None)
        if user:
            await self._log_operation(request, status_code, user, duration, params)

    def _should_log(self, request: Request) -> bool:
        """Return True for ``/api/`` paths that are not in EXCLUDE_PATHS."""
        path = request.url.path
        if path.startswith(tuple(self.EXCLUDE_PATHS)):
            return False
        return path.startswith("/api/")

    async def _log_operation(
        self,
        request: Request,
        status_code: int,
        user,
        duration: int,
        params: str,
    ) -> None:
        """Build and persist one operation-log entry; never raises.

        Args:
            request: The request being logged.
            status_code: HTTP status of the outcome; >= 400 is recorded as FAILED.
            user: The object found on ``request.state.user``; its ``id``,
                ``real_name`` and ``username`` attributes are read.
            duration: Elapsed time in milliseconds.
            params: Pre-serialized request parameters.
        """
        try:
            path = request.url.path
            operation_type = self.METHOD_OPERATION_MAP.get(request.method, OperationTypeEnum.QUERY)

            # Login/logout get dedicated operation types. The login branch only
            # fires for login URLs not already excluded by EXCLUDE_PATHS.
            if "/auth/login" in path:
                operation_type = OperationTypeEnum.LOGIN
            elif "/auth/logout" in path:
                operation_type = OperationTypeEnum.LOGOUT

            succeeded = status_code < 400
            log_data = OperationLogCreate(
                operator_id=user.id,
                operator_name=user.real_name or user.username,
                operator_ip=request.client.host if request.client else None,
                module=self._get_module(path),
                operation_type=operation_type,
                method=request.method,
                url=path,
                params=params,
                result=OperationResultEnum.SUCCESS if succeeded else OperationResultEnum.FAILED,
                error_msg=None if succeeded else f"HTTP {status_code}",
                duration=duration,
                user_agent=request.headers.get("user-agent"),
            )

            # Awaited inline: the entry is saved before the response is returned.
            async with async_session_maker() as db:
                await operation_log_service.create_log(db, log_data)
        except Exception:
            # A logging failure must never affect the business request.
            self._logger.exception(
                "Failed to log operation: %s %s", request.method, request.url.path
            )

    def _get_module(self, path: str) -> OperationModuleEnum:
        """Return the business module for a URL path.

        Path segments are checked left to right for an exact key in
        PATH_MODULE_MAP, so the outermost resource wins: for example
        ``/api/v1/allocation-orders/1/assets`` maps to ALLOCATION, not ASSET.
        If no segment matches exactly, a substring scan in map order is used.
        Otherwise the result is SYSTEM_CONFIG.
        """
        for segment in path.split("/"):
            module = self.PATH_MODULE_MAP.get(f"/{segment}")
            if module is not None:
                return module
        for path_prefix, module in self.PATH_MODULE_MAP.items():
            if path_prefix in path:
                return module
        return OperationModuleEnum.SYSTEM_CONFIG

    async def _get_request_params(self, request: Request) -> str:
        """Serialize the request parameters for the log entry; never raises.

        - GET: the query string, as a masked JSON object.
        - BODY_METHODS: the body.
          - JSON and form-urlencoded bodies are parsed and masked.
          - Other text bodies are truncated to MAX_RAW_BODY_LENGTH characters.
          - Multipart bodies (uploads), empty bodies and undecodable bodies
            yield "".
        - Any other method, or any error: "".
        """
        try:
            if request.method == "GET":
                query = self._filter_sensitive_data(dict(request.query_params))
                return json.dumps(query, ensure_ascii=False)

            if request.method not in self.BODY_METHODS:
                return ""

            content_type = request.headers.get("content-type", "").lower()
            if content_type.startswith("multipart/"):
                # Don't buffer file uploads in memory just to log them.
                return ""

            body = await request.body()
            if not body:
                return ""
            text = body.decode()

            if content_type.startswith("application/x-www-form-urlencoded"):
                # Parse forms so credentials in them are masked, not logged raw.
                form = dict(parse_qsl(text, keep_blank_values=True))
                return json.dumps(self._filter_sensitive_data(form), ensure_ascii=False)

            try:
                payload = json.loads(text)
            except json.JSONDecodeError:
                # Not JSON: keep a bounded prefix of the raw text.
                return text[: self.MAX_RAW_BODY_LENGTH]
            return json.dumps(self._filter_sensitive_data(payload), ensure_ascii=False)
        except Exception:
            self._logger.debug("Could not capture request params", exc_info=True)
            return ""

    def _filter_sensitive_data(self, data: object) -> object:
        """Return a copy of ``data`` with sensitive values replaced by MASK.

        Dicts are walked recursively, and so are lists at any depth, including
        a top-level list. Keys are compared case-insensitively against
        SENSITIVE_FIELDS. Any other value is returned unchanged.
        """
        if isinstance(data, dict):
            return {
                key: (
                    self.MASK
                    if str(key).lower() in self.SENSITIVE_FIELDS
                    else self._filter_sensitive_data(value)
                )
                for key, value in data.items()
            }
        if isinstance(data, list):
            return [self._filter_sensitive_data(item) for item in data]
        return data
|