chore: Add validation for request path and query params. (#41237)

This commit is contained in:
Vincent (Wen Yu) Ge
2025-11-13 14:02:29 -05:00
committed by GitHub
parent 1a3a952123
commit fa8f824b29
4 changed files with 261 additions and 21 deletions

View File

@@ -42,6 +42,7 @@ class PydanticModelMixin:
def validated_request(
request_serializer: type[serializers.Serializer] | None = None,
*,
query_serializer: type[serializers.Serializer] | None = None,
responses: dict[int, OpenApiResponse | None] | None = None,
summary: str | None = None,
description: str | None = None,
@@ -56,31 +57,26 @@ def validated_request(
Usage:
@validated_request(
request_serializer=RequestSerializer,
request_serializer=RequestBodySerializer,
query_serializer=QuerySerializer,
responses={
200: Response(response=SuccessResponseSerializer, ...),
400: Response(response=InvalidRequestResponseSerializer, ...),
},
summary="Do something"
)
def my_action(self, request: ValidatedRequest, **kwargs):
# When request_serializer is provided, request.validated_data is available
request_data = request.validated_data.get("next_stage_id")
if not request_data:
return Response(
ErrorResponseSerializer({"error": "Invalid request"}).data,
status=status.HTTP_400_BAD_REQUEST,
)
return Response(SuccessResponseSerializer(request_data, context=self.get_serializer_context()).data)
Note: Use ValidatedRequest type hint when you need to access request.validated_data.
The decorator will set validated_data on the request when request_serializer is provided.
def my_action(self, request: Request, **kwargs):
# request.validated_data contains validated body data (if request_serializer provided)
# Query params are validated but not mutated
"""
def decorator(view_func: Callable) -> Callable:
parameters = []
if query_serializer is not None:
parameters.append(query_serializer)
@extend_schema(
request=request_serializer,
parameters=parameters if parameters else None,
responses=responses,
summary=summary,
description=description,
@@ -90,19 +86,22 @@ def validated_request(
)
@wraps(view_func)
def wrapper(self, request: Request, *args, **kwargs) -> Response:
if query_serializer is not None:
query_serializer_instance = query_serializer(data=request.query_params)
query_serializer_instance.is_valid(raise_exception=True)
if request_serializer is not None:
serializer = request_serializer(data=request.data)
req_validation_result = serializer.is_valid(raise_exception=strict_request_validation)
if not req_validation_result and settings.DEBUG:
logger.warning(
"Request data does not match declared serializer in @validated_request decorator. Please update the provided API schema to ensure API docs remain up to date",
"Request body does not match declared serializer in @validated_request decorator. Please update the provided API schema to ensure API docs remain up to date",
view_func=view_func.__name__,
serializer_class=request_serializer.__name__,
validation_errors=serializer.errors,
)
# Cast to ValidatedRequest and set validated_data attribute
validated_request = cast(ValidatedRequest, request)
validated_request.validated_data = serializer.validated_data

View File

@@ -55,6 +55,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {
"event": "$pageview",
"distinct_id": "user_123",
@@ -66,7 +67,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
assert response.status_code == status.HTTP_200_OK
assert response.data["status"] == "ok"
assert response.data["distinct_id"] == "user_123"
assert mock_request.validated_data["event"] == "$pageview"
assert mock_request.data["event"] == "$pageview"
def test_request_validation_with_missing_required_field(self):
"""Missing required field, should raise validation error"""
@@ -82,6 +83,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview"} # Missing 'distinct_id'
with pytest.raises(Exception) as exc_info:
@@ -108,6 +110,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
response = mock_endpoint(view_instance, mock_request)
@@ -134,6 +137,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
# Should log a warning and return the response
@@ -172,6 +176,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
# Should log a warning and return the response
@@ -207,6 +212,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
response = mock_endpoint(view_instance, mock_request)
@@ -230,6 +236,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
# Should log a warning and return the result
@@ -269,6 +276,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
# Test with DEBUG=False (production mode)
@@ -310,6 +318,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
# Should raise ValidationError regardless of DEBUG setting
@@ -336,6 +345,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
# Should raise ValidationError with serializer errors
@@ -369,6 +379,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
# Should work normally with valid data
@@ -402,6 +413,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview"} # Missing required 'distinct_id'
# Should log a warning but not raise
@@ -414,7 +426,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
mock_logger.warning.assert_called_once()
call_args = mock_logger.warning.call_args
assert (
"Request data does not match declared serializer in @validated_request decorator. Please update the provided API schema to ensure API docs remain up to date"
"Request body does not match declared serializer in @validated_request decorator. Please update the provided API schema to ensure API docs remain up to date"
in call_args[0][0]
)
assert call_args[1]["view_func"] == "mock_endpoint"
@@ -439,7 +451,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
{
"status": "ok",
"event_id": str(uuid.uuid4()),
"distinct_id": request.validated_data.get("distinct_id", "unknown"),
"distinct_id": request.data.get("distinct_id", "unknown"),
},
status=status.HTTP_200_OK,
)
@@ -447,6 +459,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"event": "$pageview"} # Missing required 'distinct_id'
# Should not log warning when DEBUG=False
@@ -485,6 +498,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {
"event": "$pageview",
"distinct_id": "user_123",
@@ -497,7 +511,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
assert response.status_code == status.HTTP_200_OK
assert response.data["status"] == "ok"
assert response.data["distinct_id"] == "user_123"
assert mock_request.validated_data["event"] == "$pageview"
assert mock_request.data["event"] == "$pageview"
def test_no_body_response_declared_as_none_succeeds(self):
"""When status code is declared as None (no body), response with no body should succeed"""
@@ -512,6 +526,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {}
response = mock_endpoint(view_instance, mock_request)
@@ -532,6 +547,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {}
with patch("posthog.api.mixins.settings") as mock_settings:
@@ -563,9 +579,191 @@ class TestValidatedRequestDecorator(APIBaseTest):
view_instance = Mock()
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {}
with pytest.raises(serializers.ValidationError) as exc_info:
mock_endpoint(view_instance, mock_request)
assert "Response status code 204 is declared with no body, but response contains data" in str(exc_info.value)
def test_query_parameter_validation_with_valid_data(self):
"""Query parameter validation: valid query params should work"""
class QueryParamSerializer(serializers.Serializer):
page = serializers.IntegerField(required=False, default=1)
limit = serializers.IntegerField(required=False, default=10, max_value=100)
@validated_request(
query_serializer=QueryParamSerializer,
responses={
200: OpenApiResponse(response=EventCaptureResponseSerializer),
},
)
def mock_endpoint(view_self, request, **kwargs):
page = request.query_params.get("page", 1)
limit = request.query_params.get("limit", 10)
return Response(
{
"status": "ok",
"event_id": str(uuid.uuid4()),
"distinct_id": "test",
"page": int(page),
"limit": int(limit),
},
status=status.HTTP_200_OK,
)
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {}
from django.http import QueryDict
mock_get = QueryDict("page=2&limit=20")
mock_request._request = Mock()
mock_request._request.GET = mock_get
mock_request.query_params = mock_get
response = mock_endpoint(view_instance, mock_request)
assert response.status_code == status.HTTP_200_OK
assert response.data["page"] == 2
assert response.data["limit"] == 20
def test_query_parameter_validation_with_invalid_data_raises(self):
"""Query parameter validation: invalid query params should raise exception"""
class QueryParamSerializer(serializers.Serializer):
limit = serializers.IntegerField(required=True, max_value=100)
@validated_request(
query_serializer=QueryParamSerializer,
responses={
200: OpenApiResponse(response=EventCaptureResponseSerializer),
},
)
def mock_endpoint(view_self, request, **kwargs):
return Response(
{
"status": "ok",
"event_id": str(uuid.uuid4()),
"distinct_id": "test",
},
status=status.HTTP_200_OK,
)
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {}
from django.http import QueryDict
mock_get = QueryDict("limit=200") # Exceeds max_value
mock_request._request = Mock()
mock_request._request.GET = mock_get
mock_request.query_params = mock_get
with pytest.raises(serializers.ValidationError) as exc_info:
mock_endpoint(view_instance, mock_request)
assert "limit" in str(exc_info.value)
def test_post_request_with_query_parameters(self):
"""POST request with query parameters: both should be validated"""
class QueryParamSerializer(serializers.Serializer):
dry_run = serializers.BooleanField(required=False, default=False)
force = serializers.BooleanField(required=False, default=False)
class PostRequestSerializer(serializers.Serializer):
action = serializers.ChoiceField(choices=["create", "update", "delete"])
payload = serializers.DictField(required=False)
@validated_request(
request_serializer=PostRequestSerializer,
query_serializer=QueryParamSerializer,
responses={
200: OpenApiResponse(response=EventCaptureResponseSerializer),
},
)
def mock_endpoint(view_self, request, **kwargs):
action = request.validated_data["action"]
# Query params remain as strings, need to convert
dry_run = request.query_params.get("dry_run", "false").lower() == "true"
force = request.query_params.get("force", "false").lower() == "true"
return Response(
{
"status": "ok",
"event_id": str(uuid.uuid4()),
"distinct_id": "test",
"action": action,
"dry_run": dry_run,
"force": force,
},
status=status.HTTP_200_OK,
)
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"action": "create", "payload": {"key": "value"}}
from django.http import QueryDict
mock_get = QueryDict("dry_run=true&force=false")
mock_request._request = Mock()
mock_request._request.GET = mock_get
mock_request.query_params = mock_get
response = mock_endpoint(view_instance, mock_request)
assert response.status_code == status.HTTP_200_OK
assert response.data["action"] == "create"
assert response.data["dry_run"] is True
assert response.data["force"] is False
def test_post_request_with_invalid_query_parameters_raises(self):
"""POST request with invalid query parameters: should raise exception"""
class QueryParamSerializer(serializers.Serializer):
timeout = serializers.IntegerField(required=True, min_value=1, max_value=3600)
class PostRequestSerializer(serializers.Serializer):
action = serializers.CharField()
@validated_request(
request_serializer=PostRequestSerializer,
query_serializer=QueryParamSerializer,
responses={
200: OpenApiResponse(response=EventCaptureResponseSerializer),
},
)
def mock_endpoint(view_self, request, **kwargs):
return Response(
{
"status": "ok",
"event_id": str(uuid.uuid4()),
"distinct_id": "test",
},
status=status.HTTP_200_OK,
)
view_instance = Mock()
view_instance.get_serializer_context = Mock(return_value={})
mock_request = Mock()
mock_request._full_data = {}
mock_request.data = {"action": "test"}
from django.http import QueryDict
mock_get = QueryDict("timeout=5000") # Exceeds max_value
mock_request._request = Mock()
mock_request._request.GET = mock_get
mock_request.query_params = mock_get
with pytest.raises(serializers.ValidationError) as exc_info:
mock_endpoint(view_instance, mock_request)
assert "timeout" in str(exc_info.value)

View File

@@ -18,6 +18,7 @@ from posthog.permissions import APIScopePermission, PostHogFeatureFlagPermission
from .models import Task, TaskRun
from .serializers import (
ErrorResponseSerializer,
TaskListQuerySerializer,
TaskRunAppendLogRequestSerializer,
TaskRunDetailSerializer,
TaskSerializer,
@@ -54,6 +55,17 @@ class TaskViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
]
}
@validated_request(
query_serializer=TaskListQuerySerializer,
responses={
200: OpenApiResponse(response=TaskSerializer, description="List of tasks"),
},
summary="List tasks",
description="Get a list of tasks for the current project, with optional filtering by origin product, stage, organization, and repository.",
)
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)
def safely_get_queryset(self, queryset):
qs = queryset.filter(team=self.team).order_by("position")
@@ -198,6 +210,26 @@ class TaskRunViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
http_method_names = ["get", "post", "patch", "head", "options"]
filter_rewrite_rules = {"team_id": "team_id"}
@validated_request(
responses={
200: OpenApiResponse(response=TaskRunDetailSerializer, description="List of task runs"),
},
summary="List task runs",
description="Get a list of runs for a specific task.",
)
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)
@validated_request(
responses={
201: OpenApiResponse(response=TaskRunDetailSerializer, description="Created task run"),
},
summary="Create task run",
description="Create a new run for a specific task.",
)
def create(self, request, *args, **kwargs):
return super().create(request, *args, **kwargs)
def safely_get_queryset(self, queryset):
# Task runs are always scoped to a specific task
task_id = self.kwargs.get("parent_lookup_task_id")

View File

@@ -282,3 +282,14 @@ class TaskRunAppendLogRequestSerializer(serializers.Serializer):
if not value:
raise serializers.ValidationError("At least one log entry is required")
return value
class TaskListQuerySerializer(serializers.Serializer):
"""Query parameters for listing tasks"""
origin_product = serializers.CharField(required=False, help_text="Filter by origin product")
stage = serializers.CharField(required=False, help_text="Filter by task run stage")
organization = serializers.CharField(required=False, help_text="Filter by repository organization")
repository = serializers.CharField(
required=False, help_text="Filter by repository name (can include org/repo format)"
)