feat(throttling): add auth throttling and structured cooldown errors
This commit is contained in:
99
apps/users/api/throttles.py
Normal file
99
apps/users/api/throttles.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
from rest_framework.settings import api_settings
|
||||||
|
from rest_framework.throttling import SimpleRateThrottle
|
||||||
|
|
||||||
|
|
||||||
|
class ScopedMobileThrottle(SimpleRateThrottle):
|
||||||
|
scope = ""
|
||||||
|
|
||||||
|
def get_rate(self):
|
||||||
|
if not self.scope:
|
||||||
|
raise ImproperlyConfigured(
|
||||||
|
f"{self.__class__.__name__} must define a scope or rate."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return api_settings.DEFAULT_THROTTLE_RATES[self.scope]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise ImproperlyConfigured(
|
||||||
|
f'No default throttle rate set for scope "{self.scope}".'
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
def parse_rate(self, rate):
|
||||||
|
if rate is None:
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
|
num_requests, period = rate.split("/")
|
||||||
|
match = re.fullmatch(r"(?:(\d+)\s*)?([smhd]|sec|second|min|minute|hour|day)s?", period.strip(), re.IGNORECASE)
|
||||||
|
if not match:
|
||||||
|
return super().parse_rate(rate)
|
||||||
|
|
||||||
|
multiplier = int(match.group(1) or "1")
|
||||||
|
unit = match.group(2).lower()
|
||||||
|
unit_seconds = {
|
||||||
|
"s": 1,
|
||||||
|
"sec": 1,
|
||||||
|
"second": 1,
|
||||||
|
"m": 60,
|
||||||
|
"min": 60,
|
||||||
|
"minute": 60,
|
||||||
|
"h": 3600,
|
||||||
|
"hour": 3600,
|
||||||
|
"d": 86400,
|
||||||
|
"day": 86400,
|
||||||
|
}[unit]
|
||||||
|
|
||||||
|
return int(num_requests), multiplier * unit_seconds
|
||||||
|
|
||||||
|
def get_mobile_identifier(self, request) -> str | None:
|
||||||
|
mobile = None
|
||||||
|
try:
|
||||||
|
mobile = request.data.get("mobile")
|
||||||
|
except Exception:
|
||||||
|
mobile = None
|
||||||
|
|
||||||
|
if not isinstance(mobile, str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized = "".join(ch for ch in mobile if ch.isdigit())
|
||||||
|
return normalized or None
|
||||||
|
|
||||||
|
def get_cache_key(self, request, view):
|
||||||
|
ident = self.get_ident(request)
|
||||||
|
mobile = self.get_mobile_identifier(request)
|
||||||
|
if mobile:
|
||||||
|
return self.cache_format % {
|
||||||
|
"scope": self.scope,
|
||||||
|
"ident": f"{ident}:{mobile}",
|
||||||
|
}
|
||||||
|
return self.cache_format % {
|
||||||
|
"scope": self.scope,
|
||||||
|
"ident": ident,
|
||||||
|
}
|
||||||
|
|
||||||
|
def allow_request(self, request, view):
|
||||||
|
allowed = super().allow_request(request, view)
|
||||||
|
if not allowed:
|
||||||
|
request._throttle_scope = self.scope
|
||||||
|
request._retry_after_seconds = self.wait()
|
||||||
|
return allowed
|
||||||
|
|
||||||
|
|
||||||
|
class OTPSendBurstThrottle(ScopedMobileThrottle):
|
||||||
|
scope = "otp_send_burst"
|
||||||
|
|
||||||
|
|
||||||
|
class OTPSendSustainedThrottle(ScopedMobileThrottle):
|
||||||
|
scope = "otp_send_sustained"
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordLoginThrottle(ScopedMobileThrottle):
|
||||||
|
scope = "login_password"
|
||||||
|
|
||||||
|
|
||||||
|
class OTPLoginThrottle(ScopedMobileThrottle):
|
||||||
|
scope = "login_otp"
|
||||||
@@ -30,6 +30,12 @@ from apps.users.api.serializers import (
|
|||||||
UserProfileSerializer,
|
UserProfileSerializer,
|
||||||
UserSearchSerializer,
|
UserSearchSerializer,
|
||||||
)
|
)
|
||||||
|
from apps.users.api.throttles import (
|
||||||
|
OTPLoginThrottle,
|
||||||
|
OTPSendBurstThrottle,
|
||||||
|
OTPSendSustainedThrottle,
|
||||||
|
PasswordLoginThrottle,
|
||||||
|
)
|
||||||
from apps.users.services.auth import (
|
from apps.users.services.auth import (
|
||||||
register_user_with_password,
|
register_user_with_password,
|
||||||
register_user_with_otp,
|
register_user_with_otp,
|
||||||
@@ -91,6 +97,7 @@ class SendOTPView(APIView):
|
|||||||
+ password reset
|
+ password reset
|
||||||
"""
|
"""
|
||||||
permission_classes = (AllowAny,)
|
permission_classes = (AllowAny,)
|
||||||
|
throttle_classes = [OTPSendBurstThrottle, OTPSendSustainedThrottle]
|
||||||
|
|
||||||
@extend_schema(request=SendOTPSerializer, responses=None)
|
@extend_schema(request=SendOTPSerializer, responses=None)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
@@ -107,6 +114,7 @@ class SendOTPView(APIView):
|
|||||||
|
|
||||||
class LoginView(APIView):
|
class LoginView(APIView):
|
||||||
permission_classes = (AllowAny,)
|
permission_classes = (AllowAny,)
|
||||||
|
throttle_classes = [PasswordLoginThrottle]
|
||||||
|
|
||||||
@extend_schema(request=LoginSerializer, responses=TokenPairSerializer)
|
@extend_schema(request=LoginSerializer, responses=TokenPairSerializer)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
@@ -123,6 +131,7 @@ class LoginView(APIView):
|
|||||||
|
|
||||||
class LoginOTPView(APIView):
|
class LoginOTPView(APIView):
|
||||||
permission_classes = (AllowAny,)
|
permission_classes = (AllowAny,)
|
||||||
|
throttle_classes = [OTPLoginThrottle]
|
||||||
|
|
||||||
@extend_schema(request=LoginOtpSerializer, responses=TokenPairSerializer)
|
@extend_schema(request=LoginOtpSerializer, responses=TokenPairSerializer)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.core.cache import cache
|
||||||
|
from django.test import override_settings
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
@@ -197,3 +200,189 @@ class UserApiViewTests(APITestCase):
|
|||||||
success = self.client.get(f"/api/users/search/?mobile={self.other_user.mobile}")
|
success = self.client.get(f"/api/users/search/?mobile={self.other_user.mobile}")
|
||||||
self.assertEqual(success.status_code, status.HTTP_200_OK)
|
self.assertEqual(success.status_code, status.HTTP_200_OK)
|
||||||
self.assertEqual(success.data["mobile"], self.other_user.mobile)
|
self.assertEqual(success.data["mobile"], self.other_user.mobile)
|
||||||
|
|
||||||
|
|
||||||
|
class UserThrottleTests(APITestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
cls.user = User.objects.create_user(
|
||||||
|
mobile="09124440001",
|
||||||
|
password="secret123",
|
||||||
|
)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
REST_FRAMEWORK={
|
||||||
|
**settings.REST_FRAMEWORK,
|
||||||
|
"DEFAULT_THROTTLE_RATES": {
|
||||||
|
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
|
||||||
|
"login_password": "2/min",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@patch("apps.users.api.views.login_with_password")
|
||||||
|
def test_password_login_returns_structured_429_with_retry_after(self, login_with_password):
|
||||||
|
login_with_password.return_value = {"access": "a", "refresh": "r"}
|
||||||
|
|
||||||
|
first = self.client.post(
|
||||||
|
"/api/users/login/",
|
||||||
|
{"mobile": "09124440001", "password": "secret123"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.1",
|
||||||
|
)
|
||||||
|
second = self.client.post(
|
||||||
|
"/api/users/login/",
|
||||||
|
{"mobile": "09124440001", "password": "secret123"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.1",
|
||||||
|
)
|
||||||
|
throttled = self.client.post(
|
||||||
|
"/api/users/login/",
|
||||||
|
{"mobile": "09124440001", "password": "secret123"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.1",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(first.status_code, 200)
|
||||||
|
self.assertEqual(second.status_code, 200)
|
||||||
|
self.assertEqual(throttled.status_code, 429)
|
||||||
|
self.assertEqual(throttled.data["code"], "throttled")
|
||||||
|
self.assertEqual(throttled.data["scope"], "login_password")
|
||||||
|
self.assertIsInstance(throttled.data["retry_after_seconds"], int)
|
||||||
|
self.assertTrue(throttled.data["throttled_until"])
|
||||||
|
self.assertIn("Retry-After", throttled.headers)
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
REST_FRAMEWORK={
|
||||||
|
**settings.REST_FRAMEWORK,
|
||||||
|
"DEFAULT_THROTTLE_RATES": {
|
||||||
|
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
|
||||||
|
"otp_send_burst": "1/min",
|
||||||
|
"otp_send_sustained": "10/day",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@patch("apps.users.api.views.generate_and_send_otp")
|
||||||
|
def test_otp_send_throttle_is_keyed_by_mobile_and_ip(self, generate_and_send_otp):
|
||||||
|
first_mobile_first = self.client.post(
|
||||||
|
"/api/users/otp/send/",
|
||||||
|
{"mobile": "09124440011", "mode": "login"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.2",
|
||||||
|
)
|
||||||
|
second_mobile_first = self.client.post(
|
||||||
|
"/api/users/otp/send/",
|
||||||
|
{"mobile": "09124440012", "mode": "login"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.2",
|
||||||
|
)
|
||||||
|
first_mobile_second = self.client.post(
|
||||||
|
"/api/users/otp/send/",
|
||||||
|
{"mobile": "09124440011", "mode": "login"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.2",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(first_mobile_first.status_code, 200)
|
||||||
|
self.assertEqual(second_mobile_first.status_code, 200)
|
||||||
|
self.assertEqual(first_mobile_second.status_code, 429)
|
||||||
|
self.assertEqual(first_mobile_second.data["scope"], "otp_send_burst")
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
REST_FRAMEWORK={
|
||||||
|
**settings.REST_FRAMEWORK,
|
||||||
|
"DEFAULT_THROTTLE_RATES": {
|
||||||
|
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
|
||||||
|
"login_otp": "1/min",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@patch("apps.users.api.views.login_with_otp")
|
||||||
|
def test_otp_login_throttle_blocks_after_limit(self, login_with_otp):
|
||||||
|
login_with_otp.return_value = {"access": "a", "refresh": "r"}
|
||||||
|
|
||||||
|
allowed = self.client.post(
|
||||||
|
"/api/users/otp/login/",
|
||||||
|
{"mobile": "09124440021", "code": "123456"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.3",
|
||||||
|
)
|
||||||
|
throttled = self.client.post(
|
||||||
|
"/api/users/otp/login/",
|
||||||
|
{"mobile": "09124440021", "code": "123456"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.3",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(allowed.status_code, 200)
|
||||||
|
self.assertEqual(throttled.status_code, 429)
|
||||||
|
self.assertEqual(throttled.data["scope"], "login_otp")
|
||||||
|
|
||||||
|
@patch.dict("rest_framework.throttling.AnonRateThrottle.THROTTLE_RATES", {"anon": "1/min"}, clear=False)
|
||||||
|
@patch("apps.users.api.views.register_user_with_otp")
|
||||||
|
def test_global_anon_throttle_applies_to_unrestricted_anonymous_endpoint(
|
||||||
|
self,
|
||||||
|
register_user_with_otp,
|
||||||
|
):
|
||||||
|
register_user_with_otp.return_value = {"access": "a", "refresh": "r"}
|
||||||
|
|
||||||
|
first = self.client.post(
|
||||||
|
"/api/users/register/",
|
||||||
|
{
|
||||||
|
"mobile": "09124440031",
|
||||||
|
"code": "12345",
|
||||||
|
"password": "secret123",
|
||||||
|
"re_password": "secret123",
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.4",
|
||||||
|
)
|
||||||
|
throttled = self.client.post(
|
||||||
|
"/api/users/register/",
|
||||||
|
{
|
||||||
|
"mobile": "09124440032",
|
||||||
|
"code": "12345",
|
||||||
|
"password": "secret123",
|
||||||
|
"re_password": "secret123",
|
||||||
|
},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.4",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(first.status_code, 201)
|
||||||
|
self.assertEqual(throttled.status_code, 429)
|
||||||
|
self.assertEqual(throttled.data["code"], "throttled")
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
REST_FRAMEWORK={
|
||||||
|
**settings.REST_FRAMEWORK,
|
||||||
|
"DEFAULT_THROTTLE_RATES": {
|
||||||
|
**settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"],
|
||||||
|
"login_password": "1/min",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@patch("apps.users.api.views.login_with_password")
|
||||||
|
def test_throttle_falls_back_to_ip_when_mobile_is_missing(self, login_with_password):
|
||||||
|
login_with_password.return_value = {"access": "a", "refresh": "r"}
|
||||||
|
|
||||||
|
first = self.client.post(
|
||||||
|
"/api/users/login/",
|
||||||
|
{"password": "secret123"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.5",
|
||||||
|
)
|
||||||
|
second = self.client.post(
|
||||||
|
"/api/users/login/",
|
||||||
|
{"password": "secret123"},
|
||||||
|
format="json",
|
||||||
|
REMOTE_ADDR="10.0.0.5",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(first.status_code, 400)
|
||||||
|
self.assertEqual(second.status_code, 429)
|
||||||
|
|||||||
@@ -131,6 +131,14 @@ REST_FRAMEWORK = {
|
|||||||
"rest_framework.throttling.AnonRateThrottle",
|
"rest_framework.throttling.AnonRateThrottle",
|
||||||
"rest_framework.throttling.UserRateThrottle",
|
"rest_framework.throttling.UserRateThrottle",
|
||||||
],
|
],
|
||||||
|
"DEFAULT_THROTTLE_RATES": {
|
||||||
|
"anon": "60/min",
|
||||||
|
"user": "300/min",
|
||||||
|
"otp_send_burst": "3/10m",
|
||||||
|
"otp_send_sustained": "10/day",
|
||||||
|
"login_password": "5/10m",
|
||||||
|
"login_otp": "5/10m",
|
||||||
|
},
|
||||||
"EXCEPTION_HANDLER": "core.exceptions.handlers.exception_handler",
|
"EXCEPTION_HANDLER": "core.exceptions.handlers.exception_handler",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
from datetime import timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.utils import timezone
|
||||||
from rest_framework import status as http_status
|
from rest_framework import status as http_status
|
||||||
from rest_framework.exceptions import ErrorDetail
|
from rest_framework.exceptions import ErrorDetail
|
||||||
|
from rest_framework.exceptions import Throttled
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.views import exception_handler as drf_exception_handler
|
from rest_framework.views import exception_handler as drf_exception_handler
|
||||||
|
|
||||||
@@ -83,7 +86,36 @@ def exception_handler(exc, context) -> Response:
|
|||||||
if status_code < 500:
|
if status_code < 500:
|
||||||
messages = _to_str_list(detail)
|
messages = _to_str_list(detail)
|
||||||
payload = _format_payload(messages, status_code)
|
payload = _format_payload(messages, status_code)
|
||||||
return Response(payload, status=status_code)
|
if isinstance(exc, Throttled):
|
||||||
|
request = context.get("request")
|
||||||
|
wait = getattr(exc, "wait", None)
|
||||||
|
retry_after_seconds = None
|
||||||
|
if wait is not None:
|
||||||
|
retry_after_seconds = max(int(wait), 0)
|
||||||
|
elif request is not None:
|
||||||
|
retry_after_seconds = getattr(request, "_retry_after_seconds", None)
|
||||||
|
throttle_scope = getattr(request, "_throttle_scope", None) if request else None
|
||||||
|
payload.update(
|
||||||
|
{
|
||||||
|
"code": "throttled",
|
||||||
|
"scope": throttle_scope,
|
||||||
|
"retry_after_seconds": retry_after_seconds,
|
||||||
|
"throttled_until": (
|
||||||
|
timezone.now() + timedelta(seconds=retry_after_seconds)
|
||||||
|
).isoformat()
|
||||||
|
if retry_after_seconds is not None
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
formatted_response = Response(payload, status=status_code)
|
||||||
|
for header, value in response.headers.items():
|
||||||
|
formatted_response[header] = value
|
||||||
|
if isinstance(exc, Throttled) and "Retry-After" not in formatted_response:
|
||||||
|
request = context.get("request")
|
||||||
|
retry_after_seconds = getattr(request, "_retry_after_seconds", None) if request else None
|
||||||
|
if retry_after_seconds is not None:
|
||||||
|
formatted_response["Retry-After"] = str(max(int(retry_after_seconds), 0))
|
||||||
|
return formatted_response
|
||||||
|
|
||||||
traceback_text = traceback.format_exc()
|
traceback_text = traceback.format_exc()
|
||||||
payload = _format_payload(
|
payload = _format_payload(
|
||||||
|
|||||||
Reference in New Issue
Block a user