feat(throttling): add auth throttling and structured cooldown errors

This commit is contained in:
2026-04-30 15:25:45 +03:30
parent 3152284cf3
commit 08e1793765
5 changed files with 338 additions and 1 deletions

View File

@@ -0,0 +1,99 @@
from __future__ import annotations
import re
from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
from rest_framework.throttling import SimpleRateThrottle
class ScopedMobileThrottle(SimpleRateThrottle):
scope = ""
def get_rate(self):
if not self.scope:
raise ImproperlyConfigured(
f"{self.__class__.__name__} must define a scope or rate."
)
try:
return api_settings.DEFAULT_THROTTLE_RATES[self.scope]
except KeyError as exc:
raise ImproperlyConfigured(
f'No default throttle rate set for scope "{self.scope}".'
) from exc
def parse_rate(self, rate):
if rate is None:
return (None, None)
num_requests, period = rate.split("/")
match = re.fullmatch(r"(?:(\d+)\s*)?([smhd]|sec|second|min|minute|hour|day)s?", period.strip(), re.IGNORECASE)
if not match:
return super().parse_rate(rate)
multiplier = int(match.group(1) or "1")
unit = match.group(2).lower()
unit_seconds = {
"s": 1,
"sec": 1,
"second": 1,
"m": 60,
"min": 60,
"minute": 60,
"h": 3600,
"hour": 3600,
"d": 86400,
"day": 86400,
}[unit]
return int(num_requests), multiplier * unit_seconds
def get_mobile_identifier(self, request) -> str | None:
mobile = None
try:
mobile = request.data.get("mobile")
except Exception:
mobile = None
if not isinstance(mobile, str):
return None
normalized = "".join(ch for ch in mobile if ch.isdigit())
return normalized or None
def get_cache_key(self, request, view):
ident = self.get_ident(request)
mobile = self.get_mobile_identifier(request)
if mobile:
return self.cache_format % {
"scope": self.scope,
"ident": f"{ident}:{mobile}",
}
return self.cache_format % {
"scope": self.scope,
"ident": ident,
}
def allow_request(self, request, view):
allowed = super().allow_request(request, view)
if not allowed:
request._throttle_scope = self.scope
request._retry_after_seconds = self.wait()
return allowed
class OTPSendBurstThrottle(ScopedMobileThrottle):
scope = "otp_send_burst"
class OTPSendSustainedThrottle(ScopedMobileThrottle):
scope = "otp_send_sustained"
class PasswordLoginThrottle(ScopedMobileThrottle):
scope = "login_password"
class OTPLoginThrottle(ScopedMobileThrottle):
scope = "login_otp"

View File

