Skip to content

feat: DIA-2201: Add endpoint to rotate personal access token #7435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 29, 2025
Merged
6 changes: 6 additions & 0 deletions label_studio/core/all_urls.json
Original file line number Diff line number Diff line change
Expand Up @@ -1936,5 +1936,11 @@
"module": "jwt_auth.views.LSTokenBlacklistView",
"name": "jwt_auth:token_blacklist",
"decorators": ""
},
{
"url": "/api/token/rotate/",
"module": "jwt_auth.views.LSAPITokenRotateView",
"name": "jwt_auth:token_rotate",
"decorators": ""
}
]
18 changes: 18 additions & 0 deletions label_studio/jwt_auth/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from jwt_auth.models import JWTSettings, LSAPIToken, TruncatedLSAPIToken
from rest_framework import serializers
from rest_framework_simplejwt.serializers import TokenBlacklistSerializer
from rest_framework_simplejwt.tokens import RefreshToken


# Recommended implementation from JWT to support drf-yasg:
Expand Down Expand Up @@ -34,3 +35,20 @@ def get_token(self, obj):

class LSAPITokenBlacklistSerializer(TokenBlacklistSerializer):
token_class = TruncatedLSAPIToken


class LSAPITokenRotateSerializer(serializers.Serializer):
refresh = serializers.CharField()

def validate(self, data):
refresh = data.get('refresh')
try:
token = RefreshToken(refresh)
except Exception:
raise serializers.ValidationError('Invalid refresh token')
data['refresh'] = token
return data


class TokenRotateResponseSerializer(serializers.Serializer):
refresh = serializers.CharField()
1 change: 1 addition & 0 deletions label_studio/jwt_auth/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
path('api/token/', views.LSAPITokenView.as_view(), name='token_manage'),
path('api/token/refresh/', views.DecoratedTokenRefreshView.as_view(), name='token_refresh'),
path('api/token/blacklist/', views.LSTokenBlacklistView.as_view(), name='token_blacklist'),
path('api/token/rotate/', views.LSAPITokenRotateView.as_view(), name='token_rotate'),
]
41 changes: 41 additions & 0 deletions label_studio/jwt_auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,22 @@
from core.permissions import all_permissions
from django.utils.decorators import method_decorator
from drf_yasg.utils import swagger_auto_schema
from jwt_auth.auth import TokenAuthenticationPhaseout
from jwt_auth.models import JWTSettings, LSAPIToken, TruncatedLSAPIToken
from jwt_auth.serializers import (
JWTSettingsSerializer,
LSAPITokenCreateSerializer,
LSAPITokenListSerializer,
TokenRefreshResponseSerializer,
TokenRotateResponseSerializer,
)
from rest_framework import generics, status
from rest_framework.authentication import SessionAuthentication
from rest_framework.exceptions import APIException
from rest_framework.generics import CreateAPIView
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import TokenBackendError, TokenError
from rest_framework_simplejwt.token_blacklist.models import BlacklistedToken, OutstandingToken
from rest_framework_simplejwt.views import TokenRefreshView, TokenViewBase
Expand Down Expand Up @@ -180,3 +184,40 @@
return Response({'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_404_NOT_FOUND)

return Response(status=status.HTTP_204_NO_CONTENT)


class LSAPITokenRotateView(TokenViewBase):
# Have to explicitly set authentication_classes here, due to how auth works in our middleware, request.user is not set
# properly before executing the view.
authentication_classes = [JWTAuthentication, TokenAuthenticationPhaseout, SessionAuthentication]
permission_classes = [IsAuthenticated]
_serializer_class = 'jwt_auth.serializers.LSAPITokenRotateSerializer'

@swagger_auto_schema(
tags=['JWT'],
operation_summary='Rotate JWT refresh token',
operation_description='Creates a new JWT refresh token and blacklists the current one.',
responses={
status.HTTP_200_OK: TokenRotateResponseSerializer,
status.HTTP_400_BAD_REQUEST: 'Invalid token or token already blacklisted',
},
)
def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)

# Ensure the user is authenticated
if not request.user or not request.user.is_authenticated:
return Response({'detail': 'Authentication credentials were not provided or are invalid.'}, status=401)

Check warning on line 211 in label_studio/jwt_auth/views.py

View check run for this annotation

Codecov / codecov/patch

label_studio/jwt_auth/views.py#L211

Added line #L211 was not covered by tests

current_token = serializer.validated_data['refresh']

# Blacklist the current token
try:
current_token.blacklist()
except TokenError:
return Response({'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_400_BAD_REQUEST)

Check warning on line 219 in label_studio/jwt_auth/views.py

View check run for this annotation

Codecov / codecov/patch

label_studio/jwt_auth/views.py#L218-L219

Added lines #L218 - L219 were not covered by tests

# Create a new token for the user
new_token = LSAPIToken.for_user(request.user)
return Response({'refresh': new_token.get_full_jwt()}, status=status.HTTP_200_OK)
44 changes: 44 additions & 0 deletions label_studio/tests/jwt_auth/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,47 @@ def test_create_token_after_blacklisting_previous():
response = client.post('/api/token/')
assert response.status_code == status.HTTP_201_CREATED
assert 'token' in response.data


@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True)
@pytest.mark.django_db
def test_rotate_token_success():
user = create_user_with_token_settings(api_tokens_enabled=True, legacy_api_tokens_enabled=False)
client = APIClient()
refresh = LSAPIToken()
client.credentials(HTTP_AUTHORIZATION=f'Bearer {refresh.access_token}')
client.force_authenticate(user)

# 1. Create first token
response = client.post('/api/token/')
assert response.status_code == status.HTTP_201_CREATED

# 2. Rotate the token
token = response.data['token']
response2 = client.post('/api/token/rotate/', data={'refresh': token}, format='json')
assert response2.status_code == status.HTTP_200_OK
assert 'refresh' in response2.data

# 3. The old refresh token should now be invalid
response3 = client.post('/api/token/rotate/', data={'refresh': token}, format='json')
assert response3.status_code == status.HTTP_400_BAD_REQUEST
assert 'detail' in response3.data or 'non_field_errors' in response3.data

# 4. The new refresh token should work for another rotation
new_token = response2.data['refresh']
response4 = client.post('/api/token/rotate/', data={'refresh': new_token}, format='json')
assert response4.status_code == status.HTTP_200_OK
assert 'refresh' in response4.data


@mock_feature_flag(flag_name='fflag__feature_develop__prompts__dia_1829_jwt_token_auth', value=True)
@pytest.mark.django_db
def test_rotate_token_requires_authentication():
user = create_user_with_token_settings(api_tokens_enabled=True, legacy_api_tokens_enabled=False)
refresh = LSAPIToken.for_user(user)
refresh_token = refresh.get_full_jwt()

client = APIClient()
# No credentials set
response = client.post('/api/token/rotate/', data={'refresh': refresh_token}, format='json')
assert response.status_code in (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)
32 changes: 3 additions & 29 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ djangorestframework-simplejwt = {extras = ["crypto"], version = "^5.4.0"}
tldextract = ">=5.1.3"

# Humansignal repo dependencies
label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/1a8ad39a63f2771db112fb488f131daf5052bad1.zip"}
label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/1ef88e2f8afe50a738fc1b03385cb1f3d6b2dd9f.zip"}

[tool.poetry.group.test.dependencies]
pytest = "7.2.2"
Expand Down
Loading