Fix API compatibility and add user/role/permission and asset import/export

This commit is contained in:
2026-01-25 23:36:23 +08:00
commit 501d11e14e
371 changed files with 68853 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
"""
中间件模块
"""
from app.middleware.operation_log import OperationLogMiddleware
__all__ = ["OperationLogMiddleware"]

View File

@@ -0,0 +1,146 @@
"""
API request/response transformation middleware.
"""
import json
from typing import Tuple
from urllib.parse import urlencode
from fastapi import Request
from fastapi.responses import JSONResponse
from app.core.config import settings
from app.core.response import success_response
from app.utils.case import convert_keys_to_snake, add_camelcase_aliases, to_snake
def _needs_api_transform(path: str) -> bool:
return path.startswith(settings.API_V1_PREFIX)
def _convert_query_params(request: Request) -> Tuple[Request, dict]:
if not request.query_params:
return request, {}
items = []
param_map = {}
for key, value in request.query_params.multi_items():
new_key = to_snake(key)
items.append((new_key, value))
if new_key not in param_map:
param_map[new_key] = value
# If page/page_size provided, add skip/limit for legacy endpoints
if "page" in param_map and "page_size" in param_map:
try:
page = int(param_map["page"])
page_size = int(param_map["page_size"])
if page > 0 and page_size > 0:
if "skip" not in param_map:
items.append(("skip", str((page - 1) * page_size)))
if "limit" not in param_map:
items.append(("limit", str(page_size)))
except ValueError:
pass
# Transfers/Recoveries: map status -> approval_status
path = request.url.path
if path.endswith("/transfers") or path.endswith("/recoveries"):
if "status" in param_map and "approval_status" not in param_map:
items.append(("approval_status", param_map["status"]))
scope = dict(request.scope)
scope["query_string"] = urlencode(items, doseq=True).encode()
return Request(scope, request.receive), param_map
async def _convert_json_body(request: Request) -> Request:
content_type = request.headers.get("content-type", "")
if "application/json" not in content_type.lower():
return request
body = await request.body()
if not body:
return request
try:
data = json.loads(body)
except json.JSONDecodeError:
return request
converted = convert_keys_to_snake(data)
# Path-specific payload compatibility
path = request.url.path
if request.method.upper() == "POST":
if path.endswith("/transfers") and isinstance(converted, dict):
if "reason" in converted and "title" not in converted:
converted["title"] = converted.get("reason")
converted.setdefault("transfer_type", "internal")
if path.endswith("/recoveries") and isinstance(converted, dict):
if "reason" in converted and "title" not in converted:
converted["title"] = converted.get("reason")
converted.setdefault("recovery_type", "org")
new_body = json.dumps(converted).encode()
async def receive():
return {"type": "http.request", "body": new_body, "more_body": False}
return Request(request.scope, receive)
async def _wrap_response(request: Request, response):
# Skip non-API paths
if not _needs_api_transform(request.url.path):
return response
# Do not wrap errors; they are already handled by exception handlers
if response.status_code >= 400:
return response
# Normalize empty 204 responses
if response.status_code == 204:
wrapped = success_response(data=None)
headers = dict(response.headers)
headers.pop("content-length", None)
return JSONResponse(status_code=200, content=wrapped, headers=headers)
content_type = response.headers.get("content-type", "")
if "application/json" not in content_type.lower():
return response
# Handle empty body (e.g., 204)
body = getattr(response, "body", None)
if not body:
wrapped = success_response(data=None)
headers = dict(response.headers)
headers.pop("content-length", None)
return JSONResponse(status_code=200, content=wrapped, headers=headers)
try:
payload = json.loads(body)
except json.JSONDecodeError:
return response
if isinstance(payload, dict) and "code" in payload and "message" in payload:
if "data" in payload:
payload["data"] = add_camelcase_aliases(payload["data"])
status_code = 200 if response.status_code == 204 else response.status_code
headers = dict(response.headers)
headers.pop("content-length", None)
return JSONResponse(status_code=status_code, content=payload, headers=headers)
wrapped = success_response(data=add_camelcase_aliases(payload))
status_code = 200 if response.status_code == 204 else response.status_code
headers = dict(response.headers)
headers.pop("content-length", None)
return JSONResponse(status_code=status_code, content=wrapped, headers=headers)
async def api_transform_middleware(request: Request, call_next):
if _needs_api_transform(request.url.path):
request, _ = _convert_query_params(request)
request = await _convert_json_body(request)
response = await call_next(request)
return await _wrap_response(request, response)

View File

