feat(throttling): add auth throttling and structured cooldown errors

This commit is contained in:
2026-04-30 15:25:45 +03:30
parent 3152284cf3
commit 08e1793765
5 changed files with 338 additions and 1 deletions

View File

@@ -0,0 +1,99 @@
from __future__ import annotations
import re
from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
from rest_framework.throttling import SimpleRateThrottle
class ScopedMobileThrottle(SimpleRateThrottle):
scope = ""
def get_rate(self):
if not self.scope:
raise ImproperlyConfigured(
f"{self.__class__.__name__} must define a scope or rate."
)
try:
return api_settings.DEFAULT_THROTTLE_RATES[self.scope]
except KeyError as exc:
raise ImproperlyConfigured(
f'No default throttle rate set for scope "{self.scope}".'
) from exc
def parse_rate(self, rate):
if rate is None:
return (None, None)
num_requests, period = rate.split("/")
match = re.fullmatch(r"(?:(\d+)\s*)?([smhd]|sec|second|min|minute|hour|day)s?", period.strip(), re.IGNORECASE)
if not match:
return super().parse_rate(rate)
multiplier = int(match.group(1) or "1")
unit = match.group(2).lower()
unit_seconds = {
"s": 1,
"sec": 1,
"second": 1,
"m": 60,
"min": 60,
"minute": 60,
"h": 3600,
"hour": 3600,
"d": 86400,
"day": 86400,
}[unit]
return int(num_requests), multiplier * unit_seconds
def get_mobile_identifier(self, request) -> str | None:
mobile = None
try:
mobile = request.data.get("mobile")
except Exception:
mobile = None
if not isinstance(mobile, str):
return None
normalized = "".join(ch for ch in mobile if ch.isdigit())
return normalized or None
def get_cache_key(self, request, view):
ident = self.get_ident(request)
mobile = self.get_mobile_identifier(request)
if mobile:
return self.cache_format % {
"scope": self.scope,
"ident": f"{ident}:{mobile}",
}
return self.cache_format % {
"scope": self.scope,
"ident": ident,
}
def allow_request(self, request, view):
allowed = super().allow_request(request, view)
if not allowed:
request._throttle_scope = self.scope
request._retry_after_seconds = self.wait()
return allowed
class OTPSendBurstThrottle(ScopedMobileThrottle):
scope = "otp_send_burst"
class OTPSendSustainedThrottle(ScopedMobileThrottle):
scope = "otp_send_sustained"
class PasswordLoginThrottle(ScopedMobileThrottle):
scope = "login_password"
class OTPLoginThrottle(ScopedMobileThrottle):
scope = "login_otp"

View File

