Files

285 lines
9.5 KiB
Python

import json
import logging
from typing import Iterator
from django.conf import settings
from django.core import signing
from django.http import JsonResponse, StreamingHttpResponse
from django.utils import timezone
from django.views import View
from drf_spectacular.utils import extend_schema
from rest_framework import status, viewsets
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.notifications.api.serializers import (
NotificationDeleteResponseSerializer,
NotificationListResponseSerializer,
NotificationMarkAllReadResponseSerializer,
NotificationSeenRequestSerializer,
NotificationSeenResponseSerializer,
NotificationStreamTokenResponseSerializer,
NotificationTypeFilterSerializer,
)
from apps.notifications.services import RedisNotificationStore
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
# Salt shared by _issue_stream_token_for_user and _validate_stream_token; both
# sides must use the same value for a stream token's signature to verify.
STREAM_TOKEN_SALT: str = "notifications.stream"
def _safe_int(value, default: int) -> int:
try:
return max(int(value), 0)
except (TypeError, ValueError):
return default
def _format_sse_event(event: str, data: dict) -> str:
payload = json.dumps(data, ensure_ascii=False, default=str)
return f"event: {event}\ndata: {payload}\n\n"
def _issue_stream_token_for_user(user_id: str) -> str:
    """Sign a token that authorizes opening the notification stream.

    The signed claims carry the stringified ``user_id`` and a
    ``"notification_stream"`` type marker. The token is signed with
    ``STREAM_TOKEN_SALT``. No expiry is embedded here; the token's age is
    checked when ``_validate_stream_token`` loads it with ``max_age``.
    """
    claims = {"user_id": str(user_id), "type": "notification_stream"}
    return signing.dumps(claims, salt=STREAM_TOKEN_SALT)
def _validate_stream_token(token: str | None) -> str:
    """Verify a stream token and return the user id it was issued for.

    Args:
        token: The raw token string, or ``None``/empty if the client sent none.

    Returns:
        The ``user_id`` claim, as a string.

    Raises:
        signing.BadSignature: The token is missing, has a bad signature, or is
            not a ``"notification_stream"`` token.
        signing.SignatureExpired: The token is older than
            ``settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS``. This is a
            ``BadSignature`` subclass, so callers catch it first.
    """
    if not token:
        raise signing.BadSignature("Missing stream token")
    lifetime = settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS
    claims = signing.loads(token, salt=STREAM_TOKEN_SALT, max_age=lifetime)
    token_type = claims.get("type")
    if token_type != "notification_stream":
        raise signing.BadSignature("Invalid stream token type")
    return str(claims["user_id"])
def _notification_list_payload(request) -> dict:
    """Build the paginated notification-list response body for ``request.user``.

    Query parameters:
        limit: Page size. Defaults to ``settings.NOTIFICATION_DEFAULT_PAGE_SIZE``
            when absent or unparsable, and is capped at
            ``settings.NOTIFICATION_MAX_PAGE_SIZE``. Negative values become 0.
        offset: Number of notifications to skip. Defaults to 0.
        type: Optional type filter, passed through to the store unchanged.

    Returns:
        A dict with ``count`` (total matching notifications), ``unread_count``
        and ``notifications`` (the requested page).
    """
    user_id = str(request.user.id)
    params = request.query_params
    limit = min(
        _safe_int(params.get("limit"), settings.NOTIFICATION_DEFAULT_PAGE_SIZE),
        settings.NOTIFICATION_MAX_PAGE_SIZE,
    )
    offset = _safe_int(params.get("offset"), 0)
    type_filter = params.get("type")
    notifications, total_count = RedisNotificationStore.list(
        user_id,
        limit=limit,
        offset=offset,
        type_filter=type_filter,
    )
    return {
        "count": total_count,
        "unread_count": RedisNotificationStore.unread_count(user_id),
        "notifications": notifications,
    }


class NotificationListViewSet(viewsets.ViewSet):
    """ViewSet exposing the authenticated user's notifications as ``list``."""

    permission_classes = [IsAuthenticated]
    serializer_class = NotificationListResponseSerializer

    @extend_schema(responses=NotificationListResponseSerializer)
    def list(self, request):
        """Return one page of notifications plus total and unread counts."""
        return Response(_notification_list_payload(request))


class NotificationListView(APIView):
    """APIView exposing the authenticated user's notifications via GET."""

    permission_classes = [IsAuthenticated]
    serializer_class = NotificationListResponseSerializer

    @extend_schema(responses=NotificationListResponseSerializer)
    def get(self, request):
        """Return one page of notifications plus total and unread counts."""
        return Response(_notification_list_payload(request))
