"""API request/response transformation middleware.

For paths under ``settings.API_V1_PREFIX``, query-parameter names and JSON
body keys are renamed via ``to_snake`` / ``convert_keys_to_snake`` before the
endpoint runs. Successful JSON responses are wrapped with ``success_response``
after their data is passed through ``add_camelcase_aliases``.
"""

import json
from typing import Optional, Tuple
from urllib.parse import parse_qsl, urlencode

from fastapi import Request
from fastapi.responses import JSONResponse

from app.core.config import settings
from app.core.response import success_response
from app.utils.case import convert_keys_to_snake, add_camelcase_aliases, to_snake


def _needs_api_transform(path: str) -> bool:
    """Return True when *path* falls under the versioned API prefix."""
    return path.startswith(settings.API_V1_PREFIX)


def _replay_body(request: Request, body: bytes) -> Request:
    """Return a Request over ``request.scope`` whose receive channel yields *body*.

    The body is delivered once, as a single ``http.request`` message. Later
    receive calls are forwarded to the original channel, so a caller waiting
    for ``http.disconnect`` blocks on the real connection instead of being
    handed the body message again in a tight loop.
    """
    upstream_receive = request.receive
    delivered = False

    async def receive():
        nonlocal delivered
        if delivered:
            return await upstream_receive()
        delivered = True
        return {"type": "http.request", "body": body, "more_body": False}

    return Request(request.scope, receive)


def _convert_query_params(request: Request) -> Tuple[Request, dict]:
    """Rewrite the request's query string with snake_case parameter names.

    Legacy aliases are appended as well:
    * ``skip``/``limit`` derived from positive integer ``page``/``page_size``
      (each only if not already supplied);
    * ``approval_status`` copied from ``status`` on paths ending in
      ``/transfers`` or ``/recoveries`` (only if not already supplied).

    The rewritten query string is stored in ``request.scope`` in place.
    Starlette's BaseHTTPMiddleware (recent versions) runs the downstream app
    with its own reference to that scope dict rather than the scope of the
    Request handed to ``call_next``, so a rewrite on a copied scope could be
    lost. Note that ``request.url`` may already have been cached with the
    original query string.

    Returns:
        ``(request, param_map)``. ``param_map`` maps each snake_case name to
        its first value; it is ``{}`` (and the scope is untouched) when there
        are no query parameters.
    """
    # Parse straight from the scope (same rules Starlette's QueryParams uses)
    # so no query_params cache is created on the request before the rewrite.
    raw_query = request.scope.get("query_string", b"").decode("latin-1")
    pairs = parse_qsl(raw_query, keep_blank_values=True)
    if not pairs:
        return request, {}

    items = []
    param_map = {}
    for key, value in pairs:
        new_key = to_snake(key)
        items.append((new_key, value))
        param_map.setdefault(new_key, value)

    # If page/page_size are provided, add skip/limit for legacy endpoints.
    if "page" in param_map and "page_size" in param_map:
        try:
            page = int(param_map["page"])
            page_size = int(param_map["page_size"])
        except ValueError:
            pass  # non-numeric paging values are forwarded untouched
        else:
            if page > 0 and page_size > 0:
                if "skip" not in param_map:
                    items.append(("skip", str((page - 1) * page_size)))
                if "limit" not in param_map:
                    items.append(("limit", str(page_size)))

    # Transfers/recoveries: also expose `status` as `approval_status`.
    path = request.url.path
    if path.endswith(("/transfers", "/recoveries")):
        if "status" in param_map and "approval_status" not in param_map:
            items.append(("approval_status", param_map["status"]))

    request.scope["query_string"] = urlencode(items, doseq=True).encode()
    return request, param_map


