diff --git a/apps/users/api/throttles.py b/apps/users/api/throttles.py new file mode 100644 index 0000000..9d49697 --- /dev/null +++ b/apps/users/api/throttles.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import re + +from django.core.exceptions import ImproperlyConfigured +from rest_framework.settings import api_settings +from rest_framework.throttling import SimpleRateThrottle + + +class ScopedMobileThrottle(SimpleRateThrottle): + scope = "" + + def get_rate(self): + if not self.scope: + raise ImproperlyConfigured( + f"{self.__class__.__name__} must define a scope or rate." + ) + + try: + return api_settings.DEFAULT_THROTTLE_RATES[self.scope] + except KeyError as exc: + raise ImproperlyConfigured( + f'No default throttle rate set for scope "{self.scope}".' + ) from exc + + def parse_rate(self, rate): + if rate is None: + return (None, None) + + num_requests, period = rate.split("/") + match = re.fullmatch(r"(?:(\d+)\s*)?([smhd]|sec|second|min|minute|hour|day)s?", period.strip(), re.IGNORECASE) + if not match: + return super().parse_rate(rate) + + multiplier = int(match.group(1) or "1") + unit = match.group(2).lower() + unit_seconds = { + "s": 1, + "sec": 1, + "second": 1, + "m": 60, + "min": 60, + "minute": 60, + "h": 3600, + "hour": 3600, + "d": 86400, + "day": 86400, + }[unit] + + return int(num_requests), multiplier * unit_seconds + + def get_mobile_identifier(self, request) -> str | None: + mobile = None + try: + mobile = request.data.get("mobile") + except Exception: + mobile = None + + if not isinstance(mobile, str): + return None + + normalized = "".join(ch for ch in mobile if ch.isdigit()) + return normalized or None + + def get_cache_key(self, request, view): + ident = self.get_ident(request) + mobile = self.get_mobile_identifier(request) + if mobile: + return self.cache_format % { + "scope": self.scope, + "ident": f"{ident}:{mobile}", + } + return self.cache_format % { + "scope": self.scope, + "ident": ident, + } + + def allow_request(self, request, view): + allowed = super().allow_request(request, view) + if not allowed: + request._throttle_scope = self.scope + request._retry_after_seconds = self.wait() + return allowed + + +class OTPSendBurstThrottle(ScopedMobileThrottle): + scope = "otp_send_burst" + + +class OTPSendSustainedThrottle(ScopedMobileThrottle): + scope = "otp_send_sustained" + + +class PasswordLoginThrottle(ScopedMobileThrottle): + scope = "login_password" + + +class OTPLoginThrottle(ScopedMobileThrottle): + scope = "login_otp" diff --git a/apps/users/api/views.py b/apps/users/api/views.py index dffdb36..1d85d1b 100644 --- a/apps/users/api/views.py +++ b/apps/users/api/views.py @@ -30,6 +30,12 @@ from apps.users.api.serializers import ( UserProfileSerializer, UserSearchSerializer, ) +from apps.users.api.throttles import ( + OTPLoginThrottle, + OTPSendBurstThrottle, + OTPSendSustainedThrottle, + PasswordLoginThrottle, +) from apps.users.services.auth import ( register_user_with_password, register_user_with_otp, @@ -91,6 +97,7 @@ class SendOTPView(APIView): + password reset """ permission_classes = (AllowAny,) + throttle_classes = [OTPSendBurstThrottle, OTPSendSustainedThrottle] @extend_schema(request=SendOTPSerializer, responses=None) def post(self, request): @@ -107,6 +114,7 @@ class SendOTPView(APIView): class LoginView(APIView): permission_classes = (AllowAny,) + throttle_classes = [PasswordLoginThrottle] @extend_schema(request=LoginSerializer, responses=TokenPairSerializer) def post(self, request): @@ -123,6 +131,7 @@ class LoginView(APIView): class LoginOTPView(APIView): permission_classes = (AllowAny,) + throttle_classes = [OTPLoginThrottle] @extend_schema(request=LoginOtpSerializer, responses=TokenPairSerializer) def post(self, request): diff --git a/apps/users/tests/test_api_views.py b/apps/users/tests/test_api_views.py index 47ad56a..1cefe95 100644 --- a/apps/users/tests/test_api_views.py +++ b/apps/users/tests/test_api_views.py @@ -1,5 +1,8 @@ from unittest.mock import patch +from django.conf import settings +from django.core.cache import cache +from django.test import override_settings from rest_framework.test import APIRequestFactory from rest_framework import status from rest_framework.test import APITestCase @@ -197,3 +200,189 @@ class UserApiViewTests(APITestCase): success = self.client.get(f"/api/users/search/?mobile={self.other_user.mobile}") self.assertEqual(success.status_code, status.HTTP_200_OK) self.assertEqual(success.data["mobile"], self.other_user.mobile) + + +class UserThrottleTests(APITestCase): + @classmethod + def setUpTestData(cls): + cls.user = User.objects.create_user( + mobile="09124440001", + password="secret123", + ) + + def setUp(self): + cache.clear() + + def tearDown(self): + cache.clear() + + @override_settings( + REST_FRAMEWORK={ + **settings.REST_FRAMEWORK, + "DEFAULT_THROTTLE_RATES": { + **settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"], + "login_password": "2/min", + }, + } + ) + @patch("apps.users.api.views.login_with_password") + def test_password_login_returns_structured_429_with_retry_after(self, login_with_password): + login_with_password.return_value = {"access": "a", "refresh": "r"} + + first = self.client.post( + "/api/users/login/", + {"mobile": "09124440001", "password": "secret123"}, + format="json", + REMOTE_ADDR="10.0.0.1", + ) + second = self.client.post( + "/api/users/login/", + {"mobile": "09124440001", "password": "secret123"}, + format="json", + REMOTE_ADDR="10.0.0.1", + ) + throttled = self.client.post( + "/api/users/login/", + {"mobile": "09124440001", "password": "secret123"}, + format="json", + REMOTE_ADDR="10.0.0.1", + ) + + self.assertEqual(first.status_code, 200) + self.assertEqual(second.status_code, 200) + self.assertEqual(throttled.status_code, 429) + self.assertEqual(throttled.data["code"], "throttled") + self.assertEqual(throttled.data["scope"], "login_password") + self.assertIsInstance(throttled.data["retry_after_seconds"], int) + self.assertTrue(throttled.data["throttled_until"]) + self.assertIn("Retry-After", throttled.headers) + + @override_settings( + REST_FRAMEWORK={ + **settings.REST_FRAMEWORK, + "DEFAULT_THROTTLE_RATES": { + **settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"], + "otp_send_burst": "1/min", + "otp_send_sustained": "10/day", + }, + } + ) + @patch("apps.users.api.views.generate_and_send_otp") + def test_otp_send_throttle_is_keyed_by_mobile_and_ip(self, generate_and_send_otp): + first_mobile_first = self.client.post( + "/api/users/otp/send/", + {"mobile": "09124440011", "mode": "login"}, + format="json", + REMOTE_ADDR="10.0.0.2", + ) + second_mobile_first = self.client.post( + "/api/users/otp/send/", + {"mobile": "09124440012", "mode": "login"}, + format="json", + REMOTE_ADDR="10.0.0.2", + ) + first_mobile_second = self.client.post( + "/api/users/otp/send/", + {"mobile": "09124440011", "mode": "login"}, + format="json", + REMOTE_ADDR="10.0.0.2", + ) + + self.assertEqual(first_mobile_first.status_code, 200) + self.assertEqual(second_mobile_first.status_code, 200) + self.assertEqual(first_mobile_second.status_code, 429) + self.assertEqual(first_mobile_second.data["scope"], "otp_send_burst") + + @override_settings( + REST_FRAMEWORK={ + **settings.REST_FRAMEWORK, + "DEFAULT_THROTTLE_RATES": { + **settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"], + "login_otp": "1/min", + }, + } + ) + @patch("apps.users.api.views.login_with_otp") + def test_otp_login_throttle_blocks_after_limit(self, login_with_otp): + login_with_otp.return_value = {"access": "a", "refresh": "r"} + + allowed = self.client.post( + "/api/users/otp/login/", + {"mobile": "09124440021", "code": "123456"}, + format="json", + REMOTE_ADDR="10.0.0.3", + ) + throttled = self.client.post( + "/api/users/otp/login/", + {"mobile": "09124440021", "code": "123456"}, + format="json", + REMOTE_ADDR="10.0.0.3", + ) + + self.assertEqual(allowed.status_code, 200) + self.assertEqual(throttled.status_code, 429) + self.assertEqual(throttled.data["scope"], "login_otp") + + @patch.dict("rest_framework.throttling.AnonRateThrottle.THROTTLE_RATES", {"anon": "1/min"}, clear=False) + @patch("apps.users.api.views.register_user_with_otp") + def test_global_anon_throttle_applies_to_unrestricted_anonymous_endpoint( + self, + register_user_with_otp, + ): + register_user_with_otp.return_value = {"access": "a", "refresh": "r"} + + first = self.client.post( + "/api/users/register/", + { + "mobile": "09124440031", + "code": "12345", + "password": "secret123", + "re_password": "secret123", + }, + format="json", + REMOTE_ADDR="10.0.0.4", + ) + throttled = self.client.post( + "/api/users/register/", + { + "mobile": "09124440032", + "code": "12345", + "password": "secret123", + "re_password": "secret123", + }, + format="json", + REMOTE_ADDR="10.0.0.4", + ) + + self.assertEqual(first.status_code, 201) + self.assertEqual(throttled.status_code, 429) + self.assertEqual(throttled.data["code"], "throttled") + + @override_settings( + REST_FRAMEWORK={ + **settings.REST_FRAMEWORK, + "DEFAULT_THROTTLE_RATES": { + **settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"], + "login_password": "1/min", + }, + } + ) + @patch("apps.users.api.views.login_with_password") + def test_throttle_falls_back_to_ip_when_mobile_is_missing(self, login_with_password): + login_with_password.return_value = {"access": "a", "refresh": "r"} + + first = self.client.post( + "/api/users/login/", + {"password": "secret123"}, + format="json", + REMOTE_ADDR="10.0.0.5", + ) + second = self.client.post( + "/api/users/login/", + {"password": "secret123"}, + format="json", + REMOTE_ADDR="10.0.0.5", + ) + + self.assertEqual(first.status_code, 400) + self.assertEqual(second.status_code, 429) diff --git a/config/settings/base.py b/config/settings/base.py index 9999cbd..d4c6336 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -131,6 +131,14 @@ REST_FRAMEWORK = { "rest_framework.throttling.AnonRateThrottle", "rest_framework.throttling.UserRateThrottle", ], + "DEFAULT_THROTTLE_RATES": { + "anon": "60/min", + "user": "300/min", + "otp_send_burst": "3/10m", + "otp_send_sustained": "10/day", + "login_password": "5/10m", + "login_otp": "5/10m", + }, "EXCEPTION_HANDLER": "core.exceptions.handlers.exception_handler", } diff --git a/core/exceptions/handlers.py b/core/exceptions/handlers.py index 976266a..9524ad6 100644 --- a/core/exceptions/handlers.py +++ b/core/exceptions/handlers.py @@ -1,11 +1,14 @@ import logging import traceback from collections.abc import Iterable +from datetime import timedelta from typing import Any from django.conf import settings +from django.utils import timezone from rest_framework import status as http_status from rest_framework.exceptions import ErrorDetail +from rest_framework.exceptions import Throttled from rest_framework.response import Response from rest_framework.views import exception_handler as drf_exception_handler @@ -83,7 +86,36 @@ def exception_handler(exc, context) -> Response: if status_code < 500: messages = _to_str_list(detail) payload = _format_payload(messages, status_code) - return Response(payload, status=status_code) + if isinstance(exc, Throttled): + request = context.get("request") + wait = getattr(exc, "wait", None) + retry_after_seconds = None + if wait is not None: + retry_after_seconds = max(int(wait), 0) + elif request is not None: + retry_after_seconds = getattr(request, "_retry_after_seconds", None) + throttle_scope = getattr(request, "_throttle_scope", None) if request else None + payload.update( + { + "code": "throttled", + "scope": throttle_scope, + "retry_after_seconds": retry_after_seconds, + "throttled_until": ( + timezone.now() + timedelta(seconds=retry_after_seconds) + ).isoformat() + if retry_after_seconds is not None + else None, + } + ) + formatted_response = Response(payload, status=status_code) + for header, value in response.headers.items(): + formatted_response[header] = value + if isinstance(exc, Throttled) and "Retry-After" not in formatted_response: + request = context.get("request") + retry_after_seconds = getattr(request, "_retry_after_seconds", None) if request else None + if retry_after_seconds is not None: + formatted_response["Retry-After"] = str(max(int(retry_after_seconds), 0)) + return formatted_response traceback_text = traceback.format_exc() payload = _format_payload(