""" 操作日志中间件 """ import time import json from typing import Callable from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from sqlalchemy.ext.asyncio import AsyncSession from app.db.session import async_session_maker from app.schemas.operation_log import OperationLogCreate, OperationModuleEnum, OperationTypeEnum, OperationResultEnum from app.services.operation_log_service import operation_log_service class OperationLogMiddleware(BaseHTTPMiddleware): """操作日志中间件""" # 不需要记录的路径 EXCLUDE_PATHS = [ "/health", "/docs", "/openapi.json", "/api/v1/auth/login", "/api/v1/auth/captcha", ] # 路径到模块的映射 PATH_MODULE_MAP = { "/auth": OperationModuleEnum.AUTH, "/device-types": OperationModuleEnum.DEVICE_TYPE, "/organizations": OperationModuleEnum.ORGANIZATION, "/assets": OperationModuleEnum.ASSET, "/brands": OperationModuleEnum.BRAND_SUPPLIER, "/suppliers": OperationModuleEnum.BRAND_SUPPLIER, "/allocation-orders": OperationModuleEnum.ALLOCATION, "/maintenance-records": OperationModuleEnum.MAINTENANCE, "/system-config": OperationModuleEnum.SYSTEM_CONFIG, "/users": OperationModuleEnum.USER, "/statistics": OperationModuleEnum.STATISTICS, "/operation-logs": OperationModuleEnum.SYSTEM_CONFIG, "/notifications": OperationModuleEnum.SYSTEM_CONFIG, } # 方法到操作类型的映射 METHOD_OPERATION_MAP = { "GET": OperationTypeEnum.QUERY, "POST": OperationTypeEnum.CREATE, "PUT": OperationTypeEnum.UPDATE, "PATCH": OperationTypeEnum.UPDATE, "DELETE": OperationTypeEnum.DELETE, } async def dispatch(self, request: Request, call_next: Callable) -> Response: """处理请求""" # 检查是否需要记录 if self._should_log(request): # 记录开始时间 start_time = time.time() # 获取用户信息 user = getattr(request.state, "user", None) # 处理请求 response = await call_next(request) # 计算执行时长 duration = int((time.time() - start_time) * 1000) # 异步记录日志 if user: await self._log_operation(request, response, user, duration) return response return await call_next(request) def _should_log(self, request: Request) -> bool: """判断是否需要记录日志""" path = request.url.path # 检查排除路径 for exclude_path in self.EXCLUDE_PATHS: if path.startswith(exclude_path): return False # 只记录API请求 return path.startswith("/api/") async def _log_operation( self, request: Request, response: Response, user, duration: int ): """记录操作日志""" try: # 获取模块 module = self._get_module(request.url.path) # 获取操作类型 operation_type = self.METHOD_OPERATION_MAP.get(request.method, OperationTypeEnum.QUERY) # 特殊处理:如果是登录/登出 if "/auth/login" in request.url.path: operation_type = OperationTypeEnum.LOGIN elif "/auth/logout" in request.url.path: operation_type = OperationTypeEnum.LOGOUT # 获取请求参数 params = await self._get_request_params(request) # 构建日志数据 log_data = OperationLogCreate( operator_id=user.id, operator_name=user.real_name or user.username, operator_ip=request.client.host if request.client else None, module=module, operation_type=operation_type, method=request.method, url=request.url.path, params=params, result=OperationResultEnum.SUCCESS if response.status_code < 400 else OperationResultEnum.FAILED, error_msg=None if response.status_code < 400 else f"HTTP {response.status_code}", duration=duration, user_agent=request.headers.get("user-agent"), ) # 异步保存日志 async with async_session_maker() as db: await operation_log_service.create_log(db, log_data) except Exception as e: # 记录日志失败不应影响业务 print(f"Failed to log operation: {e}") def _get_module(self, path: str) -> OperationModuleEnum: """根据路径获取模块""" for path_prefix, module in self.PATH_MODULE_MAP.items(): if path_prefix in path: return module return OperationModuleEnum.SYSTEM_CONFIG async def _get_request_params(self, request: Request) -> str: """获取请求参数""" try: # GET请求 if request.method == "GET": params = dict(request.query_params) return json.dumps(params, ensure_ascii=False) # POST/PUT/DELETE请求 if request.method in ["POST", "PUT", "DELETE", "PATCH"]: try: body = await request.body() if body: # 尝试解析JSON try: body_json = json.loads(body.decode()) # 过滤敏感字段 filtered_body = self._filter_sensitive_data(body_json) return json.dumps(filtered_body, ensure_ascii=False) except json.JSONDecodeError: # 不是JSON,返回原始数据 return body.decode()[:500] # 限制长度 except Exception: pass return "" except Exception: return "" def _filter_sensitive_data(self, data: dict) -> dict: """过滤敏感数据""" sensitive_fields = ["password", "old_password", "new_password", "token", "secret"] if not isinstance(data, dict): return data filtered = {} for key, value in data.items(): if key in sensitive_fields: filtered[key] = "******" elif isinstance(value, dict): filtered[key] = self._filter_sensitive_data(value) elif isinstance(value, list): filtered[key] = [ self._filter_sensitive_data(item) if isinstance(item, dict) else item for item in value ] else: filtered[key] = value return filtered