@@ -0,0 +1,194 @@
"""
操作日志中间件
"""
import time
import json
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import async_session_maker
from app.schemas.operation_log import OperationLogCreate, OperationModuleEnum, OperationTypeEnum, OperationResultEnum
from app.services.operation_log_service import operation_log_service
class OperationLogMiddleware(BaseHTTPMiddleware):
"""操作日志中间件"""
# 不需要记录的路径
EXCLUDE_PATHS = [
"/health",
"/docs",
"/openapi.json",
"/api/v1/auth/login",
"/api/v1/auth/captcha",
]
# 路径到模块的映射
PATH_MODULE_MAP = {
"/auth": OperationModuleEnum.AUTH,
"/device-types": OperationModuleEnum.DEVICE_TYPE,
"/organizations": OperationModuleEnum.ORGANIZATION,
"/assets": OperationModuleEnum.ASSET,
"/brands": OperationModuleEnum.BRAND_SUPPLIER,
"/suppliers": OperationModuleEnum.BRAND_SUPPLIER,
"/allocation-orders": OperationModuleEnum.ALLOCATION,
"/maintenance-records": OperationModuleEnum.MAINTENANCE,
"/system-config": OperationModuleEnum.SYSTEM_CONFIG,
"/users": OperationModuleEnum.USER,
"/statistics": OperationModuleEnum.STATISTICS,
"/operation-logs": OperationModuleEnum.SYSTEM_CONFIG,
"/notifications": OperationModuleEnum.SYSTEM_CONFIG,
}
# 方法到操作类型的映射
METHOD_OPERATION_MAP = {
"GET": OperationTypeEnum.QUERY,
"POST": OperationTypeEnum.CREATE,
"PUT": OperationTypeEnum.UPDATE,
"PATCH": OperationTypeEnum.UPDATE,
"DELETE": OperationTypeEnum.DELETE,
}
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""处理请求"""
# 检查是否需要记录
if self._should_log(request):
# 记录开始时间
start_time = time.time()
# 获取用户信息
user = getattr(request.state, "user", None)
# 处理请求
response = await call_next(request)
# 计算执行时长
duration = int((time.time() - start_time) * 1000)
# 异步记录日志
if user:
await self._log_operation(request, response, user, duration)
return response
return await call_next(request)
def _should_log(self, request: Request) -> bool:
"""判断是否需要记录日志"""
path = request.url.path
# 检查排除路径
for exclude_path in self.EXCLUDE_PATHS:
if path.startswith(exclude_path):
return False
# 只记录API请求
return path.startswith("/api/")
async def _log_operation(
self,
request: Request,
response: Response,
user,
duration: int
):
"""记录操作日志"""
try:
# 获取模块
module = self._get_module(request.url.path)
# 获取操作类型
operation_type = self.METHOD_OPERATION_MAP.get(request.method, OperationTypeEnum.QUERY)
# 特殊处理:如果是登录/登出
if "/auth/login" in request.url.path:
operation_type = OperationTypeEnum.LOGIN
elif "/auth/logout" in request.url.path:
operation_type = OperationTypeEnum.LOGOUT
# 获取请求参数
params = await self._get_request_params(request)
# 构建日志数据
log_data = OperationLogCreate(
operator_id=user.id,
operator_name=user.real_name or user.username,
operator_ip=request.client.host if request.client else None,
module=module,
operation_type=operation_type,
method=request.method,
url=request.url.path,
params=params,
result=OperationResultEnum.SUCCESS if response.status_code < 400 else OperationResultEnum.FAILED,
error_msg=None if response.status_code < 400 else f"HTTP {response.status_code}",
duration=duration,
user_agent=request.headers.get("user-agent"),
)
# 异步保存日志
async with async_session_maker() as db:
await operation_log_service.create_log(db, log_data)
except Exception as e:
# 记录日志失败不应影响业务
print(f"Failed to log operation: {e}")
def _get_module(self, path: str) -> OperationModuleEnum:
"""根据路径获取模块"""
for path_prefix, module in self.PATH_MODULE_MAP.items():
if path_prefix in path:
return module
return OperationModuleEnum.SYSTEM_CONFIG
async def _get_request_params(self, request: Request) -> str:
"""获取请求参数"""
try:
# GET请求
if request.method == "GET":
params = dict(request.query_params)
return json.dumps(params, ensure_ascii=False)
# POST/PUT/DELETE请求
if request.method in ["POST", "PUT", "DELETE", "PATCH"]:
try:
body = await request.body()
if body:
# 尝试解析JSON
try:
body_json = json.loads(body.decode())
# 过滤敏感字段
filtered_body = self._filter_sensitive_data(body_json)
return json.dumps(filtered_body, ensure_ascii=False)
except json.JSONDecodeError:
# 不是JSON返回原始数据
return body.decode()[:500] # 限制长度
except Exception:
pass
return ""
except Exception:
return ""
def _filter_sensitive_data(self, data: dict) -> dict:
"""过滤敏感数据"""
sensitive_fields = ["password", "old_password", "new_password", "token", "secret"]
if not isinstance(data, dict):
return data
filtered = {}
for key, value in data.items():
if key in sensitive_fields:
filtered[key] = "******"
elif isinstance(value, dict):
filtered[key] = self._filter_sensitive_data(value)
elif isinstance(value, list):
filtered[key] = [
self._filter_sensitive_data(item) if isinstance(item, dict) else item
for item in value
]
else:
filtered[key] = value
return filtered