feat(notifications): add redis-backed sse notification streaming

This commit is contained in:
2026-04-25 11:27:46 +03:30
parent e7de587f59
commit 0ca3255270
14 changed files with 1146 additions and 1 deletions

View File

View File

View File

@@ -0,0 +1,7 @@
from django.apps import AppConfig
class NotificationsConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "apps.notifications"
verbose_name = "09-notifications"

View File

View File

@@ -0,0 +1,61 @@
from rest_framework import serializers
class NotificationSerializer(serializers.Serializer):
id = serializers.CharField()
type = serializers.CharField()
title = serializers.CharField(allow_blank=True)
message = serializers.CharField(allow_blank=True)
level = serializers.ChoiceField(
choices=("info", "success", "warning", "error")
)
created_at = serializers.CharField()
is_seen = serializers.BooleanField()
delete_on_seen = serializers.BooleanField()
action_url = serializers.CharField(
required=False, allow_blank=True, allow_null=True
)
entity_type = serializers.CharField(
required=False, allow_blank=True, allow_null=True
)
entity_id = serializers.CharField(
required=False, allow_blank=True, allow_null=True
)
meta = serializers.JSONField(required=False)
class NotificationListResponseSerializer(serializers.Serializer):
count = serializers.IntegerField()
unread_count = serializers.IntegerField()
notifications = NotificationSerializer(many=True)
class NotificationSeenRequestSerializer(serializers.Serializer):
id = serializers.CharField()
class NotificationDeleteResponseSerializer(serializers.Serializer):
deleted = serializers.BooleanField()
notification_id = serializers.CharField(required=False)
unread_count = serializers.IntegerField(required=False)
class NotificationSeenResponseSerializer(serializers.Serializer):
marked_read = serializers.BooleanField()
notification_id = serializers.CharField(required=False)
deleted = serializers.BooleanField(required=False)
unread_count = serializers.IntegerField(required=False)
notification = NotificationSerializer(required=False, allow_null=True)
class NotificationMarkAllReadResponseSerializer(serializers.Serializer):
marked_read = serializers.IntegerField()
class NotificationTypeFilterSerializer(serializers.Serializer):
type = serializers.CharField(required=False, allow_blank=True, allow_null=True)
class NotificationStreamTokenResponseSerializer(serializers.Serializer):
token = serializers.CharField()
expires_in = serializers.IntegerField()

View File

