From 054bb5a5827cd125584f78c24e8d009bc02e6b6e Mon Sep 17 00:00:00 2001 From: Amirhossein Khalili Date: Thu, 30 Apr 2026 16:13:12 +0330 Subject: [PATCH] feat(cache): add targeted server-side response caching --- apps/reports/api/views.py | 40 +++++- apps/reports/tests/test_views.py | 28 ++++ apps/workspaces/api/views.py | 44 ++++++- apps/workspaces/signals.py | 96 ++++++++++++-- apps/workspaces/tests/test_api_permissions.py | 45 +++++++ apps/workspaces/tests/test_rates.py | 70 ++++++++++ core/services/__init__.py | 1 + core/services/cache.py | 124 ++++++++++++++++++ 8 files changed, 432 insertions(+), 16 deletions(-) create mode 100644 core/services/__init__.py create mode 100644 core/services/cache.py diff --git a/apps/reports/api/views.py b/apps/reports/api/views.py index 37cc501..40b4824 100644 --- a/apps/reports/api/views.py +++ b/apps/reports/api/views.py @@ -23,6 +23,10 @@ from apps.reports.services import ( load_report_filters, ) from apps.reports.tasks import generate_report_export_task +from core.services.cache import CACHE_NAMESPACE_REPORTS, get_or_set_cache_payload + + +REPORT_CACHE_TTL_SECONDS = 90 class ReportChartView(APIView): @@ -30,7 +34,17 @@ class ReportChartView(APIView): @extend_schema(responses=dict) def get(self, request): - return Response(build_chart_report(request.user, request.query_params)) + workspace_id = request.query_params.get("workspace") + payload = get_or_set_cache_payload( + CACHE_NAMESPACE_REPORTS, + ttl_seconds=REPORT_CACHE_TTL_SECONDS, + builder=lambda: build_chart_report(request.user, request.query_params), + resource="chart", + user_id=request.user.id, + workspace_id=workspace_id, + params=request.query_params, + ) + return Response(payload) class ReportTableView(APIView): @@ -38,7 +52,17 @@ class ReportTableView(APIView): @extend_schema(responses=dict) def get(self, request): - return Response(build_table_report(request.user, request.query_params)) + workspace_id = request.query_params.get("workspace") + payload = get_or_set_cache_payload( + CACHE_NAMESPACE_REPORTS, + ttl_seconds=REPORT_CACHE_TTL_SECONDS, + builder=lambda: build_table_report(request.user, request.query_params), + resource="table", + user_id=request.user.id, + workspace_id=workspace_id, + params=request.query_params, + ) + return Response(payload) class ReportDayDetailsView(APIView): @@ -46,7 +70,17 @@ class ReportDayDetailsView(APIView): @extend_schema(responses=dict) def get(self, request): - return Response(build_day_details_report(request.user, request.query_params)) + workspace_id = request.query_params.get("workspace") + payload = get_or_set_cache_payload( + CACHE_NAMESPACE_REPORTS, + ttl_seconds=REPORT_CACHE_TTL_SECONDS, + builder=lambda: build_day_details_report(request.user, request.query_params), + resource="day-details", + user_id=request.user.id, + workspace_id=workspace_id, + params=request.query_params, + ) + return Response(payload) class ReportExportJobViewSet( diff --git a/apps/reports/tests/test_views.py b/apps/reports/tests/test_views.py index e9f7a05..acfac81 100644 --- a/apps/reports/tests/test_views.py +++ b/apps/reports/tests/test_views.py @@ -2,6 +2,7 @@ from datetime import date, timedelta from decimal import Decimal from unittest.mock import patch +from django.core.cache import cache from rest_framework.test import APITestCase from apps.clients.models import Client @@ -82,6 +83,9 @@ class ReportViewTests(APITestCase): ) entry_member.tags.add(cls.tag) + def setUp(self): + cache.clear() + def test_member_only_sees_own_chart_report(self): self.client.force_authenticate(user=self.member) @@ -208,3 +212,27 @@ class ReportViewTests(APITestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.data["summary"]["total_duration"], "02:00:00") self.assertEqual(response.data["scope"]["from_date"], "2026-04-21") + + def test_table_report_cache_stays_until_time_entry_invalidation(self): + self.client.force_authenticate(user=self.owner) + url = "/api/reports/table/" + params = {"workspace": str(self.workspace.id), "period": "this_month"} + + first_response = self.client.get(url, params) + self.assertEqual(first_response.status_code, 200) + self.assertEqual(first_response.data["summary"]["total_duration"], "03:00:00") + + member_entry = TimeEntry.objects.get(description="Member work") + TimeEntry.objects.filter(id=member_entry.id).update(duration=timedelta(hours=5)) + + cached_response = self.client.get(url, params) + self.assertEqual(cached_response.status_code, 200) + self.assertEqual(cached_response.data["summary"]["total_duration"], "03:00:00") + + member_entry.refresh_from_db() + member_entry.description = "Member work updated" + member_entry.save(update_fields=["description"]) + + fresh_response = self.client.get(url, params) + self.assertEqual(fresh_response.status_code, 200) + self.assertEqual(fresh_response.data["summary"]["total_duration"], "07:00:00") diff --git a/apps/workspaces/api/views.py b/apps/workspaces/api/views.py index bad2194..ab055c0 100644 --- a/apps/workspaces/api/views.py +++ b/apps/workspaces/api/views.py @@ -42,6 +42,17 @@ from apps.workspaces.services import ( update_workspace_user_rate, ) from core.paginations.limit_offset import CustomLimitOffsetPagination +from core.services.cache import ( + CACHE_NAMESPACE_PRICE_UNITS, + CACHE_NAMESPACE_WORKSPACE_MEMBERSHIPS, + CACHE_NAMESPACE_WORKSPACE_RATES, + get_namespace_version, + get_or_set_cache_payload, +) + + +REFERENCE_CACHE_TTL_SECONDS = 60 * 5 +PRICE_UNITS_CACHE_TTL_SECONDS = 60 * 60 class WorkspaceViewSet(ModelViewSet): @@ -129,7 +140,15 @@ class WorkspaceMembershipViewSet(ModelViewSet): status=status.HTTP_403_FORBIDDEN, ) - return super().list(request, *args, **kwargs) + payload = get_or_set_cache_payload( + CACHE_NAMESPACE_WORKSPACE_MEMBERSHIPS, + ttl_seconds=REFERENCE_CACHE_TTL_SECONDS, + builder=lambda: super(WorkspaceMembershipViewSet, self).list(request, *args, **kwargs).data, + user_id=request.user.id, + workspace_id=workspace_id, + params=request.query_params, + ) + return Response(payload) def create(self, request, *args, **kwargs): """ @@ -271,6 +290,16 @@ class PriceUnitViewSet(ModelViewSet): def get_queryset(self): return PriceUnit.objects.filter(is_deleted=False) + def list(self, request, *args, **kwargs): + payload = get_or_set_cache_payload( + CACHE_NAMESPACE_PRICE_UNITS, + ttl_seconds=PRICE_UNITS_CACHE_TTL_SECONDS, + builder=lambda: super(PriceUnitViewSet, self).list(request, *args, **kwargs).data, + user_id=request.user.id, + params=request.query_params, + ) + return Response(payload) + class WorkspaceUserRateViewSet(ModelViewSet): serializer_class = WorkspaceUserRateSerializer @@ -310,7 +339,18 @@ class WorkspaceUserRateViewSet(ModelViewSet): ) workspace = get_object_or_404(Workspace, id=workspace_id, is_deleted=False) self._ensure_manage_access(request.user, workspace) - return super().list(request, *args, **kwargs) + payload = get_or_set_cache_payload( + CACHE_NAMESPACE_WORKSPACE_RATES, + ttl_seconds=REFERENCE_CACHE_TTL_SECONDS, + builder=lambda: super(WorkspaceUserRateViewSet, self).list(request, *args, **kwargs).data, + user_id=request.user.id, + workspace_id=workspace_id, + params=request.query_params, + extra_versions={ + CACHE_NAMESPACE_PRICE_UNITS: get_namespace_version(CACHE_NAMESPACE_PRICE_UNITS), + }, + ) + return Response(payload) def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) diff --git a/apps/workspaces/signals.py b/apps/workspaces/signals.py index 207c085..17a904b 100644 --- a/apps/workspaces/signals.py +++ b/apps/workspaces/signals.py @@ -1,14 +1,88 @@ -from django.db.models.signals import post_save -from django.dispatch import receiver - -from apps.workspaces.models import Workspace, WorkspaceMembership +from django.db.models.signals import m2m_changed, post_delete, post_save +from django.dispatch import receiver + +from apps.clients.models import Client +from apps.projects.models import Project, ProjectRate, ProjectUserRate +from apps.tags.models import Tag +from apps.time_entries.models import TimeEntry +from apps.workspaces.models import Workspace, WorkspaceMembership +from apps.workspaces.models import PriceUnit, WorkspaceUserRate +from core.services.cache import ( + CACHE_NAMESPACE_PRICE_UNITS, + CACHE_NAMESPACE_REPORTS, + CACHE_NAMESPACE_WORKSPACE_MEMBERSHIPS, + CACHE_NAMESPACE_WORKSPACE_RATES, + bump_namespace_version, +) @receiver(post_save, sender=Workspace) -def create_owner_membership(sender, instance, created, **kwargs): - if created: - WorkspaceMembership.objects.create( - workspace=instance, - user=instance.owner, - role=WorkspaceMembership.Role.OWNER, - ) +def create_owner_membership(sender, instance, created, **kwargs): + if created: + WorkspaceMembership.objects.create( + workspace=instance, + user=instance.owner, + role=WorkspaceMembership.Role.OWNER, + ) + + +def _bump_workspace_reports(instance): + workspace_id = getattr(instance, "workspace_id", None) + if not workspace_id and hasattr(instance, "project"): + workspace_id = getattr(instance.project, "workspace_id", None) + if workspace_id: + bump_namespace_version(CACHE_NAMESPACE_REPORTS, str(workspace_id)) + + +def _bump_workspace_memberships(instance): + workspace_id = getattr(instance, "workspace_id", None) + if workspace_id: + bump_namespace_version(CACHE_NAMESPACE_WORKSPACE_MEMBERSHIPS, str(workspace_id)) + + +def _bump_workspace_rates(instance): + workspace_id = getattr(instance, "workspace_id", None) + if workspace_id: + bump_namespace_version(CACHE_NAMESPACE_WORKSPACE_RATES, str(workspace_id)) + + +@receiver(post_save, sender=TimeEntry) +@receiver(post_delete, sender=TimeEntry) +@receiver(post_save, sender=Project) +@receiver(post_delete, sender=Project) +@receiver(post_save, sender=Client) +@receiver(post_delete, sender=Client) +@receiver(post_save, sender=Tag) +@receiver(post_delete, sender=Tag) +@receiver(post_save, sender=ProjectRate) +@receiver(post_delete, sender=ProjectRate) +@receiver(post_save, sender=ProjectUserRate) +@receiver(post_delete, sender=ProjectUserRate) +def invalidate_workspace_report_cache(sender, instance, **kwargs): + _bump_workspace_reports(instance) + + +@receiver(m2m_changed, sender=TimeEntry.tags.through) +def invalidate_workspace_report_cache_for_tags(sender, instance, action, **kwargs): + if action in {"post_add", "post_remove", "post_clear"}: + _bump_workspace_reports(instance) + + +@receiver(post_save, sender=WorkspaceMembership) +@receiver(post_delete, sender=WorkspaceMembership) +def invalidate_workspace_membership_caches(sender, instance, **kwargs): + _bump_workspace_memberships(instance) + _bump_workspace_reports(instance) + + +@receiver(post_save, sender=WorkspaceUserRate) +@receiver(post_delete, sender=WorkspaceUserRate) +def invalidate_workspace_rate_caches(sender, instance, **kwargs): + _bump_workspace_rates(instance) + _bump_workspace_reports(instance) + + +@receiver(post_save, sender=PriceUnit) +@receiver(post_delete, sender=PriceUnit) +def invalidate_price_unit_cache(sender, instance, **kwargs): + bump_namespace_version(CACHE_NAMESPACE_PRICE_UNITS) diff --git a/apps/workspaces/tests/test_api_permissions.py b/apps/workspaces/tests/test_api_permissions.py index a3ed6ec..74367b9 100644 --- a/apps/workspaces/tests/test_api_permissions.py +++ b/apps/workspaces/tests/test_api_permissions.py @@ -1,6 +1,8 @@ from types import SimpleNamespace +from django.core.cache import cache from django.test import TestCase +from rest_framework.test import APITestCase from apps.users.models import User from apps.workspaces.api.permissions import ( @@ -144,3 +146,46 @@ class WorkspacePermissionTests(TestCase): object(), ) ) + + +class WorkspaceMembershipCacheTests(APITestCase): + @classmethod + def setUpTestData(cls): + cls.owner = User.objects.create_user(mobile="09127770031", password="secret123") + cls.member = User.objects.create_user(mobile="09127770032", password="secret123") + cls.workspace = Workspace.objects.create(name="Membership Cache", owner=cls.owner) + cls.membership = WorkspaceMembership.objects.create( + workspace=cls.workspace, + user=cls.member, + role=WorkspaceMembership.Role.MEMBER, + is_active=True, + ) + + def setUp(self): + cache.clear() + self.client.force_authenticate(user=self.owner) + + def test_membership_list_cache_invalidates_after_membership_save(self): + params = {"workspace": str(self.workspace.id)} + + first_response = self.client.get("/api/workspace-memberships/", params) + self.assertEqual(first_response.status_code, 200) + target = next(item for item in first_response.data["items"] if item["id"] == str(self.membership.id)) + self.assertEqual(target["role"], WorkspaceMembership.Role.MEMBER) + + WorkspaceMembership.objects.filter(id=self.membership.id).update(role=WorkspaceMembership.Role.GUEST) + + cached_response = self.client.get("/api/workspace-memberships/", params) + self.assertEqual(cached_response.status_code, 200) + target = next(item for item in cached_response.data["items"] if item["id"] == str(self.membership.id)) + self.assertEqual(target["role"], WorkspaceMembership.Role.MEMBER) + + self.membership.refresh_from_db() + self.membership.is_active = False + self.membership.save(update_fields=["is_active"]) + + fresh_response = self.client.get("/api/workspace-memberships/", params) + self.assertEqual(fresh_response.status_code, 200) + target = next(item for item in fresh_response.data["items"] if item["id"] == str(self.membership.id)) + self.assertEqual(target["role"], WorkspaceMembership.Role.GUEST) + self.assertFalse(target["is_active"]) diff --git a/apps/workspaces/tests/test_rates.py b/apps/workspaces/tests/test_rates.py index d17c033..38f7021 100644 --- a/apps/workspaces/tests/test_rates.py +++ b/apps/workspaces/tests/test_rates.py @@ -1,5 +1,6 @@ from decimal import Decimal +from django.core.cache import cache from django.test import TestCase from rest_framework.test import APITestCase @@ -53,6 +54,9 @@ class WorkspaceRateTests(APITestCase): symbol="EUR", ) + def setUp(self): + cache.clear() + def test_resolve_rate_uses_workspace_user_rate(self): WorkspaceUserRate.objects.create( workspace=self.workspace, @@ -122,6 +126,72 @@ class WorkspaceRateTests(APITestCase): self.assertEqual(response.status_code, 403) + def test_workspace_user_rates_cache_invalidates_after_rate_save(self): + rate = WorkspaceUserRate.objects.create( + workspace=self.workspace, + user=self.member, + hourly_rate=Decimal("30.00"), + currency="USD", + effective_from=self.workspace.created_at, + is_active=True, + ) + self.client.force_authenticate(user=self.admin) + + first_response = self.client.get( + "/api/workspace-user-rates/", + {"workspace": str(self.workspace.id)}, + ) + self.assertEqual(first_response.status_code, 200) + self.assertEqual(first_response.data["items"][0]["hourly_rate"], "30.00") + + WorkspaceUserRate.objects.filter(id=rate.id).update(hourly_rate=Decimal("45.00")) + + cached_response = self.client.get( + "/api/workspace-user-rates/", + {"workspace": str(self.workspace.id)}, + ) + self.assertEqual(cached_response.status_code, 200) + self.assertEqual(cached_response.data["items"][0]["hourly_rate"], "30.00") + + rate.refresh_from_db() + rate.currency = "EUR" + rate.save(update_fields=["currency"]) + + fresh_response = self.client.get( + "/api/workspace-user-rates/", + {"workspace": str(self.workspace.id)}, + ) + self.assertEqual(fresh_response.status_code, 200) + self.assertEqual(fresh_response.data["items"][0]["hourly_rate"], "45.00") + self.assertEqual(fresh_response.data["items"][0]["currency"], "EUR") + + def test_price_unit_cache_invalidates_after_price_unit_create(self): + self.client.force_authenticate(user=self.owner) + + first_response = self.client.get("/api/price-units/") + self.assertEqual(first_response.status_code, 200) + self.assertEqual(first_response.data[0]["name"], "Euro") + self.assertEqual(len(first_response.data), 2) + + PriceUnit.objects.filter(code="EUR").update(name="Updated Euro") + + cached_response = self.client.get("/api/price-units/") + self.assertEqual(cached_response.status_code, 200) + self.assertEqual(cached_response.data[0]["name"], "Euro") + + PriceUnit.objects.create( + code="GBP", + name="British Pound", + local_name="Pound", + symbol="£", + ) + + fresh_response = self.client.get("/api/price-units/") + self.assertEqual(fresh_response.status_code, 200) + self.assertEqual(len(fresh_response.data), 3) + euro_row = next(item for item in fresh_response.data if item["code"] == "EUR") + self.assertEqual(euro_row["name"], "Updated Euro") + class WorkspaceRateServiceTests(TestCase): @classmethod diff --git a/core/services/__init__.py b/core/services/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/core/services/__init__.py @@ -0,0 +1 @@ + diff --git a/core/services/cache.py b/core/services/cache.py new file mode 100644 index 0000000..bbc09c2 --- /dev/null +++ b/core/services/cache.py @@ -0,0 +1,124 @@ +import hashlib +import json +from collections.abc import Callable, Mapping +from typing import Any + +from django.core.cache import cache + + +CACHE_NAMESPACE_REPORTS = "reports" +CACHE_NAMESPACE_WORKSPACE_MEMBERSHIPS = "workspace-memberships" +CACHE_NAMESPACE_WORKSPACE_RATES = "workspace-rates" +CACHE_NAMESPACE_PRICE_UNITS = "price-units" + +_CACHE_VERSION_TTL_SECONDS = 60 * 60 * 24 * 30 + + +def _stringify_value(value: Any) -> str: + if value is None: + return "" + if isinstance(value, bool): + return "true" if value else "false" + return str(value) + + +def normalize_query_params(params: Any) -> dict[str, list[str]]: + if hasattr(params, "lists"): + raw_items = params.lists() + elif isinstance(params, Mapping): + raw_items = params.items() + else: + raw_items = [] + + normalized: dict[str, list[str]] = {} + for key, value in raw_items: + if isinstance(value, (list, tuple)): + values = [_stringify_value(item) for item in value if item is not None] + else: + values = [_stringify_value(value)] + normalized[str(key)] = sorted(values) + + return dict(sorted(normalized.items())) + + +def get_namespace_version(namespace: str, workspace_id: str | None = None) -> int: + scope = workspace_id or "global" + cache_key = f"cache-version:{namespace}:{scope}" + version = cache.get(cache_key) + if version is None: + cache.set(cache_key, 1, timeout=_CACHE_VERSION_TTL_SECONDS) + return 1 + return int(version) + + +def bump_namespace_version(namespace: str, workspace_id: str | None = None) -> int: + scope = workspace_id or "global" + cache_key = f"cache-version:{namespace}:{scope}" + version = cache.get(cache_key) + if version is None: + cache.set(cache_key, 2, timeout=_CACHE_VERSION_TTL_SECONDS) + return 2 + try: + return int(cache.incr(cache_key)) + except ValueError: + next_version = int(version) + 1 + cache.set(cache_key, next_version, timeout=_CACHE_VERSION_TTL_SECONDS) + return next_version + + +def build_cache_key( + namespace: str, + *, + resource: str | None = None, + user_id: Any = None, + workspace_id: Any = None, + params: Any = None, + extra_versions: Mapping[str, int] | None = None, +) -> str: + normalized_params = normalize_query_params(params or {}) + params_json = json.dumps(normalized_params, sort_keys=True, separators=(",", ":")) + params_hash = hashlib.md5(params_json.encode("utf-8")).hexdigest() + namespace_version = get_namespace_version(namespace, str(workspace_id) if workspace_id else None) + + segments = [ + namespace, + f"resource:{resource or 'default'}", + f"v{namespace_version}", + f"user:{user_id or 'anon'}", + f"workspace:{workspace_id or 'global'}", + ] + + if extra_versions: + for key, value in sorted(extra_versions.items()): + segments.append(f"{key}:v{value}") + + segments.append(params_hash) + return ":".join(segments) + + +def get_or_set_cache_payload( + namespace: str, + *, + ttl_seconds: int, + builder: Callable[[], Any], + resource: str | None = None, + user_id: Any = None, + workspace_id: Any = None, + params: Any = None, + extra_versions: Mapping[str, int] | None = None, +) -> Any: + cache_key = build_cache_key( + namespace, + resource=resource, + user_id=user_id, + workspace_id=workspace_id, + params=params, + extra_versions=extra_versions, + ) + payload = cache.get(cache_key) + if payload is not None: + return payload + + payload = builder() + cache.set(cache_key, payload, timeout=ttl_seconds) + return payload