@@ -30,6 +30,12 @@ from apps.users.api.serializers import (
UserProfileSerializer,
UserSearchSerializer,
)
from apps.users.api.throttles import (
OTPLoginThrottle,
OTPSendBurstThrottle,
OTPSendSustainedThrottle,
PasswordLoginThrottle,
)
from apps.users.services.auth import (
register_user_with_password,
register_user_with_otp,
@@ -91,6 +97,7 @@ class SendOTPView(APIView):
+ password reset
"""
permission_classes = (AllowAny,)
throttle_classes = [OTPSendBurstThrottle, OTPSendSustainedThrottle]
@extend_schema(request=SendOTPSerializer, responses=None)
def post(self, request):
@@ -107,6 +114,7 @@ class SendOTPView(APIView):
class LoginView(APIView):
permission_classes = (AllowAny,)
throttle_classes = [PasswordLoginThrottle]
@extend_schema(request=LoginSerializer, responses=TokenPairSerializer)
def post(self, request):
@@ -123,6 +131,7 @@ class LoginView(APIView):
class LoginOTPView(APIView):
permission_classes = (AllowAny,)
throttle_classes = [OTPLoginThrottle]
@extend_schema(request=LoginOtpSerializer, responses=TokenPairSerializer)
def post(self, request):

View File

@@ -1,5 +1,8 @@
from unittest.mock import patch
from django.conf import settings
from django.core.cache import cache
from django.test import override_settings
from rest_framework.test import APIRequestFactory
from rest_framework import status
from rest_framework.test import APITestCase
@@ -197,3 +200,189 @@ class UserApiViewTests(APITestCase):
success = self.client.get(f"/api/users/search/?mobile={self.other_user.mobile}")
self.assertEqual(success.status_code, status.HTTP_200_OK)
self.assertEqual(success.data["mobile"], self.other_user.mobile)
class UserThrottleTests(APITestCase):
@classmethod
def setUpTestData(cls):
cls.user = User.objects.create_user(
mobile="09124440001",
password="secret123",
)
def setUp(self):
cache.clear()
def tearDown(self):
cache.clear()
@override_settings(
REST_FRAMEWORK={
**settings.REST_FRAMEWORK,
"DEFAULT_THROTTLE_RATES": {
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
"login_password": "2/min",
},
}
)
@patch("apps.users.api.views.login_with_password")
def test_password_login_returns_structured_429_with_retry_after(self, login_with_password):
login_with_password.return_value = {"access": "a", "refresh": "r"}
first = self.client.post(
"/api/users/login/",
{"mobile": "09124440001", "password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.1",
)
second = self.client.post(
"/api/users/login/",
{"mobile": "09124440001", "password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.1",
)
throttled = self.client.post(
"/api/users/login/",
{"mobile": "09124440001", "password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.1",
)
self.assertEqual(first.status_code, 200)
self.assertEqual(second.status_code, 200)
self.assertEqual(throttled.status_code, 429)
self.assertEqual(throttled.data["code"], "throttled")
self.assertEqual(throttled.data["scope"], "login_password")
self.assertIsInstance(throttled.data["retry_after_seconds"], int)
self.assertTrue(throttled.data["throttled_until"])
self.assertIn("Retry-After", throttled.headers)
@override_settings(
REST_FRAMEWORK={
**settings.REST_FRAMEWORK,
"DEFAULT_THROTTLE_RATES": {
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
"otp_send_burst": "1/min",
"otp_send_sustained": "10/day",
},
}
)
@patch("apps.users.api.views.generate_and_send_otp")
def test_otp_send_throttle_is_keyed_by_mobile_and_ip(self, generate_and_send_otp):
first_mobile_first = self.client.post(
"/api/users/otp/send/",
{"mobile": "09124440011", "mode": "login"},
format="json",
REMOTE_ADDR="10.0.0.2",
)
second_mobile_first = self.client.post(
"/api/users/otp/send/",
{"mobile": "09124440012", "mode": "login"},
format="json",
REMOTE_ADDR="10.0.0.2",
)
first_mobile_second = self.client.post(
"/api/users/otp/send/",
{"mobile": "09124440011", "mode": "login"},
format="json",
REMOTE_ADDR="10.0.0.2",
)
self.assertEqual(first_mobile_first.status_code, 200)
self.assertEqual(second_mobile_first.status_code, 200)
self.assertEqual(first_mobile_second.status_code, 429)
self.assertEqual(first_mobile_second.data["scope"], "otp_send_burst")
@override_settings(
REST_FRAMEWORK={
**settings.REST_FRAMEWORK,
"DEFAULT_THROTTLE_RATES": {
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
"login_otp": "1/min",
},
}
)
@patch("apps.users.api.views.login_with_otp")
def test_otp_login_throttle_blocks_after_limit(self, login_with_otp):
login_with_otp.return_value = {"access": "a", "refresh": "r"}
allowed = self.client.post(
"/api/users/otp/login/",
{"mobile": "09124440021", "code": "123456"},
format="json",
REMOTE_ADDR="10.0.0.3",
)
throttled = self.client.post(
"/api/users/otp/login/",
{"mobile": "09124440021", "code": "123456"},
format="json",
REMOTE_ADDR="10.0.0.3",
)
self.assertEqual(allowed.status_code, 200)
self.assertEqual(throttled.status_code, 429)
self.assertEqual(throttled.data["scope"], "login_otp")
@patch.dict("rest_framework.throttling.AnonRateThrottle.THROTTLE_RATES", {"anon": "1/min"}, clear=False)
@patch("apps.users.api.views.register_user_with_otp")
def test_global_anon_throttle_applies_to_unrestricted_anonymous_endpoint(
self,
register_user_with_otp,
):
register_user_with_otp.return_value = {"access": "a", "refresh": "r"}
first = self.client.post(
"/api/users/register/",
{
"mobile": "09124440031",
"code": "12345",
"password": "secret123",
"re_password": "secret123",
},
format="json",
REMOTE_ADDR="10.0.0.4",
)
throttled = self.client.post(
"/api/users/register/",
{
"mobile": "09124440032",
"code": "12345",
"password": "secret123",
"re_password": "secret123",
},
format="json",
REMOTE_ADDR="10.0.0.4",
)
self.assertEqual(first.status_code, 201)
self.assertEqual(throttled.status_code, 429)
self.assertEqual(throttled.data["code"], "throttled")
@override_settings(
REST_FRAMEWORK={
**settings.REST_FRAMEWORK,
"DEFAULT_THROTTLE_RATES": {
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
"login_password": "1/min",
},
}
)
@patch("apps.users.api.views.login_with_password")
def test_throttle_falls_back_to_ip_when_mobile_is_missing(self, login_with_password):
login_with_password.return_value = {"access": "a", "refresh": "r"}
first = self.client.post(
"/api/users/login/",
{"password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.5",
)
second = self.client.post(
"/api/users/login/",
{"password": "secret123"},
format="json",
REMOTE_ADDR="10.0.0.5",
)
self.assertEqual(first.status_code, 400)
self.assertEqual(second.status_code, 429)