@@ -0,0 +1,355 @@
import json
import uuid
from datetime import timedelta
import redis
from django.conf import settings
from django.utils import timezone
from django.utils.dateparse import parse_datetime
redis_client = redis.StrictRedis.from_url(settings.REDIS_URL, decode_responses=True)
def _isoformat_datetime(value) -> str:
if not value:
return timezone.now().isoformat()
if isinstance(value, str):
parsed = parse_datetime(value)
if parsed is not None:
value = parsed
else:
return value
if timezone.is_naive(value):
value = timezone.make_aware(value, timezone.get_current_timezone())
return timezone.localtime(value).isoformat()
class RedisNotificationStore:
USERS_KEY = "notif:users"
@classmethod
def _ids_key(cls, user_id: str) -> str:
return f"notif:{user_id}:ids"
@classmethod
def _data_key(cls, user_id: str) -> str:
return f"notif:{user_id}:data"
@classmethod
def _channel_key(cls, user_id: str) -> str:
prefix = settings.NOTIFICATION_REDIS_CHANNEL_PREFIX.rstrip(":")
return f"{prefix}:{user_id}"
@staticmethod
def _normalize_notification(data: dict | None) -> dict:
payload = dict(data or {})
return {
"id": str(payload.get("id") or uuid.uuid4()),
"type": payload.get("type") or "notification",
"title": payload.get("title") or "",
"message": payload.get("message") or "",
"level": payload.get("level") or "info",
"created_at": _isoformat_datetime(payload.get("created_at")),
"is_seen": bool(payload.get("is_seen", False)),
"delete_on_seen": bool(payload.get("delete_on_seen", False)),
"action_url": payload.get("action_url"),
"entity_type": payload.get("entity_type"),
"entity_id": payload.get("entity_id"),
"meta": payload.get("meta") or {},
}
@classmethod
def _publish_event(cls, user_id: str, event: str, data: dict) -> None:
if not settings.NOTIFICATIONS_ENABLED:
return
payload = {
"event": event,
"data": data,
}
redis_client.publish(
cls._channel_key(user_id),
json.dumps(payload, ensure_ascii=False, default=str),
)
@classmethod
def unread_count(cls, user_id: str, *, type_filter: str | None = None) -> int:
notifications, _ = cls.list(
user_id,
limit=settings.NOTIFICATION_MAX_PAGE_SIZE,
offset=0,
type_filter=type_filter,
paginate=False,
)
return sum(1 for notification in notifications if not notification.get("is_seen"))
@classmethod
def add(cls, user_id: str, payload: dict) -> dict:
data = cls._normalize_notification(payload)
created_at = parse_datetime(data["created_at"]) or timezone.now()
created_at_ts = created_at.timestamp()
json_str = json.dumps(data, ensure_ascii=False, default=str)
ids_key = cls._ids_key(user_id)
data_key = cls._data_key(user_id)
pipe = redis_client.pipeline()
pipe.zadd(ids_key, {data["id"]: created_at_ts})
pipe.hset(data_key, data["id"], json_str)
pipe.sadd(cls.USERS_KEY, user_id)
pipe.execute()
unread_count = cls.unread_count(user_id)
cls._publish_event(
user_id,
"notification",
{
"notification": data,
"unread_count": unread_count,
},
)
cls._publish_event(
user_id,
"unread_count",
{
"unread_count": unread_count,
},
)
return data
@classmethod
def list(
cls,
user_id: str,
*,
limit: int | None = None,
offset: int = 0,
type_filter: str | None = None,
paginate: bool = True,
) -> tuple[list[dict], int]:
ids_key = cls._ids_key(user_id)
data_key = cls._data_key(user_id)
ids = redis_client.zrevrange(ids_key, 0, -1)
if not ids:
return [], 0
pipe = redis_client.pipeline()
for notif_id in ids:
pipe.hget(data_key, notif_id)
raw_items = pipe.execute()
items: list[dict] = []
cleanup_ids: list[str] = []
for notif_id, raw in zip(ids, raw_items, strict=False):
if not raw:
cleanup_ids.append(notif_id)
continue
try:
data = cls._normalize_notification(json.loads(raw))
except json.JSONDecodeError:
cleanup_ids.append(notif_id)
continue
if type_filter and data.get("type") != type_filter:
continue
items.append(data)
if cleanup_ids:
redis_client.zrem(ids_key, *cleanup_ids)
redis_client.hdel(data_key, *cleanup_ids)
total_count = len(items)
if not paginate:
return items, total_count
safe_offset = max(offset, 0)
safe_limit = max(limit or settings.NOTIFICATION_DEFAULT_PAGE_SIZE, 1)
return items[safe_offset : safe_offset + safe_limit], total_count
@classmethod
def get(cls, user_id: str, notif_id: str) -> dict | None:
data_key = cls._data_key(user_id)
raw = redis_client.hget(data_key, notif_id)
if not raw:
return None
try:
return cls._normalize_notification(json.loads(raw))
except json.JSONDecodeError:
return None
@classmethod
def delete(cls, user_id: str, notif_id: str) -> bool:
ids_key = cls._ids_key(user_id)
data_key = cls._data_key(user_id)
pipe = redis_client.pipeline()
pipe.zrem(ids_key, notif_id)
pipe.hdel(data_key, notif_id)
result = pipe.execute()
if any(result):
unread_count = cls.unread_count(user_id)
cls._publish_event(
user_id,
"notification_seen",
{
"notification_id": notif_id,
"deleted": True,
"unread_count": unread_count,
},
)
cls._publish_event(
user_id,
"unread_count",
{
"unread_count": unread_count,
},
)
return True
return False
@classmethod
def mark_seen(cls, user_id: str, notif_id: str) -> dict | None:
data = cls.get(user_id, notif_id)
if not data:
return None
if data.get("delete_on_seen"):
deleted = cls.delete(user_id, notif_id)
if deleted:
return {
"notification_id": notif_id,
"deleted": True,
"notification": None,
"unread_count": cls.unread_count(user_id),
}
return None
if not data.get("is_seen"):
data["is_seen"] = True
data_key = cls._data_key(user_id)
redis_client.hset(
data_key, notif_id, json.dumps(data, ensure_ascii=False, default=str)
)
unread_count = cls.unread_count(user_id)
payload = {
"notification_id": notif_id,
"deleted": False,
"notification": data,
"unread_count": unread_count,
}
cls._publish_event(user_id, "notification_seen", payload)
cls._publish_event(
user_id,
"unread_count",
{
"unread_count": unread_count,
},
)
return payload
@classmethod
def mark_all_seen(
cls,
user_id: str,
*,
delete_on_seen_only: bool = False,
type_filter: str | None = None,
) -> int:
ids_key = cls._ids_key(user_id)
data_key = cls._data_key(user_id)
ids = redis_client.zrevrange(ids_key, 0, -1)
if not ids:
return 0
updated = 0
pipe = redis_client.pipeline()
for notif_id in ids:
raw = redis_client.hget(data_key, notif_id)
if not raw:
continue
try:
data = cls._normalize_notification(json.loads(raw))
except json.JSONDecodeError:
continue
if type_filter and data.get("type") != type_filter:
continue
if delete_on_seen_only and not data.get("delete_on_seen"):
continue
if data.get("delete_on_seen"):
pipe.zrem(ids_key, notif_id)
pipe.hdel(data_key, notif_id)
else:
data["is_seen"] = True
pipe.hset(
data_key,
notif_id,
json.dumps(data, ensure_ascii=False, default=str),
)
updated += 1
if updated:
pipe.execute()
unread_count = cls.unread_count(user_id, type_filter=type_filter)
cls._publish_event(
user_id,
"notification_mark_all_read",
{
"type": type_filter,
"unread_count": unread_count,
},
)
cls._publish_event(
user_id,
"unread_count",
{
"unread_count": cls.unread_count(user_id),
},
)
return updated
@classmethod
def get_pubsub(cls):
return redis_client.pubsub(ignore_subscribe_messages=True)
@classmethod
def cleanup_expired(cls, retention_days: int = 30) -> int:
cutoff_ts = (timezone.now() - timedelta(days=retention_days)).timestamp()
removed = 0
user_ids = redis_client.smembers(cls.USERS_KEY)
for user_id in user_ids:
ids_key = cls._ids_key(user_id)
data_key = cls._data_key(user_id)
old_ids = redis_client.zrangebyscore(ids_key, "-inf", cutoff_ts)
if not old_ids:
continue
pipe = redis_client.pipeline()
user_removed = 0
for notif_id in old_ids:
raw = redis_client.hget(data_key, notif_id)
if not raw:
pipe.zrem(ids_key, notif_id)
pipe.hdel(data_key, notif_id)
removed += 1
user_removed += 1
continue
try:
data = cls._normalize_notification(json.loads(raw))
except json.JSONDecodeError:
pipe.zrem(ids_key, notif_id)
pipe.hdel(data_key, notif_id)
removed += 1
user_removed += 1
continue
if data.get("delete_on_seen"):
continue
pipe.zrem(ids_key, notif_id)
pipe.hdel(data_key, notif_id)
removed += 1
user_removed += 1
if user_removed:
pipe.execute()
if redis_client.zcard(ids_key) == 0:
redis_client.srem(cls.USERS_KEY, user_id)
return removed

