147 lines
5.1 KiB
Python
147 lines
5.1 KiB
Python
"""
|
|
API request/response transformation middleware.
|
|
"""
|
|
import json
|
|
from typing import Tuple
|
|
from urllib.parse import urlencode
|
|
|
|
from fastapi import Request
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from app.core.config import settings
|
|
from app.core.response import success_response
|
|
from app.utils.case import convert_keys_to_snake, add_camelcase_aliases, to_snake
|
|
|
|
|
|
def _needs_api_transform(path: str) -> bool:
    """Return True when *path* lies under ``settings.API_V1_PREFIX``.

    The comparison respects path-segment boundaries: with a prefix of
    ``/api/v1`` both ``/api/v1`` and ``/api/v1/users`` match, but
    ``/api/v10/users`` and ``/api/v1-legacy`` do not (a bare ``startswith``
    would accept them).
    """
    # Normalise a configured trailing slash so "/api/v1" and "/api/v1/" behave
    # the same. An empty or "/" prefix reduces to "", which matches every
    # absolute path, as the plain startswith("") / startswith("/") did.
    prefix = settings.API_V1_PREFIX.rstrip("/")
    return path == prefix or path.startswith(prefix + "/")
|
|
|
|
|
|
def _convert_query_params(request: Request) -> Tuple[Request, dict]:
    """Rewrite query-parameter names to snake_case and add legacy aliases.

    Every ``key=value`` pair is kept (repeated keys included) with its key
    passed through ``to_snake``. Two compatibility aliases are then appended:

    * ``page`` + ``page_size`` (both positive integers) also produce
      ``skip = (page - 1) * page_size`` and ``limit = page_size``, unless the
      client already sent ``skip`` / ``limit``.
    * On paths ending in ``/transfers`` or ``/recoveries`` (a trailing slash is
      ignored), ``status`` is copied to ``approval_status`` unless the latter
      was sent.

    Args:
        request: The incoming request.

    Returns:
        A request carrying the rewritten query string, and a dict mapping each
        snake_case key to its first value as received (the added aliases are
        not included). When there are no query parameters, the original request
        and an empty dict are returned.
    """
    if not request.query_params:
        return request, {}

    items = []
    param_map = {}
    for key, value in request.query_params.multi_items():
        new_key = to_snake(key)
        items.append((new_key, value))
        # First occurrence wins for the lookup map; all occurrences stay in items.
        param_map.setdefault(new_key, value)

    # If page/page_size provided, add skip/limit for legacy endpoints.
    if "page" in param_map and "page_size" in param_map:
        try:
            page = int(param_map["page"])
            page_size = int(param_map["page_size"])
        except ValueError:
            pass  # Non-numeric paging values: leave them for the endpoint to validate.
        else:
            if page > 0 and page_size > 0:
                if "skip" not in param_map:
                    items.append(("skip", str((page - 1) * page_size)))
                if "limit" not in param_map:
                    items.append(("limit", str(page_size)))

    # Transfers/Recoveries: map status -> approval_status.
    path = request.url.path.rstrip("/")
    if path.endswith(("/transfers", "/recoveries")):
        if "status" in param_map and "approval_status" not in param_map:
            items.append(("approval_status", param_map["status"]))

    # Write into the existing scope dict rather than a copy. Some Starlette
    # BaseHTTPMiddleware versions forward the original ASGI scope to the app
    # instead of the request handed to call_next; in-place mutation keeps the
    # rewrite visible either way. The rewrite is idempotent, so a second pass
    # over the same scope is harmless.
    request.scope["query_string"] = urlencode(items, doseq=True).encode()
    return Request(request.scope, request.receive), param_map
