mirror of
https://github.com/BillyOutlast/posthog.git
synced 2026-02-04 03:01:23 +01:00
chore: Add validation for request path and query params. (#41237)
This commit is contained in:
committed by
GitHub
parent
1a3a952123
commit
fa8f824b29
@@ -42,6 +42,7 @@ class PydanticModelMixin:
|
||||
def validated_request(
|
||||
request_serializer: type[serializers.Serializer] | None = None,
|
||||
*,
|
||||
query_serializer: type[serializers.Serializer] | None = None,
|
||||
responses: dict[int, OpenApiResponse | None] | None = None,
|
||||
summary: str | None = None,
|
||||
description: str | None = None,
|
||||
@@ -56,31 +57,26 @@ def validated_request(
|
||||
|
||||
Usage:
|
||||
@validated_request(
|
||||
request_serializer=RequestSerializer,
|
||||
request_serializer=RequestBodySerializer,
|
||||
query_serializer=QuerySerializer,
|
||||
responses={
|
||||
200: Response(response=SuccessResponseSerializer, ...),
|
||||
400: Response(response=InvalidRequestResponseSerializer, ...),
|
||||
},
|
||||
summary="Do something"
|
||||
)
|
||||
def my_action(self, request: ValidatedRequest, **kwargs):
|
||||
# When request_serializer is provided, request.validated_data is available
|
||||
request_data = request.validated_data.get("next_stage_id")
|
||||
|
||||
if not request_data:
|
||||
return Response(
|
||||
ErrorResponseSerializer({"error": "Invalid request"}).data,
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
return Response(SuccessResponseSerializer(request_data, context=self.get_serializer_context()).data)
|
||||
|
||||
Note: Use ValidatedRequest type hint when you need to access request.validated_data.
|
||||
The decorator will set validated_data on the request when request_serializer is provided.
|
||||
def my_action(self, request: Request, **kwargs):
|
||||
# request.validated_data contains validated body data (if request_serializer provided)
|
||||
# Query params are validated but not mutated
|
||||
"""
|
||||
|
||||
def decorator(view_func: Callable) -> Callable:
|
||||
parameters = []
|
||||
if query_serializer is not None:
|
||||
parameters.append(query_serializer)
|
||||
|
||||
@extend_schema(
|
||||
request=request_serializer,
|
||||
parameters=parameters if parameters else None,
|
||||
responses=responses,
|
||||
summary=summary,
|
||||
description=description,
|
||||
@@ -90,19 +86,22 @@ def validated_request(
|
||||
)
|
||||
@wraps(view_func)
|
||||
def wrapper(self, request: Request, *args, **kwargs) -> Response:
|
||||
if query_serializer is not None:
|
||||
query_serializer_instance = query_serializer(data=request.query_params)
|
||||
query_serializer_instance.is_valid(raise_exception=True)
|
||||
|
||||
if request_serializer is not None:
|
||||
serializer = request_serializer(data=request.data)
|
||||
req_validation_result = serializer.is_valid(raise_exception=strict_request_validation)
|
||||
|
||||
if not req_validation_result and settings.DEBUG:
|
||||
logger.warning(
|
||||
"Request data does not match declared serializer in @validated_request decorator. Please update the provided API schema to ensure API docs remain up to date",
|
||||
"Request body does not match declared serializer in @validated_request decorator. Please update the provided API schema to ensure API docs remain up to date",
|
||||
view_func=view_func.__name__,
|
||||
serializer_class=request_serializer.__name__,
|
||||
validation_errors=serializer.errors,
|
||||
)
|
||||
|
||||
# Cast to ValidatedRequest and set validated_data attribute
|
||||
validated_request = cast(ValidatedRequest, request)
|
||||
validated_request.validated_data = serializer.validated_data
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {
|
||||
"event": "$pageview",
|
||||
"distinct_id": "user_123",
|
||||
@@ -66,7 +67,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["status"] == "ok"
|
||||
assert response.data["distinct_id"] == "user_123"
|
||||
assert mock_request.validated_data["event"] == "$pageview"
|
||||
assert mock_request.data["event"] == "$pageview"
|
||||
|
||||
def test_request_validation_with_missing_required_field(self):
|
||||
"""Missing required field, should raise validation error"""
|
||||
@@ -82,6 +83,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
|
||||
view_instance = Mock()
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview"} # Missing 'distinct_id'
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
@@ -108,6 +110,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
|
||||
|
||||
response = mock_endpoint(view_instance, mock_request)
|
||||
@@ -134,6 +137,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
|
||||
|
||||
# Should log a warning and return the response
|
||||
@@ -172,6 +176,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
|
||||
|
||||
# Should log a warning and return the response
|
||||
@@ -207,6 +212,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
|
||||
view_instance = Mock()
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
|
||||
|
||||
response = mock_endpoint(view_instance, mock_request)
|
||||
@@ -230,6 +236,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
|
||||
|
||||
# Should log a warning and return the result
|
||||
@@ -269,6 +276,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
|
||||
|
||||
# Test with DEBUG=False (production mode)
|
||||
@@ -310,6 +318,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
|
||||
|
||||
# Should raise ValidationError regardless of DEBUG setting
|
||||
@@ -336,6 +345,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
|
||||
|
||||
# Should raise ValidationError with serializer errors
|
||||
@@ -369,6 +379,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview", "distinct_id": "user_123"}
|
||||
|
||||
# Should work normally with valid data
|
||||
@@ -402,6 +413,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview"} # Missing required 'distinct_id'
|
||||
|
||||
# Should log a warning but not raise
|
||||
@@ -414,7 +426,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
mock_logger.warning.assert_called_once()
|
||||
call_args = mock_logger.warning.call_args
|
||||
assert (
|
||||
"Request data does not match declared serializer in @validated_request decorator. Please update the provided API schema to ensure API docs remain up to date"
|
||||
"Request body does not match declared serializer in @validated_request decorator. Please update the provided API schema to ensure API docs remain up to date"
|
||||
in call_args[0][0]
|
||||
)
|
||||
assert call_args[1]["view_func"] == "mock_endpoint"
|
||||
@@ -439,7 +451,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
{
|
||||
"status": "ok",
|
||||
"event_id": str(uuid.uuid4()),
|
||||
"distinct_id": request.validated_data.get("distinct_id", "unknown"),
|
||||
"distinct_id": request.data.get("distinct_id", "unknown"),
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
@@ -447,6 +459,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"event": "$pageview"} # Missing required 'distinct_id'
|
||||
|
||||
# Should not log warning when DEBUG=False
|
||||
@@ -485,6 +498,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {
|
||||
"event": "$pageview",
|
||||
"distinct_id": "user_123",
|
||||
@@ -497,7 +511,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["status"] == "ok"
|
||||
assert response.data["distinct_id"] == "user_123"
|
||||
assert mock_request.validated_data["event"] == "$pageview"
|
||||
assert mock_request.data["event"] == "$pageview"
|
||||
|
||||
def test_no_body_response_declared_as_none_succeeds(self):
|
||||
"""When status code is declared as None (no body), response with no body should succeed"""
|
||||
@@ -512,6 +526,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
|
||||
view_instance = Mock()
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {}
|
||||
|
||||
response = mock_endpoint(view_instance, mock_request)
|
||||
@@ -532,6 +547,7 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
|
||||
view_instance = Mock()
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {}
|
||||
|
||||
with patch("posthog.api.mixins.settings") as mock_settings:
|
||||
@@ -563,9 +579,191 @@ class TestValidatedRequestDecorator(APIBaseTest):
|
||||
|
||||
view_instance = Mock()
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {}
|
||||
|
||||
with pytest.raises(serializers.ValidationError) as exc_info:
|
||||
mock_endpoint(view_instance, mock_request)
|
||||
|
||||
assert "Response status code 204 is declared with no body, but response contains data" in str(exc_info.value)
|
||||
|
||||
def test_query_parameter_validation_with_valid_data(self):
|
||||
"""Query parameter validation: valid query params should work"""
|
||||
|
||||
class QueryParamSerializer(serializers.Serializer):
|
||||
page = serializers.IntegerField(required=False, default=1)
|
||||
limit = serializers.IntegerField(required=False, default=10, max_value=100)
|
||||
|
||||
@validated_request(
|
||||
query_serializer=QueryParamSerializer,
|
||||
responses={
|
||||
200: OpenApiResponse(response=EventCaptureResponseSerializer),
|
||||
},
|
||||
)
|
||||
def mock_endpoint(view_self, request, **kwargs):
|
||||
page = request.query_params.get("page", 1)
|
||||
limit = request.query_params.get("limit", 10)
|
||||
return Response(
|
||||
{
|
||||
"status": "ok",
|
||||
"event_id": str(uuid.uuid4()),
|
||||
"distinct_id": "test",
|
||||
"page": int(page),
|
||||
"limit": int(limit),
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {}
|
||||
from django.http import QueryDict
|
||||
|
||||
mock_get = QueryDict("page=2&limit=20")
|
||||
mock_request._request = Mock()
|
||||
mock_request._request.GET = mock_get
|
||||
mock_request.query_params = mock_get
|
||||
|
||||
response = mock_endpoint(view_instance, mock_request)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["page"] == 2
|
||||
assert response.data["limit"] == 20
|
||||
|
||||
def test_query_parameter_validation_with_invalid_data_raises(self):
|
||||
"""Query parameter validation: invalid query params should raise exception"""
|
||||
|
||||
class QueryParamSerializer(serializers.Serializer):
|
||||
limit = serializers.IntegerField(required=True, max_value=100)
|
||||
|
||||
@validated_request(
|
||||
query_serializer=QueryParamSerializer,
|
||||
responses={
|
||||
200: OpenApiResponse(response=EventCaptureResponseSerializer),
|
||||
},
|
||||
)
|
||||
def mock_endpoint(view_self, request, **kwargs):
|
||||
return Response(
|
||||
{
|
||||
"status": "ok",
|
||||
"event_id": str(uuid.uuid4()),
|
||||
"distinct_id": "test",
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {}
|
||||
from django.http import QueryDict
|
||||
|
||||
mock_get = QueryDict("limit=200") # Exceeds max_value
|
||||
mock_request._request = Mock()
|
||||
mock_request._request.GET = mock_get
|
||||
mock_request.query_params = mock_get
|
||||
|
||||
with pytest.raises(serializers.ValidationError) as exc_info:
|
||||
mock_endpoint(view_instance, mock_request)
|
||||
|
||||
assert "limit" in str(exc_info.value)
|
||||
|
||||
def test_post_request_with_query_parameters(self):
|
||||
"""POST request with query parameters: both should be validated"""
|
||||
|
||||
class QueryParamSerializer(serializers.Serializer):
|
||||
dry_run = serializers.BooleanField(required=False, default=False)
|
||||
force = serializers.BooleanField(required=False, default=False)
|
||||
|
||||
class PostRequestSerializer(serializers.Serializer):
|
||||
action = serializers.ChoiceField(choices=["create", "update", "delete"])
|
||||
payload = serializers.DictField(required=False)
|
||||
|
||||
@validated_request(
|
||||
request_serializer=PostRequestSerializer,
|
||||
query_serializer=QueryParamSerializer,
|
||||
responses={
|
||||
200: OpenApiResponse(response=EventCaptureResponseSerializer),
|
||||
},
|
||||
)
|
||||
def mock_endpoint(view_self, request, **kwargs):
|
||||
action = request.validated_data["action"]
|
||||
# Query params remain as strings, need to convert
|
||||
dry_run = request.query_params.get("dry_run", "false").lower() == "true"
|
||||
force = request.query_params.get("force", "false").lower() == "true"
|
||||
return Response(
|
||||
{
|
||||
"status": "ok",
|
||||
"event_id": str(uuid.uuid4()),
|
||||
"distinct_id": "test",
|
||||
"action": action,
|
||||
"dry_run": dry_run,
|
||||
"force": force,
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"action": "create", "payload": {"key": "value"}}
|
||||
from django.http import QueryDict
|
||||
|
||||
mock_get = QueryDict("dry_run=true&force=false")
|
||||
mock_request._request = Mock()
|
||||
mock_request._request.GET = mock_get
|
||||
mock_request.query_params = mock_get
|
||||
|
||||
response = mock_endpoint(view_instance, mock_request)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["action"] == "create"
|
||||
assert response.data["dry_run"] is True
|
||||
assert response.data["force"] is False
|
||||
|
||||
def test_post_request_with_invalid_query_parameters_raises(self):
|
||||
"""POST request with invalid query parameters: should raise exception"""
|
||||
|
||||
class QueryParamSerializer(serializers.Serializer):
|
||||
timeout = serializers.IntegerField(required=True, min_value=1, max_value=3600)
|
||||
|
||||
class PostRequestSerializer(serializers.Serializer):
|
||||
action = serializers.CharField()
|
||||
|
||||
@validated_request(
|
||||
request_serializer=PostRequestSerializer,
|
||||
query_serializer=QueryParamSerializer,
|
||||
responses={
|
||||
200: OpenApiResponse(response=EventCaptureResponseSerializer),
|
||||
},
|
||||
)
|
||||
def mock_endpoint(view_self, request, **kwargs):
|
||||
return Response(
|
||||
{
|
||||
"status": "ok",
|
||||
"event_id": str(uuid.uuid4()),
|
||||
"distinct_id": "test",
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
view_instance = Mock()
|
||||
view_instance.get_serializer_context = Mock(return_value={})
|
||||
mock_request = Mock()
|
||||
mock_request._full_data = {}
|
||||
mock_request.data = {"action": "test"}
|
||||
from django.http import QueryDict
|
||||
|
||||
mock_get = QueryDict("timeout=5000") # Exceeds max_value
|
||||
mock_request._request = Mock()
|
||||
mock_request._request.GET = mock_get
|
||||
mock_request.query_params = mock_get
|
||||
|
||||
with pytest.raises(serializers.ValidationError) as exc_info:
|
||||
mock_endpoint(view_instance, mock_request)
|
||||
|
||||
assert "timeout" in str(exc_info.value)
|
||||
|
||||
@@ -18,6 +18,7 @@ from posthog.permissions import APIScopePermission, PostHogFeatureFlagPermission
|
||||
from .models import Task, TaskRun
|
||||
from .serializers import (
|
||||
ErrorResponseSerializer,
|
||||
TaskListQuerySerializer,
|
||||
TaskRunAppendLogRequestSerializer,
|
||||
TaskRunDetailSerializer,
|
||||
TaskSerializer,
|
||||
@@ -54,6 +55,17 @@ class TaskViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
|
||||
]
|
||||
}
|
||||
|
||||
@validated_request(
|
||||
query_serializer=TaskListQuerySerializer,
|
||||
responses={
|
||||
200: OpenApiResponse(response=TaskSerializer, description="List of tasks"),
|
||||
},
|
||||
summary="List tasks",
|
||||
description="Get a list of tasks for the current project, with optional filtering by origin product, stage, organization, and repository.",
|
||||
)
|
||||
def list(self, request, *args, **kwargs):
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
def safely_get_queryset(self, queryset):
|
||||
qs = queryset.filter(team=self.team).order_by("position")
|
||||
|
||||
@@ -198,6 +210,26 @@ class TaskRunViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
|
||||
http_method_names = ["get", "post", "patch", "head", "options"]
|
||||
filter_rewrite_rules = {"team_id": "team_id"}
|
||||
|
||||
@validated_request(
|
||||
responses={
|
||||
200: OpenApiResponse(response=TaskRunDetailSerializer, description="List of task runs"),
|
||||
},
|
||||
summary="List task runs",
|
||||
description="Get a list of runs for a specific task.",
|
||||
)
|
||||
def list(self, request, *args, **kwargs):
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
@validated_request(
|
||||
responses={
|
||||
201: OpenApiResponse(response=TaskRunDetailSerializer, description="Created task run"),
|
||||
},
|
||||
summary="Create task run",
|
||||
description="Create a new run for a specific task.",
|
||||
)
|
||||
def create(self, request, *args, **kwargs):
|
||||
return super().create(request, *args, **kwargs)
|
||||
|
||||
def safely_get_queryset(self, queryset):
|
||||
# Task runs are always scoped to a specific task
|
||||
task_id = self.kwargs.get("parent_lookup_task_id")
|
||||
|
||||
@@ -282,3 +282,14 @@ class TaskRunAppendLogRequestSerializer(serializers.Serializer):
|
||||
if not value:
|
||||
raise serializers.ValidationError("At least one log entry is required")
|
||||
return value
|
||||
|
||||
|
||||
class TaskListQuerySerializer(serializers.Serializer):
|
||||
"""Query parameters for listing tasks"""
|
||||
|
||||
origin_product = serializers.CharField(required=False, help_text="Filter by origin product")
|
||||
stage = serializers.CharField(required=False, help_text="Filter by task run stage")
|
||||
organization = serializers.CharField(required=False, help_text="Filter by repository organization")
|
||||
repository = serializers.CharField(
|
||||
required=False, help_text="Filter by repository name (can include org/repo format)"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user