View File

@@ -0,0 +1,11 @@
from celery import shared_task
from django.conf import settings
from apps.notifications.services import RedisNotificationStore
@shared_task(name="notifications.cleanup_redis_notifications")
def cleanup_redis_notifications():
return RedisNotificationStore.cleanup_expired(
retention_days=settings.NOTIFICATION_RETENTION_DAYS
)

View File

@@ -0,0 +1,200 @@
import json
from collections import defaultdict
import pytest
from apps.notifications import services
from apps.notifications.services import RedisNotificationStore
class FakePipeline:
def __init__(self, client):
self.client = client
self.operations = []
def __getattr__(self, name):
def wrapper(*args, **kwargs):
self.operations.append((name, args, kwargs))
return self
return wrapper
def execute(self):
results = []
for name, args, kwargs in self.operations:
results.append(getattr(self.client, name)(*args, **kwargs))
self.operations.clear()
return results
class FakePubSub:
def __init__(self):
self.channels = []
self.messages = []
self.closed = False
def subscribe(self, channel):
self.channels.append(channel)
def unsubscribe(self, channel):
if channel in self.channels:
self.channels.remove(channel)
def get_message(self, timeout=1.0):
if self.messages:
return self.messages.pop(0)
return None
def close(self):
self.closed = True
class FakeRedis:
def __init__(self):
self.sorted_sets = defaultdict(dict)
self.hashes = defaultdict(dict)
self.sets = defaultdict(set)
self.published = []
self.pubsub_instance = FakePubSub()
def pipeline(self):
return FakePipeline(self)
def zadd(self, key, mapping):
self.sorted_sets[key].update(mapping)
return len(mapping)
def hset(self, key, field, value):
self.hashes[key][field] = value
return 1
def sadd(self, key, *members):
before = len(self.sets[key])
self.sets[key].update(members)
return len(self.sets[key]) - before
def zrevrange(self, key, start, stop):
items = sorted(
self.sorted_sets[key].items(),
key=lambda item: (item[1], item[0]),
reverse=True,
)
if stop == -1:
return [member for member, _ in items[start:]]
return [member for member, _ in items[start : stop + 1]]
def hget(self, key, field):
return self.hashes[key].get(field)
def zrem(self, key, *members):
removed = 0
for member in members:
if member in self.sorted_sets[key]:
del self.sorted_sets[key][member]
removed += 1
return removed
def hdel(self, key, *fields):
removed = 0
for field in fields:
if field in self.hashes[key]:
del self.hashes[key][field]
removed += 1
return removed
def smembers(self, key):
return set(self.sets[key])
def srem(self, key, member):
if member in self.sets[key]:
self.sets[key].remove(member)
return 1
return 0
def zrangebyscore(self, key, min_score, max_score):
lower = float("-inf") if min_score == "-inf" else float(min_score)
upper = float(max_score)
return [
member
for member, score in self.sorted_sets[key].items()
if lower <= score <= upper
]
def zcard(self, key):
return len(self.sorted_sets[key])
def publish(self, channel, message):
self.published.append((channel, json.loads(message)))
return 1
def pubsub(self, ignore_subscribe_messages=True):
return self.pubsub_instance
@pytest.fixture()
def fake_redis(monkeypatch):
redis = FakeRedis()
monkeypatch.setattr(services, "redis_client", redis)
return redis
def test_add_publishes_notification_and_unread_count(fake_redis, settings):
settings.NOTIFICATIONS_ENABLED = True
notification = RedisNotificationStore.add(
"user-1",
{
"title": "Build finished",
"message": "Your deploy completed.",
"level": "success",
},
)
assert notification["title"] == "Build finished"
assert notification["message"] == "Your deploy completed."
assert notification["level"] == "success"
assert len(fake_redis.published) == 2
channel, payload = fake_redis.published[0]
assert channel == f"{settings.NOTIFICATION_REDIS_CHANNEL_PREFIX}:user-1"
assert payload["event"] == "notification"
assert payload["data"]["notification"]["id"] == notification["id"]
assert payload["data"]["unread_count"] == 1
def test_mark_seen_and_mark_all_seen_publish_sync_events(fake_redis, settings):
settings.NOTIFICATIONS_ENABLED = True
first = RedisNotificationStore.add("user-2", {"title": "First"})
second = RedisNotificationStore.add("user-2", {"title": "Second"})
fake_redis.published.clear()
payload = RedisNotificationStore.mark_seen("user-2", first["id"])
assert payload["notification_id"] == first["id"]
assert payload["deleted"] is False
assert payload["notification"]["is_seen"] is True
assert fake_redis.published[0][1]["event"] == "notification_seen"
fake_redis.published.clear()
updated = RedisNotificationStore.mark_all_seen("user-2")
assert updated == 2
assert fake_redis.published[0][1]["event"] == "notification_mark_all_read"
assert fake_redis.published[1][1]["event"] == "unread_count"
assert fake_redis.published[1][1]["data"]["unread_count"] == 0
def test_list_returns_total_count_and_filtered_notifications(fake_redis):
RedisNotificationStore.add("user-3", {"title": "General", "type": "general"})
RedisNotificationStore.add("user-3", {"title": "Billing", "type": "billing"})
RedisNotificationStore.add("user-3", {"title": "General 2", "type": "general"})
notifications, total_count = RedisNotificationStore.list(
"user-3",
limit=1,
offset=0,
type_filter="general",
)
assert total_count == 2
assert len(notifications) == 1
assert notifications[0]["type"] == "general"

