feat: add active livestream subs + increase incoming events queue (#36624)

This commit is contained in:
Paweł Szczur
2025-08-14 16:16:43 +02:00
committed by GitHub
parent 3de358b47e
commit caf34c5d80
5 changed files with 173 additions and 4 deletions

View File

@@ -3,10 +3,11 @@ package events
import (
"fmt"
"log"
"slices"
"sync/atomic"
"github.com/gofrs/uuid/v5"
"slices"
"github.com/posthog/posthog/livestream/metrics"
)
type Subscription struct {
@@ -86,6 +87,7 @@ func uuidFromDistinctId(teamId int, distinctId string) string {
func removeSubscription(subID uint64, subs []Subscription) []Subscription {
for i, sub := range subs {
if subID == sub.SubID {
metrics.SubTotal.Dec()
return slices.Delete(subs, i, i+1)
}
}
@@ -97,6 +99,7 @@ func (c *Filter) Run() {
select {
case newSub := <-c.SubChan:
c.subs = append(c.subs, newSub)
metrics.SubTotal.Inc()
case unSub := <-c.UnSubChan:
c.subs = removeSubscription(unSub.SubID, c.subs)
case event := <-c.inboundChan:

View File

@@ -63,7 +63,7 @@ func NewPostHogKafkaConsumer(
"security.protocol": securityProtocol,
"fetch.message.max.bytes": 1_000_000_000,
"fetch.max.bytes": 1_000_000_000,
"queued.max.messages.kbytes": 1_000_000,
"queued.max.messages.kbytes": 2_000_000,
}
consumer, err := kafka.NewConsumer(config)
@@ -75,7 +75,7 @@ func NewPostHogKafkaConsumer(
consumer: consumer,
topic: topic,
geolocator: geolocator,
incoming: make(chan []byte, (1+parallel)*10),
incoming: make(chan []byte, (1+parallel)*100),
outgoingChan: outgoingChan,
statsChan: statsChan,
parallel: parallel,

View File

@@ -74,7 +74,7 @@ func StreamEventsHandler(log echo.Logger, subChan chan events.Subscription, filt
)
teamID, token, err = auth.GetAuthClaims(c.Request().Header)
if err != nil {
if err != nil || token == "" || teamID == 0 {
return echo.NewHTTPError(http.StatusUnauthorized, "wrong token")
}

View File

@@ -0,0 +1,162 @@
package handlers
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/posthog/posthog/livestream/auth"
"github.com/posthog/posthog/livestream/events"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStreamEventsHandler_AuthValidation(t *testing.T) {
logger := echo.New().Logger
subChan := make(chan events.Subscription, 10)
filter := &events.Filter{
UnSubChan: make(chan events.Subscription, 10),
}
handler := StreamEventsHandler(logger, subChan, filter)
tests := []struct {
name string
setupHeader func(*http.Request)
expectedStatus int
expectedError string
description string
}{
{
name: "Missing authorization header returns unauthorized",
setupHeader: func(req *http.Request) {
},
expectedStatus: http.StatusUnauthorized,
expectedError: "wrong token",
description: "When auth header is missing, GetAuthClaims returns error and handler should return 401",
},
{
name: "Invalid auth header returns unauthorized",
setupHeader: func(req *http.Request) {
req.Header.Set("Authorization", "InvalidToken")
},
expectedStatus: http.StatusUnauthorized,
expectedError: "wrong token",
description: "When auth header is invalid, GetAuthClaims returns error and handler should return 401",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/events", nil)
ctx, canc := context.WithTimeout(context.Background(), time.Millisecond)
defer canc()
req = req.WithContext(ctx)
tt.setupHeader(req)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.Error(t, err, tt.description)
httpErr, ok := err.(*echo.HTTPError)
require.True(t, ok, "error should be an HTTPError")
assert.Equal(t, tt.expectedStatus, httpErr.Code)
assert.Equal(t, tt.expectedError, httpErr.Message)
})
}
}
func TestStreamEventsHandler_TokenAndTeamIDValidation(t *testing.T) {
viper.Set("jwt.secret", "test-secret-for-handlers")
logger := echo.New().Logger
subChan := make(chan events.Subscription, 10)
filter := &events.Filter{
UnSubChan: make(chan events.Subscription, 10),
}
handler := StreamEventsHandler(logger, subChan, filter)
tests := []struct {
name string
claims jwt.MapClaims
expectError bool
errorMessage string
description string
}{
{
name: "Empty api_token should return unauthorized",
claims: jwt.MapClaims{
"team_id": 123,
"api_token": "",
},
expectError: true,
errorMessage: "wrong token",
description: "New validation: empty token should be rejected even with valid JWT",
},
{
name: "Team ID 0 should return unauthorized",
claims: jwt.MapClaims{
"team_id": 0,
"api_token": "valid-token",
},
expectError: true,
errorMessage: "wrong token",
description: "New validation: teamID=0 should be rejected even with valid JWT",
},
{
name: "HappyPath",
claims: jwt.MapClaims{
"team_id": 7,
"api_token": "valid-token",
},
expectError: false,
description: "New validation: teamID=7 should be accepted even with valid JWT",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := createJWTToken(auth.ExpectedScope, tt.claims)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/events", nil)
req.Header.Set("Authorization", "Bearer "+token)
ctx, canc := context.WithTimeout(context.Background(), time.Millisecond)
defer canc()
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
if tt.expectError {
require.Error(t, err, tt.description)
httpErr, ok := err.(*echo.HTTPError)
require.True(t, ok, "error should be an HTTPError")
assert.Equal(t, http.StatusUnauthorized, httpErr.Code)
assert.Equal(t, tt.errorMessage, httpErr.Message)
} else {
assert.NoError(t, err, tt.description)
}
})
}
}
func createJWTToken(audience string, claims jwt.MapClaims) string {
newClaims := jwt.MapClaims{
"aud": audience,
"exp": time.Now().Add(time.Hour).Unix(),
}
for k, v := range claims {
newClaims[k] = v
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, newClaims)
tokenString, _ := token.SignedString([]byte(viper.GetString("jwt.secret")))
return tokenString
}

View File

@@ -46,4 +46,8 @@ var (
Name: "livestream_unsub_queue_use_ratio",
Help: "How much of unsub queue is used (disconnecting removes subscription)",
})
SubTotal = promauto.NewGauge(prometheus.GaugeOpts{
Name: "livestream_active_event_subscriptions_total",
Help: "How many active event subscriptions we have",
})
)