@@ -30,6 +30,12 @@ from apps.users.api.serializers import (
UserProfileSerializer, UserProfileSerializer,
UserSearchSerializer, UserSearchSerializer,
) )
from apps.users.api.throttles import (
OTPLoginThrottle,
OTPSendBurstThrottle,
OTPSendSustainedThrottle,
PasswordLoginThrottle,
)
from apps.users.services.auth import ( from apps.users.services.auth import (
register_user_with_password, register_user_with_password,
register_user_with_otp, register_user_with_otp,
@@ -91,6 +97,7 @@ class SendOTPView(APIView):
+ password reset + password reset
""" """
permission_classes = (AllowAny,) permission_classes = (AllowAny,)
throttle_classes = [OTPSendBurstThrottle, OTPSendSustainedThrottle]
@extend_schema(request=SendOTPSerializer, responses=None) @extend_schema(request=SendOTPSerializer, responses=None)
def post(self, request): def post(self, request):
@@ -107,6 +114,7 @@ class SendOTPView(APIView):
class LoginView(APIView): class LoginView(APIView):
permission_classes = (AllowAny,) permission_classes = (AllowAny,)
throttle_classes = [PasswordLoginThrottle]
@extend_schema(request=LoginSerializer, responses=TokenPairSerializer) @extend_schema(request=LoginSerializer, responses=TokenPairSerializer)
def post(self, request): def post(self, request):
@@ -123,6 +131,7 @@ class LoginView(APIView):
class LoginOTPView(APIView): class LoginOTPView(APIView):
permission_classes = (AllowAny,) permission_classes = (AllowAny,)
throttle_classes = [OTPLoginThrottle]
@extend_schema(request=LoginOtpSerializer, responses=TokenPairSerializer) @extend_schema(request=LoginOtpSerializer, responses=TokenPairSerializer)
def post(self, request): def post(self, request):

View File

@@ -1,5 +1,8 @@
from unittest.mock import patch from unittest.mock import patch
from django.conf import settings
from django.core.cache import cache
from django.test import override_settings
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework import status from rest_framework import status
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
@@ -197,3 +200,189 @@ class UserApiViewTests(APITestCase):
success = self.client.get(f"/api/users/search/?mobile={self.other_user.mobile}") success = self.client.get(f"/api/users/search/?mobile={self.other_user.mobile}")
self.assertEqual(success.status_code, status.HTTP_200_OK) self.assertEqual(success.status_code, status.HTTP_200_OK)
self.assertEqual(success.data["mobile"], self.other_user.mobile) self.assertEqual(success.data["mobile"], self.other_user.mobile)
class UserThrottleTests(APITestCase):
@classmethod
def setUpTestData(cls):
cls.user = User.objects.create_user(
mobile="09124440001",
password="secret123",
)
def setUp(self):
cache.clear()
def tearDown(self):
cache.clear()
@override_settings(
REST_FRAMEWORK={
**settings.REST_FRAMEWORK,
"DEFAULT_THROTTLE_RATES": {
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
"login_password": "2/min",
},
}
)
@patch("apps.users.api.views.login_with_password")
def test_password_login_returns_structured_429_with_retry_after(self, login_with_password):
login_with_password.return_value = {"access": "a", "refresh": "r"}
first = self.client.post(
"/api/users/login/",
{"mobile": "09124440001", "password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.1",
)
second = self.client.post(
"/api/users/login/",
{"mobile": "09124440001", "password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.1",
)
throttled = self.client.post(
"/api/users/login/",
{"mobile": "09124440001", "password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.1",
)
self.assertEqual(first.status_code, 200)
self.assertEqual(second.status_code, 200)
self.assertEqual(throttled.status_code, 429)
self.assertEqual(throttled.data["code"], "throttled")
self.assertEqual(throttled.data["scope"], "login_password")
self.assertIsInstance(throttled.data["retry_after_seconds"], int)
self.assertTrue(throttled.data["throttled_until"])
self.assertIn("Retry-After", throttled.headers)
@override_settings(
REST_FRAMEWORK={
**settings.REST_FRAMEWORK,
"DEFAULT_THROTTLE_RATES": {
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
"otp_send_burst": "1/min",
"otp_send_sustained": "10/day",
},
}
)
@patch("apps.users.api.views.generate_and_send_otp")
def test_otp_send_throttle_is_keyed_by_mobile_and_ip(self, generate_and_send_otp):
first_mobile_first = self.client.post(
"/api/users/otp/send/",
{"mobile": "09124440011", "mode": "login"},
format="json",
REMOTE_ADDR="10.0.0.2",
)
second_mobile_first = self.client.post(
"/api/users/otp/send/",
{"mobile": "09124440012", "mode": "login"},
format="json",
REMOTE_ADDR="10.0.0.2",
)
first_mobile_second = self.client.post(
"/api/users/otp/send/",
{"mobile": "09124440011", "mode": "login"},
format="json",
REMOTE_ADDR="10.0.0.2",
)
self.assertEqual(first_mobile_first.status_code, 200)
self.assertEqual(second_mobile_first.status_code, 200)
self.assertEqual(first_mobile_second.status_code, 429)
self.assertEqual(first_mobile_second.data["scope"], "otp_send_burst")
@override_settings(
REST_FRAMEWORK={
**settings.REST_FRAMEWORK,
"DEFAULT_THROTTLE_RATES": {
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
"login_otp": "1/min",
},
}
)
@patch("apps.users.api.views.login_with_otp")
def test_otp_login_throttle_blocks_after_limit(self, login_with_otp):
login_with_otp.return_value = {"access": "a", "refresh": "r"}
allowed = self.client.post(
"/api/users/otp/login/",
{"mobile": "09124440021", "code": "123456"},
format="json",
REMOTE_ADDR="10.0.0.3",
)
throttled = self.client.post(
"/api/users/otp/login/",
{"mobile": "09124440021", "code": "123456"},
format="json",
REMOTE_ADDR="10.0.0.3",
)
self.assertEqual(allowed.status_code, 200)
self.assertEqual(throttled.status_code, 429)
self.assertEqual(throttled.data["scope"], "login_otp")
@patch.dict("rest_framework.throttling.AnonRateThrottle.THROTTLE_RATES", {"anon": "1/min"}, clear=False)
@patch("apps.users.api.views.register_user_with_otp")
def test_global_anon_throttle_applies_to_unrestricted_anonymous_endpoint(
self,
register_user_with_otp,
):
register_user_with_otp.return_value = {"access": "a", "refresh": "r"}
first = self.client.post(
"/api/users/register/",
{
"mobile": "09124440031",
"code": "12345",
"password": "secret123",
"re_password": "secret123",
},
format="json",
REMOTE_ADDR="10.0.0.4",
)
throttled = self.client.post(
"/api/users/register/",
{
"mobile": "09124440032",
"code": "12345",
"password": "secret123",
"re_password": "secret123",
},
format="json",
REMOTE_ADDR="10.0.0.4",
)
self.assertEqual(first.status_code, 201)
self.assertEqual(throttled.status_code, 429)
self.assertEqual(throttled.data["code"], "throttled")
@override_settings(
REST_FRAMEWORK={
**settings.REST_FRAMEWORK,
"DEFAULT_THROTTLE_RATES": {
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
"login_password": "1/min",
},
}
)
@patch("apps.users.api.views.login_with_password")
def test_throttle_falls_back_to_ip_when_mobile_is_missing(self, login_with_password):
login_with_password.return_value = {"access": "a", "refresh": "r"}
first = self.client.post(
"/api/users/login/",
{"password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.5",
)
second = self.client.post(
"/api/users/login/",
{"password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.5",
)
self.assertEqual(first.status_code, 400)
self.assertEqual(second.status_code, 429)

View File

@@ -131,6 +131,14 @@ REST_FRAMEWORK = {
"rest_framework.throttling.AnonRateThrottle", "rest_framework.throttling.AnonRateThrottle",
"rest_framework.throttling.UserRateThrottle", "rest_framework.throttling.UserRateThrottle",
], ],
"DEFAULT_THROTTLE_RATES": {
"anon": "60/min",
"user": "300/min",
"otp_send_burst": "3/10m",
"otp_send_sustained": "10/day",
"login_password": "5/10m",
"login_otp": "5/10m",
},
"EXCEPTION_HANDLER": "core.exceptions.handlers.exception_handler", "EXCEPTION_HANDLER": "core.exceptions.handlers.exception_handler",
} }

View File

@@ -1,11 +1,14 @@
import logging import logging
import traceback import traceback
from collections.abc import Iterable from collections.abc import Iterable
from datetime import timedelta
from typing import Any from typing import Any
from django.conf import settings from django.conf import settings
from django.utils import timezone
from rest_framework import status as http_status from rest_framework import status as http_status
from rest_framework.exceptions import ErrorDetail from rest_framework.exceptions import ErrorDetail
from rest_framework.exceptions import Throttled
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import exception_handler as drf_exception_handler from rest_framework.views import exception_handler as drf_exception_handler
@@ -83,7 +86,36 @@ def exception_handler(exc, context) -> Response:
if status_code < 500: if status_code < 500:
messages = _to_str_list(detail) messages = _to_str_list(detail)
payload = _format_payload(messages, status_code) payload = _format_payload(messages, status_code)
return Response(payload, status=status_code) if isinstance(exc, Throttled):
request = context.get("request")
wait = getattr(exc, "wait", None)
retry_after_seconds = None
if wait is not None:
retry_after_seconds = max(int(wait), 0)
elif request is not None:
retry_after_seconds = getattr(request, "_retry_after_seconds", None)
throttle_scope = getattr(request, "_throttle_scope", None) if request else None
payload.update(
{
"code": "throttled",
"scope": throttle_scope,
"retry_after_seconds": retry_after_seconds,
"throttled_until": (
timezone.now() + timedelta(seconds=retry_after_seconds)
).isoformat()
if retry_after_seconds is not None
else None,
}
)
formatted_response = Response(payload, status=status_code)
for header, value in response.headers.items():
formatted_response[header] = value
if isinstance(exc, Throttled) and "Retry-After" not in formatted_response:
request = context.get("request")
retry_after_seconds = getattr(request, "_retry_after_seconds", None) if request else None
if retry_after_seconds is not None:
formatted_response["Retry-After"] = str(max(int(retry_after_seconds), 0))
return formatted_response
traceback_text = traceback.format_exc() traceback_text = traceback.format_exc()
payload = _format_payload( payload = _format_payload(