View File

@@ -0,0 +1,165 @@
import json
import time
from datetime import timedelta
import pytest
from django.utils import timezone
from rest_framework.test import APIClient
from apps.notifications import services, views
from apps.notifications.services import RedisNotificationStore
from apps.notifications.tests.test_services import FakePubSub, FakeRedis
from apps.users.models import User
@pytest.fixture()
def fake_redis(monkeypatch):
redis = FakeRedis()
monkeypatch.setattr(services, "redis_client", redis)
return redis
@pytest.fixture()
def user(db):
return User.objects.create_user(mobile="09121111111", password="secret123")
@pytest.fixture()
def second_user(db):
return User.objects.create_user(mobile="09122222222", password="secret123")
def _read_sse_chunks(response, count):
iterator = iter(response.streaming_content)
chunks = []
for _ in range(count):
chunk = next(iterator)
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
chunks.append(chunk)
response.close()
return chunks
def _parse_sse_data(chunk: str) -> dict:
for line in chunk.splitlines():
if line.startswith("data: "):
return json.loads(line.removeprefix("data: "))
raise AssertionError("SSE payload did not include data")
def test_stream_token_endpoint_returns_short_lived_token(user):
client = APIClient()
client.force_authenticate(user=user)
response = client.post("/api/notifications/stream-token/")
assert response.status_code == 200
assert response.data["token"]
assert response.data["expires_in"] > 0
def test_stream_endpoint_rejects_missing_and_expired_token(user, settings):
client = APIClient()
missing = client.get("/api/notifications/stream/")
assert missing.status_code == 401
settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS = 1
token = views._issue_stream_token_for_user(str(user.id))
time.sleep(1.1)
expired = client.get(f"/api/notifications/stream/?token={token}")
assert expired.status_code == 401
def test_stream_endpoint_sends_only_current_users_notifications(
fake_redis, user, second_user, monkeypatch
):
RedisNotificationStore.add(str(user.id), {"title": "For current user"})
RedisNotificationStore.add(str(second_user.id), {"title": "For another user"})
pubsub = FakePubSub()
monkeypatch.setattr(RedisNotificationStore, "get_pubsub", classmethod(lambda cls: pubsub))
token = views._issue_stream_token_for_user(str(user.id))
client = APIClient()
response = client.get(
f"/api/notifications/stream/?token={token}",
HTTP_ACCEPT="text/event-stream",
)
retry_line, connected_chunk = _read_sse_chunks(response, 2)
assert response.status_code == 200
assert retry_line.startswith("retry:")
connected = _parse_sse_data(connected_chunk)
assert connected["unread_count"] == 1
assert [item["title"] for item in connected["notifications"]] == ["For current user"]
def test_stream_endpoint_emits_heartbeat(fake_redis, user, settings, monkeypatch):
pubsub = FakePubSub()
monkeypatch.setattr(RedisNotificationStore, "get_pubsub", classmethod(lambda cls: pubsub))
settings.NOTIFICATION_SSE_HEARTBEAT_SECONDS = 1
first_now = timezone.now()
tick_values = iter(
[
first_now,
first_now,
first_now + timedelta(seconds=2),
first_now + timedelta(seconds=2),
first_now + timedelta(seconds=2),
first_now + timedelta(seconds=2),
]
)
last_tick = first_now + timedelta(seconds=2)
def fake_now():
return next(tick_values, last_tick)
monkeypatch.setattr(views.timezone, "now", fake_now)
view = views.NotificationStreamView()
stream = view._build_stream(str(user.id))
chunks = [next(stream) for _ in range(4)]
stream.close()
assert "event: ping" in chunks[3]
def test_notification_list_and_seen_endpoints_work(fake_redis, user):
notification = RedisNotificationStore.add(
str(user.id),
{"title": "Deploy succeeded", "type": "deploy"},
)
client = APIClient()
client.force_authenticate(user=user)
list_response = client.get("/api/notifications/list/?type=deploy")
assert list_response.status_code == 200
assert list_response.data["count"] == 1
assert list_response.data["unread_count"] == 1
assert list_response.data["notifications"][0]["title"] == "Deploy succeeded"
seen_response = client.post("/api/notifications/seen/", {"id": notification["id"]}, format="json")
assert seen_response.status_code == 200
assert seen_response.data["marked_read"] is True
assert seen_response.data["notification"]["is_seen"] is True
def test_notification_delete_endpoint_removes_notification(fake_redis, user):
notification = RedisNotificationStore.add(
str(user.id),
{"title": "Delete me", "type": "deploy"},
)
client = APIClient()
client.force_authenticate(user=user)
response = client.delete(f"/api/notifications/{notification['id']}/")
assert response.status_code == 200
assert response.data["deleted"] is True
assert response.data["notification_id"] == notification["id"]
assert RedisNotificationStore.get(str(user.id), notification["id"]) is None

