feat(throttling): add auth throttling and structured cooldown errors
This commit is contained in:
99
apps/users/api/throttles.py
Normal file
99
apps/users/api/throttles.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.throttling import SimpleRateThrottle
|
||||
|
||||
|
||||
class ScopedMobileThrottle(SimpleRateThrottle):
|
||||
scope = ""
|
||||
|
||||
def get_rate(self):
|
||||
if not self.scope:
|
||||
raise ImproperlyConfigured(
|
||||
f"{self.__class__.__name__} must define a scope or rate."
|
||||
)
|
||||
|
||||
try:
|
||||
return api_settings.DEFAULT_THROTTLE_RATES[self.scope]
|
||||
except KeyError as exc:
|
||||
raise ImproperlyConfigured(
|
||||
f'No default throttle rate set for scope "{self.scope}".'
|
||||
) from exc
|
||||
|
||||
def parse_rate(self, rate):
|
||||
if rate is None:
|
||||
return (None, None)
|
||||
|
||||
num_requests, period = rate.split("/")
|
||||
match = re.fullmatch(r"(?:(\d+)\s*)?([smhd]|sec|second|min|minute|hour|day)s?", period.strip(), re.IGNORECASE)
|
||||
if not match:
|
||||
return super().parse_rate(rate)
|
||||
|
||||
multiplier = int(match.group(1) or "1")
|
||||
unit = match.group(2).lower()
|
||||
unit_seconds = {
|
||||
"s": 1,
|
||||
"sec": 1,
|
||||
"second": 1,
|
||||
"m": 60,
|
||||
"min": 60,
|
||||
"minute": 60,
|
||||
"h": 3600,
|
||||
"hour": 3600,
|
||||
"d": 86400,
|
||||
"day": 86400,
|
||||
}[unit]
|
||||
|
||||
return int(num_requests), multiplier * unit_seconds
|
||||
|
||||
def get_mobile_identifier(self, request) -> str | None:
|
||||
mobile = None
|
||||
try:
|
||||
mobile = request.data.get("mobile")
|
||||
except Exception:
|
||||
mobile = None
|
||||
|
||||
if not isinstance(mobile, str):
|
||||
return None
|
||||
|
||||
normalized = "".join(ch for ch in mobile if ch.isdigit())
|
||||
return normalized or None
|
||||
|
||||
def get_cache_key(self, request, view):
|
||||
ident = self.get_ident(request)
|
||||
mobile = self.get_mobile_identifier(request)
|
||||
if mobile:
|
||||
return self.cache_format % {
|
||||
"scope": self.scope,
|
||||
"ident": f"{ident}:{mobile}",
|
||||
}
|
||||
return self.cache_format % {
|
||||
"scope": self.scope,
|
||||
"ident": ident,
|
||||
}
|
||||
|
||||
def allow_request(self, request, view):
|
||||
allowed = super().allow_request(request, view)
|
||||
if not allowed:
|
||||
request._throttle_scope = self.scope
|
||||
request._retry_after_seconds = self.wait()
|
||||
return allowed
|
||||
|
||||
|
||||
class OTPSendBurstThrottle(ScopedMobileThrottle):
|
||||
scope = "otp_send_burst"
|
||||
|
||||
|
||||
class OTPSendSustainedThrottle(ScopedMobileThrottle):
|
||||
scope = "otp_send_sustained"
|
||||
|
||||
|
||||
class PasswordLoginThrottle(ScopedMobileThrottle):
|
||||
scope = "login_password"
|
||||
|
||||
|
||||
class OTPLoginThrottle(ScopedMobileThrottle):
|
||||
scope = "login_otp"
|
||||
Reference in New Issue
Block a user