class NotificationSeenView(APIView):
    """Mark a single notification as seen for the authenticated user."""

    permission_classes = [IsAuthenticated]
    serializer_class = NotificationSeenRequestSerializer

    @extend_schema(
        request=NotificationSeenRequestSerializer,
        responses=NotificationSeenResponseSerializer,
    )
    def post(self, request):
        """Validate ``{"id": ...}`` and mark that notification as seen.

        Returns 404 with ``{"marked_read": False}`` when the store reports no
        such notification (``mark_seen`` returned ``None``). Otherwise returns
        ``{"marked_read": True}`` merged with the dict the store returned.
        """
        request_serializer = self.serializer_class(data=request.data)
        request_serializer.is_valid(raise_exception=True)
        target_id = request_serializer.validated_data["id"]
        result = RedisNotificationStore.mark_seen(str(request.user.id), target_id)
        if result is not None:
            return Response({"marked_read": True, **result})
        return Response({"marked_read": False}, status=status.HTTP_404_NOT_FOUND)
class NotificationDeleteView(APIView):
    """Delete one notification belonging to the authenticated user."""

    permission_classes = [IsAuthenticated]

    @extend_schema(responses=NotificationDeleteResponseSerializer)
    def delete(self, request, notif_id: str):
        """Delete ``notif_id`` and report the user's remaining unread count.

        Responds 404 with ``{"deleted": False}`` when the store reports that
        nothing was deleted.
        """
        owner_id = str(request.user.id)
        if RedisNotificationStore.delete(owner_id, notif_id):
            body = {
                "deleted": True,
                "notification_id": notif_id,
                "unread_count": RedisNotificationStore.unread_count(owner_id),
            }
            return Response(body)
        return Response({"deleted": False}, status=status.HTTP_404_NOT_FOUND)
class NotificationMarkAllReadView(APIView):
    """Mark all of the user's notifications (optionally of one type) as read."""

    permission_classes = [IsAuthenticated]
    serializer_class = NotificationTypeFilterSerializer

    @extend_schema(
        request=NotificationTypeFilterSerializer,
        responses=NotificationMarkAllReadResponseSerializer,
    )
    def post(self, request):
        """Mark notifications as seen and return the store's result.

        The optional ``type`` filter is read from the request body first. The
        query string is used when the body value is missing or falsy.
        """
        # NOTE(review): serializer_class is declared (and used for the schema)
        # but the body is never validated through it. A non-dict JSON body
        # (e.g. a list) would presumably make request.data.get raise. Confirm
        # whether validation is intended here.
        type_filter = request.data.get("type")
        if not type_filter:
            type_filter = request.query_params.get("type")
        marked = RedisNotificationStore.mark_all_seen(
            str(request.user.id),
            delete_on_seen_only=False,
            type_filter=type_filter,
        )
        return Response({"marked_read": marked})
class NotificationStreamTokenView(APIView):
    """Issue a signed token the client presents to ``NotificationStreamView``."""

    permission_classes = [IsAuthenticated]

    @extend_schema(responses=NotificationStreamTokenResponseSerializer)
    def post(self, request):
        """Return ``{"token", "expires_in"}``, or 503 when notifications are off.

        ``expires_in`` echoes
        ``settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS``, the same
        ``max_age`` used when the token is validated.
        """
        if settings.NOTIFICATIONS_ENABLED:
            token = _issue_stream_token_for_user(str(request.user.id))
            lifetime = settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS
            return Response({"token": token, "expires_in": lifetime})
        return Response(
            {"detail": "Notifications are disabled."},
            status=status.HTTP_503_SERVICE_UNAVAILABLE,
        )