View File

@@ -0,0 +1,35 @@
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from apps.notifications import views
router = DefaultRouter()
router.register("box", views.NotificationListViewSet, basename="box")
app_name = "notification"
urlpatterns = [
path("", include(router.urls)),
path("list/", views.NotificationListView.as_view(), name="notifications"),
path("seen/", views.NotificationSeenView.as_view(), name="notifications-seen"),
path(
"stream-token/",
views.NotificationStreamTokenView.as_view(),
name="notifications-stream-token",
),
path(
"stream/",
views.NotificationStreamView.as_view(),
name="notifications-stream",
),
path(
"seen/all/",
views.NotificationMarkAllReadView.as_view(),
name="notifications-mark-read",
),
path(
"<str:notif_id>/",
views.NotificationDeleteView.as_view(),
name="notifications-delete",
),
]

284
apps/notifications/views.py Normal file
View File

@@ -0,0 +1,284 @@
import json
import logging
from typing import Iterator
from django.conf import settings
from django.core import signing
from django.http import JsonResponse, StreamingHttpResponse
from django.utils import timezone
from django.views import View
from drf_spectacular.utils import extend_schema
from rest_framework import status, viewsets
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.notifications.serializers import (
NotificationDeleteResponseSerializer,
NotificationListResponseSerializer,
NotificationMarkAllReadResponseSerializer,
NotificationSeenRequestSerializer,
NotificationSeenResponseSerializer,
NotificationStreamTokenResponseSerializer,
NotificationTypeFilterSerializer,
)
from apps.notifications.services import RedisNotificationStore
logger = logging.getLogger(__name__)
STREAM_TOKEN_SALT = "notifications.stream"
def _safe_int(value, default: int) -> int:
try:
return max(int(value), 0)
except (TypeError, ValueError):
return default
def _format_sse_event(event: str, data: dict) -> str:
payload = json.dumps(data, ensure_ascii=False, default=str)
return f"event: {event}\ndata: {payload}\n\n"
def _issue_stream_token_for_user(user_id: str) -> str:
return signing.dumps(
{
"user_id": str(user_id),
"type": "notification_stream",
},
salt=STREAM_TOKEN_SALT,
)
def _validate_stream_token(token: str | None) -> str:
if not token:
raise signing.BadSignature("Missing stream token")
payload = signing.loads(
token,
salt=STREAM_TOKEN_SALT,
max_age=settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS,
)
if payload.get("type") != "notification_stream":
raise signing.BadSignature("Invalid stream token type")
return str(payload["user_id"])
class NotificationListViewSet(viewsets.ViewSet):
permission_classes = [IsAuthenticated]
serializer_class = NotificationListResponseSerializer
def list(self, request):
user_id = str(request.user.id)
limit = min(
_safe_int(
request.query_params.get("limit"),
settings.NOTIFICATION_DEFAULT_PAGE_SIZE,
),
settings.NOTIFICATION_MAX_PAGE_SIZE,
)
offset = _safe_int(request.query_params.get("offset"), 0)
type_filter = request.query_params.get("type")
notifications, total_count = RedisNotificationStore.list(
user_id,
limit=limit,
offset=offset,
type_filter=type_filter,
)
return Response(
{
"count": total_count,
"unread_count": RedisNotificationStore.unread_count(user_id),
"notifications": notifications,
}
)
class NotificationListView(APIView):
permission_classes = [IsAuthenticated]
serializer_class = NotificationListResponseSerializer
@extend_schema(responses=NotificationListResponseSerializer)
def get(self, request):
user_id = str(request.user.id)
limit = min(
_safe_int(
request.query_params.get("limit"),
settings.NOTIFICATION_DEFAULT_PAGE_SIZE,
),
settings.NOTIFICATION_MAX_PAGE_SIZE,
)
offset = _safe_int(request.query_params.get("offset"), 0)
type_filter = request.query_params.get("type")
notifications, total_count = RedisNotificationStore.list(
user_id,
limit=limit,
offset=offset,
type_filter=type_filter,
)
return Response(
{
"count": total_count,
"unread_count": RedisNotificationStore.unread_count(user_id),
"notifications": notifications,
}
)
class NotificationSeenView(APIView):
permission_classes = [IsAuthenticated]
serializer_class = NotificationSeenRequestSerializer
@extend_schema(
request=NotificationSeenRequestSerializer,
responses=NotificationSeenResponseSerializer,
)
def post(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
notif_id = serializer.validated_data["id"]
payload = RedisNotificationStore.mark_seen(str(request.user.id), notif_id)
if payload is None:
return Response({"marked_read": False}, status=status.HTTP_404_NOT_FOUND)
return Response({"marked_read": True, **payload})
class NotificationDeleteView(APIView):
permission_classes = [IsAuthenticated]
@extend_schema(responses=NotificationDeleteResponseSerializer)
def delete(self, request, notif_id: str):
deleted = RedisNotificationStore.delete(str(request.user.id), notif_id)
if not deleted:
return Response({"deleted": False}, status=status.HTTP_404_NOT_FOUND)
return Response(
{
"deleted": True,
"notification_id": notif_id,
"unread_count": RedisNotificationStore.unread_count(
str(request.user.id)
),
}
)
class NotificationMarkAllReadView(APIView):
permission_classes = [IsAuthenticated]
serializer_class = NotificationTypeFilterSerializer
@extend_schema(
request=NotificationTypeFilterSerializer,
responses=NotificationMarkAllReadResponseSerializer,
)
def post(self, request):
type_filter = request.data.get("type") or request.query_params.get("type")
updated = RedisNotificationStore.mark_all_seen(
str(request.user.id),
delete_on_seen_only=False,
type_filter=type_filter,
)
return Response({"marked_read": updated})
class NotificationStreamTokenView(APIView):
permission_classes = [IsAuthenticated]
@extend_schema(responses=NotificationStreamTokenResponseSerializer)
def post(self, request):
if not settings.NOTIFICATIONS_ENABLED:
return Response(
{"detail": "Notifications are disabled."},
status=status.HTTP_503_SERVICE_UNAVAILABLE,
)
return Response(
{
"token": _issue_stream_token_for_user(str(request.user.id)),
"expires_in": settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS,
}
)
class NotificationStreamView(View):
def _build_stream(self, user_id: str) -> Iterator[str]:
pubsub = RedisNotificationStore.get_pubsub()
channel = RedisNotificationStore._channel_key(user_id)
heartbeat_seconds = max(settings.NOTIFICATION_SSE_HEARTBEAT_SECONDS, 1)
initial_notifications, _ = RedisNotificationStore.list(
user_id,
limit=settings.NOTIFICATION_DEFAULT_PAGE_SIZE,
offset=0,
)
unread_count = RedisNotificationStore.unread_count(user_id)
yield f"retry: {settings.NOTIFICATION_SSE_RETRY_MS}\n\n"
yield _format_sse_event(
"connected",
{
"notifications": initial_notifications,
"unread_count": unread_count,
},
)
yield _format_sse_event(
"unread_count",
{
"unread_count": unread_count,
},
)
pubsub.subscribe(channel)
last_ping_at = timezone.now()
try:
while True:
message = pubsub.get_message(timeout=1.0)
if message and message.get("type") == "message":
try:
payload = json.loads(message["data"])
except json.JSONDecodeError:
logger.warning("Invalid notification stream payload for user %s", user_id)
else:
event = payload.get("event") or "notification"
data = payload.get("data") or {}
yield _format_sse_event(event, data)
if (timezone.now() - last_ping_at).total_seconds() >= heartbeat_seconds:
last_ping_at = timezone.now()
yield _format_sse_event(
"ping",
{"timestamp": last_ping_at.isoformat()},
)
except GeneratorExit:
logger.debug("Notification stream closed for user %s", user_id)
finally:
try:
pubsub.unsubscribe(channel)
finally:
pubsub.close()
def get(self, request, *args, **kwargs):
if not settings.NOTIFICATIONS_ENABLED:
return JsonResponse(
{"detail": "Notifications are disabled."},
status=status.HTTP_503_SERVICE_UNAVAILABLE,
)
try:
user_id = _validate_stream_token(request.GET.get("token"))
except signing.SignatureExpired:
return JsonResponse(
{"detail": "Stream token expired."},
status=status.HTTP_401_UNAUTHORIZED,
)
except signing.BadSignature:
return JsonResponse(
{"detail": "Invalid stream token."},
status=status.HTTP_401_UNAUTHORIZED,
)
response = StreamingHttpResponse(
self._build_stream(user_id),
content_type="text/event-stream",
)
response["Cache-Control"] = "no-cache"
response["X-Accel-Buffering"] = "no"
return response

View File

@@ -45,6 +45,7 @@ LOCAL_APPS = [
"apps.projects",
"apps.tags",
"apps.time_entries",
"apps.notifications",
]
INSTALLED_APPS = DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS
@@ -202,10 +203,35 @@ CELERY_ACCEPT_CONTENT = ["json"]
CELERY_TASK_SERIALIZER = "json"
CELERY_RESULT_SERIALIZER = "json"
CELERY_TASK_ALWAYS_EAGER = False
CELERY_IMPORTS = ("apps.users.tasks",)
CELERY_TIMEZONE = os.getenv("TIME_ZONE")
CELERY_TASK_TRACK_STARTED = True
NOTIFICATIONS_ENABLED = os.getenv("NOTIFICATIONS_ENABLED", "True") == "True"
NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS = int(
os.getenv("NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS", "90")
)
NOTIFICATION_SSE_HEARTBEAT_SECONDS = int(
os.getenv("NOTIFICATION_SSE_HEARTBEAT_SECONDS", "20")
)
NOTIFICATION_SSE_RETRY_MS = int(os.getenv("NOTIFICATION_SSE_RETRY_MS", "5000"))
NOTIFICATION_DEFAULT_PAGE_SIZE = int(
os.getenv("NOTIFICATION_DEFAULT_PAGE_SIZE", "20")
)
NOTIFICATION_MAX_PAGE_SIZE = int(os.getenv("NOTIFICATION_MAX_PAGE_SIZE", "50"))
NOTIFICATION_REDIS_CHANNEL_PREFIX = os.getenv(
"NOTIFICATION_REDIS_CHANNEL_PREFIX", "notif:user"
)
NOTIFICATION_RETENTION_DAYS = int(os.getenv("NOTIFICATION_RETENTION_DAYS", "30"))
NOTIFICATION_TOAST_LEVELS = tuple(
level.strip()
for level in os.getenv(
"NOTIFICATION_TOAST_LEVELS", "info,success,warning,error"
).split(",")
if level.strip()
)
CELERY_IMPORTS = ("apps.users.tasks", "apps.notifications.tasks")
STORAGES = {
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},

View File

@@ -21,6 +21,7 @@ urlpatterns = [
path('api/', include('apps.projects.api.urls'), name="projects"),
path('api/', include('apps.tags.api.urls'), name="tags"),
path('api/', include('apps.time_entries.api.urls'), name="time_entries"),
path("api/notifications/", include("apps.notifications.urls"), name="notifications"),
]
if settings.DEBUG: