From caf34c5d804dc1204c740be7f588b45a8fbabfd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Szczur?= Date: Thu, 14 Aug 2025 16:16:43 +0200 Subject: [PATCH] feat: add active livestream subs + increase incoming events queue (#36624) --- livestream/events/filter.go | 5 +- livestream/events/kafka.go | 4 +- livestream/handlers/handlers.go | 2 +- livestream/handlers/handlers_test.go | 162 +++++++++++++++++++++++++++ livestream/metrics/metrics.go | 4 + 5 files changed, 173 insertions(+), 4 deletions(-) create mode 100644 livestream/handlers/handlers_test.go diff --git a/livestream/events/filter.go b/livestream/events/filter.go index 49a72427d1..74bf4581ea 100644 --- a/livestream/events/filter.go +++ b/livestream/events/filter.go @@ -3,10 +3,11 @@ package events import ( "fmt" "log" + "slices" "sync/atomic" "github.com/gofrs/uuid/v5" - "slices" + "github.com/posthog/posthog/livestream/metrics" ) type Subscription struct { @@ -86,6 +87,7 @@ func uuidFromDistinctId(teamId int, distinctId string) string { func removeSubscription(subID uint64, subs []Subscription) []Subscription { for i, sub := range subs { if subID == sub.SubID { + metrics.SubTotal.Dec() return slices.Delete(subs, i, i+1) } } @@ -97,6 +99,7 @@ func (c *Filter) Run() { select { case newSub := <-c.SubChan: c.subs = append(c.subs, newSub) + metrics.SubTotal.Inc() case unSub := <-c.UnSubChan: c.subs = removeSubscription(unSub.SubID, c.subs) case event := <-c.inboundChan: diff --git a/livestream/events/kafka.go b/livestream/events/kafka.go index 678ac9103b..5dd01740b9 100644 --- a/livestream/events/kafka.go +++ b/livestream/events/kafka.go @@ -63,7 +63,7 @@ func NewPostHogKafkaConsumer( "security.protocol": securityProtocol, "fetch.message.max.bytes": 1_000_000_000, "fetch.max.bytes": 1_000_000_000, - "queued.max.messages.kbytes": 1_000_000, + "queued.max.messages.kbytes": 2_000_000, } consumer, err := kafka.NewConsumer(config) @@ -75,7 +75,7 @@ func NewPostHogKafkaConsumer( consumer: consumer, topic: topic, geolocator: geolocator, - incoming: make(chan []byte, (1+parallel)*10), + incoming: make(chan []byte, (1+parallel)*100), outgoingChan: outgoingChan, statsChan: statsChan, parallel: parallel, diff --git a/livestream/handlers/handlers.go b/livestream/handlers/handlers.go index d623c0e1c7..953a35800a 100644 --- a/livestream/handlers/handlers.go +++ b/livestream/handlers/handlers.go @@ -74,7 +74,7 @@ func StreamEventsHandler(log echo.Logger, subChan chan events.Subscription, filt ) teamID, token, err = auth.GetAuthClaims(c.Request().Header) - if err != nil { + if err != nil || token == "" || teamID == 0 { return echo.NewHTTPError(http.StatusUnauthorized, "wrong token") } diff --git a/livestream/handlers/handlers_test.go b/livestream/handlers/handlers_test.go new file mode 100644 index 0000000000..6da0fa2207 --- /dev/null +++ b/livestream/handlers/handlers_test.go @@ -0,0 +1,162 @@ +package handlers + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v4" + "github.com/posthog/posthog/livestream/auth" + "github.com/posthog/posthog/livestream/events" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamEventsHandler_AuthValidation(t *testing.T) { + logger := echo.New().Logger + subChan := make(chan events.Subscription, 10) + filter := &events.Filter{ + UnSubChan: make(chan events.Subscription, 10), + } + handler := StreamEventsHandler(logger, subChan, filter) + + tests := []struct { + name string + setupHeader func(*http.Request) + expectedStatus int + expectedError string + description string + }{ + { + name: "Missing authorization header returns unauthorized", + setupHeader: func(req *http.Request) { + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "wrong token", + description: "When auth header is missing, GetAuthClaims returns error and handler should return 401", + }, + { + name: "Invalid auth header returns unauthorized", + setupHeader: func(req *http.Request) { + req.Header.Set("Authorization", "InvalidToken") + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "wrong token", + description: "When auth header is invalid, GetAuthClaims returns error and handler should return 401", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/events", nil) + ctx, canc := context.WithTimeout(context.Background(), time.Millisecond) + defer canc() + req = req.WithContext(ctx) + tt.setupHeader(req) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + require.Error(t, err, tt.description) + httpErr, ok := err.(*echo.HTTPError) + require.True(t, ok, "error should be an HTTPError") + assert.Equal(t, tt.expectedStatus, httpErr.Code) + assert.Equal(t, tt.expectedError, httpErr.Message) + }) + } +} + +func TestStreamEventsHandler_TokenAndTeamIDValidation(t *testing.T) { + viper.Set("jwt.secret", "test-secret-for-handlers") + + logger := echo.New().Logger + subChan := make(chan events.Subscription, 10) + filter := &events.Filter{ + UnSubChan: make(chan events.Subscription, 10), + } + handler := StreamEventsHandler(logger, subChan, filter) + + tests := []struct { + name string + claims jwt.MapClaims + expectError bool + errorMessage string + description string + }{ + { + name: "Empty api_token should return unauthorized", + claims: jwt.MapClaims{ + "team_id": 123, + "api_token": "", + }, + expectError: true, + errorMessage: "wrong token", + description: "New validation: empty token should be rejected even with valid JWT", + }, + { + name: "Team ID 0 should return unauthorized", + claims: jwt.MapClaims{ + "team_id": 0, + "api_token": "valid-token", + }, + expectError: true, + errorMessage: "wrong token", + description: "New validation: teamID=0 should be rejected even with valid JWT", + }, + { + name: "HappyPath", + claims: jwt.MapClaims{ + "team_id": 7, + "api_token": "valid-token", + }, + expectError: false, + description: "New validation: teamID=7 should be accepted even with valid JWT", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token := createJWTToken(auth.ExpectedScope, tt.claims) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/events", nil) + req.Header.Set("Authorization", "Bearer "+token) + ctx, canc := context.WithTimeout(context.Background(), time.Millisecond) + defer canc() + req = req.WithContext(ctx) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + if tt.expectError { + require.Error(t, err, tt.description) + httpErr, ok := err.(*echo.HTTPError) + require.True(t, ok, "error should be an HTTPError") + assert.Equal(t, http.StatusUnauthorized, httpErr.Code) + assert.Equal(t, tt.errorMessage, httpErr.Message) + } else { + assert.NoError(t, err, tt.description) + } + }) + } +} + +func createJWTToken(audience string, claims jwt.MapClaims) string { + newClaims := jwt.MapClaims{ + "aud": audience, + "exp": time.Now().Add(time.Hour).Unix(), + } + for k, v := range claims { + newClaims[k] = v + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, newClaims) + tokenString, _ := token.SignedString([]byte(viper.GetString("jwt.secret"))) + return tokenString +} diff --git a/livestream/metrics/metrics.go b/livestream/metrics/metrics.go index 68fa467406..286ab645ae 100644 --- a/livestream/metrics/metrics.go +++ b/livestream/metrics/metrics.go @@ -46,4 +46,8 @@ var ( Name: "livestream_unsub_queue_use_ratio", Help: "How much of unsub queue is used (disconnecting removes subscription)", }) + SubTotal = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "livestream_active_event_subscriptions_total", + Help: "How many active event subscriptions we have", + }) )