169 lines
6.3 KiB
Python
169 lines
6.3 KiB
Python
import json
|
|
import time
|
|
from datetime import timedelta
|
|
from unittest.mock import patch
|
|
|
|
from django.test import override_settings
|
|
from django.utils import timezone
|
|
from rest_framework.test import APITestCase
|
|
|
|
from apps.notifications.api import views
|
|
from apps.notifications.services import store as services
|
|
from apps.notifications.services import RedisNotificationStore
|
|
from apps.notifications.tests.fakes import FakePubSub, FakeRedis
|
|
from apps.users.models import User
|
|
|
|
|
|
class NotificationViewTests(APITestCase):
    """API-level tests for the notification endpoints.

    Every test runs against an in-memory ``FakeRedis`` that is swapped in for
    ``services.redis_client`` in ``setUp`` and restored in ``tearDown``.
    """

    @classmethod
    def setUpTestData(cls):
        """Create the two users shared by all tests in this class."""
        cls.user = User.objects.create_user(mobile="09121111111", password="secret123")
        cls.second_user = User.objects.create_user(
            mobile="09122222222",
            password="secret123",
        )

    def setUp(self):
        """Replace the module-level Redis client with a fresh fake."""
        self.fake_redis = FakeRedis()
        self.original_redis_client = services.redis_client
        services.redis_client = self.fake_redis

    def tearDown(self):
        """Put back the Redis client that ``setUp`` replaced."""
        services.redis_client = self.original_redis_client

    @staticmethod
    def _read_sse_chunks(response, count):
        """Read exactly ``count`` chunks from a streaming response.

        Args:
            response: A response exposing ``streaming_content`` and ``close()``.
            count: Number of chunks to pull from the stream.

        Returns:
            A list of ``count`` chunks; ``bytes`` chunks are decoded as UTF-8.

        Raises:
            AssertionError: If the stream ends before ``count`` chunks arrive.
        """
        chunks = []
        try:
            iterator = iter(response.streaming_content)
            for _ in range(count):
                try:
                    chunk = next(iterator)
                except StopIteration:
                    raise AssertionError(
                        f"SSE stream ended after {len(chunks)} of {count} chunks"
                    ) from None
                if isinstance(chunk, bytes):
                    chunk = chunk.decode("utf-8")
                chunks.append(chunk)
        finally:
            # Always release the stream, even when it ends early or a chunk
            # cannot be decoded; otherwise the response is left open.
            response.close()
        return chunks

    @staticmethod
    def _parse_sse_data(chunk):
        """Return the decoded JSON of the first ``data: `` line in ``chunk``.

        Raises:
            AssertionError: If no line of ``chunk`` starts with ``data: ``.
        """
        for line in chunk.splitlines():
            if line.startswith("data: "):
                return json.loads(line.removeprefix("data: "))
        raise AssertionError("SSE payload did not include data")

    def test_stream_token_endpoint_returns_short_lived_token(self):
        """An authenticated POST yields a token and a positive ``expires_in``."""
        self.client.force_authenticate(user=self.user)

        response = self.client.post("/api/notifications/stream-token/")

        self.assertEqual(response.status_code, 200)
        self.assertTrue(response.data["token"])
        self.assertGreater(response.data["expires_in"], 0)

    def test_stream_endpoint_rejects_missing_and_expired_token(self):
        """The stream endpoint answers 401 without a token or with a stale one."""
        missing = self.client.get("/api/notifications/stream/")
        self.assertEqual(missing.status_code, 401)

        with override_settings(NOTIFICATION_STREAM_TOKEN_LIFETIME_SECONDS=1):
            token = views._issue_stream_token_for_user(str(self.user.id))
            # Wait just past the 1-second lifetime so the token is expired.
            time.sleep(1.1)
            expired = self.client.get(f"/api/notifications/stream/?token={token}")

        self.assertEqual(expired.status_code, 401)

    def test_stream_endpoint_sends_only_current_users_notifications(self):
        """The initial SSE payload contains only the token owner's notifications."""
        RedisNotificationStore.add(str(self.user.id), {"title": "For current user"})
        RedisNotificationStore.add(str(self.second_user.id), {"title": "For another user"})
        pubsub = FakePubSub()

        with patch.object(
            RedisNotificationStore,
            "get_pubsub",
            classmethod(lambda cls: pubsub),
        ):
            token = views._issue_stream_token_for_user(str(self.user.id))
            response = self.client.get(
                f"/api/notifications/stream/?token={token}",
                HTTP_ACCEPT="text/event-stream",
            )
            # The body is produced lazily, so it must be consumed while the
            # pub/sub patch is still active.
            retry_line, connected_chunk = self._read_sse_chunks(response, 2)

        self.assertEqual(response.status_code, 200)
        self.assertTrue(retry_line.startswith("retry:"))
        connected = self._parse_sse_data(connected_chunk)
        self.assertEqual(connected["unread_count"], 1)
        self.assertEqual(
            [item["title"] for item in connected["notifications"]],
            ["For current user"],
        )

    def test_stream_endpoint_emits_heartbeat(self):
        """The fourth chunk of the stream is a ping once the clock advances."""
        pubsub = FakePubSub()
        first_now = timezone.now()
        # Scripted clock: two readings at the start time, then a jump of two
        # seconds, which exceeds the 1-second heartbeat interval set below.
        tick_values = iter(
            [
                first_now,
                first_now,
                first_now + timedelta(seconds=2),
                first_now + timedelta(seconds=2),
                first_now + timedelta(seconds=2),
                first_now + timedelta(seconds=2),
            ]
        )
        last_tick = first_now + timedelta(seconds=2)

        def fake_now():
            # Keep returning the final tick once the scripted values run out.
            return next(tick_values, last_tick)

        with override_settings(NOTIFICATION_SSE_HEARTBEAT_SECONDS=1):
            with patch.object(
                RedisNotificationStore,
                "get_pubsub",
                classmethod(lambda cls: pubsub),
            ):
                with patch.object(views.timezone, "now", side_effect=fake_now):
                    view = views.NotificationStreamView()
                    stream = view._build_stream(str(self.user.id))
                    try:
                        chunks = [next(stream) for _ in range(4)]
                    finally:
                        # Close the generator even if it ends early.
                        stream.close()

        self.assertIn("event: ping", chunks[3])

    def test_notification_list_and_seen_endpoints_work(self):
        """A stored notification can be listed by type and then marked as seen."""
        notification = RedisNotificationStore.add(
            str(self.user.id),
            {"title": "Deploy succeeded", "type": "deploy"},
        )
        self.client.force_authenticate(user=self.user)

        list_response = self.client.get("/api/notifications/list/?type=deploy")
        self.assertEqual(list_response.status_code, 200)
        self.assertEqual(list_response.data["count"], 1)
        self.assertEqual(list_response.data["unread_count"], 1)
        self.assertEqual(
            list_response.data["notifications"][0]["title"],
            "Deploy succeeded",
        )

        seen_response = self.client.post(
            "/api/notifications/seen/",
            {"id": notification["id"]},
            format="json",
        )
        self.assertEqual(seen_response.status_code, 200)
        self.assertTrue(seen_response.data["marked_read"])
        self.assertTrue(seen_response.data["notification"]["is_seen"])

    def test_notification_delete_endpoint_removes_notification(self):
        """DELETE reports the removed id and the notification is gone afterwards."""
        notification = RedisNotificationStore.add(
            str(self.user.id),
            {"title": "Delete me", "type": "deploy"},
        )
        self.client.force_authenticate(user=self.user)

        response = self.client.delete(f"/api/notifications/{notification['id']}/")

        self.assertEqual(response.status_code, 200)
        self.assertTrue(response.data["deleted"])
        self.assertEqual(response.data["notification_id"], notification["id"])
        self.assertIsNone(
            RedisNotificationStore.get(str(self.user.id), notification["id"])
        )