diff --git a/apps/notifications/__init__.py b/apps/notifications/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/notifications/admin.py b/apps/notifications/admin.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/notifications/apps.py b/apps/notifications/apps.py new file mode 100644 index 0000000..4f5397c --- /dev/null +++ b/apps/notifications/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class NotificationsConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "apps.notifications" + verbose_name = "09-notifications" diff --git a/apps/notifications/migrations/__init__.py b/apps/notifications/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/notifications/models.py b/apps/notifications/models.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/notifications/serializers.py b/apps/notifications/serializers.py new file mode 100644 index 0000000..d98ca14 --- /dev/null +++ b/apps/notifications/serializers.py @@ -0,0 +1,61 @@ +from rest_framework import serializers + + +class NotificationSerializer(serializers.Serializer): + id = serializers.CharField() + type = serializers.CharField() + title = serializers.CharField(allow_blank=True) + message = serializers.CharField(allow_blank=True) + level = serializers.ChoiceField( + choices=("info", "success", "warning", "error") + ) + created_at = serializers.CharField() + is_seen = serializers.BooleanField() + delete_on_seen = serializers.BooleanField() + action_url = serializers.CharField( + required=False, allow_blank=True, allow_null=True + ) + entity_type = serializers.CharField( + required=False, allow_blank=True, allow_null=True + ) + entity_id = serializers.CharField( + required=False, allow_blank=True, allow_null=True + ) + meta = serializers.JSONField(required=False) + + +class NotificationListResponseSerializer(serializers.Serializer): + count = serializers.IntegerField() + unread_count = serializers.IntegerField() + notifications = NotificationSerializer(many=True) + + +class NotificationSeenRequestSerializer(serializers.Serializer): + id = serializers.CharField() + + +class NotificationDeleteResponseSerializer(serializers.Serializer): + deleted = serializers.BooleanField() + notification_id = serializers.CharField(required=False) + unread_count = serializers.IntegerField(required=False) + + +class NotificationSeenResponseSerializer(serializers.Serializer): + marked_read = serializers.BooleanField() + notification_id = serializers.CharField(required=False) + deleted = serializers.BooleanField(required=False) + unread_count = serializers.IntegerField(required=False) + notification = NotificationSerializer(required=False, allow_null=True) + + +class NotificationMarkAllReadResponseSerializer(serializers.Serializer): + marked_read = serializers.IntegerField() + + +class NotificationTypeFilterSerializer(serializers.Serializer): + type = serializers.CharField(required=False, allow_blank=True, allow_null=True) + + +class NotificationStreamTokenResponseSerializer(serializers.Serializer): + token = serializers.CharField() + expires_in = serializers.IntegerField() diff --git a/apps/notifications/services.py b/apps/notifications/services.py new file mode 100644 index 0000000..8b975bb --- /dev/null +++ b/apps/notifications/services.py @@ -0,0 +1,355 @@ +import json +import uuid +from datetime import timedelta + +import redis +from django.conf import settings +from django.utils import timezone +from django.utils.dateparse import parse_datetime + +redis_client = redis.StrictRedis.from_url(settings.REDIS_URL, decode_responses=True) + + +def _isoformat_datetime(value) -> str: + if not value: + return timezone.now().isoformat() + if isinstance(value, str): + parsed = parse_datetime(value) + if parsed is not None: + value = parsed + else: + return value + if timezone.is_naive(value): + value = timezone.make_aware(value, timezone.get_current_timezone()) + return timezone.localtime(value).isoformat() + + +class RedisNotificationStore: + USERS_KEY = "notif:users" + + @classmethod + def _ids_key(cls, user_id: str) -> str: + return f"notif:{user_id}:ids" + + @classmethod + def _data_key(cls, user_id: str) -> str: + return f"notif:{user_id}:data" + + @classmethod + def _channel_key(cls, user_id: str) -> str: + prefix = settings.NOTIFICATION_REDIS_CHANNEL_PREFIX.rstrip(":") + return f"{prefix}:{user_id}" + + @staticmethod + def _normalize_notification(data: dict | None) -> dict: + payload = dict(data or {}) + return { + "id": str(payload.get("id") or uuid.uuid4()), + "type": payload.get("type") or "notification", + "title": payload.get("title") or "", + "message": payload.get("message") or "", + "level": payload.get("level") or "info", + "created_at": _isoformat_datetime(payload.get("created_at")), + "is_seen": bool(payload.get("is_seen", False)), + "delete_on_seen": bool(payload.get("delete_on_seen", False)), + "action_url": payload.get("action_url"), + "entity_type": payload.get("entity_type"), + "entity_id": payload.get("entity_id"), + "meta": payload.get("meta") or {}, + } + + @classmethod + def _publish_event(cls, user_id: str, event: str, data: dict) -> None: + if not settings.NOTIFICATIONS_ENABLED: + return + payload = { + "event": event, + "data": data, + } + redis_client.publish( + cls._channel_key(user_id), + json.dumps(payload, ensure_ascii=False, default=str), + ) + + @classmethod + def unread_count(cls, user_id: str, *, type_filter: str | None = None) -> int: + notifications, _ = cls.list( + user_id, + limit=settings.NOTIFICATION_MAX_PAGE_SIZE, + offset=0, + type_filter=type_filter, + paginate=False, + ) + return sum(1 for notification in notifications if not notification.get("is_seen")) + + @classmethod + def add(cls, user_id: str, payload: dict) -> dict: + data = cls._normalize_notification(payload) + created_at = parse_datetime(data["created_at"]) or timezone.now() + created_at_ts = created_at.timestamp() + json_str = json.dumps(data, ensure_ascii=False, default=str) + + ids_key = cls._ids_key(user_id) + data_key = cls._data_key(user_id) + + pipe = redis_client.pipeline() + pipe.zadd(ids_key, {data["id"]: created_at_ts}) + pipe.hset(data_key, data["id"], json_str) + pipe.sadd(cls.USERS_KEY, user_id) + pipe.execute() + + unread_count = cls.unread_count(user_id) + cls._publish_event( + user_id, + "notification", + { + "notification": data, + "unread_count": unread_count, + }, + ) + cls._publish_event( + user_id, + "unread_count", + { + "unread_count": unread_count, + }, + ) + return data + + @classmethod + def list( + cls, + user_id: str, + *, + limit: int | None = None, + offset: int = 0, + type_filter: str | None = None, + paginate: bool = True, + ) -> tuple[list[dict], int]: + ids_key = cls._ids_key(user_id) + data_key = cls._data_key(user_id) + + ids = redis_client.zrevrange(ids_key, 0, -1) + if not ids: + return [], 0 + + pipe = redis_client.pipeline() + for notif_id in ids: + pipe.hget(data_key, notif_id) + raw_items = pipe.execute() + + items: list[dict] = [] + cleanup_ids: list[str] = [] + for notif_id, raw in zip(ids, raw_items, strict=False): + if not raw: + cleanup_ids.append(notif_id) + continue + try: + data = cls._normalize_notification(json.loads(raw)) + except json.JSONDecodeError: + cleanup_ids.append(notif_id) + continue + if type_filter and data.get("type") != type_filter: + continue + items.append(data) + + if cleanup_ids: + redis_client.zrem(ids_key, *cleanup_ids) + redis_client.hdel(data_key, *cleanup_ids) + + total_count = len(items) + if not paginate: + return items, total_count + + safe_offset = max(offset, 0) + safe_limit = max(limit or settings.NOTIFICATION_DEFAULT_PAGE_SIZE, 1) + return items[safe_offset : safe_offset + safe_limit], total_count + + @classmethod + def get(cls, user_id: str, notif_id: str) -> dict | None: + data_key = cls._data_key(user_id) + raw = redis_client.hget(data_key, notif_id) + if not raw: + return None + try: + return cls._normalize_notification(json.loads(raw)) + except json.JSONDecodeError: + return None + + @classmethod + def delete(cls, user_id: str, notif_id: str) -> bool: + ids_key = cls._ids_key(user_id) + data_key = cls._data_key(user_id) + pipe = redis_client.pipeline() + pipe.zrem(ids_key, notif_id) + pipe.hdel(data_key, notif_id) + result = pipe.execute() + if any(result): + unread_count = cls.unread_count(user_id) + cls._publish_event( + user_id, + "notification_seen", + { + "notification_id": notif_id, + "deleted": True, + "unread_count": unread_count, + }, + ) + cls._publish_event( + user_id, + "unread_count", + { + "unread_count": unread_count, + }, + ) + return True + return False + + @classmethod + def mark_seen(cls, user_id: str, notif_id: str) -> dict | None: + data = cls.get(user_id, notif_id) + if not data: + return None + + if data.get("delete_on_seen"): + deleted = cls.delete(user_id, notif_id) + if deleted: + return { + "notification_id": notif_id, + "deleted": True, + "notification": None, + "unread_count": cls.unread_count(user_id), + } + return None + + if not data.get("is_seen"): + data["is_seen"] = True + data_key = cls._data_key(user_id) + redis_client.hset( + data_key, notif_id, json.dumps(data, ensure_ascii=False, default=str) + ) + + unread_count = cls.unread_count(user_id) + payload = { + "notification_id": notif_id, + "deleted": False, + "notification": data, + "unread_count": unread_count, + } + cls._publish_event(user_id, "notification_seen", payload) + cls._publish_event( + user_id, + "unread_count", + { + "unread_count": unread_count, + }, + ) + return payload + + @classmethod + def mark_all_seen( + cls, + user_id: str, + *, + delete_on_seen_only: bool = False, + type_filter: str | None = None, + ) -> int: + ids_key = cls._ids_key(user_id) + data_key = cls._data_key(user_id) + ids = redis_client.zrevrange(ids_key, 0, -1) + if not ids: + return 0 + + updated = 0 + pipe = redis_client.pipeline() + for notif_id in ids: + raw = redis_client.hget(data_key, notif_id) + if not raw: + continue + try: + data = cls._normalize_notification(json.loads(raw)) + except json.JSONDecodeError: + continue + if type_filter and data.get("type") != type_filter: + continue + if delete_on_seen_only and not data.get("delete_on_seen"): + continue + + if data.get("delete_on_seen"): + pipe.zrem(ids_key, notif_id) + pipe.hdel(data_key, notif_id) + else: + data["is_seen"] = True + pipe.hset( + data_key, + notif_id, + json.dumps(data, ensure_ascii=False, default=str), + ) + updated += 1 + + if updated: + pipe.execute() + unread_count = cls.unread_count(user_id, type_filter=type_filter) + cls._publish_event( + user_id, + "notification_mark_all_read", + { + "type": type_filter, + "unread_count": unread_count, + }, + ) + cls._publish_event( + user_id, + "unread_count", + { + "unread_count": cls.unread_count(user_id), + }, + ) + return updated + + @classmethod + def get_pubsub(cls): + return redis_client.pubsub(ignore_subscribe_messages=True) + + @classmethod + def cleanup_expired(cls, retention_days: int = 30) -> int: + cutoff_ts = (timezone.now() - timedelta(days=retention_days)).timestamp() + removed = 0 + user_ids = redis_client.smembers(cls.USERS_KEY) + for user_id in user_ids: + ids_key = cls._ids_key(user_id) + data_key = cls._data_key(user_id) + old_ids = redis_client.zrangebyscore(ids_key, "-inf", cutoff_ts) + if not old_ids: + continue + + pipe = redis_client.pipeline() + user_removed = 0 + for notif_id in old_ids: + raw = redis_client.hget(data_key, notif_id) + if not raw: + pipe.zrem(ids_key, notif_id) + pipe.hdel(data_key, notif_id) + removed += 1 + user_removed += 1 + continue + try: + data = cls._normalize_notification(json.loads(raw)) + except json.JSONDecodeError: + pipe.zrem(ids_key, notif_id) + pipe.hdel(data_key, notif_id) + removed += 1 + user_removed += 1 + continue + if data.get("delete_on_seen"): + continue + pipe.zrem(ids_key, notif_id) + pipe.hdel(data_key, notif_id) + removed += 1 + user_removed += 1 + if user_removed: + pipe.execute() + + if redis_client.zcard(ids_key) == 0: + redis_client.srem(cls.USERS_KEY, user_id) + return removed diff --git a/apps/notifications/tasks.py b/apps/notifications/tasks.py new file mode 100644 index 0000000..eeb5d85 --- /dev/null +++ b/apps/notifications/tasks.py @@ -0,0 +1,11 @@ +from celery import shared_task +from django.conf import settings + +from apps.notifications.services import RedisNotificationStore + + +@shared_task(name="notifications.cleanup_redis_notifications") +def cleanup_redis_notifications(): + return RedisNotificationStore.cleanup_expired( + retention_days=settings.NOTIFICATION_RETENTION_DAYS + ) diff --git a/apps/notifications/tests/test_services.py b/apps/notifications/tests/test_services.py new file mode 100644 index 0000000..deae25a --- /dev/null +++ b/apps/notifications/tests/test_services.py @@ -0,0 +1,200 @@ +import json +from collections import defaultdict + +import pytest + +from apps.notifications import services +from apps.notifications.services import RedisNotificationStore + + +class FakePipeline: + def __init__(self, client): + self.client = client + self.operations = [] + + def __getattr__(self, name): + def wrapper(*args, **kwargs): + self.operations.append((name, args, kwargs)) + return self + + return wrapper + + def execute(self): + results = [] + for name, args, kwargs in self.operations: + results.append(getattr(self.client, name)(*args, **kwargs)) + self.operations.clear() + return results + + +class FakePubSub: + def __init__(self): + self.channels = [] + self.messages = [] + self.closed = False + + def subscribe(self, channel): + self.channels.append(channel) + + def unsubscribe(self, channel): + if channel in self.channels: + self.channels.remove(channel) + + def get_message(self, timeout=1.0): + if self.messages: + return self.messages.pop(0) + return None + + def close(self): + self.closed = True + + +class FakeRedis: + def __init__(self): + self.sorted_sets = defaultdict(dict) + self.hashes = defaultdict(dict) + self.sets = defaultdict(set) + self.published = [] + self.pubsub_instance = FakePubSub() + + def pipeline(self): + return FakePipeline(self) + + def zadd(self, key, mapping): + self.sorted_sets[key].update(mapping) + return len(mapping) + + def hset(self, key, field, value): + self.hashes[key][field] = value + return 1 + + def sadd(self, key, *members): + before = len(self.sets[key]) + self.sets[key].update(members) + return len(self.sets[key]) - before + + def zrevrange(self, key, start, stop): + items = sorted( + self.sorted_sets[key].items(), + key=lambda item: (item[1], item[0]), + reverse=True, + ) + if stop == -1: + return [member for member, _ in items[start:]] + return [member for member, _ in items[start : stop + 1]] + + def hget(self, key, field): + return self.hashes[key].get(field) + + def zrem(self, key, *members): + removed = 0 + for member in members: + if member in self.sorted_sets[key]: + del self.sorted_sets[key][member] + removed += 1 + return removed + + def hdel(self, key, *fields): + removed = 0 + for field in fields: + if field in self.hashes[key]: + del self.hashes[key][field] + removed += 1 + return removed + + def smembers(self, key): + return set(self.sets[key]) + + def srem(self, key, member): + if member in self.sets[key]: + self.sets[key].remove(member) + return 1 + return 0 + + def zrangebyscore(self, key, min_score, max_score): + lower = float("-inf") if min_score == "-inf" else float(min_score) + upper = float(max_score) + return [ + member + for member, score in self.sorted_sets[key].items() + if lower <= score <= upper + ] + + def zcard(self, key): + return len(self.sorted_sets[key]) + + def publish(self, channel, message): + self.published.append((channel, json.loads(message))) + return 1 + + def pubsub(self, ignore_subscribe_messages=True): + return self.pubsub_instance + + +@pytest.fixture() +def fake_redis(monkeypatch): + redis = FakeRedis() + monkeypatch.setattr(services, "redis_client", redis) + return redis + + +def test_add_publishes_notification_and_unread_count(fake_redis, settings): + settings.NOTIFICATIONS_ENABLED = True + + notification = RedisNotificationStore.add( + "user-1", + { + "title": "Build finished", + "message": "Your deploy completed.", + "level": "success", + }, + ) + + assert notification["title"] == "Build finished" + assert notification["message"] == "Your deploy completed." + assert notification["level"] == "success" + assert len(fake_redis.published) == 2 + channel, payload = fake_redis.published[0] + assert channel == f"{settings.NOTIFICATION_REDIS_CHANNEL_PREFIX}:user-1" + assert payload["event"] == "notification" + assert payload["data"]["notification"]["id"] == notification["id"] + assert payload["data"]["unread_count"] == 1 + + +def test_mark_seen_and_mark_all_seen_publish_sync_events(fake_redis, settings): + settings.NOTIFICATIONS_ENABLED = True + first = RedisNotificationStore.add("user-2", {"title": "First"}) + second = RedisNotificationStore.add("user-2", {"title": "Second"}) + fake_redis.published.clear() + + payload = RedisNotificationStore.mark_seen("user-2", first["id"]) + + assert payload["notification_id"] == first["id"] + assert payload["deleted"] is False + assert payload["notification"]["is_seen"] is True + assert fake_redis.published[0][1]["event"] == "notification_seen" + + fake_redis.published.clear() + updated = RedisNotificationStore.mark_all_seen("user-2") + + assert updated == 2 + assert fake_redis.published[0][1]["event"] == "notification_mark_all_read" + assert fake_redis.published[1][1]["event"] == "unread_count" + assert fake_redis.published[1][1]["data"]["unread_count"] == 0 + + +def test_list_returns_total_count_and_filtered_notifications(fake_redis): + RedisNotificationStore.add("user-3", {"title": "General", "type": "general"}) + RedisNotificationStore.add("user-3", {"title": "Billing", "type": "billing"}) + RedisNotificationStore.add("user-3", {"title": "General 2", "type": "general"}) + + notifications, total_count = RedisNotificationStore.list( + "user-3", + limit=1, + offset=0, + type_filter="general", + ) + + assert total_count == 2 + assert len(notifications) == 1 + assert notifications[0]["type"] == "general" diff --git a/apps/notifications/tests/test_views.py b/apps/notifications/tests/test_views.py new file mode 100644 index 0000000..8826c3b --- /dev/null +++ b/apps/notifications/tests/test_views.py @@ -0,0 +1,165 @@ +import json +import time +from datetime import timedelta + +import pytest +from django.utils import timezone +from rest_framework.test import APIClient + +from apps.notifications import services, views +from apps.notifications.services import RedisNotificationStore +from apps.notifications.tests.test_services import FakePubSub, FakeRedis +from apps.users.models import User + + +@pytest.fixture() +def fake_redis(monkeypatch): + redis = FakeRedis() + monkeypatch.setattr(services, "redis_client", redis) + return redis + + +@pytest.fixture() +def user(db): + return User.objects.create_user(mobile="09121111111", password="secret123") + + +@pytest.fixture() +def second_user(db): + return User.objects.create_user(mobile="09122222222", password="secret123") + + +def _read_sse_chunks(response, count): + iterator = iter(response.streaming_content) + chunks = [] + for _ in range(count): + chunk = next(iterator) + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8") + chunks.append(chunk) + response.close() + return chunks + + +def _parse_sse_data(chunk: str) -> dict: + for line in chunk.splitlines(): + if line.startswith("data: "): + return json.loads(line.removeprefix("data: ")) + raise AssertionError("SSE payload did not include data") + + +def test_stream_token_endpoint_returns_short_lived_token(user): + client = APIClient() + client.force_authenticate(user=user) + + response = client.post("/api/notifications/stream-token/") + + assert response.status_code == 200 + assert response.data["token"] + assert response.data["expires_in"] > 0 + + +def test_stream_endpoint_rejects_missing_and_expired_token(user, settings): + client = APIClient() + + missing = client.get("/api/notifications/stream/") + assert missing.status_code == 401 + + settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS = 1 + token = views._issue_stream_token_for_user(str(user.id)) + time.sleep(1.1) + + expired = client.get(f"/api/notifications/stream/?token={token}") + assert expired.status_code == 401 + + +def test_stream_endpoint_sends_only_current_users_notifications( + fake_redis, user, second_user, monkeypatch +): + RedisNotificationStore.add(str(user.id), {"title": "For current user"}) + RedisNotificationStore.add(str(second_user.id), {"title": "For another user"}) + pubsub = FakePubSub() + monkeypatch.setattr(RedisNotificationStore, "get_pubsub", classmethod(lambda cls: pubsub)) + token = views._issue_stream_token_for_user(str(user.id)) + + client = APIClient() + response = client.get( + f"/api/notifications/stream/?token={token}", + HTTP_ACCEPT="text/event-stream", + ) + retry_line, connected_chunk = _read_sse_chunks(response, 2) + + assert response.status_code == 200 + assert retry_line.startswith("retry:") + connected = _parse_sse_data(connected_chunk) + assert connected["unread_count"] == 1 + assert [item["title"] for item in connected["notifications"]] == ["For current user"] + + +def test_stream_endpoint_emits_heartbeat(fake_redis, user, settings, monkeypatch): + pubsub = FakePubSub() + monkeypatch.setattr(RedisNotificationStore, "get_pubsub", classmethod(lambda cls: pubsub)) + settings.NOTIFICATION_SSE_HEARTBEAT_SECONDS = 1 + + first_now = timezone.now() + tick_values = iter( + [ + first_now, + first_now, + first_now + timedelta(seconds=2), + first_now + timedelta(seconds=2), + first_now + timedelta(seconds=2), + first_now + timedelta(seconds=2), + ] + ) + last_tick = first_now + timedelta(seconds=2) + + def fake_now(): + return next(tick_values, last_tick) + + monkeypatch.setattr(views.timezone, "now", fake_now) + view = views.NotificationStreamView() + stream = view._build_stream(str(user.id)) + + chunks = [next(stream) for _ in range(4)] + stream.close() + + assert "event: ping" in chunks[3] + + +def test_notification_list_and_seen_endpoints_work(fake_redis, user): + notification = RedisNotificationStore.add( + str(user.id), + {"title": "Deploy succeeded", "type": "deploy"}, + ) + + client = APIClient() + client.force_authenticate(user=user) + + list_response = client.get("/api/notifications/list/?type=deploy") + assert list_response.status_code == 200 + assert list_response.data["count"] == 1 + assert list_response.data["unread_count"] == 1 + assert list_response.data["notifications"][0]["title"] == "Deploy succeeded" + + seen_response = client.post("/api/notifications/seen/", {"id": notification["id"]}, format="json") + assert seen_response.status_code == 200 + assert seen_response.data["marked_read"] is True + assert seen_response.data["notification"]["is_seen"] is True + + +def test_notification_delete_endpoint_removes_notification(fake_redis, user): + notification = RedisNotificationStore.add( + str(user.id), + {"title": "Delete me", "type": "deploy"}, + ) + + client = APIClient() + client.force_authenticate(user=user) + + response = client.delete(f"/api/notifications/{notification['id']}/") + + assert response.status_code == 200 + assert response.data["deleted"] is True + assert response.data["notification_id"] == notification["id"] + assert RedisNotificationStore.get(str(user.id), notification["id"]) is None diff --git a/apps/notifications/urls.py b/apps/notifications/urls.py new file mode 100644 index 0000000..83d84fc --- /dev/null +++ b/apps/notifications/urls.py @@ -0,0 +1,35 @@ +from django.urls import include, path +from rest_framework.routers import DefaultRouter + +from apps.notifications import views + +router = DefaultRouter() +router.register("box", views.NotificationListViewSet, basename="box") + +app_name = "notification" + +urlpatterns = [ + path("", include(router.urls)), + path("list/", views.NotificationListView.as_view(), name="notifications"), + path("seen/", views.NotificationSeenView.as_view(), name="notifications-seen"), + path( + "stream-token/", + views.NotificationStreamTokenView.as_view(), + name="notifications-stream-token", + ), + path( + "stream/", + views.NotificationStreamView.as_view(), + name="notifications-stream", + ), + path( + "seen/all/", + views.NotificationMarkAllReadView.as_view(), + name="notifications-mark-read", + ), + path( + "/", + views.NotificationDeleteView.as_view(), + name="notifications-delete", + ), +] diff --git a/apps/notifications/views.py b/apps/notifications/views.py new file mode 100644 index 0000000..e7b80d2 --- /dev/null +++ b/apps/notifications/views.py @@ -0,0 +1,284 @@ +import json +import logging +from typing import Iterator + +from django.conf import settings +from django.core import signing +from django.http import JsonResponse, StreamingHttpResponse +from django.utils import timezone +from django.views import View +from drf_spectacular.utils import extend_schema +from rest_framework import status, viewsets +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.views import APIView + +from apps.notifications.serializers import ( + NotificationDeleteResponseSerializer, + NotificationListResponseSerializer, + NotificationMarkAllReadResponseSerializer, + NotificationSeenRequestSerializer, + NotificationSeenResponseSerializer, + NotificationStreamTokenResponseSerializer, + NotificationTypeFilterSerializer, +) +from apps.notifications.services import RedisNotificationStore + +logger = logging.getLogger(__name__) + +STREAM_TOKEN_SALT = "notifications.stream" + + +def _safe_int(value, default: int) -> int: + try: + return max(int(value), 0) + except (TypeError, ValueError): + return default + + +def _format_sse_event(event: str, data: dict) -> str: + payload = json.dumps(data, ensure_ascii=False, default=str) + return f"event: {event}\ndata: {payload}\n\n" + + +def _issue_stream_token_for_user(user_id: str) -> str: + return signing.dumps( + { + "user_id": str(user_id), + "type": "notification_stream", + }, + salt=STREAM_TOKEN_SALT, + ) + + +def _validate_stream_token(token: str | None) -> str: + if not token: + raise signing.BadSignature("Missing stream token") + payload = signing.loads( + token, + salt=STREAM_TOKEN_SALT, + max_age=settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS, + ) + if payload.get("type") != "notification_stream": + raise signing.BadSignature("Invalid stream token type") + return str(payload["user_id"]) + + +class NotificationListViewSet(viewsets.ViewSet): + permission_classes = [IsAuthenticated] + serializer_class = NotificationListResponseSerializer + + def list(self, request): + user_id = str(request.user.id) + limit = min( + _safe_int( + request.query_params.get("limit"), + settings.NOTIFICATION_DEFAULT_PAGE_SIZE, + ), + settings.NOTIFICATION_MAX_PAGE_SIZE, + ) + offset = _safe_int(request.query_params.get("offset"), 0) + type_filter = request.query_params.get("type") + notifications, total_count = RedisNotificationStore.list( + user_id, + limit=limit, + offset=offset, + type_filter=type_filter, + ) + return Response( + { + "count": total_count, + "unread_count": RedisNotificationStore.unread_count(user_id), + "notifications": notifications, + } + ) + + +class NotificationListView(APIView): + permission_classes = [IsAuthenticated] + serializer_class = NotificationListResponseSerializer + + @extend_schema(responses=NotificationListResponseSerializer) + def get(self, request): + user_id = str(request.user.id) + limit = min( + _safe_int( + request.query_params.get("limit"), + settings.NOTIFICATION_DEFAULT_PAGE_SIZE, + ), + settings.NOTIFICATION_MAX_PAGE_SIZE, + ) + offset = _safe_int(request.query_params.get("offset"), 0) + type_filter = request.query_params.get("type") + notifications, total_count = RedisNotificationStore.list( + user_id, + limit=limit, + offset=offset, + type_filter=type_filter, + ) + return Response( + { + "count": total_count, + "unread_count": RedisNotificationStore.unread_count(user_id), + "notifications": notifications, + } + ) + + +class NotificationSeenView(APIView): + permission_classes = [IsAuthenticated] + serializer_class = NotificationSeenRequestSerializer + + @extend_schema( + request=NotificationSeenRequestSerializer, + responses=NotificationSeenResponseSerializer, + ) + def post(self, request): + serializer = self.serializer_class(data=request.data) + serializer.is_valid(raise_exception=True) + notif_id = serializer.validated_data["id"] + payload = RedisNotificationStore.mark_seen(str(request.user.id), notif_id) + if payload is None: + return Response({"marked_read": False}, status=status.HTTP_404_NOT_FOUND) + return Response({"marked_read": True, **payload}) + + +class NotificationDeleteView(APIView): + permission_classes = [IsAuthenticated] + + @extend_schema(responses=NotificationDeleteResponseSerializer) + def delete(self, request, notif_id: str): + deleted = RedisNotificationStore.delete(str(request.user.id), notif_id) + if not deleted: + return Response({"deleted": False}, status=status.HTTP_404_NOT_FOUND) + return Response( + { + "deleted": True, + "notification_id": notif_id, + "unread_count": RedisNotificationStore.unread_count( + str(request.user.id) + ), + } + ) + + +class NotificationMarkAllReadView(APIView): + permission_classes = [IsAuthenticated] + serializer_class = NotificationTypeFilterSerializer + + @extend_schema( + request=NotificationTypeFilterSerializer, + responses=NotificationMarkAllReadResponseSerializer, + ) + def post(self, request): + type_filter = request.data.get("type") or request.query_params.get("type") + updated = RedisNotificationStore.mark_all_seen( + str(request.user.id), + delete_on_seen_only=False, + type_filter=type_filter, + ) + return Response({"marked_read": updated}) + + +class NotificationStreamTokenView(APIView): + permission_classes = [IsAuthenticated] + + @extend_schema(responses=NotificationStreamTokenResponseSerializer) + def post(self, request): + if not settings.NOTIFICATIONS_ENABLED: + return Response( + {"detail": "Notifications are disabled."}, + status=status.HTTP_503_SERVICE_UNAVAILABLE, + ) + + return Response( + { + "token": _issue_stream_token_for_user(str(request.user.id)), + "expires_in": settings.NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS, + } + ) + + +class NotificationStreamView(View): + def _build_stream(self, user_id: str) -> Iterator[str]: + pubsub = RedisNotificationStore.get_pubsub() + channel = RedisNotificationStore._channel_key(user_id) + heartbeat_seconds = max(settings.NOTIFICATION_SSE_HEARTBEAT_SECONDS, 1) + initial_notifications, _ = RedisNotificationStore.list( + user_id, + limit=settings.NOTIFICATION_DEFAULT_PAGE_SIZE, + offset=0, + ) + unread_count = RedisNotificationStore.unread_count(user_id) + + yield f"retry: {settings.NOTIFICATION_SSE_RETRY_MS}\n\n" + yield _format_sse_event( + "connected", + { + "notifications": initial_notifications, + "unread_count": unread_count, + }, + ) + yield _format_sse_event( + "unread_count", + { + "unread_count": unread_count, + }, + ) + + pubsub.subscribe(channel) + last_ping_at = timezone.now() + try: + while True: + message = pubsub.get_message(timeout=1.0) + if message and message.get("type") == "message": + try: + payload = json.loads(message["data"]) + except json.JSONDecodeError: + logger.warning("Invalid notification stream payload for user %s", user_id) + else: + event = payload.get("event") or "notification" + data = payload.get("data") or {} + yield _format_sse_event(event, data) + + if (timezone.now() - last_ping_at).total_seconds() >= heartbeat_seconds: + last_ping_at = timezone.now() + yield _format_sse_event( + "ping", + {"timestamp": last_ping_at.isoformat()}, + ) + except GeneratorExit: + logger.debug("Notification stream closed for user %s", user_id) + finally: + try: + pubsub.unsubscribe(channel) + finally: + pubsub.close() + + def get(self, request, *args, **kwargs): + if not settings.NOTIFICATIONS_ENABLED: + return JsonResponse( + {"detail": "Notifications are disabled."}, + status=status.HTTP_503_SERVICE_UNAVAILABLE, + ) + + try: + user_id = _validate_stream_token(request.GET.get("token")) + except signing.SignatureExpired: + return JsonResponse( + {"detail": "Stream token expired."}, + status=status.HTTP_401_UNAUTHORIZED, + ) + except signing.BadSignature: + return JsonResponse( + {"detail": "Invalid stream token."}, + status=status.HTTP_401_UNAUTHORIZED, + ) + + response = StreamingHttpResponse( + self._build_stream(user_id), + content_type="text/event-stream", + ) + response["Cache-Control"] = "no-cache" + response["X-Accel-Buffering"] = "no" + return response diff --git a/config/settings/base.py b/config/settings/base.py index 0738ae6..3b5296a 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -45,6 +45,7 @@ LOCAL_APPS = [ "apps.projects", "apps.tags", "apps.time_entries", + "apps.notifications", ] INSTALLED_APPS = DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS @@ -202,10 +203,35 @@ CELERY_ACCEPT_CONTENT = ["json"] CELERY_TASK_SERIALIZER = "json" CELERY_RESULT_SERIALIZER = "json" CELERY_TASK_ALWAYS_EAGER = False -CELERY_IMPORTS = ("apps.users.tasks",) CELERY_TIMEZONE = os.getenv("TIME_ZONE") CELERY_TASK_TRACK_STARTED = True +NOTIFICATIONS_ENABLED = os.getenv("NOTIFICATIONS_ENABLED", "True") == "True" +NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS = int( + os.getenv("NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS", "90") +) +NOTIFICATION_SSE_HEARTBEAT_SECONDS = int( + os.getenv("NOTIFICATION_SSE_HEARTBEAT_SECONDS", "20") +) +NOTIFICATION_SSE_RETRY_MS = int(os.getenv("NOTIFICATION_SSE_RETRY_MS", "5000")) +NOTIFICATION_DEFAULT_PAGE_SIZE = int( + os.getenv("NOTIFICATION_DEFAULT_PAGE_SIZE", "20") +) +NOTIFICATION_MAX_PAGE_SIZE = int(os.getenv("NOTIFICATION_MAX_PAGE_SIZE", "50")) +NOTIFICATION_REDIS_CHANNEL_PREFIX = os.getenv( + "NOTIFICATION_REDIS_CHANNEL_PREFIX", "notif:user" +) +NOTIFICATION_RETENTION_DAYS = int(os.getenv("NOTIFICATION_RETENTION_DAYS", "30")) +NOTIFICATION_TOAST_LEVELS = tuple( + level.strip() + for level in os.getenv( + "NOTIFICATION_TOAST_LEVELS", "info,success,warning,error" + ).split(",") + if level.strip() +) + +CELERY_IMPORTS = ("apps.users.tasks", "apps.notifications.tasks") + STORAGES = { "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, diff --git a/config/urls.py b/config/urls.py index bb94a9e..bd0fbd8 100644 --- a/config/urls.py +++ b/config/urls.py @@ -21,6 +21,7 @@ urlpatterns = [ path('api/', include('apps.projects.api.urls'), name="projects"), path('api/', include('apps.tags.api.urls'), name="tags"), path('api/', include('apps.time_entries.api.urls'), name="time_entries"), + path("api/notifications/", include("apps.notifications.urls"), name="notifications"), ] if settings.DEBUG: