128 lines
4.8 KiB
Python
128 lines
4.8 KiB
Python
import logging
|
|
import traceback
|
|
from collections.abc import Iterable
|
|
from datetime import timedelta
|
|
from typing import Any
|
|
|
|
from django.conf import settings
|
|
from django.utils import timezone
|
|
from rest_framework import status as http_status
|
|
from rest_framework.exceptions import ErrorDetail
|
|
from rest_framework.exceptions import Throttled
|
|
from rest_framework.response import Response
|
|
from rest_framework.views import exception_handler as drf_exception_handler
|
|
|
|
# Module-level logger; exception_handler reports server errors through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _flatten_messages(values: Iterable) -> list[str]:
    """Concatenate the flattened messages of every element of *values*.

    Each element is expanded with ``_to_str_list`` and the resulting lists are
    joined in iteration order.
    """
    return [message for value in values for message in _to_str_list(value)]
|
|
|
|
|
|
def _to_str_list(value: str | ErrorDetail | list | tuple | dict) -> list[str]:
|
|
if isinstance(value, str | ErrorDetail):
|
|
return [str(value)]
|
|
if isinstance(value, list | tuple):
|
|
return _flatten_messages(value)
|
|
if isinstance(value, dict):
|
|
items: list[str] = []
|
|
for field, v in value.items():
|
|
msgs = _to_str_list(v)
|
|
for msg in msgs:
|
|
if field in ("non_field_errors", "__all__"):
|
|
items.append(str(msg))
|
|
else:
|
|
items.append(f"{msg}")
|
|
return items
|
|
return [str(value)]
|
|
|
|
|
|
def _format_payload(messages: list[str], status_code: int) -> dict[str, Any]:
    """Build the JSON error envelope returned to API clients.

    Args:
        messages: Flattened error messages (e.g. from ``_to_str_list``).
        status_code: HTTP status code of the response.

    Returns:
        ``{"error": str, "status_code": int, "messages": [{"message": str}, ...]}``.

        ``error`` is the first raw message. When *messages* is empty, it is the
        standard reason phrase for *status_code*, or ``"Error"`` for codes that
        have none.

        Each ``messages`` entry is cleaned in two steps. First, a leading
        ``"error:"`` marker is removed. Second, a ``"<token>: "`` prefix is
        stripped, where the token has no whitespace and is followed by a colon
        and whitespace (e.g. ``"email: Enter a valid email."``). Sentences that
        merely contain a colon are kept intact.
    """
    # Local import keeps the file's top-level import block untouched.
    # rest_framework.status has no HTTP_STATUS_CODES mapping, so use the stdlib.
    from http import HTTPStatus

    clean_messages: list[str] = []
    for msg in messages:
        # Only a *leading* marker is removed. Deleting "error:" anywhere would
        # mangle text such as "Parse error: bad token".
        text = msg.strip().removeprefix("error:").strip()
        # Strip a field-style prefix, but never split inside prose or values
        # such as "...formats instead: hh:mm", "12:30" or "http://...".
        head, sep, tail = text.partition(":")
        if sep and head and not any(ch.isspace() for ch in head) and tail[:1].isspace():
            text = tail.strip()
        clean_messages.append(text)

    if messages:
        error_message = messages[0]
    else:
        try:
            error_message = HTTPStatus(status_code).phrase
        except ValueError:  # non-standard status code
            error_message = "Error"

    return {
        "error": error_message,
        "status_code": status_code,
        "messages": [{"message": msg} for msg in clean_messages],
    }
|
|
|
|
|
|
def _request_extra(context: dict[str, Any]) -> dict[str, Any]:
    """Collect request metadata to attach to a log record via ``extra=``.

    Attributes are read defensively with ``getattr``. If the request is
    missing, or lacks ``META`` / ``method`` / ``get_full_path``, the function
    returns ``None`` for that field (``""`` for the user agent) instead of
    raising.
    """
    req = context.get("request")
    meta_dict = getattr(req, "META", {})

    extra: dict[str, Any] = {"request_method": getattr(req, "method", None)}
    full_path = getattr(req, "get_full_path", lambda: None)
    extra["request_url"] = full_path()

    if meta_dict:
        extra["remote_addr"] = meta_dict.get("REMOTE_ADDR")
        extra["user_agent"] = meta_dict.get("HTTP_USER_AGENT", "")
    else:
        extra["remote_addr"] = None
        extra["user_agent"] = ""
    return extra
