diff --git a/apps/users/api/views.py b/apps/users/api/views.py index 6840309..509f715 100644 --- a/apps/users/api/views.py +++ b/apps/users/api/views.py @@ -1,61 +1,59 @@ -from django.http import HttpResponseRedirect from django.contrib.auth import get_user_model +from django.http import HttpResponseRedirect from django_filters.rest_framework import DjangoFilterBackend from drf_spectacular.utils import extend_schema, inline_serializer from rest_framework import serializers, status from rest_framework.filters import OrderingFilter, SearchFilter from rest_framework.generics import ListAPIView, UpdateAPIView +from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin from rest_framework.parsers import FormParser, MultiPartParser from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView -from rest_framework_simplejwt.authentication import JWTAuthentication -from rest_framework.mixins import UpdateModelMixin, RetrieveModelMixin, DestroyModelMixin from rest_framework.viewsets import GenericViewSet - - -from core.paginations.limit_offset import CustomLimitOffsetPagination +from rest_framework_simplejwt.authentication import JWTAuthentication from apps.users.api.serializers import ( ChangePasswordSerializer, - LoginOtpSerializer, - LoginSerializer, GoogleOAuthClaimVerifySerializer, GoogleOAuthCompleteSerializer, GoogleOAuthFlowSerializer, + LoginOtpSerializer, + LoginSerializer, + LogoutSerializer, RegisterSerializer, + RegisterWithPasswordSerializer, ResetPasswordSerializer, SendOTPSerializer, + TokenPairSerializer, UserListSerializer, UserProfilePictureSerializer, - LogoutSerializer, - TokenPairSerializer, - RegisterWithPasswordSerializer, UserProfileSerializer, UserSearchSerializer, ) from apps.users.api.throttles import ( + GoogleClaimSendBurstThrottle, + GoogleClaimSendSustainedThrottle, + GoogleClaimVerifyThrottle, OTPLoginThrottle, OTPSendBurstThrottle, OTPSendSustainedThrottle, PasswordLoginThrottle, - GoogleClaimSendBurstThrottle, - GoogleClaimSendSustainedThrottle, - GoogleClaimVerifyThrottle, ) from apps.users.services.auth import ( - register_user_with_password, - register_user_with_otp, - generate_and_send_otp, - login_with_password, - login_with_otp, - reset_password_with_otp, change_password, - logout_user + generate_and_send_otp, + login_with_otp, + login_with_password, + logout_user, + register_user_with_otp, + register_user_with_password, + reset_password_with_otp, ) from apps.users.services.google_oauth import ( build_authenticated_flow_payload, build_google_authorization_url, + build_google_callback_error_redirect_url, build_google_callback_redirect_url, build_pending_google_flow_payload, complete_google_signup, @@ -68,6 +66,7 @@ from apps.users.services.google_oauth import ( sync_user_from_google_profile, verify_google_claim, ) +from core.paginations.limit_offset import CustomLimitOffsetPagination User = get_user_model() @@ -89,7 +88,7 @@ class RegisterWithPasswordView(APIView): status=status.HTTP_400_BAD_REQUEST, ) - tokens = register_user_with_password(mobile, password) + tokens = register_user_with_password(mobile, password) return Response(tokens, status=status.HTTP_201_CREATED) @@ -125,7 +124,7 @@ class SendOTPView(APIView): def post(self, request): serializer = SendOTPSerializer(data=request.data) serializer.is_valid(raise_exception=True) - + payload = generate_and_send_otp( mobile=serializer.validated_data["mobile"], mode=serializer.validated_data["mode"] @@ -142,7 +141,7 @@ class LoginView(APIView): def post(self, request): serializer = LoginSerializer(data=request.data) serializer.is_valid(raise_exception=True) - + tokens = login_with_password( mobile=serializer.validated_data["mobile"], password=serializer.validated_data["password"], @@ -159,7 +158,7 @@ class LoginOTPView(APIView): def post(self, request): serializer = LoginOtpSerializer(data=request.data) serializer.is_valid(raise_exception=True) - + tokens = login_with_otp( mobile=serializer.validated_data["mobile"], code=serializer.validated_data["code"], @@ -182,22 +181,42 @@ class GoogleOAuthCallbackView(APIView): @extend_schema(responses=None) def get(self, request): if request.query_params.get("error"): - raise serializers.ValidationError( - {"detail": request.query_params.get("error_description") or "Google sign-in was cancelled."} + return HttpResponseRedirect( + build_google_callback_error_redirect_url( + code=request.query_params.get("error") or "google_sign_in_cancelled", + detail=( + request.query_params.get("error_description") + or "Google sign-in was cancelled." + ), + ) ) + try: + consume_google_state(request.query_params.get("state")) + profile = exchange_code_for_google_profile(request.query_params.get("code")) + social_account = find_social_account_for_profile(profile) - consume_google_state(request.query_params.get("state")) - profile = exchange_code_for_google_profile(request.query_params.get("code")) - social_account = find_social_account_for_profile(profile) + if social_account: + sync_user_from_google_profile(social_account.user, profile) + flow_payload = build_authenticated_flow_payload(social_account.user) + else: + flow_payload = build_pending_google_flow_payload(profile) - if social_account: - sync_user_from_google_profile(social_account.user, profile) - flow_payload = build_authenticated_flow_payload(social_account.user) - else: - flow_payload = build_pending_google_flow_payload(profile) - - flow = create_google_flow(flow_payload) - return HttpResponseRedirect(build_google_callback_redirect_url(flow)) + flow = create_google_flow(flow_payload) + return HttpResponseRedirect(build_google_callback_redirect_url(flow)) + except serializers.ValidationError as exc: + detail = exc.detail + if isinstance(detail, dict): + message = detail.get("detail", "Google sign-in could not be completed.") + else: + message = detail + if isinstance(message, list): + message = message[0] if message else "Google sign-in could not be completed." + return HttpResponseRedirect( + build_google_callback_error_redirect_url( + code="google_callback_failed", + detail=str(message), + ) + ) class GoogleOAuthFlowView(APIView): @@ -254,12 +273,12 @@ class GoogleOAuthClaimVerifyView(APIView): class ResetPasswordView(APIView): permission_classes = (AllowAny,) serializer_class = ResetPasswordSerializer - + @extend_schema(request=ResetPasswordSerializer) def post(self, request): serializer = ResetPasswordSerializer(data=request.data) serializer.is_valid(raise_exception=True) - + reset_password_with_otp( mobile=serializer.validated_data["mobile"], code=serializer.validated_data["code"], @@ -276,7 +295,7 @@ class ChangePasswordView(APIView): def patch(self, request, *args, **kwargs): serializer = ChangePasswordSerializer(data=request.data, context={"request": request}) serializer.is_valid(raise_exception=True) - + change_password( user=request.user, old_password=serializer.validated_data["old_password"], @@ -385,14 +404,14 @@ class UserSearchAPIView(APIView): mobile = request.query_params.get('mobile') if not mobile: return Response( - {"detail": "Mobile parameter is required."}, + {"detail": "Mobile parameter is required."}, status=status.HTTP_400_BAD_REQUEST ) - + user = User.objects.filter(mobile=mobile).first() if not user: return Response( - {"detail": "User not found."}, + {"detail": "User not found."}, status=status.HTTP_404_NOT_FOUND ) diff --git a/apps/users/services/google_oauth.py b/apps/users/services/google_oauth.py index ed51faa..8265e87 100644 --- a/apps/users/services/google_oauth.py +++ b/apps/users/services/google_oauth.py @@ -367,6 +367,16 @@ def build_google_callback_redirect_url(flow: str) -> str: return f"{get_frontend_google_callback_url()}?flow={flow}" +def build_google_callback_error_redirect_url(*, code: str, detail: str) -> str: + params = urlencode( + { + "error": code, + "error_description": detail, + } + ) + return f"{get_frontend_google_callback_url()}?{params}" + + def find_social_account_for_profile(profile: GoogleProfile) -> UserSocialAccount | None: return ( UserSocialAccount.objects.select_related("user") diff --git a/apps/users/tests/test_api_views.py b/apps/users/tests/test_api_views.py index 0717f5f..0ced766 100644 --- a/apps/users/tests/test_api_views.py +++ b/apps/users/tests/test_api_views.py @@ -7,7 +7,7 @@ from django.core.cache import cache from django.core.management import call_command from django.db import IntegrityError from django.test import override_settings -from rest_framework import status +from rest_framework import serializers, status from rest_framework.test import APIRequestFactory, APITestCase from apps.users.api.views import RegisterWithPasswordView @@ -673,6 +673,41 @@ class GoogleOAuthApiTests(APITestCase): self.assertEqual(flow_response.data["resolution"], "new_account") self.assertIsNone(flow_response.data["mobile_hint"]) + def test_google_callback_redirects_cancellation_back_to_frontend(self): + response = self.client.get( + "/api/users/oauth/google/callback/?error=access_denied&error_description=User%20cancelled", + ) + + self.assertEqual(response.status_code, 302) + self.assertIn("/auth/google/callback?error=access_denied", response["Location"]) + self.assertIn("error_description=User+cancelled", response["Location"]) + + @patch("apps.users.api.views.exchange_code_for_google_profile") + def test_google_callback_redirects_backend_errors_back_to_frontend( + self, + exchange_code_for_google_profile, + ): + exchange_code_for_google_profile.side_effect = serializers.ValidationError( + {"detail": "Google token exchange failed."} + ) + + start_response = self.client.get("/api/users/oauth/google/start/") + state = start_response["Location"].split("state=", 1)[1].split("&", 1)[0] + + response = self.client.get( + f"/api/users/oauth/google/callback/?state={state}&code=google-code", + ) + + self.assertEqual(response.status_code, 302) + self.assertIn( + "/auth/google/callback?error=google_callback_failed", + response["Location"], + ) + self.assertIn( + "error_description=Google+token+exchange+failed.", + response["Location"], + ) + @patch("apps.users.api.views.exchange_code_for_google_profile") def test_google_callback_redirects_with_email_claim_flow_for_matching_email( self,