class NotificationStreamView(View):
    """Server-Sent Events endpoint that streams live notifications to one user.

    Requests are authenticated by a signed ``token`` query parameter, checked by
    ``_validate_stream_token``, rather than by DRF permissions. Clients obtain
    the token from ``NotificationStreamTokenView``.
    """

    @staticmethod
    def _message_to_frame(message: dict, user_id: str) -> str | None:
        """Convert one Redis pub/sub ``message`` into an SSE frame.

        The message's ``data`` is expected to be a JSON object with optional
        ``event`` (default ``"notification"``) and ``data`` (default ``{}``)
        keys. Malformed payloads are logged and yield ``None``, so one bad
        message cannot terminate the stream.
        """
        try:
            payload = json.loads(message["data"])
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError and the UnicodeDecodeError
            # raised for undecodable bytes.
            logger.warning("Invalid notification stream payload for user %s", user_id)
            return None
        if not isinstance(payload, dict):
            logger.warning("Invalid notification stream payload for user %s", user_id)
            return None
        event = payload.get("event") or "notification"
        data = payload.get("data") or {}
        return _format_sse_event(event, data)

    def _build_stream(self, user_id: str) -> Iterator[str]:
        """Yield SSE frames for ``user_id`` until the generator is closed.

        Frames, in order:
        1. A ``retry:`` reconnection hint.
        2. A ``connected`` event with the first page of notifications and the
           unread count.
        3. A standalone ``unread_count`` event.
        4. Every message published on the user's Redis channel, interleaved
           with ``ping`` heartbeats.

        The channel is subscribed *before* the initial snapshot is read, so
        notifications published in between are buffered instead of lost. As a
        consequence, a notification can appear both in the snapshot and as a
        live event. NOTE(review): clients should de-duplicate, presumably by
        notification id. Confirm the client does so.

        The pub/sub connection is always unsubscribed and closed on exit:
        client disconnect, a Redis error, or a failure while building the
        snapshot.
        """
        pubsub = RedisNotificationStore.get_pubsub()
        channel = RedisNotificationStore._channel_key(user_id)
        heartbeat_seconds = max(settings.NOTIFICATION_SSE_HEARTBEAT_SECONDS, 1)
        subscribed = False
        try:
            pubsub.subscribe(channel)
            subscribed = True
            initial_notifications, _ = RedisNotificationStore.list(
                user_id,
                limit=settings.NOTIFICATION_DEFAULT_PAGE_SIZE,
                offset=0,
            )
            unread_count = RedisNotificationStore.unread_count(user_id)
            yield f"retry: {settings.NOTIFICATION_SSE_RETRY_MS}\n\n"
            yield _format_sse_event(
                "connected",
                {
                    "notifications": initial_notifications,
                    "unread_count": unread_count,
                },
            )
            yield _format_sse_event("unread_count", {"unread_count": unread_count})
            last_ping_at = timezone.now()
            while True:
                # The 1s timeout bounds how long the loop blocks, so heartbeats
                # are still checked when no messages arrive.
                message = pubsub.get_message(timeout=1.0)
                # Only "message" entries carry payloads; subscribe confirmations
                # and other control messages are skipped.
                if message and message.get("type") == "message":
                    frame = self._message_to_frame(message, user_id)
                    if frame is not None:
                        yield frame
                if (timezone.now() - last_ping_at).total_seconds() >= heartbeat_seconds:
                    last_ping_at = timezone.now()
                    yield _format_sse_event(
                        "ping",
                        {"timestamp": last_ping_at.isoformat()},
                    )
        except GeneratorExit:
            # Raised at the suspended yield when the response iterator is closed.
            logger.debug("Notification stream closed for user %s", user_id)
        finally:
            try:
                if subscribed:
                    pubsub.unsubscribe(channel)
            finally:
                pubsub.close()

    def get(self, request, *args, **kwargs):
        """Open the SSE stream for the user named by the ``token`` query param.

        Responds 503 when notifications are disabled, and 401 when the token
        is expired or otherwise invalid. Otherwise returns a
        ``text/event-stream`` response. It sets ``Cache-Control: no-cache``
        and ``X-Accel-Buffering: no``; the latter asks an nginx reverse proxy
        not to buffer the stream.
        """
        if not settings.NOTIFICATIONS_ENABLED:
            return JsonResponse(
                {"detail": "Notifications are disabled."},
                status=status.HTTP_503_SERVICE_UNAVAILABLE,
            )
        try:
            user_id = _validate_stream_token(request.GET.get("token"))
        except signing.SignatureExpired:
            # Must precede BadSignature: SignatureExpired is its subclass.
            return JsonResponse(
                {"detail": "Stream token expired."},
                status=status.HTTP_401_UNAUTHORIZED,
            )
        except signing.BadSignature:
            return JsonResponse(
                {"detail": "Invalid stream token."},
                status=status.HTTP_401_UNAUTHORIZED,
            )
        response = StreamingHttpResponse(
            self._build_stream(user_id),
            content_type="text/event-stream",
        )
        response["Cache-Control"] = "no-cache"
        response["X-Accel-Buffering"] = "no"
        return response