|
|
|
|
|
|
def _throttle_details(exc: Throttled, request: Any) -> dict[str, Any]:
    """Return the extra payload fields for a ``Throttled`` error.

    ``retry_after_seconds`` is taken from ``exc.wait`` when it is set.
    Otherwise it comes from a ``_retry_after_seconds`` attribute on the request,
    which is populated outside this module (TODO confirm which throttle class
    sets it). The value is normalised with ``max(int(...), 0)`` so the payload
    and the ``Retry-After`` header agree. ``request`` may be ``None``.
    """
    wait = getattr(exc, "wait", None)
    if wait is None:
        wait = getattr(request, "_retry_after_seconds", None)
    retry_after_seconds = None if wait is None else max(int(wait), 0)

    throttled_until = None
    if retry_after_seconds is not None:
        throttled_until = (timezone.now() + timedelta(seconds=retry_after_seconds)).isoformat()

    return {
        "code": "throttled",
        "scope": getattr(request, "_throttle_scope", None),
        "retry_after_seconds": retry_after_seconds,
        "throttled_until": throttled_until,
    }


def exception_handler(exc, context) -> Response:
    """DRF exception handler that wraps every error in a uniform JSON envelope.

    Flow:
      1. Delegate to DRF's default handler, which returns a response for the
         exceptions DRF knows how to handle and ``None`` otherwise.
      2. Unhandled exceptions and 5xx responses are logged with the traceback
         and request metadata. Under ``settings.DEBUG`` they are re-raised.
      3. Unhandled exceptions become a generic 500 body. The exception text and
         traceback are deliberately *not* sent to the client; they are in the
         log record instead.
      4. Handled responses keep their own status code (including 5xx such as
         503) and headers, and their data is flattened into ``messages``.
         ``Throttled`` errors also get ``code`` / ``scope`` /
         ``retry_after_seconds`` / ``throttled_until``, plus a ``Retry-After``
         header when one is not already present.

    Args:
        exc: The exception raised by the view.
        context: DRF handler context; ``context["request"]`` is read if present.

    Returns:
        A ``Response`` carrying the formatted payload.

    Raises:
        The original *exc*, when ``settings.DEBUG`` is true and the error is
        unhandled or has a 5xx status.
    """
    response = drf_exception_handler(exc, context)
    is_server_error = response is None or response.status_code >= 500

    if is_server_error:
        # exc_info=exc logs this exception's traceback even when the handler is
        # invoked outside an active ``except`` block (unlike logger.exception).
        logger.error("DRF exception", exc_info=exc, extra=_request_extra(context))
        if settings.DEBUG:
            raise exc

    if response is None:
        status_code = http_status.HTTP_500_INTERNAL_SERVER_ERROR
        return Response(
            _format_payload(["Internal server error."], status_code),
            status=status_code,
        )

    status_code = response.status_code
    payload = _format_payload(_to_str_list(response.data), status_code)

    throttle = None
    if isinstance(exc, Throttled):
        throttle = _throttle_details(exc, context.get("request"))
        payload.update(throttle)

    formatted_response = Response(payload, status=status_code)
    # Carry over headers DRF attached to its response (it may set e.g.
    # Retry-After or WWW-Authenticate).
    for header, value in response.headers.items():
        formatted_response[header] = value

    if (
        throttle is not None
        and throttle["retry_after_seconds"] is not None
        and "Retry-After" not in formatted_response
    ):
        formatted_response["Retry-After"] = str(throttle["retry_after_seconds"])

    return formatted_response
|