|
|
|
|
|
|
def _replay_body(request: Request, body: bytes) -> Request:
    """Return a request over a copy of *request*'s scope whose body is *body*.

    The first ``receive()`` call yields *body* as a single, final
    ``http.request`` message. Later calls are forwarded to the original
    receive channel, so ``http.disconnect`` can still be observed instead of
    the body being handed out again on every poll. If the request carries a
    ``content-length`` header, it is rewritten to ``len(body)``.
    """
    scope = dict(request.scope)
    length = str(len(body)).encode("latin-1")
    scope["headers"] = [
        (name, length if name.lower() == b"content-length" else value)
        for name, value in request.scope.get("headers", [])
    ]
    delivered = False

    async def receive():
        nonlocal delivered
        if delivered:
            return await request.receive()
        delivered = True
        return {"type": "http.request", "body": body, "more_body": False}

    return Request(scope, receive)


async def _convert_json_body(request: Request) -> Request:
    """Convert a JSON request body's keys to snake_case and apply payload aliases.

    Only requests whose ``content-type`` contains ``application/json`` are
    touched. Their parsed body is passed through ``convert_keys_to_snake``.
    For POSTs to ``.../transfers`` or ``.../recoveries`` (a trailing slash is
    ignored) whose body is an object:

    * ``reason`` is copied to ``title`` when ``title`` is absent;
    * ``transfer_type`` defaults to ``"internal"`` (transfers) or
      ``recovery_type`` defaults to ``"org"`` (recoveries).

    Args:
        request: The incoming request.

    Returns:
        The original request for non-JSON content types. Otherwise a request
        whose body is the re-serialised converted payload. Empty or unparseable
        bodies are forwarded unchanged.
    """
    content_type = request.headers.get("content-type", "")
    if "application/json" not in content_type.lower():
        return request

    body = await request.body()
    # The receive channel has now delivered the whole body. Every return
    # below hands downstream a request that replays a body; returning
    # `request` itself would leave the app reading from an already-drained
    # channel.
    if not body:
        return _replay_body(request, body)

    try:
        data = json.loads(body)
    except ValueError:
        # Covers JSONDecodeError and UnicodeDecodeError (non-UTF-8 bytes):
        # forward the original bytes so downstream parsing sees them as sent.
        return _replay_body(request, body)

    converted = convert_keys_to_snake(data)

    # Path-specific payload compatibility.
    if request.method.upper() == "POST" and isinstance(converted, dict):
        path = request.url.path.rstrip("/")
        if path.endswith("/transfers"):
            if "reason" in converted and "title" not in converted:
                converted["title"] = converted.get("reason")
            converted.setdefault("transfer_type", "internal")
        if path.endswith("/recoveries"):
            if "reason" in converted and "title" not in converted:
                converted["title"] = converted.get("reason")
            converted.setdefault("recovery_type", "org")

    return _replay_body(request, json.dumps(converted).encode())
|
|
|
|
|
|
async def _read_response_body(response) -> Tuple[bytes, bool]:
    """Return ``(body, drained)`` for *response*.

    Responses that already hold their rendered bytes in ``.body`` are read
    directly, with ``drained`` False. Otherwise the async ``body_iterator``
    (streaming responses) is consumed and joined, with ``drained`` True: the
    response object can no longer be sent as-is and must be replaced. A
    response with neither attribute yields ``b""``.
    """
    body = getattr(response, "body", None)
    if body is not None:
        return body, False
    iterator = getattr(response, "body_iterator", None)
    if iterator is None:
        return b"", False
    charset = getattr(response, "charset", "utf-8")
    chunks = []
    async for chunk in iterator:
        chunks.append(chunk.encode(charset) if isinstance(chunk, str) else bytes(chunk))
    return b"".join(chunks), True


