Fix API compatibility and add user/role/permission and asset import/export
This commit is contained in:
6
backend/app/middleware/__init__.py
Normal file
6
backend/app/middleware/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
中间件模块
|
||||
"""
|
||||
from app.middleware.operation_log import OperationLogMiddleware
|
||||
|
||||
__all__ = ["OperationLogMiddleware"]
|
||||
146
backend/app/middleware/api_transform.py
Normal file
146
backend/app/middleware/api_transform.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
API request/response transformation middleware.
|
||||
"""
|
||||
import json
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.response import success_response
|
||||
from app.utils.case import convert_keys_to_snake, add_camelcase_aliases, to_snake
|
||||
|
||||
|
||||
def _needs_api_transform(path: str) -> bool:
|
||||
return path.startswith(settings.API_V1_PREFIX)
|
||||
|
||||
|
||||
def _convert_query_params(request: Request) -> Tuple[Request, dict]:
|
||||
if not request.query_params:
|
||||
return request, {}
|
||||
|
||||
items = []
|
||||
param_map = {}
|
||||
for key, value in request.query_params.multi_items():
|
||||
new_key = to_snake(key)
|
||||
items.append((new_key, value))
|
||||
if new_key not in param_map:
|
||||
param_map[new_key] = value
|
||||
|
||||
# If page/page_size provided, add skip/limit for legacy endpoints
|
||||
if "page" in param_map and "page_size" in param_map:
|
||||
try:
|
||||
page = int(param_map["page"])
|
||||
page_size = int(param_map["page_size"])
|
||||
if page > 0 and page_size > 0:
|
||||
if "skip" not in param_map:
|
||||
items.append(("skip", str((page - 1) * page_size)))
|
||||
if "limit" not in param_map:
|
||||
items.append(("limit", str(page_size)))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Transfers/Recoveries: map status -> approval_status
|
||||
path = request.url.path
|
||||
if path.endswith("/transfers") or path.endswith("/recoveries"):
|
||||
if "status" in param_map and "approval_status" not in param_map:
|
||||
items.append(("approval_status", param_map["status"]))
|
||||
|
||||
scope = dict(request.scope)
|
||||
scope["query_string"] = urlencode(items, doseq=True).encode()
|
||||
return Request(scope, request.receive), param_map
|
||||
|
||||
|
||||
async def _convert_json_body(request: Request) -> Request:
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if "application/json" not in content_type.lower():
|
||||
return request
|
||||
|
||||
body = await request.body()
|
||||
if not body:
|
||||
return request
|
||||
|
||||
try:
|
||||
data = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
return request
|
||||
|
||||
converted = convert_keys_to_snake(data)
|
||||
|
||||
# Path-specific payload compatibility
|
||||
path = request.url.path
|
||||
if request.method.upper() == "POST":
|
||||
if path.endswith("/transfers") and isinstance(converted, dict):
|
||||
if "reason" in converted and "title" not in converted:
|
||||
converted["title"] = converted.get("reason")
|
||||
converted.setdefault("transfer_type", "internal")
|
||||
if path.endswith("/recoveries") and isinstance(converted, dict):
|
||||
if "reason" in converted and "title" not in converted:
|
||||
converted["title"] = converted.get("reason")
|
||||
converted.setdefault("recovery_type", "org")
|
||||
|
||||
new_body = json.dumps(converted).encode()
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": new_body, "more_body": False}
|
||||
|
||||
return Request(request.scope, receive)
|
||||
|
||||
|
||||
async def _wrap_response(request: Request, response):
|
||||
# Skip non-API paths
|
||||
if not _needs_api_transform(request.url.path):
|
||||
return response
|
||||
|
||||
# Do not wrap errors; they are already handled by exception handlers
|
||||
if response.status_code >= 400:
|
||||
return response
|
||||
|
||||
# Normalize empty 204 responses
|
||||
if response.status_code == 204:
|
||||
wrapped = success_response(data=None)
|
||||
headers = dict(response.headers)
|
||||
headers.pop("content-length", None)
|
||||
return JSONResponse(status_code=200, content=wrapped, headers=headers)
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" not in content_type.lower():
|
||||
return response
|
||||
|
||||
# Handle empty body (e.g., 204)
|
||||
body = getattr(response, "body", None)
|
||||
if not body:
|
||||
wrapped = success_response(data=None)
|
||||
headers = dict(response.headers)
|
||||
headers.pop("content-length", None)
|
||||
return JSONResponse(status_code=200, content=wrapped, headers=headers)
|
||||
|
||||
try:
|
||||
payload = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
return response
|
||||
|
||||
if isinstance(payload, dict) and "code" in payload and "message" in payload:
|
||||
if "data" in payload:
|
||||
payload["data"] = add_camelcase_aliases(payload["data"])
|
||||
status_code = 200 if response.status_code == 204 else response.status_code
|
||||
headers = dict(response.headers)
|
||||
headers.pop("content-length", None)
|
||||
return JSONResponse(status_code=status_code, content=payload, headers=headers)
|
||||
|
||||
wrapped = success_response(data=add_camelcase_aliases(payload))
|
||||
status_code = 200 if response.status_code == 204 else response.status_code
|
||||
headers = dict(response.headers)
|
||||
headers.pop("content-length", None)
|
||||
return JSONResponse(status_code=status_code, content=wrapped, headers=headers)
|
||||
|
||||
|
||||
async def api_transform_middleware(request: Request, call_next):
|
||||
if _needs_api_transform(request.url.path):
|
||||
request, _ = _convert_query_params(request)
|
||||
request = await _convert_json_body(request)
|
||||
|
||||
response = await call_next(request)
|
||||
return await _wrap_response(request, response)
|
||||
194
backend/app/middleware/operation_log.py
Normal file
194
backend/app/middleware/operation_log.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
操作日志中间件
|
||||
"""
|
||||
import time
|
||||
import json
|
||||
from typing import Callable
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.db.session import async_session_maker
|
||||
from app.schemas.operation_log import OperationLogCreate, OperationModuleEnum, OperationTypeEnum, OperationResultEnum
|
||||
from app.services.operation_log_service import operation_log_service
|
||||
|
||||
|
||||
class OperationLogMiddleware(BaseHTTPMiddleware):
|
||||
"""操作日志中间件"""
|
||||
|
||||
# 不需要记录的路径
|
||||
EXCLUDE_PATHS = [
|
||||
"/health",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/captcha",
|
||||
]
|
||||
|
||||
# 路径到模块的映射
|
||||
PATH_MODULE_MAP = {
|
||||
"/auth": OperationModuleEnum.AUTH,
|
||||
"/device-types": OperationModuleEnum.DEVICE_TYPE,
|
||||
"/organizations": OperationModuleEnum.ORGANIZATION,
|
||||
"/assets": OperationModuleEnum.ASSET,
|
||||
"/brands": OperationModuleEnum.BRAND_SUPPLIER,
|
||||
"/suppliers": OperationModuleEnum.BRAND_SUPPLIER,
|
||||
"/allocation-orders": OperationModuleEnum.ALLOCATION,
|
||||
"/maintenance-records": OperationModuleEnum.MAINTENANCE,
|
||||
"/system-config": OperationModuleEnum.SYSTEM_CONFIG,
|
||||
"/users": OperationModuleEnum.USER,
|
||||
"/statistics": OperationModuleEnum.STATISTICS,
|
||||
"/operation-logs": OperationModuleEnum.SYSTEM_CONFIG,
|
||||
"/notifications": OperationModuleEnum.SYSTEM_CONFIG,
|
||||
}
|
||||
|
||||
# 方法到操作类型的映射
|
||||
METHOD_OPERATION_MAP = {
|
||||
"GET": OperationTypeEnum.QUERY,
|
||||
"POST": OperationTypeEnum.CREATE,
|
||||
"PUT": OperationTypeEnum.UPDATE,
|
||||
"PATCH": OperationTypeEnum.UPDATE,
|
||||
"DELETE": OperationTypeEnum.DELETE,
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""处理请求"""
|
||||
# 检查是否需要记录
|
||||
if self._should_log(request):
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 获取用户信息
|
||||
user = getattr(request.state, "user", None)
|
||||
|
||||
# 处理请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 计算执行时长
|
||||
duration = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 异步记录日志
|
||||
if user:
|
||||
await self._log_operation(request, response, user, duration)
|
||||
|
||||
return response
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _should_log(self, request: Request) -> bool:
|
||||
"""判断是否需要记录日志"""
|
||||
path = request.url.path
|
||||
|
||||
# 检查排除路径
|
||||
for exclude_path in self.EXCLUDE_PATHS:
|
||||
if path.startswith(exclude_path):
|
||||
return False
|
||||
|
||||
# 只记录API请求
|
||||
return path.startswith("/api/")
|
||||
|
||||
async def _log_operation(
|
||||
self,
|
||||
request: Request,
|
||||
response: Response,
|
||||
user,
|
||||
duration: int
|
||||
):
|
||||
"""记录操作日志"""
|
||||
try:
|
||||
# 获取模块
|
||||
module = self._get_module(request.url.path)
|
||||
|
||||
# 获取操作类型
|
||||
operation_type = self.METHOD_OPERATION_MAP.get(request.method, OperationTypeEnum.QUERY)
|
||||
|
||||
# 特殊处理:如果是登录/登出
|
||||
if "/auth/login" in request.url.path:
|
||||
operation_type = OperationTypeEnum.LOGIN
|
||||
elif "/auth/logout" in request.url.path:
|
||||
operation_type = OperationTypeEnum.LOGOUT
|
||||
|
||||
# 获取请求参数
|
||||
params = await self._get_request_params(request)
|
||||
|
||||
# 构建日志数据
|
||||
log_data = OperationLogCreate(
|
||||
operator_id=user.id,
|
||||
operator_name=user.real_name or user.username,
|
||||
operator_ip=request.client.host if request.client else None,
|
||||
module=module,
|
||||
operation_type=operation_type,
|
||||
method=request.method,
|
||||
url=request.url.path,
|
||||
params=params,
|
||||
result=OperationResultEnum.SUCCESS if response.status_code < 400 else OperationResultEnum.FAILED,
|
||||
error_msg=None if response.status_code < 400 else f"HTTP {response.status_code}",
|
||||
duration=duration,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
)
|
||||
|
||||
# 异步保存日志
|
||||
async with async_session_maker() as db:
|
||||
await operation_log_service.create_log(db, log_data)
|
||||
|
||||
except Exception as e:
|
||||
# 记录日志失败不应影响业务
|
||||
print(f"Failed to log operation: {e}")
|
||||
|
||||
def _get_module(self, path: str) -> OperationModuleEnum:
|
||||
"""根据路径获取模块"""
|
||||
for path_prefix, module in self.PATH_MODULE_MAP.items():
|
||||
if path_prefix in path:
|
||||
return module
|
||||
return OperationModuleEnum.SYSTEM_CONFIG
|
||||
|
||||
async def _get_request_params(self, request: Request) -> str:
|
||||
"""获取请求参数"""
|
||||
try:
|
||||
# GET请求
|
||||
if request.method == "GET":
|
||||
params = dict(request.query_params)
|
||||
return json.dumps(params, ensure_ascii=False)
|
||||
|
||||
# POST/PUT/DELETE请求
|
||||
if request.method in ["POST", "PUT", "DELETE", "PATCH"]:
|
||||
try:
|
||||
body = await request.body()
|
||||
if body:
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
body_json = json.loads(body.decode())
|
||||
# 过滤敏感字段
|
||||
filtered_body = self._filter_sensitive_data(body_json)
|
||||
return json.dumps(filtered_body, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
# 不是JSON,返回原始数据
|
||||
return body.decode()[:500] # 限制长度
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ""
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _filter_sensitive_data(self, data: dict) -> dict:
|
||||
"""过滤敏感数据"""
|
||||
sensitive_fields = ["password", "old_password", "new_password", "token", "secret"]
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
filtered = {}
|
||||
for key, value in data.items():
|
||||
if key in sensitive_fields:
|
||||
filtered[key] = "******"
|
||||
elif isinstance(value, dict):
|
||||
filtered[key] = self._filter_sensitive_data(value)
|
||||
elif isinstance(value, list):
|
||||
filtered[key] = [
|
||||
self._filter_sensitive_data(item) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
return filtered
|
||||
Reference in New Issue
Block a user