async def _convert_json_body(request: Request) -> Request:
    """Convert a JSON request body's keys to snake_case and apply payload aliases.

    Requests whose content type does not mention ``application/json`` are
    returned untouched. Once the body has been read, the original receive
    channel is drained, so the body is always replayed through a fresh
    channel: converted when it parsed as JSON, otherwise byte-for-byte.

    For POSTs to paths ending in ``/transfers`` / ``/recoveries`` with an
    object payload, ``title`` defaults to ``reason`` and
    ``transfer_type`` / ``recovery_type`` default to ``"internal"`` / ``"org"``.
    """
    content_type = request.headers.get("content-type", "")
    if "application/json" not in content_type.lower():
        return request

    body = await request.body()
    if not body:
        return _replay_body(request, body)
    try:
        data = json.loads(body)
    except ValueError:  # JSONDecodeError, or UnicodeDecodeError on bad bytes
        return _replay_body(request, body)

    converted = convert_keys_to_snake(data)

    # Path-specific payload compatibility.
    if request.method.upper() == "POST" and isinstance(converted, dict):
        path = request.url.path
        if path.endswith("/transfers"):
            if "reason" in converted and "title" not in converted:
                converted["title"] = converted["reason"]
            converted.setdefault("transfer_type", "internal")
        elif path.endswith("/recoveries"):
            if "reason" in converted and "title" not in converted:
                converted["title"] = converted["reason"]
            converted.setdefault("recovery_type", "org")

    # NOTE(review): Starlette >= 0.28's BaseHTTPMiddleware appears to replay the
    # body cached on the original request instead of reading the receive
    # channel of the Request passed to call_next — confirm the converted body
    # actually reaches the endpoint with the installed Starlette version.
    return _replay_body(request, json.dumps(converted).encode())


async def _read_response_body(response) -> Optional[bytes]:
    """Return the complete body of *response*, or None if it cannot be read.

    Some responses expose only ``body_iterator`` rather than ``body`` (e.g.
    the streaming response Starlette's ``call_next`` returns). Such an
    iterator is drained and replaced by one that yields the buffered bytes, so
    the response object can still be returned and sent unchanged.
    """
    body = getattr(response, "body", None)
    if body is not None:
        return body

    iterator = getattr(response, "body_iterator", None)
    if iterator is None:
        return None

    charset = getattr(response, "charset", "utf-8")
    chunks = []
    async for chunk in iterator:
        if isinstance(chunk, (bytes, bytearray, memoryview)):
            chunks.append(chunk)
        else:
            chunks.append(chunk.encode(charset))
    buffered = b"".join(chunks)

    async def _replay():
        yield buffered

    response.body_iterator = _replay()
    return buffered


def _json_response(source, content, status_code: int) -> JSONResponse:
    """Build a JSONResponse carrying *content* while keeping *source*'s headers.

    Every header except ``content-length`` (recomputed for the new body) is
    copied. Headers are copied pair by pair, so repeated headers such as
    ``set-cookie`` survive; ``dict(source.headers)`` would keep only the first.
    """
    new_response = JSONResponse(status_code=status_code, content=content)
    for key, value in source.headers.items():
        name = key.lower()
        if name == "content-length":
            continue
        if name == "content-type":
            new_response.headers["content-type"] = value  # replace, don't duplicate
        else:
            new_response.headers.append(key, value)
    return new_response


async def _wrap_response(request: Request, response):
    """Wrap a successful API JSON response in the ``success_response`` envelope.

    The response is returned unchanged in these cases:
    * non-API paths;
    * status codes >= 400;
    * non-JSON content types;
    * bodies that cannot be read or decoded as JSON.

    204 responses and empty JSON bodies become a 200 envelope with
    ``data=None``. Payloads that already have ``code`` and ``message`` keys
    keep their shape; only their ``data`` member (if present) is passed
    through ``add_camelcase_aliases``.
    """
    # Skip non-API paths.
    if not _needs_api_transform(request.url.path):
        return response

    # Do not wrap errors; they are already handled by exception handlers.
    if response.status_code >= 400:
        return response

    # Normalize empty 204 responses.
    if response.status_code == 204:
        return _json_response(response, success_response(data=None), 200)

    content_type = response.headers.get("content-type", "")
    if "application/json" not in content_type.lower():
        return response

    body = await _read_response_body(response)
    if body is None:
        return response  # body not inspectable; leave the response alone
    if not body:
        return _json_response(response, success_response(data=None), 200)

    try:
        payload = json.loads(body)
    except ValueError:  # JSONDecodeError, or UnicodeDecodeError on bad bytes
        return response

    if isinstance(payload, dict) and "code" in payload and "message" in payload:
        if "data" in payload:
            payload["data"] = add_camelcase_aliases(payload["data"])
        return _json_response(response, payload, response.status_code)

    wrapped = success_response(data=add_camelcase_aliases(payload))
    return _json_response(response, wrapped, response.status_code)


async def api_transform_middleware(request: Request, call_next):
    """HTTP middleware: normalize API requests, then envelope their responses.

    Requests under the API prefix get their query parameters and JSON body
    converted before ``call_next``. Every response goes through
    ``_wrap_response``, which itself ignores non-API paths.
    """
    if _needs_api_transform(request.url.path):
        request, _ = _convert_query_params(request)
        request = await _convert_json_body(request)
    response = await call_next(request)
    return await _wrap_response(request, response)