def _json_response_like(original, content, status_code: int):
    """Build a ``JSONResponse`` for *content* that keeps *original*'s metadata.

    All of *original*'s raw headers are carried over, including repeated ones
    such as several ``set-cookie`` lines (``dict(headers)`` would keep only
    one). The exception is ``content-length``, which comes from the newly
    rendered body. Headers the original lacks (e.g. ``content-type``) come
    from ``JSONResponse``. The original's background task, if any, is
    attached to the new response.
    """
    rebuilt = JSONResponse(
        content=content,
        status_code=status_code,
        background=getattr(original, "background", None),
    )
    carried = [
        (name, value)
        for name, value in original.raw_headers
        if name.lower() != b"content-length"
    ]
    carried_names = {name.lower() for name, _ in carried}
    rebuilt.raw_headers = carried + [
        (name, value) for name, value in rebuilt.raw_headers if name not in carried_names
    ]
    return rebuilt


def _replay_response(original, body: bytes):
    """Rebuild a response whose body stream was drained, re-sending *body* unchanged."""
    from fastapi.responses import Response  # local: only needed on this rare path

    replay = Response(
        content=body,
        status_code=original.status_code,
        background=getattr(original, "background", None),
    )
    # Same bytes as before, so the original headers (content-length included) still hold.
    replay.raw_headers = list(original.raw_headers)
    return replay


async def _wrap_response(request: Request, response):
    """Wrap successful JSON responses on API paths in the ``success_response`` envelope.

    Returned unchanged:

    * responses for paths outside the API prefix;
    * responses with status >= 400, which are left to the exception handlers;
    * non-JSON responses;
    * JSON responses whose body cannot be parsed. If the body had to be
      streamed to find that out, an equivalent response is rebuilt.

    Rewritten as status 200 with ``data=None``:

    * 204 responses;
    * JSON responses with an empty body.

    Any other JSON response keeps its status. A payload already shaped like an
    envelope (a dict with ``code`` and ``message``) only has its ``data``
    passed through ``add_camelcase_aliases``. Anything else becomes
    ``success_response(data=add_camelcase_aliases(payload))``. Headers
    (repeated ones included, ``content-length`` excepted) and background tasks
    are preserved on rebuilt responses.
    """
    # Skip non-API paths.
    if not _needs_api_transform(request.url.path):
        return response

    # Do not wrap errors; they are already handled by exception handlers.
    if response.status_code >= 400:
        return response

    # Normalize empty 204 responses.
    if response.status_code == 204:
        return _json_response_like(response, success_response(data=None), 200)

    content_type = response.headers.get("content-type", "")
    if "application/json" not in content_type.lower():
        return response

    body, drained = await _read_response_body(response)
    if not body:
        return _json_response_like(response, success_response(data=None), 200)

    try:
        payload = json.loads(body)
    except ValueError:
        # JSONDecodeError or UnicodeDecodeError: pass the bytes through untouched.
        return _replay_response(response, body) if drained else response

    if isinstance(payload, dict) and "code" in payload and "message" in payload:
        if "data" in payload:
            payload["data"] = add_camelcase_aliases(payload["data"])
        content = payload
    else:
        content = success_response(data=add_camelcase_aliases(payload))
    # 204 was handled above, so the original status is always kept here.
    return _json_response_like(response, content, response.status_code)
|
|
|
|
|
|
async def api_transform_middleware(request: Request, call_next):
    """HTTP middleware: normalise API requests, run the app, then post-process the response.

    For paths accepted by ``_needs_api_transform``, the request first goes
    through ``_convert_query_params`` and then ``_convert_json_body`` before
    being passed to ``call_next``. Every response, API or not, is then handed
    to ``_wrap_response`` together with the (possibly rebuilt) request.

    NOTE(review): depending on the installed Starlette version,
    ``BaseHTTPMiddleware``'s ``call_next`` may forward the original ASGI scope
    and/or receive channel rather than the request object passed to it. In
    that case the rebuilt request here does not reach the endpoint. Confirm
    against the pinned Starlette version.
    """
    outgoing = request
    if _needs_api_transform(outgoing.url.path):
        # The returned parameter map is not needed here.
        outgoing, _unused_params = _convert_query_params(outgoing)
        outgoing = await _convert_json_body(outgoing)

    downstream_response = await call_next(outgoing)
    return await _wrap_response(outgoing, downstream_response)
|