merge history

Neil Kakkar committed on 2024-06-10 14:00:00 +01:00
98 changed files with 17084 additions and 0 deletions

5
rust/.dockerignore Normal file

@@ -0,0 +1,5 @@
.env
.git
.github
docker
target

1
rust/.env Normal file

@@ -0,0 +1 @@
DATABASE_URL=postgres://posthog:posthog@localhost:15432/test_database

79
rust/.github/workflows/docker-build.yml vendored Normal file

@@ -0,0 +1,79 @@
name: Build container images
on:
workflow_dispatch:
push:
branches:
- 'main'
jobs:
build:
name: Build and publish container image
strategy:
matrix:
image:
- capture
- hook-api
- hook-janitor
- hook-worker
runs-on: depot-ubuntu-22.04-4
permissions:
id-token: write # allow issuing OIDC tokens for this workflow run
contents: read # allow reading the repo contents
packages: write # allow push to ghcr.io
steps:
- name: Check Out Repo
uses: actions/checkout@v3
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: posthog
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Login to ghcr.io
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Docker meta
id: meta
uses: docker/metadata-action@v4
with:
images: ghcr.io/posthog/hog-rs/${{ matrix.image }}
tags: |
type=ref,event=pr
type=ref,event=branch
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v2
- name: Build and push image
id: docker_build
uses: depot/build-push-action@v1
with:
context: ./
file: ./Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
platforms: linux/arm64
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: BIN=${{ matrix.image }}
- name: Container image digest
run: echo ${{ steps.docker_build.outputs.digest }}


@@ -0,0 +1,74 @@
name: Build hook-migrator docker image
on:
workflow_dispatch:
push:
branches:
- 'main'
permissions:
packages: write
jobs:
build:
name: build and publish hook-migrator image
runs-on: depot-ubuntu-22.04-4
permissions:
id-token: write # allow issuing OIDC tokens for this workflow run
contents: read # allow reading the repo contents
packages: write # allow push to ghcr.io
steps:
- name: Check Out Repo
uses: actions/checkout@v3
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: posthog
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Login to ghcr.io
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Docker meta
id: meta
uses: docker/metadata-action@v4
with:
images: ghcr.io/posthog/hog-rs/hook-migrator
tags: |
type=ref,event=pr
type=ref,event=branch
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v2
- name: Build and push migrator
id: docker_build_hook_migrator
uses: depot/build-push-action@v1
with:
context: ./
file: ./Dockerfile.migrate
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
platforms: linux/arm64
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Hook-migrator image digest
run: echo ${{ steps.docker_build_hook_migrator.outputs.digest }}

105
rust/.github/workflows/rust.yml vendored Normal file

@@ -0,0 +1,105 @@
name: Rust
on:
workflow_dispatch:
push:
branches: [main]
pull_request:
env:
CARGO_TERM_COLOR: always
jobs:
build:
runs-on: depot-ubuntu-22.04-4
steps:
- uses: actions/checkout@v3
- name: Install rust
uses: dtolnay/rust-toolchain@1.77
- uses: actions/cache@v3
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-release-${{ hashFiles('**/Cargo.lock') }}
- name: Run cargo build
run: cargo build --all --locked --release && find target/release/ -maxdepth 1 -executable -type f | xargs strip
test:
runs-on: depot-ubuntu-22.04-4
timeout-minutes: 10
steps:
- uses: actions/checkout@v3
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: posthog
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Setup dependencies
run: |
docker compose up kafka redis db echo_server -d --wait
docker compose up setup_test_db
echo "127.0.0.1 kafka" | sudo tee -a /etc/hosts
- name: Install rust
uses: dtolnay/rust-toolchain@1.77
- uses: actions/cache@v3
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-debug-${{ hashFiles('**/Cargo.lock') }}
- name: Run cargo test
run: cargo test --all-features
linting:
runs-on: depot-ubuntu-22.04-4
steps:
- uses: actions/checkout@v3
- name: Install rust
uses: dtolnay/rust-toolchain@1.77
with:
components: clippy,rustfmt
- uses: actions/cache@v3
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-debug-${{ hashFiles('**/Cargo.lock') }}
- name: Check format
run: cargo fmt -- --check
- name: Run clippy
run: cargo clippy -- -D warnings
- name: Run cargo check
run: cargo check --all-features
shear:
runs-on: depot-ubuntu-22.04-4
steps:
- uses: actions/checkout@v3
- name: Install cargo-binstall
uses: cargo-bins/cargo-binstall@main
- name: Install cargo-shear
run: cargo binstall --no-confirm cargo-shear
- run: cargo shear

1
rust/.gitignore vendored Normal file

@@ -0,0 +1 @@
/target

3822
rust/Cargo.lock generated Normal file

File diff suppressed because it is too large

81
rust/Cargo.toml Normal file

@@ -0,0 +1,81 @@
[workspace]
resolver = "2"
members = [
"capture",
"common/health",
"feature-flags",
"hook-api",
"hook-common",
"hook-janitor",
"hook-worker",
]
[workspace.lints.rust]
# See https://doc.rust-lang.org/stable/rustc/lints/listing/allowed-by-default.html
unsafe_code = "forbid" # forbid cannot be ignored with an annotation
unstable_features = "forbid"
macro_use_extern_crate = "forbid"
let_underscore_drop = "deny"
non_ascii_idents = "deny"
trivial_casts = "deny"
trivial_numeric_casts = "deny"
unit_bindings = "deny"
[workspace.lints.clippy]
# See https://rust-lang.github.io/rust-clippy/, we might want to add more
enum_glob_use = "deny"
[workspace.dependencies]
anyhow = "1.0"
assert-json-diff = "2.0.2"
async-trait = "0.1.74"
axum = { version = "0.7.5", features = ["http2", "macros", "matched-path"] }
axum-client-ip = "0.6.0"
base64 = "0.22.0"
bytes = "1"
chrono = { version = "0.4" }
envconfig = "0.10.0"
eyre = "0.6.9"
flate2 = "1.0"
futures = { version = "0.3.29" }
governor = { version = "0.5.1", features = ["dashmap"] }
http = { version = "1.1.0" }
http-body-util = "0.1.0"
metrics = "0.22.0"
metrics-exporter-prometheus = "0.14.0"
once_cell = "1.18.0"
opentelemetry = { version = "0.22.0", features = ["trace"]}
opentelemetry-otlp = "0.15.0"
opentelemetry_sdk = { version = "0.22.1", features = ["trace", "rt-tokio"] }
rand = "0.8.5"
rdkafka = { version = "0.36.0", features = ["cmake-build", "ssl", "tracing"] }
reqwest = { version = "0.12.3", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
serde_derive = { version = "1.0" }
serde_json = { version = "1.0" }
serde_urlencoded = "0.7.1"
sqlx = { version = "0.7", features = [
"chrono",
"json",
"migrate",
"postgres",
"runtime-tokio",
"tls-native-tls",
"uuid",
] }
time = { version = "0.3.20", features = [
"formatting",
"macros",
"parsing",
"serde",
] }
thiserror = { version = "1.0" }
tokio = { version = "1.34.0", features = ["full"] }
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["cors", "limit", "trace"] }
tracing = "0.1.40"
tracing-opentelemetry = "0.23.0"
tracing-subscriber = { version="0.3.18", features = ["env-filter"] }
url = { version = "2.5.0" }
uuid = { version = "1.6.1", features = ["v7", "serde"] }

38
rust/Dockerfile Normal file

@@ -0,0 +1,38 @@
FROM docker.io/lukemathwalker/cargo-chef:latest-rust-1.77-bookworm AS chef
ARG BIN
WORKDIR /app
FROM chef AS planner
ARG BIN
COPY . .
RUN cargo chef prepare --recipe-path recipe.json --bin $BIN
FROM chef AS builder
ARG BIN
# Ensure working C compile setup (not installed by default in arm64 images)
RUN apt update && apt install build-essential cmake -y
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
COPY . .
RUN cargo build --release --bin $BIN
FROM debian:bookworm-slim AS runtime
RUN apt-get update && \
apt-get install -y --no-install-recommends \
"ca-certificates" \
&& \
rm -rf /var/lib/apt/lists/*
ARG BIN
ENV BIN=$BIN
WORKDIR /app
USER nobody
COPY --from=builder /app/target/release/$BIN /usr/local/bin
ENTRYPOINT ["/bin/sh", "-c", "/usr/local/bin/$BIN"]

16
rust/Dockerfile.migrate Normal file

@@ -0,0 +1,16 @@
FROM docker.io/library/rust:1.74.0-buster AS builder
RUN apt update && apt install build-essential cmake -y
RUN cargo install sqlx-cli@0.7.3 --no-default-features --features native-tls,postgres --root /app/target/release/
FROM debian:bullseye-20230320-slim AS runtime
WORKDIR /sqlx
ADD bin /sqlx/bin/
ADD migrations /sqlx/migrations/
COPY --from=builder /app/target/release/bin/sqlx /usr/local/bin
RUN chmod +x ./bin/migrate
CMD ["./bin/migrate"]

21
rust/LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 PostHog
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

40
rust/README.md Normal file

@@ -0,0 +1,40 @@
# hog-rs
PostHog Rust service monorepo. This is *not* the Rust client library for PostHog.
## capture
This is a rewrite of [capture.py](https://github.com/PostHog/posthog/blob/master/posthog/api/capture.py), in Rust.
### Why?
Capture is very simple. It takes some JSON, checks a key in Redis, and then pushes onto Kafka. It's mostly IO bound.
We currently use far too much compute to run this service, and it could be more efficient. This effort should not take too long to complete, but should massively reduce our CPU usage - and therefore spend.
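To make the shape of that hot path concrete, here is a rough sketch of the flow (illustrative only, assuming `serde`/`serde_json`; the helper names below are hypothetical stand-ins, not the actual types in this repo):
```rust
// Sketch of the capture hot path: parse JSON, consult Redis for quota limits,
// then hand the event to Kafka. Everything here is a simplified stand-in for
// the real code under rust/capture/src/.
use serde::Deserialize;

#[derive(Deserialize)]
#[allow(dead_code)]
struct RawEvent {
    event: String,
    distinct_id: String,
}

// Stand-in for the Redis-backed billing limiter (a cached zrangebyscore lookup).
fn is_billing_limited(_token: &str) -> bool {
    false
}

// Stand-in for the rdkafka producer (send_result + awaiting the delivery future).
fn produce_to_kafka(topic: &str, key: &str, payload: &str) {
    println!("would produce to {topic} keyed by {key}: {payload}");
}

fn handle(token: &str, body: &str) -> Result<(), String> {
    let event: RawEvent = serde_json::from_str(body).map_err(|e| e.to_string())?;
    if is_billing_limited(token) {
        return Err("quota limited".into());
    }
    // The partition key preserves per-user ordering, as in ProcessedEvent::key().
    let key = format!("{token}:{}", event.distinct_id);
    produce_to_kafka("events_plugin_ingestion", &key, body);
    Ok(())
}

fn main() {
    handle("phc_example", r#"{"event":"pageview","distinct_id":"user-1"}"#).unwrap();
}
```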
### How?
I'm trying to ensure the rewrite at least vaguely resembles the Python version. This should both minimize accidental regressions and serve as a "Rosetta stone" for engineers at PostHog who have not written Rust before.
## rusty-hook
A reliable and performant webhook system for PostHog
### Requirements
1. [Rust](https://www.rust-lang.org/tools/install).
2. [Docker](https://docs.docker.com/engine/install/), or [podman](https://podman.io/docs/installation) and [podman-compose](https://github.com/containers/podman-compose#installation): To setup development stack.
### Testing
1. Start development stack:
```bash
docker compose -f docker-compose.yml up -d --wait
```
2. Test:
```bash
# Note that tests require a DATABASE_URL environment variable to be set, e.g.:
# export DATABASE_URL=postgres://posthog:posthog@localhost:15432/test_database
# But there is an .env file in the project root that should be used automatically.
cargo test
```

4
rust/bin/migrate Executable file

@@ -0,0 +1,4 @@
#!/bin/sh
sqlx database create
sqlx migrate run

53
rust/capture/Cargo.toml Normal file

@@ -0,0 +1,53 @@
[package]
name = "capture"
version = "0.1.0"
edition = "2021"
[lints]
workspace = true
[dependencies]
anyhow = { workspace = true }
async-trait = { workspace = true }
axum = { workspace = true }
axum-client-ip = { workspace = true }
base64 = { workspace = true }
bytes = { workspace = true }
envconfig = { workspace = true }
flate2 = { workspace = true }
governor = { workspace = true }
health = { path = "../common/health" }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
opentelemetry = { workspace = true }
opentelemetry-otlp = { workspace = true }
opentelemetry_sdk = { workspace = true }
rand = { workspace = true }
rdkafka = { workspace = true }
redis = { version = "0.23.3", features = [
"tokio-comp",
"cluster",
"cluster-async",
] }
serde = { workspace = true }
serde_json = { workspace = true }
serde_urlencoded = { workspace = true }
thiserror = { workspace = true }
time = { workspace = true }
tokio = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
tracing-opentelemetry = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { workspace = true }
[dev-dependencies]
assert-json-diff = { workspace = true }
axum-test-helper = { git = "https://github.com/posthog/axum-test-helper.git" } # TODO: remove, directly use reqwest like capture-server tests
anyhow = { workspace = true }
futures = { workspace = true }
once_cell = { workspace = true }
rand = { workspace = true }
rdkafka = { workspace = true }
reqwest = { workspace = true }
serde_json = { workspace = true }

109
rust/capture/src/api.rs Normal file

@@ -0,0 +1,109 @@
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use time::OffsetDateTime;
use uuid::Uuid;
use crate::token::InvalidTokenReason;
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
pub enum CaptureResponseCode {
Ok = 1,
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
pub struct CaptureResponse {
pub status: CaptureResponseCode,
}
#[derive(Error, Debug)]
pub enum CaptureError {
#[error("failed to decode request: {0}")]
RequestDecodingError(String),
#[error("failed to parse request: {0}")]
RequestParsingError(#[from] serde_json::Error),
#[error("request holds no event")]
EmptyBatch,
#[error("event submitted with an empty event name")]
MissingEventName,
#[error("event submitted with an empty distinct_id")]
EmptyDistinctId,
#[error("event submitted without a distinct_id")]
MissingDistinctId,
#[error("event submitted without an api_key")]
NoTokenError,
#[error("batch submitted with inconsistent api_key values")]
MultipleTokensError,
#[error("API key is not valid: {0}")]
TokenValidationError(#[from] InvalidTokenReason),
#[error("transient error, please retry")]
RetryableSinkError,
#[error("maximum event size exceeded")]
EventTooBig,
#[error("invalid event could not be processed")]
NonRetryableSinkError,
#[error("billing limit reached")]
BillingLimit,
#[error("rate limited")]
RateLimited,
}
impl IntoResponse for CaptureError {
fn into_response(self) -> Response {
match self {
CaptureError::RequestDecodingError(_)
| CaptureError::RequestParsingError(_)
| CaptureError::EmptyBatch
| CaptureError::MissingEventName
| CaptureError::EmptyDistinctId
| CaptureError::MissingDistinctId
| CaptureError::EventTooBig
| CaptureError::NonRetryableSinkError => (StatusCode::BAD_REQUEST, self.to_string()),
CaptureError::NoTokenError
| CaptureError::MultipleTokensError
| CaptureError::TokenValidationError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),
CaptureError::RetryableSinkError => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()),
CaptureError::BillingLimit | CaptureError::RateLimited => {
(StatusCode::TOO_MANY_REQUESTS, self.to_string())
}
}
.into_response()
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum DataType {
AnalyticsMain,
AnalyticsHistorical,
}
#[derive(Clone, Debug, Serialize, Eq, PartialEq)]
pub struct ProcessedEvent {
#[serde(skip_serializing)]
pub data_type: DataType,
pub uuid: Uuid,
pub distinct_id: String,
pub ip: String,
pub data: String,
pub now: String,
#[serde(
with = "time::serde::rfc3339::option",
skip_serializing_if = "Option::is_none"
)]
pub sent_at: Option<OffsetDateTime>,
pub token: String,
}
impl ProcessedEvent {
pub fn key(&self) -> String {
format!("{}:{}", self.token, self.distinct_id)
}
}

57
rust/capture/src/config.rs Normal file

@@ -0,0 +1,57 @@
use std::{net::SocketAddr, num::NonZeroU32};
use envconfig::Envconfig;
#[derive(Envconfig, Clone)]
pub struct Config {
#[envconfig(default = "false")]
pub print_sink: bool,
#[envconfig(default = "127.0.0.1:3000")]
pub address: SocketAddr,
pub redis_url: String,
pub otel_url: Option<String>,
#[envconfig(default = "false")]
pub overflow_enabled: bool,
#[envconfig(default = "100")]
pub overflow_per_second_limit: NonZeroU32,
#[envconfig(default = "1000")]
pub overflow_burst_limit: NonZeroU32,
pub overflow_forced_keys: Option<String>, // Comma-delimited keys
#[envconfig(nested = true)]
pub kafka: KafkaConfig,
#[envconfig(default = "1.0")]
pub otel_sampling_rate: f64,
#[envconfig(default = "capture")]
pub otel_service_name: String,
#[envconfig(default = "true")]
pub export_prometheus: bool,
}
#[derive(Envconfig, Clone)]
pub struct KafkaConfig {
#[envconfig(default = "20")]
pub kafka_producer_linger_ms: u32, // Maximum time between producer batches during low traffic
#[envconfig(default = "400")]
pub kafka_producer_queue_mib: u32, // Size of the in-memory producer queue in mebibytes
#[envconfig(default = "20000")]
pub kafka_message_timeout_ms: u32, // Time before we stop retrying producing a message: 20 seconds
#[envconfig(default = "none")]
pub kafka_compression_codec: String, // none, gzip, snappy, lz4, zstd
pub kafka_hosts: String,
#[envconfig(default = "events_plugin_ingestion")]
pub kafka_topic: String,
#[envconfig(default = "events_plugin_ingestion_historical")]
pub kafka_historical_topic: String,
#[envconfig(default = "false")]
pub kafka_tls: bool,
}

13
rust/capture/src/lib.rs Normal file

@@ -0,0 +1,13 @@
pub mod api;
pub mod config;
pub mod limiters;
pub mod prometheus;
pub mod redis;
pub mod router;
pub mod server;
pub mod sinks;
pub mod time;
pub mod token;
pub mod utils;
pub mod v0_endpoint;
pub mod v0_request;

196
rust/capture/src/limiters/billing.rs Normal file

@@ -0,0 +1,196 @@
use std::{collections::HashSet, ops::Sub, sync::Arc};
use crate::redis::Client;
/// Limit accounts by team ID if they hit a billing limit
///
/// We have an async celery worker that regularly checks on accounts + assesses if they are beyond
/// a billing limit. If this is the case, a key is set in redis.
///
/// Requirements
///
/// 1. Updates from the celery worker should be reflected in capture within a short period of time
/// 2. Capture should cope with redis being _totally down_, and fail open
/// 3. We should not hit redis for every single request
///
/// The solution here is to read from the cache until a time interval is hit, and then fetch new
/// data. The write requires taking a lock that stalls all readers, though so long as redis reads
/// stay fast we're ok.
///
/// Some small delay between an account being limited and the limit taking effect is acceptable.
/// However, ideally we should not allow requests from some pods but 429 from others.
use thiserror::Error;
use time::{Duration, OffsetDateTime};
use tokio::sync::RwLock;
use tracing::instrument;
// todo: fetch from env
const QUOTA_LIMITER_CACHE_KEY: &str = "@posthog/quota-limits/";
#[derive(Debug)]
pub enum QuotaResource {
Events,
Recordings,
}
impl QuotaResource {
fn as_str(&self) -> &'static str {
match self {
Self::Events => "events",
Self::Recordings => "recordings",
}
}
}
#[derive(Error, Debug)]
pub enum LimiterError {
#[error("updater already running - there can only be one")]
UpdaterRunning,
}
#[derive(Clone)]
pub struct BillingLimiter {
limited: Arc<RwLock<HashSet<String>>>,
redis: Arc<dyn Client + Send + Sync>,
interval: Duration,
updated: Arc<RwLock<OffsetDateTime>>,
}
impl BillingLimiter {
/// Create a new BillingLimiter.
///
/// Pass in a shared redis client; the limited set starts empty and is refreshed from
/// Redis at most once per `interval`. If the Redis fetch fails, lookups fail open and
/// nothing is limited, so events are not dropped by mistake.
pub fn new(
interval: Duration,
redis: Arc<dyn Client + Send + Sync>,
) -> anyhow::Result<BillingLimiter> {
let limited = Arc::new(RwLock::new(HashSet::new()));
// Force an update immediately if we have any reasonable interval
let updated = OffsetDateTime::from_unix_timestamp(0)?;
let updated = Arc::new(RwLock::new(updated));
Ok(BillingLimiter {
interval,
limited,
updated,
redis,
})
}
#[instrument(skip_all)]
async fn fetch_limited(
client: &Arc<dyn Client + Send + Sync>,
resource: QuotaResource,
) -> anyhow::Result<Vec<String>> {
let now = time::OffsetDateTime::now_utc().unix_timestamp();
client
.zrangebyscore(
format!("{QUOTA_LIMITER_CACHE_KEY}{}", resource.as_str()),
now.to_string(),
String::from("+Inf"),
)
.await
}
#[instrument(skip_all, fields(key = key))]
pub async fn is_limited(&self, key: &str, resource: QuotaResource) -> bool {
// hold the read lock to clone it, very briefly. clone is ok because it's very small 🤏
// rwlock can have many readers, but one writer. the writer will wait in a queue with all
// the readers, so we want to hold read locks for the smallest time possible to avoid
// writers waiting for too long. and vice versa.
let updated = {
let updated = self.updated.read().await;
*updated
};
let now = OffsetDateTime::now_utc();
let since_update = now.sub(updated);
// If an update is due, fetch the set from redis + cache it until the next update is due.
// Otherwise, return a value from the cache
//
// This update will block readers! Keep it fast.
if since_update > self.interval {
// open the update lock to change the update, and prevent anyone else from doing so
let mut updated = self.updated.write().await;
*updated = OffsetDateTime::now_utc();
let span = tracing::debug_span!("updating billing cache from redis");
let _span = span.enter();
// a few requests might end up in here concurrently, but I don't think a few extra will
// be a big problem. If it is, we can rework the concurrency a bit.
// On prod atm we call this around 15 times per second at peak times, and it usually
// completes in <1ms.
let set = Self::fetch_limited(&self.redis, resource).await;
tracing::debug!("fetched set from redis, caching");
if let Ok(set) = set {
let set = HashSet::from_iter(set.iter().cloned());
let mut limited = self.limited.write().await;
*limited = set;
tracing::debug!("updated cache from redis");
limited.contains(key)
} else {
tracing::error!("failed to fetch from redis in time, failing open");
// If we fail to fetch the set, something really wrong is happening. To avoid
// dropping events that we don't mean to drop, fail open and accept data. Better
// than angry customers :)
//
// TODO: Consider backing off our redis checks
false
}
} else {
let l = self.limited.read().await;
l.contains(key)
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use time::Duration;
use crate::{
limiters::billing::{BillingLimiter, QuotaResource},
redis::MockRedisClient,
};
#[tokio::test]
async fn test_dynamic_limited() {
let client = MockRedisClient::new().zrangebyscore_ret(vec![String::from("banana")]);
let client = Arc::new(client);
let limiter = BillingLimiter::new(Duration::microseconds(1), client)
.expect("Failed to create billing limiter");
assert_eq!(
limiter
.is_limited("idk it doesn't matter", QuotaResource::Events)
.await,
false
);
assert_eq!(
limiter
.is_limited("some_org_hit_limits", QuotaResource::Events)
.await,
false
);
assert!(limiter.is_limited("banana", QuotaResource::Events).await);
}
}

2
rust/capture/src/limiters/mod.rs Normal file

@@ -0,0 +1,2 @@
pub mod billing;
pub mod overflow;

127
rust/capture/src/limiters/overflow.rs Normal file

@@ -0,0 +1,127 @@
/// The analytics ingestion pipeline provides ordering guarantees for events of the same
/// token and distinct_id. We currently achieve this through a locality constraint on the
/// Kafka partition (consistent partition hashing through a computed key).
///
/// Volume spikes to a given key can create lag on the destination partition and induce
/// ingestion lag. To protect the downstream systems, capture can relax this locality
/// constraint when bursts are detected. When that happens, the excess traffic will be
/// spread across all partitions and be processed by the overflow consumer, without
/// strict ordering guarantees.
use std::collections::HashSet;
use std::num::NonZeroU32;
use std::sync::Arc;
use governor::{clock, state::keyed::DefaultKeyedStateStore, Quota, RateLimiter};
use metrics::gauge;
use rand::Rng;
// See: https://docs.rs/governor/latest/governor/_guide/index.html#usage-in-multiple-threads
#[derive(Clone)]
pub struct OverflowLimiter {
limiter: Arc<RateLimiter<String, DefaultKeyedStateStore<String>, clock::DefaultClock>>,
forced_keys: HashSet<String>,
}
impl OverflowLimiter {
pub fn new(per_second: NonZeroU32, burst: NonZeroU32, forced_keys: Option<String>) -> Self {
let quota = Quota::per_second(per_second).allow_burst(burst);
let limiter = Arc::new(governor::RateLimiter::dashmap(quota));
let forced_keys: HashSet<String> = match forced_keys {
None => HashSet::new(),
Some(values) => values.split(',').map(String::from).collect(),
};
OverflowLimiter {
limiter,
forced_keys,
}
}
pub fn is_limited(&self, key: &String) -> bool {
self.forced_keys.contains(key) || self.limiter.check_key(key).is_err()
}
/// Reports the number of tracked keys to prometheus every 10 seconds,
/// needs to be spawned in a separate task.
pub async fn report_metrics(&self) {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(10));
loop {
interval.tick().await;
gauge!("partition_limits_key_count").set(self.limiter.len() as f64);
}
}
/// Clean up the rate limiter state, once per minute. Ensure we don't use more memory than
/// necessary.
pub async fn clean_state(&self) {
// Give a small amount of randomness to the interval to ensure we don't have all replicas
// locking at the same time. The lock isn't going to be held for long, but this will reduce
// impact regardless.
let interval_secs = rand::thread_rng().gen_range(60..70);
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(interval_secs));
loop {
interval.tick().await;
self.limiter.retain_recent();
self.limiter.shrink_to_fit();
}
}
}
#[cfg(test)]
mod tests {
use crate::limiters::overflow::OverflowLimiter;
use std::num::NonZeroU32;
#[tokio::test]
async fn low_limits() {
let limiter = OverflowLimiter::new(
NonZeroU32::new(1).unwrap(),
NonZeroU32::new(1).unwrap(),
None,
);
let token = String::from("test");
assert!(!limiter.is_limited(&token));
assert!(limiter.is_limited(&token));
}
#[tokio::test]
async fn bursting() {
let limiter = OverflowLimiter::new(
NonZeroU32::new(1).unwrap(),
NonZeroU32::new(3).unwrap(),
None,
);
let token = String::from("test");
assert!(!limiter.is_limited(&token));
assert!(!limiter.is_limited(&token));
assert!(!limiter.is_limited(&token));
assert!(limiter.is_limited(&token));
}
#[tokio::test]
async fn forced_key() {
let key_one = String::from("one");
let key_two = String::from("two");
let key_three = String::from("three");
let forced_keys = Some(String::from("one,three"));
let limiter = OverflowLimiter::new(
NonZeroU32::new(1).unwrap(),
NonZeroU32::new(1).unwrap(),
forced_keys,
);
// One and three are limited from the start, two is not
assert!(limiter.is_limited(&key_one));
assert!(!limiter.is_limited(&key_two));
assert!(limiter.is_limited(&key_three));
// Two is limited on the second event
assert!(limiter.is_limited(&key_two));
}
}

88
rust/capture/src/main.rs Normal file

@@ -0,0 +1,88 @@
use std::time::Duration;
use envconfig::Envconfig;
use opentelemetry::{KeyValue, Value};
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::trace::{BatchConfig, RandomIdGenerator, Sampler, Tracer};
use opentelemetry_sdk::{runtime, Resource};
use tokio::signal;
use tracing::level_filters::LevelFilter;
use tracing::Level;
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
use capture::config::Config;
use capture::server::serve;
async fn shutdown() {
let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to register SIGTERM handler");
let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt())
.expect("failed to register SIGINT handler");
tokio::select! {
_ = term.recv() => {},
_ = interrupt.recv() => {},
};
tracing::info!("Shutting down gracefully...");
}
fn init_tracer(sink_url: &str, sampling_rate: f64, service_name: &str) -> Tracer {
opentelemetry_otlp::new_pipeline()
.tracing()
.with_trace_config(
opentelemetry_sdk::trace::Config::default()
.with_sampler(Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(
sampling_rate,
))))
.with_id_generator(RandomIdGenerator::default())
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
Value::from(service_name.to_string()),
)])),
)
.with_batch_config(BatchConfig::default())
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(sink_url)
.with_timeout(Duration::from_secs(3)),
)
.install_batch(runtime::Tokio)
.unwrap()
}
#[tokio::main]
async fn main() {
let config = Config::init_from_env().expect("Invalid configuration:");
// Instantiate tracing outputs:
// - stdout with a level configured by the RUST_LOG envvar (default=ERROR)
// - OpenTelemetry if enabled, for levels INFO and higher
let log_layer = tracing_subscriber::fmt::layer().with_filter(EnvFilter::from_default_env());
let otel_layer = config
.otel_url
.clone()
.map(|url| {
OpenTelemetryLayer::new(init_tracer(
&url,
config.otel_sampling_rate,
&config.otel_service_name,
))
})
.with_filter(LevelFilter::from_level(Level::INFO));
tracing_subscriber::registry()
.with(log_layer)
.with(otel_layer)
.init();
// Open the TCP port and start the server
let listener = tokio::net::TcpListener::bind(config.address)
.await
.expect("could not bind port");
serve(config, listener, shutdown()).await
}

70
rust/capture/src/prometheus.rs Normal file

@@ -0,0 +1,70 @@
// Middleware + prometheus exporter setup
use std::time::Instant;
use axum::body::Body;
use axum::{extract::MatchedPath, http::Request, middleware::Next, response::IntoResponse};
use metrics::counter;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
pub fn report_dropped_events(cause: &'static str, quantity: u64) {
counter!("capture_events_dropped_total", "cause" => cause).increment(quantity);
}
pub fn report_overflow_partition(quantity: u64) {
counter!("capture_partition_key_capacity_exceeded_total").increment(quantity);
}
pub fn setup_metrics_recorder() -> PrometheusHandle {
// The exponential progression breaks at the end on purpose: the timeout on our ingress is 60s, so the buckets top out at 30.0 and 60.0
const EXPONENTIAL_SECONDS: &[f64] = &[
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0,
];
const BATCH_SIZES: &[f64] = &[
1.0, 10.0, 25.0, 50.0, 75.0, 100.0, 250.0, 500.0, 750.0, 1000.0,
];
PrometheusBuilder::new()
.set_buckets_for_metric(
Matcher::Full("http_requests_duration_seconds".to_string()),
EXPONENTIAL_SECONDS,
)
.unwrap()
.set_buckets_for_metric(Matcher::Suffix("_batch_size".to_string()), BATCH_SIZES)
.unwrap()
.install_recorder()
.unwrap()
}
/// Middleware to record some common HTTP metrics
/// Works for arbitrary body types (eg Vec<u8>, Streams, a deserialized thing, etc) since only request metadata is inspected
/// Someday tower-http might provide a metrics middleware: https://github.com/tower-rs/tower-http/issues/57
pub async fn track_metrics(req: Request<Body>, next: Next) -> impl IntoResponse {
let start = Instant::now();
let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>() {
matched_path.as_str().to_owned()
} else {
req.uri().path().to_owned()
};
let method = req.method().clone();
// Run the rest of the request handling first, so we can measure it and get response
// codes.
let response = next.run(req).await;
let latency = start.elapsed().as_secs_f64();
let status = response.status().as_u16().to_string();
let labels = [
("method", method.to_string()),
("path", path),
("status", status),
];
metrics::counter!("http_requests_total", &labels).increment(1);
metrics::histogram!("http_requests_duration_seconds", &labels).record(latency);
response
}

80
rust/capture/src/redis.rs Normal file

@@ -0,0 +1,80 @@
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use redis::AsyncCommands;
use tokio::time::timeout;
// average for all commands is <10ms, check grafana
const REDIS_TIMEOUT_MILLISECS: u64 = 10;
/// A simple redis wrapper
/// I'm currently just exposing the commands we use, for ease of implementation
/// Allows for testing + injecting failures
/// We can also swap it out for alternative implementations in the future
/// I tried using redis-rs Connection/ConnectionLike traits but honestly things just got really
/// awkward to work with.
#[async_trait]
pub trait Client {
// A very simplified wrapper, but works for our usage
async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>>;
}
pub struct RedisClient {
client: redis::Client,
}
impl RedisClient {
pub fn new(addr: String) -> Result<RedisClient> {
let client = redis::Client::open(addr)?;
Ok(RedisClient { client })
}
}
#[async_trait]
impl Client for RedisClient {
async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>> {
let mut conn = self.client.get_async_connection().await?;
let results = conn.zrangebyscore(k, min, max);
let fut = timeout(Duration::from_millis(REDIS_TIMEOUT_MILLISECS), results).await?;
Ok(fut?)
}
}
// mockall got really annoying with async and results so I'm just gonna do my own
#[derive(Clone)]
pub struct MockRedisClient {
zrangebyscore_ret: Vec<String>,
}
impl MockRedisClient {
pub fn new() -> MockRedisClient {
MockRedisClient {
zrangebyscore_ret: Vec::new(),
}
}
pub fn zrangebyscore_ret(&mut self, ret: Vec<String>) -> Self {
self.zrangebyscore_ret = ret;
self.clone()
}
}
impl Default for MockRedisClient {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Client for MockRedisClient {
// A very simplified wrapper, but works for our usage
async fn zrangebyscore(&self, _k: String, _min: String, _max: String) -> Result<Vec<String>> {
Ok(self.zrangebyscore_ret.clone())
}
}

113
rust/capture/src/router.rs Normal file

@@ -0,0 +1,113 @@
use std::future::ready;
use std::sync::Arc;
use axum::http::Method;
use axum::{
routing::{get, post},
Router,
};
use health::HealthRegistry;
use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
use tower_http::trace::TraceLayer;
use crate::{
limiters::billing::BillingLimiter, redis::Client, sinks, time::TimeSource, v0_endpoint,
};
use crate::prometheus::{setup_metrics_recorder, track_metrics};
#[derive(Clone)]
pub struct State {
pub sink: Arc<dyn sinks::Event + Send + Sync>,
pub timesource: Arc<dyn TimeSource + Send + Sync>,
pub redis: Arc<dyn Client + Send + Sync>,
pub billing: BillingLimiter,
}
async fn index() -> &'static str {
"capture"
}
pub fn router<
TZ: TimeSource + Send + Sync + 'static,
S: sinks::Event + Send + Sync + 'static,
R: Client + Send + Sync + 'static,
>(
timesource: TZ,
liveness: HealthRegistry,
sink: S,
redis: Arc<R>,
billing: BillingLimiter,
metrics: bool,
) -> Router {
let state = State {
sink: Arc::new(sink),
timesource: Arc::new(timesource),
redis,
billing,
};
// Very permissive CORS policy, as old SDK versions
// and reverse proxies might send funky headers.
let cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers(AllowHeaders::mirror_request())
.allow_credentials(true)
.allow_origin(AllowOrigin::mirror_request());
let router = Router::new()
// TODO: use NormalizePathLayer::trim_trailing_slash
.route("/", get(index))
.route("/_readiness", get(index))
.route("/_liveness", get(move || ready(liveness.get_status())))
.route(
"/e",
post(v0_endpoint::event)
.get(v0_endpoint::event)
.options(v0_endpoint::options),
)
.route(
"/e/",
post(v0_endpoint::event)
.get(v0_endpoint::event)
.options(v0_endpoint::options),
)
.route(
"/batch",
post(v0_endpoint::event)
.get(v0_endpoint::event)
.options(v0_endpoint::options),
)
.route(
"/batch/",
post(v0_endpoint::event)
.get(v0_endpoint::event)
.options(v0_endpoint::options),
)
.route(
"/i/v0/e",
post(v0_endpoint::event)
.get(v0_endpoint::event)
.options(v0_endpoint::options),
)
.route(
"/i/v0/e/",
post(v0_endpoint::event)
.get(v0_endpoint::event)
.options(v0_endpoint::options),
)
.layer(TraceLayer::new_for_http())
.layer(cors)
.layer(axum::middleware::from_fn(track_metrics))
.with_state(state);
// Don't install metrics unless asked to
// Installing a global recorder when capture is used as a library (during tests etc)
// does not work well.
if metrics {
let recorder_handle = setup_metrics_recorder();
router.route("/metrics", get(move || ready(recorder_handle.render())))
} else {
router
}
}

98
rust/capture/src/server.rs Normal file

@@ -0,0 +1,98 @@
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use health::{ComponentStatus, HealthRegistry};
use time::Duration;
use tokio::net::TcpListener;
use crate::config::Config;
use crate::limiters::billing::BillingLimiter;
use crate::limiters::overflow::OverflowLimiter;
use crate::redis::RedisClient;
use crate::router;
use crate::sinks::kafka::KafkaSink;
use crate::sinks::print::PrintSink;
pub async fn serve<F>(config: Config, listener: TcpListener, shutdown: F)
where
F: Future<Output = ()> + Send + 'static,
{
let liveness = HealthRegistry::new("liveness");
let redis_client =
Arc::new(RedisClient::new(config.redis_url).expect("failed to create redis client"));
let billing = BillingLimiter::new(Duration::seconds(5), redis_client.clone())
.expect("failed to create billing limiter");
let app = if config.print_sink {
// Print sink is only used for local debug, don't allow a container with it to run on prod
liveness
.register("print_sink".to_string(), Duration::seconds(30))
.await
.report_status(ComponentStatus::Unhealthy)
.await;
router::router(
crate::time::SystemTime {},
liveness,
PrintSink {},
redis_client,
billing,
config.export_prometheus,
)
} else {
let sink_liveness = liveness
.register("rdkafka".to_string(), Duration::seconds(30))
.await;
let partition = match config.overflow_enabled {
false => None,
true => {
let partition = OverflowLimiter::new(
config.overflow_per_second_limit,
config.overflow_burst_limit,
config.overflow_forced_keys,
);
if config.export_prometheus {
let partition = partition.clone();
tokio::spawn(async move {
partition.report_metrics().await;
});
}
{
// Ensure that the rate limiter state does not grow unbounded
let partition = partition.clone();
tokio::spawn(async move {
partition.clean_state().await;
});
}
Some(partition)
}
};
let sink = KafkaSink::new(config.kafka, sink_liveness, partition)
.expect("failed to start Kafka sink");
router::router(
crate::time::SystemTime {},
liveness,
sink,
redis_client,
billing,
config.export_prometheus,
)
};
// run our app with hyper
// `axum::Server` is a re-export of `hyper::Server`
tracing::info!("listening on {:?}", listener.local_addr().unwrap());
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown)
.await
.unwrap()
}

435
rust/capture/src/sinks/kafka.rs Normal file

@@ -0,0 +1,435 @@
use std::time::Duration;
use async_trait::async_trait;
use health::HealthHandle;
use metrics::{counter, gauge, histogram};
use rdkafka::error::{KafkaError, RDKafkaErrorCode};
use rdkafka::producer::{DeliveryFuture, FutureProducer, FutureRecord, Producer};
use rdkafka::util::Timeout;
use rdkafka::ClientConfig;
use tokio::task::JoinSet;
use tracing::log::{debug, error, info};
use tracing::{info_span, instrument, Instrument};
use crate::api::{CaptureError, DataType, ProcessedEvent};
use crate::config::KafkaConfig;
use crate::limiters::overflow::OverflowLimiter;
use crate::prometheus::report_dropped_events;
use crate::sinks::Event;
struct KafkaContext {
liveness: HealthHandle,
}
impl rdkafka::ClientContext for KafkaContext {
fn stats(&self, stats: rdkafka::Statistics) {
// Signal liveness, as the main rdkafka loop is running and calling us
self.liveness.report_healthy_blocking();
// Update exported metrics
gauge!("capture_kafka_callback_queue_depth",).set(stats.replyq as f64);
gauge!("capture_kafka_producer_queue_depth",).set(stats.msg_cnt as f64);
gauge!("capture_kafka_producer_queue_depth_limit",).set(stats.msg_max as f64);
gauge!("capture_kafka_producer_queue_bytes",).set(stats.msg_max as f64);
gauge!("capture_kafka_producer_queue_bytes_limit",).set(stats.msg_size_max as f64);
for (topic, stats) in stats.topics {
gauge!(
"capture_kafka_produce_avg_batch_size_bytes",
"topic" => topic.clone()
)
.set(stats.batchsize.avg as f64);
gauge!(
"capture_kafka_produce_avg_batch_size_events",
"topic" => topic
)
.set(stats.batchcnt.avg as f64);
}
for (_, stats) in stats.brokers {
let id_string = format!("{}", stats.nodeid);
if let Some(rtt) = stats.rtt {
gauge!(
"capture_kafka_produce_rtt_latency_us",
"quantile" => "p50",
"broker" => id_string.clone()
)
.set(rtt.p50 as f64);
gauge!(
"capture_kafka_produce_rtt_latency_us",
"quantile" => "p90",
"broker" => id_string.clone()
)
.set(rtt.p90 as f64);
gauge!(
"capture_kafka_produce_rtt_latency_us",
"quantile" => "p95",
"broker" => id_string.clone()
)
.set(rtt.p95 as f64);
gauge!(
"capture_kafka_produce_rtt_latency_us",
"quantile" => "p99",
"broker" => id_string.clone()
)
.set(rtt.p99 as f64);
}
gauge!(
"capture_kafka_broker_requests_pending",
"broker" => id_string.clone()
)
.set(stats.outbuf_cnt as f64);
gauge!(
"capture_kafka_broker_responses_awaiting",
"broker" => id_string.clone()
)
.set(stats.waitresp_cnt as f64);
counter!(
"capture_kafka_broker_tx_errors_total",
"broker" => id_string.clone()
)
.absolute(stats.txerrs);
counter!(
"capture_kafka_broker_rx_errors_total",
"broker" => id_string.clone()
)
.absolute(stats.rxerrs);
counter!(
"capture_kafka_broker_request_timeouts",
"broker" => id_string
)
.absolute(stats.req_timeouts);
}
}
}
#[derive(Clone)]
pub struct KafkaSink {
producer: FutureProducer<KafkaContext>,
partition: Option<OverflowLimiter>,
main_topic: String,
historical_topic: String,
}
impl KafkaSink {
pub fn new(
config: KafkaConfig,
liveness: HealthHandle,
partition: Option<OverflowLimiter>,
) -> anyhow::Result<KafkaSink> {
info!("connecting to Kafka brokers at {}...", config.kafka_hosts);
let mut client_config = ClientConfig::new();
client_config
.set("bootstrap.servers", &config.kafka_hosts)
.set("statistics.interval.ms", "10000")
.set("partitioner", "murmur2_random") // Compatibility with python-kafka
.set("linger.ms", config.kafka_producer_linger_ms.to_string())
.set(
"message.timeout.ms",
config.kafka_message_timeout_ms.to_string(),
)
.set("compression.codec", config.kafka_compression_codec)
.set(
"queue.buffering.max.kbytes",
(config.kafka_producer_queue_mib * 1024).to_string(),
);
if config.kafka_tls {
client_config
.set("security.protocol", "ssl")
.set("enable.ssl.certificate.verification", "false");
};
debug!("rdkafka configuration: {:?}", client_config);
let producer: FutureProducer<KafkaContext> =
client_config.create_with_context(KafkaContext { liveness })?;
// Ping the cluster to make sure we can reach brokers, fail after 10 seconds
drop(producer.client().fetch_metadata(
Some("__consumer_offsets"),
Timeout::After(Duration::new(10, 0)),
)?);
info!("connected to Kafka brokers");
Ok(KafkaSink {
producer,
partition,
main_topic: config.kafka_topic,
historical_topic: config.kafka_historical_topic,
})
}
pub fn flush(&self) -> Result<(), KafkaError> {
// TODO: hook it up on shutdown
self.producer.flush(Duration::new(30, 0))
}
async fn kafka_send(&self, event: ProcessedEvent) -> Result<DeliveryFuture, CaptureError> {
let payload = serde_json::to_string(&event).map_err(|e| {
error!("failed to serialize event: {}", e);
CaptureError::NonRetryableSinkError
})?;
let event_key = event.key();
let (topic, partition_key): (&str, Option<&str>) = match &event.data_type {
DataType::AnalyticsHistorical => (&self.historical_topic, Some(event_key.as_str())), // We never trigger overflow on historical events
DataType::AnalyticsMain => {
// TODO: deprecate capture-led overflow or move logic in handler
let is_limited = match &self.partition {
None => false,
Some(partition) => partition.is_limited(&event_key),
};
if is_limited {
(&self.main_topic, None) // Analytics overflow goes to the main topic without locality
} else {
(&self.main_topic, Some(event_key.as_str()))
}
}
};
match self.producer.send_result(FutureRecord {
topic,
payload: Some(&payload),
partition: None,
key: partition_key,
timestamp: None,
headers: None,
}) {
Ok(ack) => Ok(ack),
Err((e, _)) => match e.rdkafka_error_code() {
Some(RDKafkaErrorCode::MessageSizeTooLarge) => {
report_dropped_events("kafka_message_size", 1);
Err(CaptureError::EventTooBig)
}
_ => {
// TODO(maybe someday): Don't drop them but write them somewhere and try again
report_dropped_events("kafka_write_error", 1);
error!("failed to produce event: {}", e);
Err(CaptureError::RetryableSinkError)
}
},
}
}
async fn process_ack(delivery: DeliveryFuture) -> Result<(), CaptureError> {
match delivery.await {
Err(_) => {
// Cancelled due to timeout while retrying
counter!("capture_kafka_produce_errors_total").increment(1);
error!("failed to produce to Kafka before write timeout");
Err(CaptureError::RetryableSinkError)
}
Ok(Err((KafkaError::MessageProduction(RDKafkaErrorCode::MessageSizeTooLarge), _))) => {
// Rejected by broker due to message size
report_dropped_events("kafka_message_size", 1);
Err(CaptureError::EventTooBig)
}
Ok(Err((err, _))) => {
// Unretriable produce error
counter!("capture_kafka_produce_errors_total").increment(1);
error!("failed to produce to Kafka: {}", err);
Err(CaptureError::RetryableSinkError)
}
Ok(Ok(_)) => {
counter!("capture_events_ingested_total").increment(1);
Ok(())
}
}
}
}
#[async_trait]
impl Event for KafkaSink {
#[instrument(skip_all)]
async fn send(&self, event: ProcessedEvent) -> Result<(), CaptureError> {
let ack = self.kafka_send(event).await?;
histogram!("capture_event_batch_size").record(1.0);
Self::process_ack(ack)
.instrument(info_span!("ack_wait_one"))
.await
}
#[instrument(skip_all)]
async fn send_batch(&self, events: Vec<ProcessedEvent>) -> Result<(), CaptureError> {
let mut set = JoinSet::new();
let batch_size = events.len();
for event in events {
// We await kafka_send to get events in the producer queue sequentially
let ack = self.kafka_send(event).await?;
// Then stash the returned DeliveryFuture, waiting concurrently for the write ACKs from brokers.
set.spawn(Self::process_ack(ack));
}
// Await on all the produce promises, fail batch on first failure
async move {
while let Some(res) = set.join_next().await {
match res {
Ok(Ok(_)) => {}
Ok(Err(err)) => {
set.abort_all();
return Err(err);
}
Err(err) => {
set.abort_all();
error!("join error while waiting on Kafka ACK: {:?}", err);
return Err(CaptureError::RetryableSinkError);
}
}
}
Ok(())
}
.instrument(info_span!("ack_wait_many"))
.await?;
histogram!("capture_event_batch_size").record(batch_size as f64);
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::api::{CaptureError, DataType, ProcessedEvent};
use crate::config;
use crate::limiters::overflow::OverflowLimiter;
use crate::sinks::kafka::KafkaSink;
use crate::sinks::Event;
use crate::utils::uuid_v7;
use health::HealthRegistry;
use rand::distributions::Alphanumeric;
use rand::Rng;
use rdkafka::mocking::MockCluster;
use rdkafka::producer::DefaultProducerContext;
use rdkafka::types::{RDKafkaApiKey, RDKafkaRespErr};
use std::num::NonZeroU32;
use time::Duration;
async fn start_on_mocked_sink() -> (MockCluster<'static, DefaultProducerContext>, KafkaSink) {
let registry = HealthRegistry::new("liveness");
let handle = registry
.register("one".to_string(), Duration::seconds(30))
.await;
let limiter = Some(OverflowLimiter::new(
NonZeroU32::new(10).unwrap(),
NonZeroU32::new(10).unwrap(),
None,
));
let cluster = MockCluster::new(1).expect("failed to create mock brokers");
let config = config::KafkaConfig {
kafka_producer_linger_ms: 0,
kafka_producer_queue_mib: 50,
kafka_message_timeout_ms: 500,
kafka_compression_codec: "none".to_string(),
kafka_hosts: cluster.bootstrap_servers(),
kafka_topic: "events_plugin_ingestion".to_string(),
kafka_historical_topic: "events_plugin_ingestion_historical".to_string(),
kafka_tls: false,
};
let sink = KafkaSink::new(config, handle, limiter).expect("failed to create sink");
(cluster, sink)
}
#[tokio::test]
async fn kafka_sink_error_handling() {
// Uses a mocked Kafka broker that allows injecting write errors, to check error handling.
// We test different cases in a single test to amortize the startup cost of the producer.
let (cluster, sink) = start_on_mocked_sink().await;
let event: ProcessedEvent = ProcessedEvent {
data_type: DataType::AnalyticsMain,
uuid: uuid_v7(),
distinct_id: "id1".to_string(),
ip: "".to_string(),
data: "".to_string(),
now: "".to_string(),
sent_at: None,
token: "token1".to_string(),
};
// Wait for producer to be healthy, to keep kafka_message_timeout_ms short and tests faster
for _ in 0..20 {
if sink.send(event.clone()).await.is_ok() {
break;
}
}
// Send events to confirm happy path
sink.send(event.clone())
.await
.expect("failed to send one initial event");
sink.send_batch(vec![event.clone(), event.clone()])
.await
.expect("failed to send initial event batch");
// Producer should reject a 2MB message, twice the default `message.max.bytes`
let big_data = rand::thread_rng()
.sample_iter(Alphanumeric)
.take(2_000_000)
.map(char::from)
.collect();
let big_event: ProcessedEvent = ProcessedEvent {
data_type: DataType::AnalyticsMain,
uuid: uuid_v7(),
distinct_id: "id1".to_string(),
ip: "".to_string(),
data: big_data,
now: "".to_string(),
sent_at: None,
token: "token1".to_string(),
};
match sink.send(big_event).await {
Err(CaptureError::EventTooBig) => {} // Expected
Err(err) => panic!("wrong error code {}", err),
Ok(()) => panic!("should have errored"),
};
// Simulate unretriable errors
cluster.clear_request_errors(RDKafkaApiKey::Produce);
let err = [RDKafkaRespErr::RD_KAFKA_RESP_ERR_MSG_SIZE_TOO_LARGE; 1];
cluster.request_errors(RDKafkaApiKey::Produce, &err);
match sink.send(event.clone()).await {
Err(CaptureError::EventTooBig) => {} // Expected
Err(err) => panic!("wrong error code {}", err),
Ok(()) => panic!("should have errored"),
};
cluster.clear_request_errors(RDKafkaApiKey::Produce);
let err = [RDKafkaRespErr::RD_KAFKA_RESP_ERR_INVALID_PARTITIONS; 1];
cluster.request_errors(RDKafkaApiKey::Produce, &err);
match sink.send_batch(vec![event.clone(), event.clone()]).await {
Err(CaptureError::RetryableSinkError) => {} // Expected
Err(err) => panic!("wrong error code {}", err),
Ok(()) => panic!("should have errored"),
};
// Simulate transient errors, messages should go through OK
cluster.clear_request_errors(RDKafkaApiKey::Produce);
let err = [RDKafkaRespErr::RD_KAFKA_RESP_ERR_BROKER_NOT_AVAILABLE; 2];
cluster.request_errors(RDKafkaApiKey::Produce, &err);
sink.send(event.clone())
.await
.expect("failed to send one event after recovery");
cluster.clear_request_errors(RDKafkaApiKey::Produce);
let err = [RDKafkaRespErr::RD_KAFKA_RESP_ERR_BROKER_NOT_AVAILABLE; 2];
cluster.request_errors(RDKafkaApiKey::Produce, &err);
sink.send_batch(vec![event.clone(), event.clone()])
.await
.expect("failed to send event batch after recovery");
// Timeout on a sustained transient error
cluster.clear_request_errors(RDKafkaApiKey::Produce);
let err = [RDKafkaRespErr::RD_KAFKA_RESP_ERR_BROKER_NOT_AVAILABLE; 50];
cluster.request_errors(RDKafkaApiKey::Produce, &err);
match sink.send(event.clone()).await {
Err(CaptureError::RetryableSinkError) => {} // Expected
Err(err) => panic!("wrong error code {}", err),
Ok(()) => panic!("should have errored"),
};
match sink.send_batch(vec![event.clone(), event.clone()]).await {
Err(CaptureError::RetryableSinkError) => {} // Expected
Err(err) => panic!("wrong error code {}", err),
Ok(()) => panic!("should have errored"),
};
}
}

12
rust/capture/src/sinks/mod.rs Normal file

@@ -0,0 +1,12 @@
use async_trait::async_trait;
use crate::api::{CaptureError, ProcessedEvent};
pub mod kafka;
pub mod print;
#[async_trait]
pub trait Event {
async fn send(&self, event: ProcessedEvent) -> Result<(), CaptureError>;
async fn send_batch(&self, events: Vec<ProcessedEvent>) -> Result<(), CaptureError>;
}

30
rust/capture/src/sinks/print.rs Normal file

@@ -0,0 +1,30 @@
use async_trait::async_trait;
use metrics::{counter, histogram};
use tracing::log::info;
use crate::api::{CaptureError, ProcessedEvent};
use crate::sinks::Event;
pub struct PrintSink {}
#[async_trait]
impl Event for PrintSink {
async fn send(&self, event: ProcessedEvent) -> Result<(), CaptureError> {
info!("single event: {:?}", event);
counter!("capture_events_ingested_total").increment(1);
Ok(())
}
async fn send_batch(&self, events: Vec<ProcessedEvent>) -> Result<(), CaptureError> {
let span = tracing::span!(tracing::Level::INFO, "batch of events");
let _enter = span.enter();
histogram!("capture_event_batch_size").record(events.len() as f64);
counter!("capture_events_ingested_total").increment(events.len() as u64);
for event in events {
info!("event: {:?}", event);
}
Ok(())
}
}

16
rust/capture/src/time.rs Normal file

@@ -0,0 +1,16 @@
pub trait TimeSource {
// Return an ISO timestamp
fn current_time(&self) -> String;
}
#[derive(Clone)]
pub struct SystemTime {}
impl TimeSource for SystemTime {
fn current_time(&self) -> String {
let time = time::OffsetDateTime::now_utc();
time.format(&time::format_description::well_known::Rfc3339)
.expect("failed to format timestamp")
}
}

99
rust/capture/src/token.rs Normal file

@@ -0,0 +1,99 @@
use std::error::Error;
use std::fmt::Display;
/// Validate that a token is the correct shape
#[derive(Debug, PartialEq)]
pub enum InvalidTokenReason {
Empty,
// ignoring for now, as serde and the type system save us but we need to error properly
// IsNotString,
TooLong,
NotAscii,
PersonalApiKey,
}
impl InvalidTokenReason {
pub fn reason(&self) -> &str {
match *self {
Self::Empty => "empty",
Self::NotAscii => "not_ascii",
// Self::IsNotString => "not_string",
Self::TooLong => "too_long",
Self::PersonalApiKey => "personal_api_key",
}
}
}
impl Display for InvalidTokenReason {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.reason())
}
}
impl Error for InvalidTokenReason {
fn description(&self) -> &str {
self.reason()
}
}
/// Check if a token is the right shape. It may not actually be a valid token! We don't validate
/// these at the edge yet.
pub fn validate_token(token: &str) -> Result<(), InvalidTokenReason> {
if token.is_empty() {
return Err(InvalidTokenReason::Empty);
}
if token.len() > 64 {
return Err(InvalidTokenReason::TooLong);
}
if !token.is_ascii() {
return Err(InvalidTokenReason::NotAscii);
}
if token.starts_with("phx_") {
return Err(InvalidTokenReason::PersonalApiKey);
}
Ok(())
}
#[cfg(test)]
mod tests {
use crate::token::{validate_token, InvalidTokenReason};
#[test]
fn blocks_empty_tokens() {
let valid = validate_token("");
assert!(valid.is_err());
assert_eq!(valid.unwrap_err(), InvalidTokenReason::Empty);
}
#[test]
fn blocks_too_long_tokens() {
let valid =
validate_token("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
assert!(valid.is_err());
assert_eq!(valid.unwrap_err(), InvalidTokenReason::TooLong);
}
#[test]
fn blocks_invalid_ascii() {
let valid = validate_token("🦀");
assert!(valid.is_err());
assert_eq!(valid.unwrap_err(), InvalidTokenReason::NotAscii);
}
#[test]
fn blocks_personal_api_key() {
let valid = validate_token("phx_hellothere");
assert!(valid.is_err());
assert_eq!(valid.unwrap_err(), InvalidTokenReason::PersonalApiKey);
}
}

38
rust/capture/src/utils.rs Normal file

@@ -0,0 +1,38 @@
use rand::RngCore;
use uuid::Uuid;
pub fn random_bytes<const N: usize>() -> [u8; N] {
let mut ret = [0u8; N];
rand::thread_rng().fill_bytes(&mut ret);
ret
}
// basically just ripped from the uuid crate. they have it as unstable, but we can use it fine.
const fn encode_unix_timestamp_millis(millis: u64, random_bytes: &[u8; 10]) -> Uuid {
let millis_high = ((millis >> 16) & 0xFFFF_FFFF) as u32;
let millis_low = (millis & 0xFFFF) as u16;
let random_and_version =
(random_bytes[0] as u16 | ((random_bytes[1] as u16) << 8) & 0x0FFF) | (0x7 << 12);
let mut d4 = [0; 8];
d4[0] = (random_bytes[2] & 0x3F) | 0x80;
d4[1] = random_bytes[3];
d4[2] = random_bytes[4];
d4[3] = random_bytes[5];
d4[4] = random_bytes[6];
d4[5] = random_bytes[7];
d4[6] = random_bytes[8];
d4[7] = random_bytes[9];
Uuid::from_fields(millis_high, millis_low, random_and_version, &d4)
}
pub fn uuid_v7() -> Uuid {
let bytes = random_bytes();
let now = time::OffsetDateTime::now_utc();
let now_millis: u64 = now.unix_timestamp() as u64 * 1_000 + now.millisecond() as u64;
encode_unix_timestamp_millis(now_millis, &bytes)
}

View File

@@ -0,0 +1,225 @@
use std::ops::Deref;
use std::sync::Arc;
use axum::{debug_handler, Json};
use bytes::Bytes;
// TODO: stream this instead
use axum::extract::{MatchedPath, Query, State};
use axum::http::{HeaderMap, Method};
use axum_client_ip::InsecureClientIp;
use base64::Engine;
use metrics::counter;
use tracing::instrument;
use crate::limiters::billing::QuotaResource;
use crate::prometheus::report_dropped_events;
use crate::v0_request::{Compression, ProcessingContext, RawRequest};
use crate::{
api::{CaptureError, CaptureResponse, CaptureResponseCode, DataType, ProcessedEvent},
router, sinks,
utils::uuid_v7,
v0_request::{EventFormData, EventQuery, RawEvent},
};
/// Flexible endpoint that targets broad compatibility with the wide range of requests
/// currently processed by posthog-events (analytics events capture). Replay is out
/// of scope and should be processed on a separate endpoint.
///
/// Because it must accommodate several payload shapes, it is inefficient in places. A v1
/// endpoint should be created that only accepts the BatchedRequest payload shape.
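// Illustrative payload shapes this endpoint accepts (matching the RawRequest variants in
// v0_request; tokens and ids are made up):
//   single event:  {"event": "e", "distinct_id": "id1", "api_key": "token1"}
//   event array:   [{"event": "e1", "token": "token1"}, {"event": "e2", "token": "token1"}]
//   batch:         {"api_key": "token1", "batch": [{"event": "e1", "distinct_id": "id1"}]}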
#[instrument(
skip_all,
fields(
path,
token,
batch_size,
user_agent,
content_encoding,
content_type,
version,
compression,
historical_migration
)
)]
#[debug_handler]
pub async fn event(
state: State<router::State>,
InsecureClientIp(ip): InsecureClientIp,
meta: Query<EventQuery>,
headers: HeaderMap,
method: Method,
path: MatchedPath,
body: Bytes,
) -> Result<Json<CaptureResponse>, CaptureError> {
let user_agent = headers
.get("user-agent")
.map_or("unknown", |v| v.to_str().unwrap_or("unknown"));
let content_encoding = headers
.get("content-encoding")
.map_or("unknown", |v| v.to_str().unwrap_or("unknown"));
let comp = match meta.compression {
None => String::from("unknown"),
Some(Compression::Gzip) => String::from("gzip"),
Some(Compression::Unsupported) => String::from("unsupported"),
};
tracing::Span::current().record("user_agent", user_agent);
tracing::Span::current().record("content_encoding", content_encoding);
tracing::Span::current().record("version", meta.lib_version.clone());
tracing::Span::current().record("compression", comp.as_str());
tracing::Span::current().record("method", method.as_str());
tracing::Span::current().record("path", path.as_str().trim_end_matches('/'));
let request = match headers
.get("content-type")
.map_or("", |v| v.to_str().unwrap_or(""))
{
"application/x-www-form-urlencoded" => {
tracing::Span::current().record("content_type", "application/x-www-form-urlencoded");
let input: EventFormData = serde_urlencoded::from_bytes(body.deref()).map_err(|e| {
tracing::error!("failed to decode body: {}", e);
CaptureError::RequestDecodingError(String::from("invalid form data"))
})?;
let payload = base64::engine::general_purpose::STANDARD
.decode(input.data)
.map_err(|e| {
tracing::error!("failed to decode form data: {}", e);
CaptureError::RequestDecodingError(String::from("missing data field"))
})?;
RawRequest::from_bytes(payload.into())
}
ct => {
tracing::Span::current().record("content_type", ct);
RawRequest::from_bytes(body)
}
}?;
let sent_at = request.sent_at().or(meta.sent_at());
let token = match request.extract_and_verify_token() {
Ok(token) => token,
Err(err) => {
report_dropped_events("token_shape_invalid", request.events().len() as u64);
return Err(err);
}
};
let historical_migration = request.historical_migration();
let events = request.events(); // Takes ownership of request
tracing::Span::current().record("token", &token);
tracing::Span::current().record("historical_migration", historical_migration);
tracing::Span::current().record("batch_size", events.len());
if events.is_empty() {
return Err(CaptureError::EmptyBatch);
}
counter!("capture_events_received_total").increment(events.len() as u64);
let context = ProcessingContext {
lib_version: meta.lib_version.clone(),
sent_at,
token,
now: state.timesource.current_time(),
client_ip: ip.to_string(),
historical_migration,
};
let billing_limited = state
.billing
.is_limited(context.token.as_str(), QuotaResource::Events)
.await;
if billing_limited {
report_dropped_events("over_quota", events.len() as u64);
// for v0 we want to just return ok 🙃
// this is because the clients are pretty dumb and will just retry over and over and
// over...
//
// for v1, we'll return a meaningful error code and error, so that the clients can do
// something meaningful with that error
return Ok(Json(CaptureResponse {
status: CaptureResponseCode::Ok,
}));
}
tracing::debug!(context=?context, events=?events, "decoded request");
if let Err(err) = process_events(state.sink.clone(), &events, &context).await {
let cause = match err {
// TODO: automate this with a macro
CaptureError::EmptyDistinctId => "empty_distinct_id",
CaptureError::MissingDistinctId => "missing_distinct_id",
CaptureError::MissingEventName => "missing_event_name",
_ => "process_events_error",
};
report_dropped_events(cause, events.len() as u64);
tracing::log::warn!("rejected invalid payload: {}", err);
return Err(err);
}
Ok(Json(CaptureResponse {
status: CaptureResponseCode::Ok,
}))
}
pub async fn options() -> Result<Json<CaptureResponse>, CaptureError> {
Ok(Json(CaptureResponse {
status: CaptureResponseCode::Ok,
}))
}
#[instrument(skip_all)]
pub fn process_single_event(
event: &RawEvent,
context: &ProcessingContext,
) -> Result<ProcessedEvent, CaptureError> {
if event.event.is_empty() {
return Err(CaptureError::MissingEventName);
}
let data_type = match context.historical_migration {
true => DataType::AnalyticsHistorical,
false => DataType::AnalyticsMain,
};
let data = serde_json::to_string(&event).map_err(|e| {
tracing::error!("failed to encode data field: {}", e);
CaptureError::NonRetryableSinkError
})?;
Ok(ProcessedEvent {
data_type,
uuid: event.uuid.unwrap_or_else(uuid_v7),
distinct_id: event.extract_distinct_id()?,
ip: context.client_ip.clone(),
data,
now: context.now.clone(),
sent_at: context.sent_at,
token: context.token.clone(),
})
}
#[instrument(skip_all, fields(events = events.len()))]
pub async fn process_events<'a>(
sink: Arc<dyn sinks::Event + Send + Sync>,
events: &'a [RawEvent],
context: &'a ProcessingContext,
) -> Result<(), CaptureError> {
let events: Vec<ProcessedEvent> = events
.iter()
.map(|e| process_single_event(e, context))
.collect::<Result<Vec<ProcessedEvent>, CaptureError>>()?;
tracing::debug!(events=?events, "processed {} events", events.len());
if events.len() == 1 {
sink.send(events[0].clone()).await
} else {
sink.send_batch(events).await
}
}

View File

@@ -0,0 +1,434 @@
use std::collections::{HashMap, HashSet};
use std::io::prelude::*;
use bytes::{Buf, Bytes};
use flate2::read::GzDecoder;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use time::format_description::well_known::Iso8601;
use time::OffsetDateTime;
use tracing::instrument;
use uuid::Uuid;
use crate::api::CaptureError;
use crate::token::validate_token;
#[derive(Deserialize, Default)]
pub enum Compression {
#[default]
Unsupported,
#[serde(rename = "gzip", alias = "gzip-js")]
Gzip,
}
#[derive(Deserialize, Default)]
pub struct EventQuery {
pub compression: Option<Compression>,
#[serde(alias = "ver")]
pub lib_version: Option<String>,
#[serde(alias = "_")]
sent_at: Option<i64>,
}
impl EventQuery {
/// Returns the parsed value of the sent_at timestamp if present in the query params.
/// We only support the format sent by recent posthog-js versions: an integer number of milliseconds.
/// Values in integer seconds (sent by older SDKs) will be ignored.
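// Worked example (made-up value): `?_=1717243200000` is 1,717,243,200,000 ms, i.e.
// 2024-06-01T12:00:00Z. A seconds value like `1717243200` would land in January 1970
// after the nanosecond scaling below, and is rejected by the year > 2020 check.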
pub fn sent_at(&self) -> Option<OffsetDateTime> {
if let Some(value) = self.sent_at {
let value_nanos: i128 = i128::from(value) * 1_000_000; // Assume the value is in milliseconds, as sent by the latest posthog-js releases
if let Ok(sent_at) = OffsetDateTime::from_unix_timestamp_nanos(value_nanos) {
if sent_at.year() > 2020 {
// Could be lower if the input is in seconds
return Some(sent_at);
}
}
}
None
}
}
#[derive(Debug, Deserialize)]
pub struct EventFormData {
pub data: String,
}
#[derive(Default, Debug, Deserialize, Serialize)]
pub struct RawEvent {
#[serde(
alias = "$token",
alias = "api_key",
skip_serializing_if = "Option::is_none"
)]
pub token: Option<String>,
#[serde(alias = "$distinct_id", skip_serializing_if = "Option::is_none")]
pub distinct_id: Option<Value>, // posthog-js accepts arbitrary values as distinct_id
pub uuid: Option<Uuid>,
pub event: String,
#[serde(default)]
pub properties: HashMap<String, Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub timestamp: Option<String>, // Passed through if provided, parsed by ingestion
#[serde(skip_serializing_if = "Option::is_none")]
pub offset: Option<i64>, // Passed through if provided, parsed by ingestion
#[serde(rename = "$set", skip_serializing_if = "Option::is_none")]
pub set: Option<HashMap<String, Value>>,
#[serde(rename = "$set_once", skip_serializing_if = "Option::is_none")]
pub set_once: Option<HashMap<String, Value>>,
}
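// 0x1f 0x8b is the gzip magic number; the third byte (8) is the DEFLATE compression method,
// the only one defined by RFC 1952, so from_bytes can safely sniff all three bytes.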
static GZIP_MAGIC_NUMBERS: [u8; 3] = [0x1f, 0x8b, 8];
#[derive(Deserialize)]
#[serde(untagged)]
pub enum RawRequest {
/// Array of events (posthog-js)
Array(Vec<RawEvent>),
/// Batched events (/batch)
Batch(BatchedRequest),
/// Single event (/capture)
One(Box<RawEvent>),
}
#[derive(Deserialize)]
pub struct BatchedRequest {
#[serde(alias = "api_key")]
pub token: String,
pub historical_migration: Option<bool>,
pub sent_at: Option<String>,
pub batch: Vec<RawEvent>,
}
impl RawRequest {
/// Takes a request payload and tries to decompress and unmarshal it.
/// While posthog-js sends a compression query param, a sizable portion of requests
/// fail due to it being missing when the body is compressed.
/// Instead of trusting the parameter, we peek at the payload's first three bytes to
/// detect gzip, and fall back to uncompressed UTF-8 otherwise.
#[instrument(skip_all)]
pub fn from_bytes(bytes: Bytes) -> Result<RawRequest, CaptureError> {
tracing::debug!(len = bytes.len(), "decoding new event");
let payload = if bytes.starts_with(&GZIP_MAGIC_NUMBERS) {
let mut d = GzDecoder::new(bytes.reader());
let mut s = String::new();
d.read_to_string(&mut s).map_err(|e| {
tracing::error!("failed to decode gzip: {}", e);
CaptureError::RequestDecodingError(String::from("invalid gzip data"))
})?;
s
} else {
String::from_utf8(bytes.into()).map_err(|e| {
tracing::error!("failed to decode body: {}", e);
CaptureError::RequestDecodingError(String::from("invalid body encoding"))
})?
};
tracing::debug!(json = payload, "decoded event data");
Ok(serde_json::from_str::<RawRequest>(&payload)?)
}
pub fn events(self) -> Vec<RawEvent> {
match self {
RawRequest::Array(events) => events,
RawRequest::One(event) => vec![*event],
RawRequest::Batch(req) => req.batch,
}
}
pub fn extract_and_verify_token(&self) -> Result<String, CaptureError> {
let token = match self {
RawRequest::Batch(req) => req.token.to_string(),
RawRequest::One(event) => event.extract_token().ok_or(CaptureError::NoTokenError)?,
RawRequest::Array(events) => extract_token(events)?,
};
validate_token(&token)?;
Ok(token)
}
pub fn historical_migration(&self) -> bool {
match self {
RawRequest::Batch(req) => req.historical_migration.unwrap_or_default(),
_ => false,
}
}
pub fn sent_at(&self) -> Option<OffsetDateTime> {
if let RawRequest::Batch(req) = &self {
if let Some(value) = &req.sent_at {
if let Ok(parsed) = OffsetDateTime::parse(value, &Iso8601::DEFAULT) {
return Some(parsed);
}
}
}
None
}
}
#[instrument(skip_all, fields(events = events.len()))]
pub fn extract_token(events: &[RawEvent]) -> Result<String, CaptureError> {
let distinct_tokens: HashSet<Option<String>> = HashSet::from_iter(
events
.iter()
.map(RawEvent::extract_token)
.filter(Option::is_some),
);
return match distinct_tokens.len() {
0 => Err(CaptureError::NoTokenError),
1 => match distinct_tokens.iter().last() {
Some(Some(token)) => Ok(token.clone()),
_ => Err(CaptureError::NoTokenError),
},
_ => Err(CaptureError::MultipleTokensError),
};
}
impl RawEvent {
pub fn extract_token(&self) -> Option<String> {
match &self.token {
Some(value) => Some(value.clone()),
None => self
.properties
.get("token")
.and_then(Value::as_str)
.map(String::from),
}
}
/// Extracts, stringifies, and trims the distinct_id to a 200-character String.
/// SDKs send the distinct_id either in the root field or as a property,
/// and can send string, number, array, or map values. We best-effort
/// stringify complex values and make sure the result is no longer than 200 characters.
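// Illustrative inputs and outputs (mirroring the unit tests below):
//   {"distinct_id": "myid"}       -> "myid"
//   {"distinct_id": 23}           -> "23"
//   {"distinct_id": ["a", "b"]}   -> "[\"a\",\"b\"]"
//   missing, null or "" values    -> MissingDistinctId / EmptyDistinctId errors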
pub fn extract_distinct_id(&self) -> Result<String, CaptureError> {
// Breaking change compared to capture-py: None / Null is not allowed.
let value = match &self.distinct_id {
None | Some(Value::Null) => match self.properties.get("distinct_id") {
None | Some(Value::Null) => return Err(CaptureError::MissingDistinctId),
Some(id) => id,
},
Some(id) => id,
};
let distinct_id = value
.as_str()
.map(|s| s.to_owned())
.unwrap_or_else(|| value.to_string());
match distinct_id.len() {
0 => Err(CaptureError::EmptyDistinctId),
1..=200 => Ok(distinct_id),
_ => Ok(distinct_id.chars().take(200).collect()),
}
}
}
#[derive(Debug)]
pub struct ProcessingContext {
pub lib_version: Option<String>,
pub sent_at: Option<OffsetDateTime>,
pub token: String,
pub now: String,
pub client_ip: String,
pub historical_migration: bool,
}
#[cfg(test)]
mod tests {
use crate::token::InvalidTokenReason;
use base64::Engine as _;
use bytes::Bytes;
use rand::distributions::Alphanumeric;
use rand::Rng;
use serde_json::json;
use super::CaptureError;
use super::RawRequest;
#[test]
fn decode_uncompressed_raw_event() {
let base64_payload = "ewogICAgImRpc3RpbmN0X2lkIjogIm15X2lkMSIsCiAgICAiZXZlbnQiOiAibXlfZXZlbnQxIiwKICAgICJwcm9wZXJ0aWVzIjogewogICAgICAgICIkZGV2aWNlX3R5cGUiOiAiRGVza3RvcCIKICAgIH0sCiAgICAiYXBpX2tleSI6ICJteV90b2tlbjEiCn0K";
let compressed_bytes = Bytes::from(
base64::engine::general_purpose::STANDARD
.decode(base64_payload)
.expect("payload is not base64"),
);
let events = RawRequest::from_bytes(compressed_bytes)
.expect("failed to parse")
.events();
assert_eq!(1, events.len());
assert_eq!(Some("my_token1".to_string()), events[0].extract_token());
assert_eq!("my_event1".to_string(), events[0].event);
assert_eq!(
"my_id1".to_string(),
events[0]
.extract_distinct_id()
.expect("cannot find distinct_id")
);
}
#[test]
fn decode_gzipped_raw_event() {
let base64_payload = "H4sIADQSbmUCAz2MsQqAMAxE936FBEcnR2f/o4i9IRTb0AahiP9urcVMx3t3ucxQjxxn5bCrZUfLQEepYabpkzgRtOOWfyMpCpIyctVXY42PDifvsFoE73BF9hqFWuPu403YepT+WKNHmMnc5gENoFu2kwAAAA==";
let compressed_bytes = Bytes::from(
base64::engine::general_purpose::STANDARD
.decode(base64_payload)
.expect("payload is not base64"),
);
let events = RawRequest::from_bytes(compressed_bytes)
.expect("failed to parse")
.events();
assert_eq!(1, events.len());
assert_eq!(Some("my_token2".to_string()), events[0].extract_token());
assert_eq!("my_event2".to_string(), events[0].event);
assert_eq!(
"my_id2".to_string(),
events[0]
.extract_distinct_id()
.expect("cannot find distinct_id")
);
}
#[test]
fn extract_distinct_id() {
let parse_and_extract = |input: &'static str| -> Result<String, CaptureError> {
let parsed = RawRequest::from_bytes(input.into())
.expect("failed to parse")
.events();
parsed[0].extract_distinct_id()
};
// Return MissingDistinctId if not found
assert!(matches!(
parse_and_extract(r#"{"event": "e"}"#),
Err(CaptureError::MissingDistinctId)
));
// Return MissingDistinctId if null
assert!(matches!(
parse_and_extract(r#"{"event": "e", "distinct_id": null}"#),
Err(CaptureError::MissingDistinctId)
));
// Return EmptyDistinctId if empty string
assert!(matches!(
parse_and_extract(r#"{"event": "e", "distinct_id": ""}"#),
Err(CaptureError::EmptyDistinctId)
));
let assert_extracted_id = |input: &'static str, expected: &str| {
let id = parse_and_extract(input).expect("failed to extract");
assert_eq!(id, expected);
};
// Happy path: toplevel field present
assert_extracted_id(r#"{"event": "e", "distinct_id": "myid"}"#, "myid");
assert_extracted_id(r#"{"event": "e", "$distinct_id": "23"}"#, "23");
// Sourced from properties if not present in toplevel field, but toplevel wins if both present
assert_extracted_id(
r#"{"event": "e", "properties":{"distinct_id": "myid"}}"#,
"myid",
);
assert_extracted_id(
r#"{"event": "e", "distinct_id": 23, "properties":{"distinct_id": "myid"}}"#,
"23",
);
// Numbers are stringified
assert_extracted_id(r#"{"event": "e", "distinct_id": 23}"#, "23");
assert_extracted_id(r#"{"event": "e", "distinct_id": 23.4}"#, "23.4");
// Containers are stringified
assert_extracted_id(
r#"{"event": "e", "distinct_id": ["a", "b"]}"#,
r#"["a","b"]"#,
);
assert_extracted_id(
r#"{"event": "e", "distinct_id": {"string": "a", "number": 3}}"#,
r#"{"number":3,"string":"a"}"#,
);
}
#[test]
fn extract_distinct_id_trims_to_200_chars() {
let distinct_id: String = rand::thread_rng()
.sample_iter(Alphanumeric)
.take(222)
.map(char::from)
.collect();
let (expected_distinct_id, _) = distinct_id.split_at(200); // works because ascii chars only
let input = json!([{
"token": "mytoken",
"event": "myevent",
"distinct_id": distinct_id
}]);
let parsed = RawRequest::from_bytes(input.to_string().into())
.expect("failed to parse")
.events();
assert_eq!(
parsed[0].extract_distinct_id().expect("failed to extract"),
expected_distinct_id
);
}
#[test]
fn extract_and_verify_token() {
let parse_and_extract = |input: &'static str| -> Result<String, CaptureError> {
RawRequest::from_bytes(input.into())
.expect("failed to parse")
.extract_and_verify_token()
};
let assert_extracted_token = |input: &'static str, expected: &str| {
let id = parse_and_extract(input).expect("failed to extract");
assert_eq!(id, expected);
};
// Return NoTokenError if not found
assert!(matches!(
parse_and_extract(r#"{"event": "e"}"#),
Err(CaptureError::NoTokenError)
));
// Return TokenValidationError if token empty
assert!(matches!(
parse_and_extract(r#"{"api_key": "", "batch":[{"event": "e"}]}"#),
Err(CaptureError::TokenValidationError(
InvalidTokenReason::Empty
))
));
// Return TokenValidationError if personal apikey
assert!(matches!(
parse_and_extract(r#"[{"event": "e", "token": "phx_hellothere"}]"#),
Err(CaptureError::TokenValidationError(
InvalidTokenReason::PersonalApiKey
))
));
// Return MultipleTokensError if tokens don't match in array
assert!(matches!(
parse_and_extract(
r#"[{"event": "e", "token": "token1"},{"event": "e", "token": "token2"}]"#
),
Err(CaptureError::MultipleTokensError)
));
// Return token from array if consistent
assert_extracted_token(
r#"[{"event":"e","token":"token1"},{"event":"e","token":"token1"}]"#,
"token1",
);
// Return token from batch if present
assert_extracted_token(
r#"{"batch":[{"event":"e","token":"token1"}],"api_key":"batched"}"#,
"batched",
);
// Return token from single event if present
assert_extracted_token(r#"{"event":"e","$token":"single_token"}"#, "single_token");
assert_extracted_token(r#"{"event":"e","api_key":"single_token"}"#, "single_token");
}
}

View File

@@ -0,0 +1,216 @@
#![allow(dead_code)]
use std::default::Default;
use std::net::SocketAddr;
use std::num::NonZeroU32;
use std::str::FromStr;
use std::string::ToString;
use std::sync::{Arc, Once};
use std::time::Duration;
use anyhow::bail;
use once_cell::sync::Lazy;
use rand::distributions::Alphanumeric;
use rand::Rng;
use rdkafka::admin::{AdminClient, AdminOptions, NewTopic, TopicReplication};
use rdkafka::config::{ClientConfig, FromClientConfig};
use rdkafka::consumer::{BaseConsumer, Consumer};
use rdkafka::util::Timeout;
use rdkafka::{Message, TopicPartitionList};
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::time::timeout;
use tracing::{debug, warn};
use capture::config::{Config, KafkaConfig};
use capture::server::serve;
pub static DEFAULT_CONFIG: Lazy<Config> = Lazy::new(|| Config {
print_sink: false,
address: SocketAddr::from_str("127.0.0.1:0").unwrap(),
redis_url: "redis://localhost:6379/".to_string(),
overflow_enabled: false,
overflow_burst_limit: NonZeroU32::new(5).unwrap(),
overflow_per_second_limit: NonZeroU32::new(10).unwrap(),
overflow_forced_keys: None,
kafka: KafkaConfig {
kafka_producer_linger_ms: 0, // Send messages as soon as possible
kafka_producer_queue_mib: 10,
kafka_message_timeout_ms: 10000, // 10s, ACKs can be slow on low volumes, should be tuned
kafka_compression_codec: "none".to_string(),
kafka_hosts: "kafka:9092".to_string(),
kafka_topic: "events_plugin_ingestion".to_string(),
kafka_historical_topic: "events_plugin_ingestion_historical".to_string(),
kafka_tls: false,
},
otel_url: None,
otel_sampling_rate: 0.0,
otel_service_name: "capture-testing".to_string(),
export_prometheus: false,
});
static TRACING_INIT: Once = Once::new();
pub fn setup_tracing() {
TRACING_INIT.call_once(|| {
tracing_subscriber::fmt()
.with_writer(tracing_subscriber::fmt::TestWriter::new())
.init()
});
}
pub struct ServerHandle {
pub addr: SocketAddr,
shutdown: Arc<Notify>,
}
impl ServerHandle {
pub async fn for_topics(main: &EphemeralTopic, historical: &EphemeralTopic) -> Self {
let mut config = DEFAULT_CONFIG.clone();
config.kafka.kafka_topic = main.topic_name().to_string();
config.kafka.kafka_historical_topic = historical.topic_name().to_string();
Self::for_config(config).await
}
pub async fn for_config(config: Config) -> Self {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let notify = Arc::new(Notify::new());
let shutdown = notify.clone();
tokio::spawn(async move {
serve(config, listener, async move { notify.notified().await }).await
});
Self { addr, shutdown }
}
pub async fn capture_events<T: Into<reqwest::Body>>(&self, body: T) -> reqwest::Response {
let client = reqwest::Client::new();
client
.post(format!("http://{:?}/i/v0/e", self.addr))
.body(body)
.send()
.await
.expect("failed to send request")
}
}
impl Drop for ServerHandle {
fn drop(&mut self) {
self.shutdown.notify_one()
}
}
pub struct EphemeralTopic {
consumer: BaseConsumer,
read_timeout: Timeout,
topic_name: String,
}
impl EphemeralTopic {
pub async fn new() -> Self {
let mut config = ClientConfig::new();
config.set("group.id", "capture_integration_tests");
config.set(
"bootstrap.servers",
DEFAULT_CONFIG.kafka.kafka_hosts.clone(),
);
config.set("debug", "all");
// TODO: check for name collision?
let topic_name = random_string("events_", 16);
let admin = AdminClient::from_config(&config).expect("failed to create admin client");
admin
.create_topics(
&[NewTopic {
name: &topic_name,
num_partitions: 1,
replication: TopicReplication::Fixed(1),
config: vec![],
}],
&AdminOptions::default(),
)
.await
.expect("failed to create topic");
let consumer: BaseConsumer = config.create().expect("failed to create consumer");
let mut assignment = TopicPartitionList::new();
assignment.add_partition(&topic_name, 0);
consumer
.assign(&assignment)
.expect("failed to assign topic");
Self {
consumer,
read_timeout: Timeout::After(Duration::from_secs(5)),
topic_name,
}
}
pub fn next_event(&self) -> anyhow::Result<serde_json::Value> {
match self.consumer.poll(self.read_timeout) {
Some(Ok(message)) => {
let body = message.payload().expect("empty kafka message");
let event = serde_json::from_slice(body)?;
Ok(event)
}
Some(Err(err)) => bail!("kafka read error: {}", err),
None => bail!("kafka read timeout"),
}
}
pub fn next_message_key(&self) -> anyhow::Result<Option<String>> {
match self.consumer.poll(self.read_timeout) {
Some(Ok(message)) => {
let key = message.key();
if let Some(key) = key {
let key = std::str::from_utf8(key)?;
let key = String::from_str(key)?;
Ok(Some(key))
} else {
Ok(None)
}
}
Some(Err(err)) => bail!("kafka read error: {}", err),
None => bail!("kafka read timeout"),
}
}
pub fn topic_name(&self) -> &str {
&self.topic_name
}
}
impl Drop for EphemeralTopic {
fn drop(&mut self) {
debug!("dropping EphemeralTopic {}...", self.topic_name);
self.consumer.unsubscribe();
match futures::executor::block_on(timeout(
Duration::from_secs(10),
delete_topic(self.topic_name.clone()),
)) {
Ok(_) => debug!("dropped topic"),
Err(err) => warn!("failed to drop topic: {}", err),
}
}
}
async fn delete_topic(topic: String) {
let mut config = ClientConfig::new();
config.set(
"bootstrap.servers",
DEFAULT_CONFIG.kafka.kafka_hosts.clone(),
);
let admin = AdminClient::from_config(&config).expect("failed to create admin client");
admin
.delete_topics(&[&topic], &AdminOptions::default())
.await
.expect("failed to delete topic");
}
pub fn random_string(prefix: &str, length: usize) -> String {
let suffix: String = rand::thread_rng()
.sample_iter(Alphanumeric)
.take(length)
.map(char::from)
.collect();
format!("{}_{}", prefix, suffix)
}

View File

@@ -0,0 +1,227 @@
use assert_json_diff::assert_json_matches_no_panic;
use async_trait::async_trait;
use axum::http::StatusCode;
use axum_test_helper::TestClient;
use base64::engine::general_purpose;
use base64::Engine;
use capture::api::{CaptureError, CaptureResponse, CaptureResponseCode, DataType, ProcessedEvent};
use capture::limiters::billing::BillingLimiter;
use capture::redis::MockRedisClient;
use capture::router::router;
use capture::sinks::Event;
use capture::time::TimeSource;
use health::HealthRegistry;
use serde::Deserialize;
use serde_json::{json, Value};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::sync::{Arc, Mutex};
use time::format_description::well_known::{Iso8601, Rfc3339};
use time::{Duration, OffsetDateTime};
#[derive(Debug, Deserialize)]
struct RequestDump {
path: String,
method: String,
content_encoding: String,
content_type: String,
ip: String,
now: String,
body: String,
output: Vec<Value>,
#[serde(default)] // default = false
historical_migration: bool,
}
static REQUESTS_DUMP_FILE_NAME: &str = "tests/requests_dump.jsonl";
#[derive(Clone)]
pub struct FixedTime {
pub time: String,
}
impl TimeSource for FixedTime {
fn current_time(&self) -> String {
self.time.to_string()
}
}
#[derive(Clone, Default)]
struct MemorySink {
events: Arc<Mutex<Vec<ProcessedEvent>>>,
}
impl MemorySink {
fn len(&self) -> usize {
self.events.lock().unwrap().len()
}
fn events(&self) -> Vec<ProcessedEvent> {
self.events.lock().unwrap().clone()
}
}
#[async_trait]
impl Event for MemorySink {
async fn send(&self, event: ProcessedEvent) -> Result<(), CaptureError> {
self.events.lock().unwrap().push(event);
Ok(())
}
async fn send_batch(&self, events: Vec<ProcessedEvent>) -> Result<(), CaptureError> {
self.events.lock().unwrap().extend_from_slice(&events);
Ok(())
}
}
#[tokio::test]
async fn it_matches_django_capture_behaviour() -> anyhow::Result<()> {
let file = File::open(REQUESTS_DUMP_FILE_NAME)?;
let reader = BufReader::new(file);
let liveness = HealthRegistry::new("dummy");
let mut mismatches = 0;
for (line_number, line_contents) in reader.lines().enumerate() {
let line_contents = line_contents?;
if line_contents.starts_with('#') {
// Skip comment lines
continue;
}
let case: RequestDump = serde_json::from_str(&line_contents)?;
let raw_body = general_purpose::STANDARD.decode(&case.body)?;
assert_eq!(
case.method, "POST",
"update code to handle method {}",
case.method
);
let sink = MemorySink::default();
let timesource = FixedTime { time: case.now };
let redis = Arc::new(MockRedisClient::new());
let billing = BillingLimiter::new(Duration::weeks(1), redis.clone())
.expect("failed to create billing limiter");
let app = router(
timesource,
liveness.clone(),
sink.clone(),
redis,
billing,
false,
);
let client = TestClient::new(app);
let mut req = client.post(&case.path).body(raw_body);
if !case.content_encoding.is_empty() {
req = req.header("Content-encoding", case.content_encoding);
}
if !case.content_type.is_empty() {
req = req.header("Content-type", case.content_type);
}
if !case.ip.is_empty() {
req = req.header("X-Forwarded-For", case.ip);
}
let res = req.send().await;
assert_eq!(
res.status(),
StatusCode::OK,
"line {} rejected: {}",
line_number,
res.text().await
);
assert_eq!(
Some(CaptureResponse {
status: CaptureResponseCode::Ok
}),
res.json().await
);
assert_eq!(
sink.len(),
case.output.len(),
"event count mismatch on line {}",
line_number
);
for (event_number, (message, expected)) in
sink.events().iter().zip(case.output.iter()).enumerate()
{
// Ensure the data type matches
if case.historical_migration {
assert_eq!(DataType::AnalyticsHistorical, message.data_type);
} else {
assert_eq!(DataType::AnalyticsMain, message.data_type);
}
// Normalizing the expected event to align with known django->rust inconsistencies
let mut expected = expected.clone();
if let Some(value) = expected.get_mut("sent_at") {
// Default ISO format is different between python and rust, both are valid
// Parse and re-print the value before comparison
let raw_value = value.as_str().expect("sent_at field is not a string");
if raw_value.is_empty() {
*value = Value::Null
} else {
let sent_at =
OffsetDateTime::parse(value.as_str().expect("empty"), &Iso8601::DEFAULT)
.expect("failed to parse expected sent_at");
*value = Value::String(sent_at.format(&Rfc3339)?)
}
}
if let Some(expected_data) = expected.get_mut("data") {
// Data is a serialized JSON map. Unmarshall both and compare them,
// instead of expecting the serialized bytes to be equal
let mut expected_props: Value =
serde_json::from_str(expected_data.as_str().expect("not str"))?;
if let Some(object) = expected_props.as_object_mut() {
// toplevel fields added by posthog-node that plugin-server will ignore anyway
object.remove("type");
object.remove("library");
object.remove("library_version");
}
let found_props: Value = serde_json::from_str(&message.data)?;
let match_config =
assert_json_diff::Config::new(assert_json_diff::CompareMode::Strict);
if let Err(e) =
assert_json_matches_no_panic(&expected_props, &found_props, match_config)
{
println!(
"data field mismatch at line {}, event {}: {}",
line_number, event_number, e
);
mismatches += 1;
} else {
*expected_data = json!(&message.data)
}
}
if let Some(object) = expected.as_object_mut() {
// site_url is unused in the pipeline now, let's drop it
object.remove("site_url");
// Remove sent_at field if empty: Rust will skip marshalling it
if let Some(None) = object.get("sent_at").map(|v| v.as_str()) {
object.remove("sent_at");
}
}
let match_config = assert_json_diff::Config::new(assert_json_diff::CompareMode::Strict);
if let Err(e) =
assert_json_matches_no_panic(&json!(expected), &json!(message), match_config)
{
println!(
"record mismatch at line {}, event {}: {}",
line_number + 1,
event_number,
e
);
mismatches += 1;
}
}
}
assert_eq!(0, mismatches, "some events didn't match");
Ok(())
}

View File

@@ -0,0 +1,351 @@
use std::num::NonZeroU32;
use anyhow::Result;
use assert_json_diff::assert_json_include;
use reqwest::StatusCode;
use serde_json::json;
use crate::common::*;
mod common;
#[tokio::test]
async fn it_captures_one_event() -> Result<()> {
setup_tracing();
let token = random_string("token", 16);
let distinct_id = random_string("id", 16);
let main_topic = EphemeralTopic::new().await;
let histo_topic = EphemeralTopic::new().await;
let server = ServerHandle::for_topics(&main_topic, &histo_topic).await;
let event = json!({
"token": token,
"event": "testing",
"distinct_id": distinct_id
});
let res = server.capture_events(event.to_string()).await;
assert_eq!(StatusCode::OK, res.status());
let event = main_topic.next_event()?;
assert_json_include!(
actual: event,
expected: json!({
"token": token,
"distinct_id": distinct_id
})
);
Ok(())
}
#[tokio::test]
async fn it_captures_a_posthogjs_array() -> Result<()> {
setup_tracing();
let token = random_string("token", 16);
let distinct_id1 = random_string("id", 16);
let distinct_id2 = random_string("id", 16);
let main_topic = EphemeralTopic::new().await;
let histo_topic = EphemeralTopic::new().await;
let server = ServerHandle::for_topics(&main_topic, &histo_topic).await;
let event = json!([{
"token": token,
"event": "event1",
"distinct_id": distinct_id1
},{
"token": token,
"event": "event2",
"distinct_id": distinct_id2
}]);
let res = server.capture_events(event.to_string()).await;
assert_eq!(StatusCode::OK, res.status());
assert_json_include!(
actual: main_topic.next_event()?,
expected: json!({
"token": token,
"distinct_id": distinct_id1
})
);
assert_json_include!(
actual: main_topic.next_event()?,
expected: json!({
"token": token,
"distinct_id": distinct_id2
})
);
Ok(())
}
#[tokio::test]
async fn it_captures_a_batch() -> Result<()> {
setup_tracing();
let token = random_string("token", 16);
let distinct_id1 = random_string("id", 16);
let distinct_id2 = random_string("id", 16);
let main_topic = EphemeralTopic::new().await;
let histo_topic = EphemeralTopic::new().await;
let server = ServerHandle::for_topics(&main_topic, &histo_topic).await;
let event = json!({
"token": token,
"batch": [{
"event": "event1",
"distinct_id": distinct_id1
},{
"event": "event2",
"distinct_id": distinct_id2
}]
});
let res = server.capture_events(event.to_string()).await;
assert_eq!(StatusCode::OK, res.status());
assert_json_include!(
actual: main_topic.next_event()?,
expected: json!({
"token": token,
"distinct_id": distinct_id1
})
);
assert_json_include!(
actual: main_topic.next_event()?,
expected: json!({
"token": token,
"distinct_id": distinct_id2
})
);
Ok(())
}
#[tokio::test]
async fn it_captures_a_historical_batch() -> Result<()> {
setup_tracing();
let token = random_string("token", 16);
let distinct_id1 = random_string("id", 16);
let distinct_id2 = random_string("id", 16);
let main_topic = EphemeralTopic::new().await;
let histo_topic = EphemeralTopic::new().await;
let server = ServerHandle::for_topics(&main_topic, &histo_topic).await;
let event = json!({
"token": token,
"historical_migration": true,
"batch": [{
"event": "event1",
"distinct_id": distinct_id1
},{
"event": "event2",
"distinct_id": distinct_id2
}]
});
let res = server.capture_events(event.to_string()).await;
assert_eq!(StatusCode::OK, res.status());
assert_json_include!(
actual: histo_topic.next_event()?,
expected: json!({
"token": token,
"distinct_id": distinct_id1
})
);
assert_json_include!(
actual: histo_topic.next_event()?,
expected: json!({
"token": token,
"distinct_id": distinct_id2
})
);
Ok(())
}
#[tokio::test]
async fn it_overflows_events_on_burst() -> Result<()> {
setup_tracing();
let token = random_string("token", 16);
let distinct_id = random_string("id", 16);
let topic = EphemeralTopic::new().await;
let mut config = DEFAULT_CONFIG.clone();
config.kafka.kafka_topic = topic.topic_name().to_string();
config.overflow_enabled = true;
config.overflow_burst_limit = NonZeroU32::new(2).unwrap();
config.overflow_per_second_limit = NonZeroU32::new(1).unwrap();
let server = ServerHandle::for_config(config).await;
let event = json!([{
"token": token,
"event": "event1",
"distinct_id": distinct_id
},{
"token": token,
"event": "event2",
"distinct_id": distinct_id
},{
"token": token,
"event": "event3",
"distinct_id": distinct_id
}]);
let res = server.capture_events(event.to_string()).await;
assert_eq!(StatusCode::OK, res.status());
assert_eq!(
topic.next_message_key()?.unwrap(),
format!("{}:{}", token, distinct_id)
);
assert_eq!(
topic.next_message_key()?.unwrap(),
format!("{}:{}", token, distinct_id)
);
assert_eq!(topic.next_message_key()?, None);
Ok(())
}
#[tokio::test]
async fn it_does_not_overflow_team_with_different_ids() -> Result<()> {
setup_tracing();
let token = random_string("token", 16);
let distinct_id = random_string("id", 16);
let distinct_id2 = random_string("id", 16);
let topic = EphemeralTopic::new().await;
let mut config = DEFAULT_CONFIG.clone();
config.kafka.kafka_topic = topic.topic_name().to_string();
config.overflow_enabled = true;
config.overflow_burst_limit = NonZeroU32::new(1).unwrap();
config.overflow_per_second_limit = NonZeroU32::new(1).unwrap();
let server = ServerHandle::for_config(config).await;
let event = json!([{
"token": token,
"event": "event1",
"distinct_id": distinct_id
},{
"token": token,
"event": "event2",
"distinct_id": distinct_id2
}]);
let res = server.capture_events(event.to_string()).await;
assert_eq!(StatusCode::OK, res.status());
assert_eq!(
topic.next_message_key()?.unwrap(),
format!("{}:{}", token, distinct_id)
);
assert_eq!(
topic.next_message_key()?.unwrap(),
format!("{}:{}", token, distinct_id2)
);
Ok(())
}
#[tokio::test]
async fn it_skips_overflows_when_disabled() -> Result<()> {
setup_tracing();
let token = random_string("token", 16);
let distinct_id = random_string("id", 16);
let topic = EphemeralTopic::new().await;
let mut config = DEFAULT_CONFIG.clone();
config.kafka.kafka_topic = topic.topic_name().to_string();
config.overflow_enabled = false;
config.overflow_burst_limit = NonZeroU32::new(2).unwrap();
config.overflow_per_second_limit = NonZeroU32::new(1).unwrap();
let server = ServerHandle::for_config(config).await;
let event = json!([{
"token": token,
"event": "event1",
"distinct_id": distinct_id
},{
"token": token,
"event": "event2",
"distinct_id": distinct_id
},{
"token": token,
"event": "event3",
"distinct_id": distinct_id
}]);
let res = server.capture_events(event.to_string()).await;
assert_eq!(StatusCode::OK, res.status());
assert_eq!(
topic.next_message_key()?.unwrap(),
format!("{}:{}", token, distinct_id)
);
assert_eq!(
topic.next_message_key()?.unwrap(),
format!("{}:{}", token, distinct_id)
);
// Should have triggered overflow, but has not
assert_eq!(
topic.next_message_key()?.unwrap(),
format!("{}:{}", token, distinct_id)
);
Ok(())
}
#[tokio::test]
async fn it_trims_distinct_id() -> Result<()> {
setup_tracing();
let token = random_string("token", 16);
let distinct_id1 = random_string("id", 200 - 3);
let distinct_id2 = random_string("id", 222);
let (trimmed_distinct_id2, _) = distinct_id2.split_at(200); // works because ascii chars
let main_topic = EphemeralTopic::new().await;
let histo_topic = EphemeralTopic::new().await;
let server = ServerHandle::for_topics(&main_topic, &histo_topic).await;
let event = json!([{
"token": token,
"event": "event1",
"distinct_id": distinct_id1
},{
"token": token,
"event": "event2",
"distinct_id": distinct_id2
}]);
let res = server.capture_events(event.to_string()).await;
assert_eq!(StatusCode::OK, res.status());
assert_json_include!(
actual: main_topic.next_event()?,
expected: json!({
"token": token,
"distinct_id": distinct_id1
})
);
assert_json_include!(
actual: main_topic.next_event()?,
expected: json!({
"token": token,
"distinct_id": trimmed_distinct_id2
})
);
Ok(())
}

File diff suppressed because one or more lines are too long

8
rust/common/README.md Normal file
View File

@@ -0,0 +1,8 @@
# Common crates for the hog-rs services
This folder holds internal crates for code reuse between services in the monorepo. To keep maintenance costs low,
these crates should ideally:
- Cover a small feature scope and use as few dependencies as possible
- Only use `{ workspace = true }` dependencies, instead of pinning versions that could diverge from the workspace (see the sketch below)
- Have adequate test coverage and documentation
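A minimal `Cargo.toml` sketch of that dependency style (crate name and dependency list are illustrative only):
```toml
[package]
name = "example-common-crate"
version = "0.1.0"
edition = "2021"

[dependencies]
serde = { workspace = true }
tracing = { workspace = true }
```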

View File

@@ -0,0 +1,13 @@
[package]
name = "health"
version = "0.1.0"
edition = "2021"
[lints]
workspace = true
[dependencies]
axum = { workspace = true }
time = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }

View File

@@ -0,0 +1,344 @@
use std::collections::HashMap;
use std::ops::Add;
use std::sync::{Arc, RwLock};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use time::Duration;
use tokio::sync::mpsc;
use tracing::{info, warn};
/// Health reporting for components of the service.
///
/// The capture server contains several asynchronous loops, and
/// the process can only be trusted with user data if all the
/// loops are properly running and reporting.
///
/// HealthRegistry allows an arbitrary number of components to
/// be registered and report their health. The process' health
/// status is the combination of these individual health statuses:
/// - if any component is unhealthy, the process is unhealthy
/// - if all components recently reported healthy, the process is healthy
/// - if a component failed to report healthy for its defined deadline,
/// it is considered unhealthy, and the check fails.
///
/// Trying to merge the k8s concepts of liveness and readiness in
/// a single state is full of foot-guns, so HealthRegistry does not
/// try to do it. Each probe should have its own instance of
/// the registry to avoid confusion.
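//
// Illustrative wiring (the component name and deadline are made up, not part of this crate):
//
//     let liveness = HealthRegistry::new("liveness");
//     let handle = liveness
//         .register("kafka_producer".to_string(), Duration::seconds(30))
//         .await;
//     // ...then, inside the component's loop, more often than every 30s:
//     handle.report_healthy().await;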
#[derive(Default, Debug)]
pub struct HealthStatus {
/// The overall status: true if all components are healthy
pub healthy: bool,
/// Current status of each registered component, for display
pub components: HashMap<String, ComponentStatus>,
}
impl IntoResponse for HealthStatus {
/// Computes the axum status code based on the overall health status,
/// and prints each component status in the body for debugging.
fn into_response(self) -> Response {
let body = format!("{:?}", self);
match self.healthy {
true => (StatusCode::OK, body),
false => (StatusCode::INTERNAL_SERVER_ERROR, body),
}
.into_response()
}
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ComponentStatus {
/// Automatically set when a component is newly registered
Starting,
/// Recently reported healthy, will need to report again before the date
HealthyUntil(time::OffsetDateTime),
/// Reported unhealthy
Unhealthy,
/// Automatically set when the HealthyUntil deadline is reached
Stalled,
}
struct HealthMessage {
component: String,
status: ComponentStatus,
}
pub struct HealthHandle {
component: String,
deadline: Duration,
sender: mpsc::Sender<HealthMessage>,
}
impl HealthHandle {
/// Asynchronously report healthy, returns when the message is queued.
/// Must be called more frequently than the configured deadline.
pub async fn report_healthy(&self) {
self.report_status(ComponentStatus::HealthyUntil(
time::OffsetDateTime::now_utc().add(self.deadline),
))
.await
}
/// Asynchronously report component status, returns when the message is queued.
pub async fn report_status(&self, status: ComponentStatus) {
let message = HealthMessage {
component: self.component.clone(),
status,
};
if let Err(err) = self.sender.send(message).await {
warn!("failed to report heath status: {}", err)
}
}
/// Synchronously report as healthy, returns when the message is queued.
/// Must be called more frequently than the configured deadline.
pub fn report_healthy_blocking(&self) {
self.report_status_blocking(ComponentStatus::HealthyUntil(
time::OffsetDateTime::now_utc().add(self.deadline),
))
}
/// Synchronously report component status, returns when the message is queued.
pub fn report_status_blocking(&self, status: ComponentStatus) {
let message = HealthMessage {
component: self.component.clone(),
status,
};
if let Err(err) = self.sender.blocking_send(message) {
warn!("failed to report heath status: {}", err)
}
}
}
#[derive(Clone)]
pub struct HealthRegistry {
name: String,
components: Arc<RwLock<HashMap<String, ComponentStatus>>>,
sender: mpsc::Sender<HealthMessage>,
}
impl HealthRegistry {
pub fn new(name: &str) -> Self {
let (tx, mut rx) = mpsc::channel::<HealthMessage>(16);
let registry = Self {
name: name.to_owned(),
components: Default::default(),
sender: tx,
};
let components = registry.components.clone();
tokio::spawn(async move {
while let Some(message) = rx.recv().await {
if let Ok(mut map) = components.write() {
_ = map.insert(message.component, message.status);
} else {
// Poisoned mutex: just warn, the probes will fail and the process will restart
warn!("poisoned HealthRegistry mutex")
}
}
});
registry
}
/// Registers a new component in the registry. The returned handle should be passed
/// to the component, to allow it to frequently report its health status.
pub async fn register(&self, component: String, deadline: Duration) -> HealthHandle {
let handle = HealthHandle {
component,
deadline,
sender: self.sender.clone(),
};
handle.report_status(ComponentStatus::Starting).await;
handle
}
/// Returns the overall process status, computed from the status of all the components
/// currently registered. Can be used as an axum handler.
pub fn get_status(&self) -> HealthStatus {
let components = self
.components
.read()
.expect("poisoned HeathRegistry mutex");
let result = HealthStatus {
healthy: !components.is_empty(), // unhealthy if no component has registered yet
components: Default::default(),
};
let now = time::OffsetDateTime::now_utc();
let result = components
.iter()
.fold(result, |mut result, (name, status)| {
match status {
ComponentStatus::HealthyUntil(until) => {
if until.gt(&now) {
_ = result.components.insert(name.clone(), status.clone())
} else {
result.healthy = false;
_ = result
.components
.insert(name.clone(), ComponentStatus::Stalled)
}
}
_ => {
result.healthy = false;
_ = result.components.insert(name.clone(), status.clone())
}
}
result
});
match result.healthy {
true => info!("{} health check ok", self.name),
false => warn!("{} health check failed: {:?}", self.name, result.components),
}
result
}
}
#[cfg(test)]
mod tests {
use crate::{ComponentStatus, HealthRegistry, HealthStatus};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use std::ops::{Add, Sub};
use time::{Duration, OffsetDateTime};
async fn assert_or_retry<F>(check: F)
where
F: Fn() -> bool,
{
assert_or_retry_for_duration(check, Duration::seconds(5)).await
}
async fn assert_or_retry_for_duration<F>(check: F, timeout: Duration)
where
F: Fn() -> bool,
{
let deadline = OffsetDateTime::now_utc().add(timeout);
while !check() && OffsetDateTime::now_utc().lt(&deadline) {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
assert!(check())
}
#[tokio::test]
async fn defaults_to_unhealthy() {
let registry = HealthRegistry::new("liveness");
assert!(!registry.get_status().healthy);
}
#[tokio::test]
async fn one_component() {
let registry = HealthRegistry::new("liveness");
// New components are registered in Starting
let handle = registry
.register("one".to_string(), Duration::seconds(30))
.await;
assert_or_retry(|| registry.get_status().components.len() == 1).await;
let mut status = registry.get_status();
assert!(!status.healthy);
assert_eq!(
status.components.get("one"),
Some(&ComponentStatus::Starting)
);
// Status goes healthy once the component reports
handle.report_healthy().await;
assert_or_retry(|| registry.get_status().healthy).await;
status = registry.get_status();
assert_eq!(status.components.len(), 1);
// Status goes unhealthy if the components says so
handle.report_status(ComponentStatus::Unhealthy).await;
assert_or_retry(|| !registry.get_status().healthy).await;
status = registry.get_status();
assert_eq!(status.components.len(), 1);
assert_eq!(
status.components.get("one"),
Some(&ComponentStatus::Unhealthy)
);
}
#[tokio::test]
async fn staleness_check() {
let registry = HealthRegistry::new("liveness");
let handle = registry
.register("one".to_string(), Duration::seconds(30))
.await;
// Status goes healthy once the component reports
handle.report_healthy().await;
assert_or_retry(|| registry.get_status().healthy).await;
let mut status = registry.get_status();
assert_eq!(status.components.len(), 1);
// If the component's ping is too old, it is considered stalled and the healthcheck fails
// FIXME: we should mock the time instead
handle
.report_status(ComponentStatus::HealthyUntil(
OffsetDateTime::now_utc().sub(Duration::seconds(1)),
))
.await;
assert_or_retry(|| !registry.get_status().healthy).await;
status = registry.get_status();
assert_eq!(status.components.len(), 1);
assert_eq!(
status.components.get("one"),
Some(&ComponentStatus::Stalled)
);
}
#[tokio::test]
async fn several_components() {
let registry = HealthRegistry::new("liveness");
let handle1 = registry
.register("one".to_string(), Duration::seconds(30))
.await;
let handle2 = registry
.register("two".to_string(), Duration::seconds(30))
.await;
assert_or_retry(|| registry.get_status().components.len() == 2).await;
// First component going healthy is not enough
handle1.report_healthy().await;
assert_or_retry(|| {
registry.get_status().components.get("one").unwrap() != &ComponentStatus::Starting
})
.await;
assert!(!registry.get_status().healthy);
// Second component going healthy brings the health to green
handle2.report_healthy().await;
assert_or_retry(|| {
registry.get_status().components.get("two").unwrap() != &ComponentStatus::Starting
})
.await;
assert!(registry.get_status().healthy);
// First component going unhealthy takes down the health to red
handle1.report_status(ComponentStatus::Unhealthy).await;
assert_or_retry(|| !registry.get_status().healthy).await;
// First component recovering returns the health to green
handle1.report_healthy().await;
assert_or_retry(|| registry.get_status().healthy).await;
// Second component going unhealthy takes down the health to red
handle2.report_status(ComponentStatus::Unhealthy).await;
assert_or_retry(|| !registry.get_status().healthy).await;
}
#[tokio::test]
async fn into_response() {
let nok = HealthStatus::default().into_response();
assert_eq!(nok.status(), StatusCode::INTERNAL_SERVER_ERROR);
let ok = HealthStatus {
healthy: true,
components: Default::default(),
}
.into_response();
assert_eq!(ok.status(), StatusCode::OK);
}
}

1
rust/depot.json Normal file
View File

@@ -0,0 +1 @@
{ "id": "zcszdgwzsw" }

89
rust/docker-compose.yml Normal file
View File

@@ -0,0 +1,89 @@
version: '3'
services:
zookeeper:
image: zookeeper:3.7.0
restart: on-failure
kafka:
image: ghcr.io/posthog/kafka-container:v2.8.2
restart: on-failure
depends_on:
- zookeeper
environment:
KAFKA_BROKER_ID: 1001
KAFKA_CFG_RESERVED_BROKER_MAX_ID: 1001
KAFKA_CFG_LISTENERS: PLAINTEXT://:9092
KAFKA_CFG_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092
KAFKA_CFG_ZOOKEEPER_CONNECT: zookeeper:2181
ALLOW_PLAINTEXT_LISTENER: 'true'
ports:
- '9092:9092'
healthcheck:
test: kafka-cluster.sh cluster-id --bootstrap-server localhost:9092 || exit 1
interval: 3s
timeout: 10s
retries: 10
redis:
image: redis:6.2.7-alpine
restart: on-failure
command: redis-server --maxmemory-policy allkeys-lru --maxmemory 200mb
ports:
- '6379:6379'
healthcheck:
test: ['CMD', 'redis-cli', 'ping']
interval: 3s
timeout: 10s
retries: 10
kafka-ui:
image: provectuslabs/kafka-ui:latest
profiles: ['ui']
ports:
- '8080:8080'
depends_on:
- zookeeper
- kafka
environment:
KAFKA_CLUSTERS_0_NAME: local
KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: kafka:9092
KAFKA_CLUSTERS_0_ZOOKEEPER: zookeeper:2181
db:
container_name: db
image: docker.io/library/postgres:16-alpine
restart: on-failure
environment:
POSTGRES_USER: posthog
POSTGRES_DB: posthog
POSTGRES_PASSWORD: posthog
healthcheck:
test: ['CMD-SHELL', 'pg_isready -U posthog']
interval: 5s
timeout: 5s
ports:
- '15432:5432'
command: postgres -c max_connections=1000 -c idle_in_transaction_session_timeout=300000
setup_test_db:
container_name: setup-test-db
build:
context: .
dockerfile: Dockerfile.migrate
restart: on-failure
depends_on:
db:
condition: service_healthy
restart: true
environment:
DATABASE_URL: postgres://posthog:posthog@db:5432/test_database
echo_server:
image: docker.io/library/caddy:2
container_name: echo-server
restart: on-failure
ports:
- '18081:8081'
volumes:
- ./docker/echo-server/Caddyfile:/etc/caddy/Caddyfile

View File

@@ -0,0 +1,17 @@
{
auto_https off
}
:8081
route /echo {
respond `{http.request.body}` 200 {
close
}
}
route /fail {
respond `{http.request.body}` 400 {
close
}
}

View File

@@ -0,0 +1,38 @@
[package]
name = "feature-flags"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = { workspace = true }
async-trait = { workspace = true }
axum = { workspace = true }
axum-client-ip = { workspace = true }
envconfig = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
bytes = { workspace = true }
rand = { workspace = true }
redis = { version = "0.23.3", features = [
"tokio-comp",
"cluster",
"cluster-async",
] }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
serde-pickle = { version = "1.1.1"}
sha1 = "0.10.6"
regex = "1.10.4"
[lints]
workspace = true
[dev-dependencies]
assert-json-diff = { workspace = true }
once_cell = "1.18.0"
reqwest = { workspace = true }

View File

@@ -0,0 +1,36 @@
# Testing
```
cargo test --package feature-flags
```
### To watch changes
```
brew install cargo-watch
```
and then run:
```
cargo watch -x test --package feature-flags
```
To run a specific test:
```
cargo watch -x "test --package feature-flags --lib -- property_matching::tests::test_match_properties_math_operators --exact --show-output"
```
# Running
```
RUST_LOG=debug cargo run --bin feature-flags
```
# Format code
```
cargo fmt --package feature-flags
```

View File

@@ -0,0 +1,67 @@
use std::collections::HashMap;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
pub enum FlagsResponseCode {
Ok = 1,
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FlagsResponse {
pub error_while_computing_flags: bool,
// TODO: better typing here, support bool responses
pub feature_flags: HashMap<String, String>,
}
#[derive(Error, Debug)]
pub enum FlagError {
#[error("failed to decode request: {0}")]
RequestDecodingError(String),
#[error("failed to parse request: {0}")]
RequestParsingError(#[from] serde_json::Error),
#[error("Empty distinct_id in request")]
EmptyDistinctId,
#[error("No distinct_id in request")]
MissingDistinctId,
#[error("No api_key in request")]
NoTokenError,
#[error("API key is not valid")]
TokenValidationError,
#[error("rate limited")]
RateLimited,
#[error("failed to parse redis cache data")]
DataParsingError,
#[error("redis unavailable")]
RedisUnavailable,
}
impl IntoResponse for FlagError {
fn into_response(self) -> Response {
match self {
FlagError::RequestDecodingError(_)
| FlagError::RequestParsingError(_)
| FlagError::EmptyDistinctId
| FlagError::MissingDistinctId => (StatusCode::BAD_REQUEST, self.to_string()),
FlagError::NoTokenError | FlagError::TokenValidationError => {
(StatusCode::UNAUTHORIZED, self.to_string())
}
FlagError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, self.to_string()),
FlagError::DataParsingError | FlagError::RedisUnavailable => {
(StatusCode::SERVICE_UNAVAILABLE, self.to_string())
}
}
.into_response()
}
}

View File

@@ -0,0 +1,24 @@
use std::net::SocketAddr;
use envconfig::Envconfig;
#[derive(Envconfig, Clone)]
pub struct Config {
#[envconfig(default = "127.0.0.1:3001")]
pub address: SocketAddr,
#[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")]
pub write_database_url: String,
#[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")]
pub read_database_url: String,
#[envconfig(default = "1024")]
pub max_concurrent_jobs: usize,
#[envconfig(default = "100")]
pub max_pg_connections: u32,
#[envconfig(default = "redis://localhost:6379/")]
pub redis_url: String,
}

View File

@@ -0,0 +1,214 @@
use serde::Deserialize;
use std::sync::Arc;
use tracing::instrument;
use crate::{
api::FlagError,
redis::{Client, CustomRedisError},
};
// TRICKY: This cache data is coming from django-redis. If it ever goes out of sync, we'll bork.
// TODO: Add integration tests across repos to ensure this doesn't happen.
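// The full key looks like `posthog:1:team_feature_flags_<team_id>`; the value holds the
// team's serialized flag list, which from_redis below parses as JSON.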
pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_";
// TODO: Revisit when dealing with groups, but it seems ideal to just
// treat it as a u8 and do our own validation on top
#[derive(Debug, Deserialize)]
pub enum GroupTypeIndex {}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OperatorType {
Exact,
IsNot,
Icontains,
NotIcontains,
Regex,
NotRegex,
Gt,
Lt,
Gte,
Lte,
IsSet,
IsNotSet,
IsDateExact,
IsDateAfter,
IsDateBefore,
}
#[derive(Debug, Clone, Deserialize)]
pub struct PropertyFilter {
pub key: String,
// TODO: Probably need a default for value, in case operators like
// is_set / is_not_set are used; those are not guaranteed to have a value
// if, say, the filter was created via the API.
pub value: serde_json::Value,
pub operator: Option<OperatorType>,
#[serde(rename = "type")]
pub prop_type: String,
pub group_type_index: Option<u8>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct FlagGroupType {
pub properties: Option<Vec<PropertyFilter>>,
pub rollout_percentage: Option<f64>,
pub variant: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct MultivariateFlagVariant {
pub key: String,
pub name: Option<String>,
pub rollout_percentage: f64,
}
#[derive(Debug, Clone, Deserialize)]
pub struct MultivariateFlagOptions {
pub variants: Vec<MultivariateFlagVariant>,
}
// TODO: test name with https://www.fileformat.info/info/charset/UTF-16/list.htm values, like '𝖕𝖗𝖔𝖕𝖊𝖗𝖙𝖞': `𝓿𝓪𝓵𝓾𝓮`
#[derive(Debug, Clone, Deserialize)]
pub struct FlagFilters {
pub groups: Vec<FlagGroupType>,
pub multivariate: Option<MultivariateFlagOptions>,
pub aggregation_group_type_index: Option<u8>,
pub payloads: Option<serde_json::Value>,
pub super_groups: Option<Vec<FlagGroupType>>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct FeatureFlag {
pub id: i64,
pub team_id: i64,
pub name: Option<String>,
pub key: String,
pub filters: FlagFilters,
#[serde(default)]
pub deleted: bool,
#[serde(default)]
pub active: bool,
#[serde(default)]
pub ensure_experience_continuity: bool,
}
impl FeatureFlag {
pub fn get_group_type_index(&self) -> Option<u8> {
self.filters.aggregation_group_type_index
}
pub fn get_conditions(&self) -> &Vec<FlagGroupType> {
&self.filters.groups
}
pub fn get_variants(&self) -> Vec<MultivariateFlagVariant> {
self.filters
.multivariate
.clone()
.map_or(vec![], |m| m.variants)
}
}
#[derive(Debug, Deserialize)]
pub struct FeatureFlagList {
pub flags: Vec<FeatureFlag>,
}
impl FeatureFlagList {
/// Returns feature flags from redis given a team_id
#[instrument(skip_all)]
pub async fn from_redis(
client: Arc<dyn Client + Send + Sync>,
team_id: i64,
) -> Result<FeatureFlagList, FlagError> {
// TODO: Instead of failing here, i.e. if not in redis, fall back to pg
let serialized_flags = client
.get(format!("{TEAM_FLAGS_CACHE_PREFIX}{}", team_id))
.await
.map_err(|e| match e {
CustomRedisError::NotFound => FlagError::TokenValidationError,
CustomRedisError::PickleError(_) => {
// TODO: Implement From trait for FlagError so we don't need to map
// CustomRedisError ourselves
tracing::error!("failed to fetch data: {}", e);
println!("failed to fetch data: {}", e);
FlagError::DataParsingError
}
_ => {
tracing::error!("Unknown redis error: {}", e);
FlagError::RedisUnavailable
}
})?;
let flags_list: Vec<FeatureFlag> =
serde_json::from_str(&serialized_flags).map_err(|e| {
tracing::error!("failed to parse data to flags list: {}", e);
println!("failed to parse data: {}", e);
FlagError::DataParsingError
})?;
Ok(FeatureFlagList { flags: flags_list })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::{
insert_flags_for_team_in_redis, insert_new_team_in_redis, setup_redis_client,
};
#[tokio::test]
async fn test_fetch_flags_from_redis() {
let client = setup_redis_client(None);
let team = insert_new_team_in_redis(client.clone())
.await
.expect("Failed to insert team");
insert_flags_for_team_in_redis(client.clone(), team.id, None)
.await
.expect("Failed to insert flags");
let flags_from_redis = FeatureFlagList::from_redis(client.clone(), team.id)
.await
.expect("Failed to fetch flags from redis");
assert_eq!(flags_from_redis.flags.len(), 1);
let flag = flags_from_redis.flags.first().expect("Empty flags in redis");
assert_eq!(flag.key, "flag1");
assert_eq!(flag.team_id, team.id);
assert_eq!(flag.filters.groups.len(), 1);
assert_eq!(
flag.filters.groups[0]
.properties
.as_ref()
.expect("Properties don't exist on flag")
.len(),
1
);
}
#[tokio::test]
async fn test_fetch_invalid_team_from_redis() {
let client = setup_redis_client(None);
match FeatureFlagList::from_redis(client.clone(), 1234).await {
Err(FlagError::TokenValidationError) => (),
_ => panic!("Expected TokenValidationError"),
};
}
#[tokio::test]
async fn test_cant_connect_to_redis_error_is_not_token_validation_error() {
let client = setup_redis_client(Some("redis://localhost:1111/".to_string()));
match FeatureFlagList::from_redis(client.clone(), 1234).await {
Err(FlagError::RedisUnavailable) => (),
_ => panic!("Expected RedisUnavailable"),
};
}
}

View File

@@ -0,0 +1,160 @@
use crate::flag_definitions::{FeatureFlag, FlagGroupType};
use sha1::{Digest, Sha1};
use std::fmt::Write;
#[derive(Debug, PartialEq, Eq)]
pub struct FeatureFlagMatch {
pub matches: bool,
pub variant: Option<String>,
//reason
//condition_index
//payload
}
// TODO: Rework FeatureFlagMatcher - python has a pretty awkward interface, where we pass in all flags, and then again
// the flag to match. I don't think there's any reason anymore to store the flags in the matcher, since we can just
// pass the flag to match directly to the get_match method. This will also make the matcher more stateless.
// Potentially, we could also make the matcher a long-lived object, with caching for group keys and such.
// It just takes in the flag and distinct_id and returns the match...
// Or, make this fully stateless
// and have a separate cache struct for caching group keys, cohort definitions, etc. - and check size, if we can keep it in memory
// for all teams. If not, we can have a LRU cache, or a cache that stores only the most recent N keys.
// But, this can be a future refactor, for now just focusing on getting the basic matcher working, write lots and lots of tests
// and then we can easily refactor stuff around.
#[derive(Debug)]
pub struct FeatureFlagMatcher {
// pub flags: Vec<FeatureFlag>,
pub distinct_id: String,
}
const LONG_SCALE: u64 = 0xfffffffffffffff;
impl FeatureFlagMatcher {
pub fn new(distinct_id: String) -> Self {
FeatureFlagMatcher {
// flags,
distinct_id,
}
}
pub fn get_match(&self, feature_flag: &FeatureFlag) -> FeatureFlagMatch {
if self.hashed_identifier(feature_flag).is_none() {
return FeatureFlagMatch {
matches: false,
variant: None,
};
}
// TODO: super groups for early access
// TODO: Variant overrides condition sort
for (index, condition) in feature_flag.get_conditions().iter().enumerate() {
let (is_match, _evaluation_reason) =
self.is_condition_match(feature_flag, condition, index);
if is_match {
// TODO: This is a bit awkward, we should handle overrides only when variants exist.
let variant = match condition.variant.clone() {
Some(variant_override) => {
if feature_flag
.get_variants()
.iter()
.any(|v| v.key == variant_override)
{
Some(variant_override)
} else {
self.get_matching_variant(feature_flag)
}
}
None => self.get_matching_variant(feature_flag),
};
// let payload = self.get_matching_payload(is_match, variant, feature_flag);
return FeatureFlagMatch {
matches: true,
variant,
};
}
}
FeatureFlagMatch {
matches: false,
variant: None,
}
}
pub fn is_condition_match(
&self,
feature_flag: &FeatureFlag,
condition: &FlagGroupType,
_index: usize,
) -> (bool, String) {
let rollout_percentage = condition.rollout_percentage.unwrap_or(100.0);
let mut condition_match = true;
if condition.properties.is_some() {
// TODO: Handle matching conditions
if !condition.properties.as_ref().unwrap().is_empty() {
condition_match = false;
}
}
if !condition_match {
return (false, "NO_CONDITION_MATCH".to_string());
} else if rollout_percentage == 100.0 {
// TODO: Check floating point shenanigans if any
return (true, "CONDITION_MATCH".to_string());
}
if self.get_hash(feature_flag, "") > (rollout_percentage / 100.0) {
return (false, "OUT_OF_ROLLOUT_BOUND".to_string());
}
(true, "CONDITION_MATCH".to_string())
}
pub fn hashed_identifier(&self, feature_flag: &FeatureFlag) -> Option<String> {
if feature_flag.get_group_type_index().is_none() {
// TODO: Use hash key overrides for experience continuity
Some(self.distinct_id.clone())
} else {
// TODO: Handle getting group key
Some("".to_string())
}
}
/// This function takes an identifier and a feature flag key and returns a float between 0 and 1.
/// Given the same identifier and key, it'll always return the same float. These floats are
/// uniformly distributed between 0 and 1, so if we want to show this feature to 20% of traffic
/// we can do _hash(key, identifier) < 0.2
pub fn get_hash(&self, feature_flag: &FeatureFlag, salt: &str) -> f64 {
// check if hashed_identifier is None
let hashed_identifier = self
.hashed_identifier(feature_flag)
.expect("hashed_identifier is None when computing hash");
let hash_key = format!("{}.{}{}", feature_flag.key, hashed_identifier, salt);
let mut hasher = Sha1::new();
hasher.update(hash_key.as_bytes());
let result = hasher.finalize();
// :TRICKY: Format the digest as a hexadecimal string and keep only its first 15 characters
let hex_str: String = result.iter().fold(String::new(), |mut acc, byte| {
let _ = write!(acc, "{:02x}", byte);
acc
})[..15]
.to_string();
let hash_val = u64::from_str_radix(&hex_str, 16).unwrap();
hash_val as f64 / LONG_SCALE as f64
}
pub fn get_matching_variant(&self, feature_flag: &FeatureFlag) -> Option<String> {
let hash = self.get_hash(feature_flag, "variant");
let mut total_percentage = 0.0;
for variant in feature_flag.get_variants() {
total_percentage += variant.rollout_percentage / 100.0;
if hash < total_percentage {
return Some(variant.key.clone());
}
}
None
}
}
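A usage sketch (not part of the commit) tying the matcher to the flag definitions above; it assumes the `create_flag_from_json` test helper defined in test_utils further below, and the function name is illustrative.
use feature_flags::flag_matching::FeatureFlagMatcher;
use feature_flags::test_utils::create_flag_from_json;

fn match_example() {
    // default test flag: one condition with an email property filter and a 50% rollout
    let flag = create_flag_from_json(None).remove(0);
    let matcher = FeatureFlagMatcher::new("user_distinct_id".to_string());
    // get_hash is deterministic for a given flag key + identifier, so the same user
    // always lands on the same side of the rollout boundary
    let hash = matcher.get_hash(&flag, "");
    let result = matcher.get_match(&flag);
    println!("hash={hash} matches={} variant={:?}", result.matches, result.variant);
}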

View File

@@ -0,0 +1,19 @@
pub mod api;
pub mod config;
pub mod flag_definitions;
pub mod flag_matching;
pub mod property_matching;
pub mod redis;
pub mod router;
pub mod server;
pub mod team;
pub mod v0_endpoint;
pub mod v0_request;
// Test modules don't need to be compiled with main binary
// #[cfg(test)]
// TODO: To use in integration tests, we need to compile with binary
// or make it a separate feature using cfg(feature = "integration-tests")
// and then use this feature only in tests.
// For now, ok to just include in binary
pub mod test_utils;

View File

@@ -0,0 +1,39 @@
use envconfig::Envconfig;
use tokio::signal;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
use feature_flags::config::Config;
use feature_flags::server::serve;
async fn shutdown() {
let mut term = signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to register SIGTERM handler");
let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt())
.expect("failed to register SIGINT handler");
tokio::select! {
_ = term.recv() => {},
_ = interrupt.recv() => {},
};
tracing::info!("Shutting down gracefully...");
}
#[tokio::main]
async fn main() {
let config = Config::init_from_env().expect("Invalid configuration:");
// Basic logging for now:
// - stdout with a level configured by the RUST_LOG envvar (default=ERROR)
let log_layer = tracing_subscriber::fmt::layer().with_filter(EnvFilter::from_default_env());
tracing_subscriber::registry().with(log_layer).init();
// Open the TCP port and start the server
let listener = tokio::net::TcpListener::bind(config.address)
.await
.expect("could not bind port");
serve(config, listener, shutdown()).await
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,96 @@
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use redis::{AsyncCommands, RedisError};
use thiserror::Error;
use tokio::time::timeout;
// average for all commands is <10ms, check grafana
const REDIS_TIMEOUT_MILLISECS: u64 = 10;
#[derive(Error, Debug)]
pub enum CustomRedisError {
#[error("Not found in redis")]
NotFound,
#[error("Pickle error: {0}")]
PickleError(#[from] serde_pickle::Error),
#[error("Redis error: {0}")]
Other(#[from] RedisError),
#[error("Timeout error")]
Timeout(#[from] tokio::time::error::Elapsed),
}
/// A simple redis wrapper
/// Copied from capture/src/redis.rs.
/// TODO: Modify this to support hincrby
#[async_trait]
pub trait Client {
// A very simplified wrapper, but works for our usage
async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>>;
async fn get(&self, k: String) -> Result<String, CustomRedisError>;
async fn set(&self, k: String, v: String) -> Result<()>;
}
pub struct RedisClient {
client: redis::Client,
}
impl RedisClient {
pub fn new(addr: String) -> Result<RedisClient> {
let client = redis::Client::open(addr)?;
Ok(RedisClient { client })
}
}
#[async_trait]
impl Client for RedisClient {
async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>> {
let mut conn = self.client.get_async_connection().await?;
let results = conn.zrangebyscore(k, min, max);
let fut = timeout(Duration::from_millis(REDIS_TIMEOUT_MILLISECS), results).await?;
Ok(fut?)
}
async fn get(&self, k: String) -> Result<String, CustomRedisError> {
let mut conn = self.client.get_async_connection().await?;
let results = conn.get(k);
let fut: Result<Vec<u8>, RedisError> =
timeout(Duration::from_millis(REDIS_TIMEOUT_MILLISECS), results).await?;
// return NotFound error when empty or not found
if match &fut {
Ok(v) => v.is_empty(),
Err(_) => false,
} {
return Err(CustomRedisError::NotFound);
}
// TRICKY: We serialise data to json, then django pickles it.
// Here we deserialize the bytes using serde_pickle, to get the json string.
let string_response: String = serde_pickle::from_slice(&fut?, Default::default())?;
Ok(string_response)
}
async fn set(&self, k: String, v: String) -> Result<()> {
// TRICKY: We serialise data to json, then django pickles it.
// Here we serialize the json string to bytes using serde_pickle.
let bytes = serde_pickle::to_vec(&v, Default::default())?;
let mut conn = self.client.get_async_connection().await?;
let results = conn.set(k, bytes);
let fut = timeout(Duration::from_millis(REDIS_TIMEOUT_MILLISECS), results).await?;
Ok(fut?)
}
}
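A round-trip sketch (not part of the commit) showing the pickling behaviour of this client; it assumes a local Redis on the default port, as in test_utils, and the key and function name are illustrative.
use std::sync::Arc;

use feature_flags::redis::{Client, RedisClient};

async fn redis_roundtrip_example() {
    let client = Arc::new(
        RedisClient::new("redis://localhost:6379/".to_string())
            .expect("failed to create redis client"),
    );
    // set() pickles the string before writing; get() unpickles it on the way back,
    // mirroring what django-redis does on the Python side
    client
        .set("example_key".to_string(), r#"{"hello":"world"}"#.to_string())
        .await
        .expect("failed to write to redis");
    let value = client
        .get("example_key".to_string())
        .await
        .expect("failed to read from redis");
    assert_eq!(value, r#"{"hello":"world"}"#);
}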

View File

@@ -0,0 +1,19 @@
use std::sync::Arc;
use axum::{routing::post, Router};
use crate::{redis::Client, v0_endpoint};
#[derive(Clone)]
pub struct State {
pub redis: Arc<dyn Client + Send + Sync>,
// TODO: Add pgClient when ready
}
pub fn router<R: Client + Send + Sync + 'static>(redis: Arc<R>) -> Router {
let state = State { redis };
Router::new()
.route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags))
.with_state(state)
}

View File

@@ -0,0 +1,31 @@
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use crate::config::Config;
use crate::redis::RedisClient;
use crate::router;
pub async fn serve<F>(config: Config, listener: TcpListener, shutdown: F)
where
F: Future<Output = ()> + Send + 'static,
{
let redis_client =
Arc::new(RedisClient::new(config.redis_url).expect("failed to create redis client"));
let app = router::router(redis_client);
// run our app with hyper, via `axum::serve`
tracing::info!("listening on {:?}", listener.local_addr().unwrap());
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown)
.await
.unwrap()
}

View File

@@ -0,0 +1,140 @@
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::instrument;
use crate::{
api::FlagError,
redis::{Client, CustomRedisError},
};
// TRICKY: This cache data is coming from django-redis. If it ever goes out of sync, we'll bork.
// TODO: Add integration tests across repos to ensure this doesn't happen.
pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:";
#[derive(Debug, Deserialize, Serialize)]
pub struct Team {
pub id: i64,
pub name: String,
pub api_token: String,
}
impl Team {
/// Validates a token, and returns a team if it exists.
#[instrument(skip_all)]
pub async fn from_redis(
client: Arc<dyn Client + Send + Sync>,
token: String,
) -> Result<Team, FlagError> {
// TODO: Instead of failing here, i.e. if not in redis, fall back to pg
let serialized_team = client
.get(format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token))
.await
.map_err(|e| match e {
CustomRedisError::NotFound => FlagError::TokenValidationError,
CustomRedisError::PickleError(_) => {
tracing::error!("failed to fetch data: {}", e);
FlagError::DataParsingError
}
_ => {
tracing::error!("Unknown redis error: {}", e);
FlagError::RedisUnavailable
}
})?;
// TODO: Consider an LRU cache for teams as well, with small TTL to skip redis/pg lookups
let team: Team = serde_json::from_str(&serialized_team).map_err(|e| {
tracing::error!("failed to parse data to team: {}", e);
FlagError::DataParsingError
})?;
Ok(team)
}
}
#[cfg(test)]
mod tests {
use rand::Rng;
use redis::AsyncCommands;
use super::*;
use crate::{
team,
test_utils::{insert_new_team_in_redis, random_string, setup_redis_client},
};
#[tokio::test]
async fn test_fetch_team_from_redis() {
let client = setup_redis_client(None);
let team = insert_new_team_in_redis(client.clone()).await.unwrap();
let target_token = team.api_token;
let team_from_redis = Team::from_redis(client.clone(), target_token.clone())
.await
.unwrap();
assert_eq!(team_from_redis.api_token, target_token);
assert_eq!(team_from_redis.id, team.id);
}
#[tokio::test]
async fn test_fetch_invalid_team_from_redis() {
let client = setup_redis_client(None);
match Team::from_redis(client.clone(), "banana".to_string()).await {
Err(FlagError::TokenValidationError) => (),
_ => panic!("Expected TokenValidationError"),
};
}
#[tokio::test]
async fn test_cant_connect_to_redis_error_is_not_token_validation_error() {
let client = setup_redis_client(Some("redis://localhost:1111/".to_string()));
match Team::from_redis(client.clone(), "banana".to_string()).await {
Err(FlagError::RedisUnavailable) => (),
_ => panic!("Expected RedisUnavailable"),
};
}
#[tokio::test]
async fn test_corrupted_data_in_redis_is_handled() {
// TODO: Extend this test with fallback to pg
let id = rand::thread_rng().gen_range(0..10_000_000);
let token = random_string("phc_", 12);
let team = Team {
id,
name: "team".to_string(),
api_token: token,
};
let serialized_team = serde_json::to_string(&team).expect("Failed to serialise team");
// manually insert non-pickled data in redis
let client =
redis::Client::open("redis://localhost:6379/").expect("Failed to create redis client");
let mut conn = client
.get_async_connection()
.await
.expect("Failed to get redis connection");
conn.set::<String, String, ()>(
format!(
"{}{}",
team::TEAM_TOKEN_CACHE_PREFIX,
team.api_token.clone()
),
serialized_team,
)
.await
.expect("Failed to write data to redis");
// now get client connection for data
let client = setup_redis_client(None);
match Team::from_redis(client.clone(), team.api_token.clone()).await {
Err(FlagError::DataParsingError) => (),
Err(other) => panic!("Expected DataParsingError, got {:?}", other),
Ok(_) => panic!("Expected DataParsingError"),
};
}
}

View File

@@ -0,0 +1,126 @@
use anyhow::Error;
use serde_json::json;
use std::sync::Arc;
use crate::{
flag_definitions::{self, FeatureFlag},
redis::{Client, RedisClient},
team::{self, Team},
};
use rand::{distributions::Alphanumeric, Rng};
pub fn random_string(prefix: &str, length: usize) -> String {
let suffix: String = rand::thread_rng()
.sample_iter(Alphanumeric)
.take(length)
.map(char::from)
.collect();
format!("{}{}", prefix, suffix)
}
pub async fn insert_new_team_in_redis(client: Arc<RedisClient>) -> Result<Team, Error> {
let id = rand::thread_rng().gen_range(0..10_000_000);
let token = random_string("phc_", 12);
let team = Team {
id,
name: "team".to_string(),
api_token: token,
};
let serialized_team = serde_json::to_string(&team)?;
client
.set(
format!(
"{}{}",
team::TEAM_TOKEN_CACHE_PREFIX,
team.api_token.clone()
),
serialized_team,
)
.await?;
Ok(team)
}
pub async fn insert_flags_for_team_in_redis(
client: Arc<RedisClient>,
team_id: i64,
json_value: Option<String>,
) -> Result<(), Error> {
let payload = match json_value {
Some(value) => value,
None => json!([{
"id": 1,
"key": "flag1",
"name": "flag1 description",
"active": true,
"deleted": false,
"team_id": team_id,
"filters": {
"groups": [
{
"properties": [
{
"key": "email",
"value": "a@b.com",
"type": "person",
},
]
},
],
},
}])
.to_string(),
};
client
.set(
format!("{}{}", flag_definitions::TEAM_FLAGS_CACHE_PREFIX, team_id),
payload,
)
.await?;
Ok(())
}
pub fn setup_redis_client(url: Option<String>) -> Arc<RedisClient> {
let redis_url = match url {
Some(value) => value,
None => "redis://localhost:6379/".to_string(),
};
let client = RedisClient::new(redis_url).expect("Failed to create redis client");
Arc::new(client)
}
pub fn create_flag_from_json(json_value: Option<String>) -> Vec<FeatureFlag> {
let payload = match json_value {
Some(value) => value,
None => json!([{
"id": 1,
"key": "flag1",
"name": "flag1 description",
"active": true,
"deleted": false,
"team_id": 1,
"filters": {
"groups": [
{
"properties": [
{
"key": "email",
"value": "a@b.com",
"type": "person",
},
],
"rollout_percentage": 50,
},
],
},
}])
.to_string(),
};
let flags: Vec<FeatureFlag> =
serde_json::from_str(&payload).expect("Failed to parse data to flags list");
flags
}

View File

@@ -0,0 +1,94 @@
use std::collections::HashMap;
use axum::{debug_handler, Json};
use bytes::Bytes;
// TODO: stream this instead
use axum::extract::{MatchedPath, Query, State};
use axum::http::{HeaderMap, Method};
use axum_client_ip::InsecureClientIp;
use tracing::instrument;
use crate::{
api::{FlagError, FlagsResponse},
router,
v0_request::{FlagRequest, FlagsQueryParams},
};
/// Feature flag evaluation endpoint.
/// Only supports a specific shape of data, and rejects any malformed data.
#[instrument(
skip_all,
fields(
path,
token,
batch_size,
user_agent,
content_encoding,
content_type,
version,
compression,
historical_migration
)
)]
#[debug_handler]
pub async fn flags(
state: State<router::State>,
InsecureClientIp(ip): InsecureClientIp,
meta: Query<FlagsQueryParams>,
headers: HeaderMap,
method: Method,
path: MatchedPath,
body: Bytes,
) -> Result<Json<FlagsResponse>, FlagError> {
let user_agent = headers
.get("user-agent")
.map_or("unknown", |v| v.to_str().unwrap_or("unknown"));
let content_encoding = headers
.get("content-encoding")
.map_or("unknown", |v| v.to_str().unwrap_or("unknown"));
tracing::Span::current().record("user_agent", user_agent);
tracing::Span::current().record("content_encoding", content_encoding);
tracing::Span::current().record("version", meta.version.clone());
tracing::Span::current().record("method", method.as_str());
tracing::Span::current().record("path", path.as_str().trim_end_matches('/'));
tracing::Span::current().record("ip", ip.to_string());
let request = match headers
.get("content-type")
.map_or("", |v| v.to_str().unwrap_or(""))
{
"application/json" => {
tracing::Span::current().record("content_type", "application/json");
FlagRequest::from_bytes(body)
}
ct => {
return Err(FlagError::RequestDecodingError(format!(
"unsupported content type: {}",
ct
)));
}
}?;
let token = request
.extract_and_verify_token(state.redis.clone())
.await?;
let distinct_id = request.extract_distinct_id()?;
tracing::Span::current().record("token", &token);
tracing::Span::current().record("distinct_id", &distinct_id);
tracing::debug!("request: {:?}", request);
// TODO: Some actual processing for evaluating the feature flag
Ok(Json(FlagsResponse {
error_while_computing_flags: false,
feature_flags: HashMap::from([
("beta-feature".to_string(), "variant-1".to_string()),
("rollout-flag".to_string(), true.to_string()),
]),
}))
}

View File

@@ -0,0 +1,138 @@
use std::{collections::HashMap, sync::Arc};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::instrument;
use crate::{api::FlagError, redis::Client, team::Team};
#[derive(Deserialize, Default)]
pub struct FlagsQueryParams {
#[serde(alias = "v")]
pub version: Option<String>,
}
#[derive(Default, Debug, Deserialize, Serialize)]
pub struct FlagRequest {
#[serde(
alias = "$token",
alias = "api_key",
skip_serializing_if = "Option::is_none"
)]
pub token: Option<String>,
#[serde(alias = "$distinct_id", skip_serializing_if = "Option::is_none")]
pub distinct_id: Option<String>,
pub geoip_disable: Option<bool>,
#[serde(default)]
pub person_properties: Option<HashMap<String, Value>>,
#[serde(default)]
pub groups: Option<HashMap<String, Value>>,
// TODO: better type this since we know its going to be a nested json
#[serde(default)]
pub group_properties: Option<HashMap<String, Value>>,
#[serde(alias = "$anon_distinct_id", skip_serializing_if = "Option::is_none")]
pub anon_distinct_id: Option<String>,
}
impl FlagRequest {
/// Takes a request payload and tries to read it.
/// Currently only supports uncompressed UTF-8 JSON; base64-encoded payloads are still a TODO.
#[instrument(skip_all)]
pub fn from_bytes(bytes: Bytes) -> Result<FlagRequest, FlagError> {
tracing::debug!(len = bytes.len(), "decoding new request");
// TODO: Add base64 decoding
let payload = String::from_utf8(bytes.into()).map_err(|e| {
tracing::error!("failed to decode body: {}", e);
FlagError::RequestDecodingError(String::from("invalid body encoding"))
})?;
tracing::debug!(json = payload, "decoded event data");
Ok(serde_json::from_str::<FlagRequest>(&payload)?)
}
pub async fn extract_and_verify_token(
&self,
redis_client: Arc<dyn Client + Send + Sync>,
) -> Result<String, FlagError> {
let token = match self {
FlagRequest {
token: Some(token), ..
} => token.to_string(),
_ => return Err(FlagError::NoTokenError),
};
// validate token
Team::from_redis(redis_client, token.clone()).await?;
// TODO: fallback when token not found in redis
Ok(token)
}
pub fn extract_distinct_id(&self) -> Result<String, FlagError> {
let distinct_id = match &self.distinct_id {
None => return Err(FlagError::MissingDistinctId),
Some(id) => id,
};
match distinct_id.len() {
0 => Err(FlagError::EmptyDistinctId),
1..=200 => Ok(distinct_id.to_owned()),
_ => Ok(distinct_id.chars().take(200).collect()),
}
}
}
#[cfg(test)]
mod tests {
use crate::api::FlagError;
use crate::v0_request::FlagRequest;
use bytes::Bytes;
use serde_json::json;
#[test]
fn empty_distinct_id_not_accepted() {
let json = json!({
"distinct_id": "",
"token": "my_token1",
});
let bytes = Bytes::from(json.to_string());
let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request");
match flag_payload.extract_distinct_id() {
Err(FlagError::EmptyDistinctId) => (),
_ => panic!("expected empty distinct id error"),
};
}
#[test]
fn too_large_distinct_id_is_truncated() {
let json = json!({
"distinct_id": std::iter::repeat("a").take(210).collect::<String>(),
"token": "my_token1",
});
let bytes = Bytes::from(json.to_string());
let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request");
assert_eq!(flag_payload.extract_distinct_id().unwrap().len(), 200);
}
#[test]
fn distinct_id_is_returned_correctly() {
let json = json!({
"$distinct_id": "alakazam",
"token": "my_token1",
});
let bytes = Bytes::from(json.to_string());
let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request");
match flag_payload.extract_distinct_id() {
Ok(id) => assert_eq!(id, "alakazam"),
_ => panic!("expected distinct id"),
};
}
}

View File

@@ -0,0 +1,71 @@
use std::net::SocketAddr;
use std::str::FromStr;
use std::string::ToString;
use std::sync::Arc;
use once_cell::sync::Lazy;
use reqwest::header::CONTENT_TYPE;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use feature_flags::config::Config;
use feature_flags::server::serve;
pub static DEFAULT_CONFIG: Lazy<Config> = Lazy::new(|| Config {
address: SocketAddr::from_str("127.0.0.1:0").unwrap(),
redis_url: "redis://localhost:6379/".to_string(),
write_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(),
read_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(),
max_concurrent_jobs: 1024,
max_pg_connections: 100,
});
pub struct ServerHandle {
pub addr: SocketAddr,
shutdown: Arc<Notify>,
}
impl ServerHandle {
pub async fn for_config(config: Config) -> ServerHandle {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let notify = Arc::new(Notify::new());
let shutdown = notify.clone();
tokio::spawn(async move {
serve(config, listener, async move { notify.notified().await }).await
});
ServerHandle { addr, shutdown }
}
pub async fn send_flags_request<T: Into<reqwest::Body>>(&self, body: T) -> reqwest::Response {
let client = reqwest::Client::new();
client
.post(format!("http://{:?}/flags", self.addr))
.body(body)
.header(CONTENT_TYPE, "application/json")
.send()
.await
.expect("failed to send request")
}
pub async fn send_invalid_header_for_flags_request<T: Into<reqwest::Body>>(
&self,
body: T,
) -> reqwest::Response {
let client = reqwest::Client::new();
client
.post(format!("http://{:?}/flags", self.addr))
.body(body)
.header(CONTENT_TYPE, "xyz")
.send()
.await
.expect("failed to send request")
}
}
impl Drop for ServerHandle {
fn drop(&mut self) {
self.shutdown.notify_one()
}
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,83 @@
use anyhow::Result;
use assert_json_diff::assert_json_include;
use reqwest::StatusCode;
use serde_json::{json, Value};
use crate::common::*;
use feature_flags::test_utils::{insert_new_team_in_redis, setup_redis_client};
pub mod common;
#[tokio::test]
async fn it_sends_flag_request() -> Result<()> {
let config = DEFAULT_CONFIG.clone();
let distinct_id = "user_distinct_id".to_string();
let client = setup_redis_client(Some(config.redis_url.clone()));
let team = insert_new_team_in_redis(client.clone()).await.unwrap();
let token = team.api_token;
let server = ServerHandle::for_config(config).await;
let payload = json!({
"token": token,
"distinct_id": distinct_id,
"groups": {"group1": "group1"}
});
let res = server.send_flags_request(payload.to_string()).await;
assert_eq!(StatusCode::OK, res.status());
// We don't want to deserialize the data into a flagResponse struct here,
// because we want to assert the shape of the raw json data.
let json_data = res.json::<Value>().await?;
assert_json_include!(
actual: json_data,
expected: json!({
"errorWhileComputingFlags": false,
"featureFlags": {
"beta-feature": "variant-1",
"rollout-flag": "true",
}
})
);
Ok(())
}
#[tokio::test]
async fn it_rejects_invalid_headers_flag_request() -> Result<()> {
let config = DEFAULT_CONFIG.clone();
let distinct_id = "user_distinct_id".to_string();
let client = setup_redis_client(Some(config.redis_url.clone()));
let team = insert_new_team_in_redis(client.clone()).await.unwrap();
let token = team.api_token;
let server = ServerHandle::for_config(config).await;
let payload = json!({
"token": token,
"distinct_id": distinct_id,
"groups": {"group1": "group1"}
});
let res = server
.send_invalid_header_for_flags_request(payload.to_string())
.await;
assert_eq!(StatusCode::BAD_REQUEST, res.status());
// We don't want to deserialize the data into a flagResponse struct here,
// because we want to assert the shape of the raw json data.
let response_text = res.text().await?;
assert_eq!(
response_text,
"failed to decode request: unsupported content type: xyz"
);
Ok(())
}

25
rust/hook-api/Cargo.toml Normal file
View File

@@ -0,0 +1,25 @@
[package]
name = "hook-api"
version = "0.1.0"
edition = "2021"
[lints]
workspace = true
[dependencies]
axum = { workspace = true }
envconfig = { workspace = true }
eyre = { workspace = true }
hook-common = { path = "../hook-common" }
http-body-util = { workspace = true }
metrics = { workspace = true }
serde = { workspace = true }
serde_derive = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
url = { workspace = true }

View File

@@ -0,0 +1,28 @@
use envconfig::Envconfig;
#[derive(Envconfig)]
pub struct Config {
#[envconfig(from = "BIND_HOST", default = "0.0.0.0")]
pub host: String,
#[envconfig(from = "BIND_PORT", default = "3300")]
pub port: u16,
#[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")]
pub database_url: String,
#[envconfig(default = "default")]
pub queue_name: String,
#[envconfig(default = "100")]
pub max_pg_connections: u32,
#[envconfig(default = "5000000")]
pub max_body_size: usize,
}
impl Config {
pub fn bind(&self) -> String {
format!("{}:{}", self.host, self.port)
}
}

View File

@@ -0,0 +1,53 @@
use axum::{routing, Router};
use tower_http::limit::RequestBodyLimitLayer;
use hook_common::pgqueue::PgQueue;
use super::webhook;
pub fn add_routes(router: Router, pg_pool: PgQueue, max_body_size: usize) -> Router {
router
.route("/", routing::get(index))
.route("/_readiness", routing::get(index))
.route("/_liveness", routing::get(index)) // No async loop for now, just check axum health
.route(
"/webhook",
routing::post(webhook::post)
.with_state(pg_pool)
.layer(RequestBodyLimitLayer::new(max_body_size)),
)
}
pub async fn index() -> &'static str {
"rusty-hook api"
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{
body::Body,
http::{Request, StatusCode},
};
use hook_common::pgqueue::PgQueue;
use http_body_util::BodyExt; // for `collect`
use sqlx::PgPool;
use tower::ServiceExt; // for `call`, `oneshot`, and `ready`
#[sqlx::test(migrations = "../migrations")]
async fn index(db: PgPool) {
let pg_queue = PgQueue::new_from_pool("test_index", db).await;
let app = add_routes(Router::new(), pg_queue, 1_000_000);
let response = app
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(&body[..], b"rusty-hook api");
}
}

View File

@@ -0,0 +1,4 @@
mod app;
mod webhook;
pub use app::add_routes;

View File

@@ -0,0 +1,279 @@
use std::time::Instant;
use axum::{extract::State, http::StatusCode, Json};
use hook_common::webhook::{WebhookJobMetadata, WebhookJobParameters};
use serde_derive::Deserialize;
use url::Url;
use hook_common::pgqueue::{NewJob, PgQueue};
use serde::Serialize;
use tracing::{debug, error};
#[derive(Serialize, Deserialize)]
pub struct WebhookPostResponse {
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
}
/// The body of a request made to create a webhook Job.
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct WebhookPostRequestBody {
parameters: WebhookJobParameters,
metadata: WebhookJobMetadata,
#[serde(default = "default_max_attempts")]
max_attempts: u32,
}
fn default_max_attempts() -> u32 {
3
}
pub async fn post(
State(pg_queue): State<PgQueue>,
Json(payload): Json<WebhookPostRequestBody>,
) -> Result<Json<WebhookPostResponse>, (StatusCode, Json<WebhookPostResponse>)> {
debug!("received payload: {:?}", payload);
let url_hostname = get_hostname(&payload.parameters.url)?;
// We could cast to i32, but this ensures we are not wrapping.
let max_attempts = i32::try_from(payload.max_attempts).map_err(|_| {
(
StatusCode::BAD_REQUEST,
Json(WebhookPostResponse {
error: Some("invalid number of max attempts".to_owned()),
}),
)
})?;
let job = NewJob::new(
max_attempts,
payload.metadata,
payload.parameters,
url_hostname.as_str(),
);
let start_time = Instant::now();
pg_queue.enqueue(job).await.map_err(internal_error)?;
let elapsed_time = start_time.elapsed().as_secs_f64();
metrics::histogram!("webhook_api_enqueue").record(elapsed_time);
Ok(Json(WebhookPostResponse { error: None }))
}
fn internal_error<E>(err: E) -> (StatusCode, Json<WebhookPostResponse>)
where
E: std::error::Error,
{
error!("internal error: {}", err);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(WebhookPostResponse {
error: Some(err.to_string()),
}),
)
}
fn get_hostname(url_str: &str) -> Result<String, (StatusCode, Json<WebhookPostResponse>)> {
let url = Url::parse(url_str).map_err(|_| {
(
StatusCode::BAD_REQUEST,
Json(WebhookPostResponse {
error: Some("could not parse url".to_owned()),
}),
)
})?;
match url.host_str() {
Some(hostname) => Ok(hostname.to_owned()),
None => Err((
StatusCode::BAD_REQUEST,
Json(WebhookPostResponse {
error: Some("couldn't extract hostname from url".to_owned()),
}),
)),
}
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{
body::Body,
http::{self, Request, StatusCode},
Router,
};
use hook_common::pgqueue::PgQueue;
use hook_common::webhook::{HttpMethod, WebhookJobParameters};
use http_body_util::BodyExt;
use sqlx::PgPool; // for `collect`
use std::collections;
use tower::ServiceExt; // for `call`, `oneshot`, and `ready`
use crate::handlers::app::add_routes;
const MAX_BODY_SIZE: usize = 1_000_000;
#[sqlx::test(migrations = "../migrations")]
async fn webhook_success(db: PgPool) {
let pg_queue = PgQueue::new_from_pool("test_index", db).await;
let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE);
let mut headers = collections::HashMap::new();
headers.insert("Content-Type".to_owned(), "application/json".to_owned());
let response = app
.oneshot(
Request::builder()
.method(http::Method::POST)
.uri("/webhook")
.header(http::header::CONTENT_TYPE, "application/json")
.body(Body::from(
serde_json::to_string(&WebhookPostRequestBody {
parameters: WebhookJobParameters {
headers,
method: HttpMethod::POST,
url: "http://example.com/".to_owned(),
body: r#"{"a": "b"}"#.to_owned(),
},
metadata: WebhookJobMetadata {
team_id: 1,
plugin_id: 2,
plugin_config_id: 3,
},
max_attempts: 1,
})
.unwrap(),
))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(&body[..], b"{}");
}
#[sqlx::test(migrations = "../migrations")]
async fn webhook_bad_url(db: PgPool) {
let pg_queue = PgQueue::new_from_pool("test_index", db).await;
let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE);
let response = app
.oneshot(
Request::builder()
.method(http::Method::POST)
.uri("/webhook")
.header(http::header::CONTENT_TYPE, "application/json")
.body(Body::from(
serde_json::to_string(&WebhookPostRequestBody {
parameters: WebhookJobParameters {
headers: collections::HashMap::new(),
method: HttpMethod::POST,
url: "invalid".to_owned(),
body: r#"{"a": "b"}"#.to_owned(),
},
metadata: WebhookJobMetadata {
team_id: 1,
plugin_id: 2,
plugin_config_id: 3,
},
max_attempts: 1,
})
.unwrap(),
))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
#[sqlx::test(migrations = "../migrations")]
async fn webhook_payload_missing_fields(db: PgPool) {
let pg_queue = PgQueue::new_from_pool("test_index", db).await;
let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE);
let response = app
.oneshot(
Request::builder()
.method(http::Method::POST)
.uri("/webhook")
.header(http::header::CONTENT_TYPE, "application/json")
.body("{}".to_owned())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
}
#[sqlx::test(migrations = "../migrations")]
async fn webhook_payload_not_json(db: PgPool) {
let pg_queue = PgQueue::new_from_pool("test_index", db).await;
let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE);
let response = app
.oneshot(
Request::builder()
.method(http::Method::POST)
.uri("/webhook")
.header(http::header::CONTENT_TYPE, "application/json")
.body("x".to_owned())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
#[sqlx::test(migrations = "../migrations")]
async fn webhook_payload_body_too_large(db: PgPool) {
let pg_queue = PgQueue::new_from_pool("test_index", db).await;
let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE);
let bytes: Vec<u8> = vec![b'a'; MAX_BODY_SIZE + 1];
let long_string = String::from_utf8_lossy(&bytes);
let response = app
.oneshot(
Request::builder()
.method(http::Method::POST)
.uri("/webhook")
.header(http::header::CONTENT_TYPE, "application/json")
.body(Body::from(
serde_json::to_string(&WebhookPostRequestBody {
parameters: WebhookJobParameters {
headers: collections::HashMap::new(),
method: HttpMethod::POST,
url: "http://example.com".to_owned(),
body: long_string.to_string(),
},
metadata: WebhookJobMetadata {
team_id: 1,
plugin_id: 2,
plugin_config_id: 3,
},
max_attempts: 1,
})
.unwrap(),
))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
}

44
rust/hook-api/src/main.rs Normal file
View File

@@ -0,0 +1,44 @@
use axum::Router;
use config::Config;
use envconfig::Envconfig;
use eyre::Result;
use hook_common::metrics::setup_metrics_routes;
use hook_common::pgqueue::PgQueue;
mod config;
mod handlers;
async fn listen(app: Router, bind: String) -> Result<()> {
let listener = tokio::net::TcpListener::bind(bind).await?;
axum::serve(listener, app).await?;
Ok(())
}
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let config = Config::init_from_env().expect("failed to load configuration from env");
let pg_queue = PgQueue::new(
// TODO: Coupling the queue name to the PgQueue object doesn't seem ideal from the api
// side, but we don't need more than one queue for now.
&config.queue_name,
&config.database_url,
config.max_pg_connections,
"hook-api",
)
.await
.expect("failed to initialize queue");
let app = handlers::add_routes(Router::new(), pg_queue, config.max_body_size);
let app = setup_metrics_routes(app);
match listen(app, config.bind()).await {
Ok(_) => {}
Err(e) => tracing::error!("failed to start hook-api http server, {}", e),
}
}

View File

@@ -0,0 +1,27 @@
[package]
name = "hook-common"
version = "0.1.0"
edition = "2021"
[lints]
workspace = true
[dependencies]
async-trait = { workspace = true }
axum = { workspace = true, features = ["http2"] }
chrono = { workspace = true }
http = { workspace = true }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
thiserror = { workspace = true }
time = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
[dev-dependencies]
tokio = { workspace = true } # We need a runtime for async tests

View File

@@ -0,0 +1,2 @@
# hook-common
Library of common utilities used by rusty-hook.

View File

@@ -0,0 +1,208 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use uuid::Uuid;
use super::{deserialize_datetime, serialize_datetime};
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub enum AppMetricCategory {
ProcessEvent,
OnEvent,
ScheduledTask,
Webhook,
ComposeWebhook,
}
// NOTE: These are stored in Postgres and deserialized by the cleanup/janitor process, so these
// names need to remain stable, or new variants need to be deployed to the cleanup/janitor
// process before they are used.
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub enum ErrorType {
TimeoutError,
ConnectionError,
BadHttpStatus(u16),
ParseError,
}
// NOTE: This is stored in Postgres and deserialized by the cleanup/janitor process, so this
// shouldn't change. It is intended to replicate the shape of `error_details` used in the
// plugin-server and by the frontend.
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct ErrorDetails {
pub error: Error,
}
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct Error {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
// This field will only be useful if we start running plugins in Rust (via a WASM runtime or
// something) and want to provide the user with stack traces like we do for TypeScript plugins.
#[serde(skip_serializing_if = "Option::is_none")]
pub stack: Option<String>,
}
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct AppMetric {
#[serde(
serialize_with = "serialize_datetime",
deserialize_with = "deserialize_datetime"
)]
pub timestamp: DateTime<Utc>,
pub team_id: u32,
pub plugin_config_id: i32,
#[serde(skip_serializing_if = "Option::is_none")]
pub job_id: Option<String>,
#[serde(
serialize_with = "serialize_category",
deserialize_with = "deserialize_category"
)]
pub category: AppMetricCategory,
pub successes: u32,
pub successes_on_retry: u32,
pub failures: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub error_uuid: Option<Uuid>,
#[serde(
serialize_with = "serialize_error_type",
deserialize_with = "deserialize_error_type",
default,
skip_serializing_if = "Option::is_none"
)]
pub error_type: Option<ErrorType>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error_details: Option<ErrorDetails>,
}
fn serialize_category<S>(category: &AppMetricCategory, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let category_str = match category {
AppMetricCategory::ProcessEvent => "processEvent",
AppMetricCategory::OnEvent => "onEvent",
AppMetricCategory::ScheduledTask => "scheduledTask",
AppMetricCategory::Webhook => "webhook",
AppMetricCategory::ComposeWebhook => "composeWebhook",
};
serializer.serialize_str(category_str)
}
fn deserialize_category<'de, D>(deserializer: D) -> Result<AppMetricCategory, D::Error>
where
D: Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
let category = match &s[..] {
"processEvent" => AppMetricCategory::ProcessEvent,
"onEvent" => AppMetricCategory::OnEvent,
"scheduledTask" => AppMetricCategory::ScheduledTask,
"webhook" => AppMetricCategory::Webhook,
"composeWebhook" => AppMetricCategory::ComposeWebhook,
_ => {
return Err(serde::de::Error::unknown_variant(
&s,
&[
"processEvent",
"onEvent",
"scheduledTask",
"webhook",
"composeWebhook",
],
))
}
};
Ok(category)
}
fn serialize_error_type<S>(error_type: &Option<ErrorType>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let error_type = match error_type {
Some(error_type) => error_type,
None => return serializer.serialize_none(),
};
let error_type = match error_type {
ErrorType::ConnectionError => "Connection Error".to_owned(),
ErrorType::TimeoutError => "Timeout Error".to_owned(),
ErrorType::BadHttpStatus(s) => format!("Bad HTTP Status: {}", s),
ErrorType::ParseError => "Parse Error".to_owned(),
};
serializer.serialize_str(&error_type)
}
fn deserialize_error_type<'de, D>(deserializer: D) -> Result<Option<ErrorType>, D::Error>
where
D: Deserializer<'de>,
{
let opt = Option::<String>::deserialize(deserializer)?;
let error_type = match opt {
Some(s) => {
let error_type = match &s[..] {
"Connection Error" => ErrorType::ConnectionError,
"Timeout Error" => ErrorType::TimeoutError,
_ if s.starts_with("Bad HTTP Status:") => {
let status = &s["Bad HTTP Status:".len()..];
ErrorType::BadHttpStatus(status.parse().map_err(serde::de::Error::custom)?)
}
"Parse Error" => ErrorType::ParseError,
_ => {
return Err(serde::de::Error::unknown_variant(
&s,
&[
"Connection Error",
"Timeout Error",
"Bad HTTP Status: <status>",
"Parse Error",
],
))
}
};
Some(error_type)
}
None => None,
};
Ok(error_type)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_app_metric_serialization() {
use chrono::prelude::*;
let app_metric = AppMetric {
timestamp: Utc.with_ymd_and_hms(2023, 12, 14, 12, 2, 0).unwrap(),
team_id: 123,
plugin_config_id: 456,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 10,
successes_on_retry: 0,
failures: 2,
error_uuid: Some(Uuid::parse_str("550e8400-e29b-41d4-a716-446655447777").unwrap()),
error_type: Some(ErrorType::ConnectionError),
error_details: Some(ErrorDetails {
error: Error {
name: "FooError".to_owned(),
message: Some("Error Message".to_owned()),
stack: None,
},
}),
};
let serialized_json = serde_json::to_string(&app_metric).unwrap();
let expected_json = r#"{"timestamp":"2023-12-14 12:02:00","team_id":123,"plugin_config_id":456,"category":"webhook","successes":10,"successes_on_retry":0,"failures":2,"error_uuid":"550e8400-e29b-41d4-a716-446655447777","error_type":"Connection Error","error_details":{"error":{"name":"FooError","message":"Error Message"}}}"#;
assert_eq!(serialized_json, expected_json);
}
}
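A deserialization sketch (not part of the commit) exercising the custom error_type mapping in the other direction, including the `Bad HTTP Status: <status>` form; the function name is illustrative.
use hook_common::kafka_messages::app_metrics::{AppMetric, AppMetricCategory, ErrorType};

fn app_metric_deserialize_example() {
    let json = r#"{"timestamp":"2023-12-14 12:02:00","team_id":123,"plugin_config_id":456,"category":"webhook","successes":10,"successes_on_retry":0,"failures":2,"error_type":"Bad HTTP Status: 502"}"#;
    let metric: AppMetric = serde_json::from_str(json).expect("failed to deserialize AppMetric");
    // the custom deserializers map "webhook" and "Bad HTTP Status: 502" back to their variants
    assert_eq!(metric.category, AppMetricCategory::Webhook);
    assert_eq!(metric.error_type, Some(ErrorType::BadHttpStatus(502)));
}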

View File

@@ -0,0 +1,25 @@
pub mod app_metrics;
pub mod plugin_logs;
use chrono::{DateTime, NaiveDateTime, Utc};
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize_datetime<S>(datetime: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&datetime.format("%Y-%m-%d %H:%M:%S").to_string())
}
pub fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
where
D: Deserializer<'de>,
{
let formatted: String = Deserialize::deserialize(deserializer)?;
let datetime = match NaiveDateTime::parse_from_str(&formatted, "%Y-%m-%d %H:%M:%S") {
Ok(d) => d.and_utc(),
Err(_) => return Err(serde::de::Error::custom("Invalid datetime format")),
};
Ok(datetime)
}
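A small sketch (not part of the commit) of the round trip these helpers implement; the format string matches the one above and the function name is illustrative.
use chrono::{NaiveDateTime, TimeZone, Utc};

fn datetime_format_example() {
    let dt = Utc.with_ymd_and_hms(2023, 12, 14, 12, 2, 0).unwrap();
    let formatted = dt.format("%Y-%m-%d %H:%M:%S").to_string();
    assert_eq!(formatted, "2023-12-14 12:02:00");
    // parsing goes through NaiveDateTime because the serialized string carries no offset
    let parsed = NaiveDateTime::parse_from_str(&formatted, "%Y-%m-%d %H:%M:%S")
        .unwrap()
        .and_utc();
    assert_eq!(parsed, dt);
}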

View File

@@ -0,0 +1,126 @@
use chrono::{DateTime, Utc};
use serde::{Serialize, Serializer};
use uuid::Uuid;
use super::serialize_datetime;
#[derive(Serialize)]
pub enum PluginLogEntrySource {
System,
Plugin,
Console,
}
#[derive(Serialize)]
pub enum PluginLogEntryType {
Debug,
Log,
Info,
Warn,
Error,
}
#[derive(Serialize)]
pub struct PluginLogEntry {
#[serde(serialize_with = "serialize_source")]
pub source: PluginLogEntrySource,
#[serde(rename = "type", serialize_with = "serialize_type")]
pub type_: PluginLogEntryType,
pub id: Uuid,
pub team_id: u32,
pub plugin_id: i32,
pub plugin_config_id: i32,
#[serde(serialize_with = "serialize_datetime")]
pub timestamp: DateTime<Utc>,
#[serde(serialize_with = "serialize_message")]
pub message: String,
pub instance_id: Uuid,
}
fn serialize_source<S>(source: &PluginLogEntrySource, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let source_str = match source {
PluginLogEntrySource::System => "SYSTEM",
PluginLogEntrySource::Plugin => "PLUGIN",
PluginLogEntrySource::Console => "CONSOLE",
};
serializer.serialize_str(source_str)
}
fn serialize_type<S>(type_: &PluginLogEntryType, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let type_str = match type_ {
PluginLogEntryType::Debug => "DEBUG",
PluginLogEntryType::Log => "LOG",
PluginLogEntryType::Info => "INFO",
PluginLogEntryType::Warn => "WARN",
PluginLogEntryType::Error => "ERROR",
};
serializer.serialize_str(type_str)
}
fn serialize_message<S>(msg: &str, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
if msg.len() > 50_000 {
return Err(serde::ser::Error::custom(
"Message is too long for ClickHouse",
));
}
serializer.serialize_str(msg)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_plugin_log_entry_serialization() {
use chrono::prelude::*;
let log_entry = PluginLogEntry {
source: PluginLogEntrySource::Plugin,
type_: PluginLogEntryType::Warn,
id: Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap(),
team_id: 4,
plugin_id: 5,
plugin_config_id: 6,
timestamp: Utc.with_ymd_and_hms(2023, 12, 14, 12, 2, 0).unwrap(),
message: "My message!".to_string(),
instance_id: Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap(),
};
let serialized_json = serde_json::to_string(&log_entry).unwrap();
assert_eq!(
serialized_json,
r#"{"source":"PLUGIN","type":"WARN","id":"550e8400-e29b-41d4-a716-446655440000","team_id":4,"plugin_id":5,"plugin_config_id":6,"timestamp":"2023-12-14 12:02:00","message":"My message!","instance_id":"00000000-0000-0000-0000-000000000000"}"#
);
}
#[test]
fn test_plugin_log_entry_message_too_long() {
use chrono::prelude::*;
let log_entry = PluginLogEntry {
source: PluginLogEntrySource::Plugin,
type_: PluginLogEntryType::Warn,
id: Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap(),
team_id: 4,
plugin_id: 5,
plugin_config_id: 6,
timestamp: Utc.with_ymd_and_hms(2023, 12, 14, 12, 2, 0).unwrap(),
message: "My message!".repeat(10_000).to_string(),
instance_id: Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap(),
};
let err = serde_json::to_string(&log_entry).unwrap_err();
assert_eq!(err.to_string(), "Message is too long for ClickHouse");
}
}

View File

@@ -0,0 +1,5 @@
pub mod kafka_messages;
pub mod metrics;
pub mod pgqueue;
pub mod retry;
pub mod webhook;

View File

@@ -0,0 +1,82 @@
use std::time::{Instant, SystemTime};
use axum::{
body::Body, extract::MatchedPath, http::Request, middleware::Next, response::IntoResponse,
routing::get, Router,
};
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
/// Bind a `TcpListener` on the provided bind address to serve a `Router` on it.
/// This function is intended to take a Router as returned by `setup_metrics_routes`, potentially with more routes added by the caller.
pub async fn serve(router: Router, bind: &str) -> Result<(), std::io::Error> {
let listener = tokio::net::TcpListener::bind(bind).await?;
axum::serve(listener, router).await?;
Ok(())
}
/// Add the Prometheus endpoint and metrics middleware to a router; this should be called last.
pub fn setup_metrics_routes(router: Router) -> Router {
let recorder_handle = setup_metrics_recorder();
router
.route(
"/metrics",
get(move || std::future::ready(recorder_handle.render())),
)
.layer(axum::middleware::from_fn(track_metrics))
}
pub fn setup_metrics_recorder() -> PrometheusHandle {
const EXPONENTIAL_SECONDS: &[f64] = &[
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
];
PrometheusBuilder::new()
.set_buckets(EXPONENTIAL_SECONDS)
.unwrap()
.install_recorder()
.unwrap()
}
/// Middleware to record some common HTTP metrics
/// Someday tower-http might provide a metrics middleware: https://github.com/tower-rs/tower-http/issues/57
pub async fn track_metrics(req: Request<Body>, next: Next) -> impl IntoResponse {
let start = Instant::now();
let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>() {
matched_path.as_str().to_owned()
} else {
req.uri().path().to_owned()
};
let method = req.method().clone();
// Run the rest of the request handling first, so we can measure it and get response
// codes.
let response = next.run(req).await;
let latency = start.elapsed().as_secs_f64();
let status = response.status().as_u16().to_string();
let labels = [
("method", method.to_string()),
("path", path),
("status", status),
];
metrics::counter!("http_requests_total", &labels).increment(1);
metrics::histogram!("http_requests_duration_seconds", &labels).record(latency);
response
}
/// Returns the number of seconds since the Unix epoch, to use in prom gauges.
/// Saturates to zero if the system time is set before epoch.
pub fn get_current_timestamp_seconds() -> f64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as f64
}

View File

@@ -0,0 +1,957 @@
//! # PgQueue
//!
//! A job queue implementation backed by a PostgreSQL table.
use std::time;
use std::{str::FromStr, sync::Arc};
use async_trait::async_trait;
use chrono;
use serde;
use sqlx::postgres::any::AnyConnectionBackend;
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
use thiserror::Error;
use tokio::sync::Mutex;
use tracing::error;
/// Enumeration of parsing errors in PgQueue.
#[derive(Error, Debug)]
pub enum ParseError {
#[error("{0} is not a valid JobStatus")]
ParseJobStatusError(String),
#[error("{0} is not a valid HttpMethod")]
ParseHttpMethodError(String),
#[error("transaction was already closed")]
TransactionAlreadyClosedError,
}
/// Enumeration of database-related errors in PgQueue.
/// Errors that can originate from sqlx and are wrapped by us to provide additional context.
#[derive(Error, Debug)]
pub enum DatabaseError {
#[error("pool creation failed with: {error}")]
PoolCreationError { error: sqlx::Error },
#[error("connection failed with: {error}")]
ConnectionError { error: sqlx::Error },
#[error("{command} query failed with: {error}")]
QueryError { command: String, error: sqlx::Error },
#[error("transaction {command} failed with: {error}")]
TransactionError { command: String, error: sqlx::Error },
#[error("transaction was already closed")]
TransactionAlreadyClosedError,
}
/// An error that occurs when a job cannot be retried.
/// Returns the underlying job so that a client can fail it.
#[derive(Error, Debug)]
#[error("retry is an invalid state for this job: {error}")]
pub struct RetryInvalidError<T> {
pub job: T,
pub error: String,
}
/// Enumeration of errors that can occur when retrying a job.
/// They are in a separate enum because a failed retry may need to return the underlying job.
#[derive(Error, Debug)]
pub enum RetryError<T> {
#[error(transparent)]
DatabaseError(#[from] DatabaseError),
#[error(transparent)]
RetryInvalidError(#[from] RetryInvalidError<T>),
}
/// Enumeration of possible statuses for a Job.
#[derive(Debug, PartialEq, sqlx::Type)]
#[sqlx(type_name = "job_status")]
#[sqlx(rename_all = "lowercase")]
pub enum JobStatus {
/// A job that is waiting in the queue to be picked up by a worker.
Available,
/// A job that was cancelled by a worker.
Cancelled,
/// A job that was successfully completed by a worker.
Completed,
/// A job that was discarded by a worker.
Discarded,
/// A job that was unsuccessfully completed by a worker.
Failed,
}
/// Allow casting JobStatus from strings.
impl FromStr for JobStatus {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"available" => Ok(JobStatus::Available),
"completed" => Ok(JobStatus::Completed),
"failed" => Ok(JobStatus::Failed),
invalid => Err(ParseError::ParseJobStatusError(invalid.to_owned())),
}
}
}
/// JobParameters are stored and read to and from a JSONB field, so we accept anything that fits `sqlx::types::Json`.
pub type JobParameters<J> = sqlx::types::Json<J>;
/// JobMetadata are stored and read to and from a JSONB field, so we accept anything that fits `sqlx::types::Json`.
pub type JobMetadata<M> = sqlx::types::Json<M>;
/// A Job to be executed by a worker dequeueing a PgQueue.
#[derive(sqlx::FromRow, Debug)]
pub struct Job<J, M> {
/// A unique id identifying a job.
pub id: i64,
/// A number corresponding to the current job attempt.
pub attempt: i32,
/// A datetime corresponding to when the job was attempted.
pub attempted_at: chrono::DateTime<chrono::offset::Utc>,
/// A vector of identifiers that have attempted this job. E.g. thread ids, pod names, etc...
pub attempted_by: Vec<String>,
/// A datetime corresponding to when the job was created.
pub created_at: chrono::DateTime<chrono::offset::Utc>,
/// The current job's number of max attempts.
pub max_attempts: i32,
/// Arbitrary job metadata stored as JSON.
pub metadata: JobMetadata<M>,
/// Arbitrary job parameters stored as JSON.
pub parameters: JobParameters<J>,
/// The queue this job belongs to.
pub queue: String,
/// The current status of the job.
pub status: JobStatus,
/// The target of the job. E.g. an endpoint or service we are trying to reach.
pub target: String,
}
impl<J, M> Job<J, M> {
/// Return true if this job attempt is greater or equal to the maximum number of possible attempts.
pub fn is_gte_max_attempts(&self) -> bool {
self.attempt >= self.max_attempts
}
/// Consume `Job` to transition it to a `RetryableJob`, i.e. a `Job` that may be retried.
fn retryable(self) -> RetryableJob {
RetryableJob {
id: self.id,
attempt: self.attempt,
queue: self.queue,
retry_queue: None,
}
}
/// Consume `Job` to complete it.
/// A `CompletedJob` is finalized and cannot be used further; it is returned for reporting or inspection.
///
/// # Arguments
///
/// * `executor`: Any sqlx::Executor that can execute the UPDATE query required to mark this `Job` as completed.
async fn complete<'c, E>(self, executor: E) -> Result<CompletedJob, sqlx::Error>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let base_query = r#"
UPDATE
job_queue
SET
last_attempt_finished_at = NOW(),
status = 'completed'::job_status
WHERE
queue = $1
AND id = $2
RETURNING
job_queue.*
"#;
sqlx::query(base_query)
.bind(&self.queue)
.bind(self.id)
.execute(executor)
.await?;
Ok(CompletedJob {
id: self.id,
queue: self.queue,
})
}
/// Consume `Job` to fail it.
/// A `FailedJob` is finalized and cannot be used further; it is returned for reporting or inspection.
///
/// # Arguments
///
/// * `error`: Any JSON-serializable value to be stored as an error.
/// * `executor`: Any sqlx::Executor that can execute the UPDATE query required to mark this `Job` as failed.
async fn fail<'c, E, S>(self, error: S, executor: E) -> Result<FailedJob<S>, sqlx::Error>
where
S: serde::Serialize + std::marker::Sync + std::marker::Send,
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let json_error = sqlx::types::Json(error);
let base_query = r#"
UPDATE
job_queue
SET
last_attempt_finished_at = NOW(),
status = 'failed'::job_status,
errors = array_append(errors, $3)
WHERE
queue = $1
AND id = $2
RETURNING
job_queue.*
"#;
sqlx::query(base_query)
.bind(&self.queue)
.bind(self.id)
.bind(&json_error)
.execute(executor)
.await?;
Ok(FailedJob {
id: self.id,
error: json_error,
queue: self.queue,
})
}
}
#[async_trait]
pub trait PgQueueJob {
async fn complete(mut self) -> Result<CompletedJob, DatabaseError>;
async fn fail<E: serde::Serialize + std::marker::Sync + std::marker::Send>(
mut self,
error: E,
) -> Result<FailedJob<E>, DatabaseError>;
async fn retry<E: serde::Serialize + std::marker::Sync + std::marker::Send>(
mut self,
error: E,
retry_interval: time::Duration,
queue: &str,
) -> Result<RetriedJob, RetryError<Box<Self>>>;
}
/// A Job within an open PostgreSQL transaction.
/// This implementation allows 'hiding' the job from any other workers running SKIP LOCKED queries.
#[derive(Debug)]
pub struct PgTransactionJob<'c, J, M> {
pub job: Job<J, M>,
/// The open transaction this job came from. If multiple jobs were queried at once, then this
/// transaction will be shared between them (across async tasks and threads as necessary). See
/// below for more information.
shared_txn: Arc<Mutex<Option<sqlx::Transaction<'c, sqlx::postgres::Postgres>>>>,
}
// Container struct for a batch of PgTransactionJob. Includes a reference to the shared transaction
// for committing the work when all of the jobs are finished.
pub struct PgTransactionBatch<'c, J, M> {
pub jobs: Vec<PgTransactionJob<'c, J, M>>,
/// The open transaction the jobs in the Vec came from. This should be used to commit or
/// rollback when all of the work is finished.
shared_txn: Arc<Mutex<Option<sqlx::Transaction<'c, sqlx::postgres::Postgres>>>>,
}
impl<'c, J, M> PgTransactionBatch<'_, J, M> {
pub async fn commit(self) -> PgQueueResult<()> {
let mut txn_guard = self.shared_txn.lock().await;
txn_guard
.as_deref_mut()
.ok_or(DatabaseError::TransactionAlreadyClosedError)?
.commit()
.await
.map_err(|e| DatabaseError::QueryError {
command: "COMMIT".to_owned(),
error: e,
})?;
Ok(())
}
}
#[async_trait]
impl<'c, J: std::marker::Send, M: std::marker::Send> PgQueueJob for PgTransactionJob<'c, J, M> {
async fn complete(mut self) -> Result<CompletedJob, DatabaseError> {
let mut txn_guard = self.shared_txn.lock().await;
let txn_ref = txn_guard
.as_deref_mut()
.ok_or(DatabaseError::TransactionAlreadyClosedError)?;
let completed_job =
self.job
.complete(txn_ref)
.await
.map_err(|error| DatabaseError::QueryError {
command: "UPDATE".to_owned(),
error,
})?;
Ok(completed_job)
}
async fn fail<S: serde::Serialize + std::marker::Sync + std::marker::Send>(
mut self,
error: S,
) -> Result<FailedJob<S>, DatabaseError> {
let mut txn_guard = self.shared_txn.lock().await;
let txn_ref = txn_guard
.as_deref_mut()
.ok_or(DatabaseError::TransactionAlreadyClosedError)?;
let failed_job =
self.job
.fail(error, txn_ref)
.await
.map_err(|error| DatabaseError::QueryError {
command: "UPDATE".to_owned(),
error,
})?;
Ok(failed_job)
}
async fn retry<E: serde::Serialize + std::marker::Sync + std::marker::Send>(
mut self,
error: E,
retry_interval: time::Duration,
queue: &str,
) -> Result<RetriedJob, RetryError<Box<PgTransactionJob<'c, J, M>>>> {
// Ideally, the transition to RetryableJob should be fallible.
// But taking ownership of self when we return this error makes things difficult.
if self.job.is_gte_max_attempts() {
return Err(RetryError::from(RetryInvalidError {
job: Box::new(self),
error: "Maximum attempts reached".to_owned(),
}));
}
let mut txn_guard = self.shared_txn.lock().await;
let txn_ref = txn_guard
.as_deref_mut()
.ok_or(DatabaseError::TransactionAlreadyClosedError)?;
let retried_job = self
.job
.retryable()
.queue(queue)
.retry(error, retry_interval, txn_ref)
.await
.map_err(|error| DatabaseError::QueryError {
command: "UPDATE".to_owned(),
error,
})?;
Ok(retried_job)
}
}
/// A Job that has failed but can still be enqueued into a PgQueue to be retried at a later point.
/// The time until retry will depend on the PgQueue's RetryPolicy.
pub struct RetryableJob {
/// A unique id identifying a job.
pub id: i64,
/// A number corresponding to the current job attempt.
pub attempt: i32,
/// A unique id identifying a job queue.
queue: String,
/// An optional separate queue where to enqueue this job when retrying.
retry_queue: Option<String>,
}
impl RetryableJob {
/// Set the queue for a `RetryableJob`.
/// If not set, the `Job` will be retried in its original queue when calling `retry`.
fn queue(mut self, queue: &str) -> Self {
self.retry_queue = Some(queue.to_owned());
self
}
/// Return the queue that a `Job` is to be retried into.
fn retry_queue(&self) -> &str {
self.retry_queue.as_ref().unwrap_or(&self.queue)
}
/// Consume this `RetryableJob` to retry it.
/// A `RetriedJob` cannot be used further; it is returned for reporting or inspection.
///
/// # Arguments
///
/// * `error`: Any JSON-serializable value to be stored as an error.
/// * `retry_interval`: The duration until the `Job` is to be retried again. Used to set `scheduled_at`.
/// * `executor`: Any sqlx::Executor that can execute the UPDATE query required to mark this `Job` as available for retrying.
async fn retry<'c, S, E>(
self,
error: S,
retry_interval: time::Duration,
executor: E,
) -> Result<RetriedJob, sqlx::Error>
where
S: serde::Serialize + std::marker::Sync + std::marker::Send,
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let json_error = sqlx::types::Json(error);
let base_query = r#"
UPDATE
job_queue
SET
last_attempt_finished_at = NOW(),
status = 'available'::job_status,
scheduled_at = NOW() + $3,
errors = array_append(errors, $4),
queue = $5
WHERE
queue = $1
AND id = $2
RETURNING
job_queue.*
"#;
sqlx::query(base_query)
.bind(&self.queue)
.bind(self.id)
.bind(retry_interval)
.bind(&json_error)
.bind(self.retry_queue())
.execute(executor)
.await?;
Ok(RetriedJob {
id: self.id,
queue: self.queue,
retry_queue: self.retry_queue.to_owned(),
})
}
}
/// State a `Job` is transitioned to after successfully completing.
#[derive(Debug)]
pub struct CompletedJob {
/// A unique id identifying a job.
pub id: i64,
/// A unique id identifying a job queue.
pub queue: String,
}
/// State a `Job` is transitioned to after it has been enqueued for retrying.
#[derive(Debug)]
pub struct RetriedJob {
/// A unique id identifying a job.
pub id: i64,
/// A unique id identifying a job queue.
pub queue: String,
pub retry_queue: Option<String>,
}
/// State a `Job` is transitioned to after exhausting all of its attempts.
#[derive(Debug)]
pub struct FailedJob<J> {
/// A unique id identifying a job.
pub id: i64,
/// Any JSON-serializable value to be stored as an error.
pub error: sqlx::types::Json<J>,
/// A unique id identifying a job queue.
pub queue: String,
}
/// This struct represents a new job being created to be enqueued into a `PgQueue`.
#[derive(Debug)]
pub struct NewJob<J, M> {
/// The maximum amount of attempts this NewJob has to complete.
pub max_attempts: i32,
/// The JSON-deserializable metadata for this NewJob.
pub metadata: JobMetadata<M>,
/// The JSON-deserializable parameters for this NewJob.
pub parameters: JobParameters<J>,
/// The target of the NewJob. E.g. an endpoint or service we are trying to reach.
pub target: String,
}
impl<J, M> NewJob<J, M> {
pub fn new(max_attempts: i32, metadata: M, parameters: J, target: &str) -> Self {
Self {
max_attempts,
metadata: sqlx::types::Json(metadata),
parameters: sqlx::types::Json(parameters),
target: target.to_owned(),
}
}
}
/// A queue implemented on top of a PostgreSQL table.
#[derive(Clone)]
pub struct PgQueue {
/// A name to identify this PgQueue as multiple may share a table.
name: String,
/// A connection pool used to connect to the PostgreSQL database.
pool: PgPool,
}
pub type PgQueueResult<T> = std::result::Result<T, DatabaseError>;
impl PgQueue {
/// Initialize a new PgQueue backed by a table in PostgreSQL by initializing a connection pool to the database at `url`.
///
/// # Arguments
///
/// * `queue_name`: A name for the queue we are going to initialize.
/// * `url`: A URL pointing to where the PostgreSQL database is hosted.
pub async fn new(
queue_name: &str,
url: &str,
max_connections: u32,
app_name: &'static str,
) -> PgQueueResult<Self> {
let name = queue_name.to_owned();
let options = PgConnectOptions::from_str(url)
.map_err(|error| DatabaseError::PoolCreationError { error })?
.application_name(app_name);
let pool = PgPoolOptions::new()
.max_connections(max_connections)
.connect_lazy_with(options);
Ok(Self { name, pool })
}
/// Initialize a new PgQueue backed by a table in PostgreSQL from a provided connection pool.
///
/// # Arguments
///
/// * `queue_name`: A name for the queue we are going to initialize.
/// * `pool`: A database connection pool to be used by this queue.
pub async fn new_from_pool(queue_name: &str, pool: PgPool) -> PgQueue {
let name = queue_name.to_owned();
Self { name, pool }
}
/// Dequeue up to `limit` `Job`s from this `PgQueue` and hold the transaction.
/// Any other `dequeue_tx` calls will skip locked rows, so by holding the transaction we ensure only one
/// worker can dequeue a job. Holding a transaction open can have performance implications, but
/// it means no `'running'` state is required.
pub async fn dequeue_tx<
'a,
J: for<'d> serde::Deserialize<'d> + std::marker::Send + std::marker::Unpin + 'static,
M: for<'d> serde::Deserialize<'d> + std::marker::Send + std::marker::Unpin + 'static,
>(
&self,
attempted_by: &str,
limit: u32,
) -> PgQueueResult<Option<PgTransactionBatch<'a, J, M>>> {
let mut tx = self
.pool
.begin()
.await
.map_err(|error| DatabaseError::ConnectionError { error })?;
// The query that follows uses a FOR UPDATE SKIP LOCKED clause.
// For more details on this see: 2ndquadrant.com/en/blog/what-is-select-skip-locked-for-in-postgresql-9-5.
let base_query = r#"
WITH available_in_queue AS (
SELECT
id
FROM
job_queue
WHERE
status = 'available'
AND scheduled_at <= NOW()
AND queue = $1
ORDER BY
attempt,
scheduled_at
LIMIT $2
FOR UPDATE SKIP LOCKED
)
UPDATE
job_queue
SET
attempted_at = NOW(),
attempt = attempt + 1,
attempted_by = array_append(attempted_by, $3::text)
FROM
available_in_queue
WHERE
job_queue.id = available_in_queue.id
RETURNING
job_queue.*
"#;
let query_result: Result<Vec<Job<J, M>>, sqlx::Error> = sqlx::query_as(base_query)
.bind(&self.name)
.bind(limit as i64)
.bind(attempted_by)
.fetch_all(&mut *tx)
.await;
match query_result {
Ok(jobs) => {
if jobs.is_empty() {
return Ok(None);
}
let shared_txn = Arc::new(Mutex::new(Some(tx)));
let pg_jobs: Vec<PgTransactionJob<J, M>> = jobs
.into_iter()
.map(|job| PgTransactionJob {
job,
shared_txn: shared_txn.clone(),
})
.collect();
Ok(Some(PgTransactionBatch {
jobs: pg_jobs,
shared_txn: shared_txn.clone(),
}))
}
// Transaction is rolled back on drop.
Err(sqlx::Error::RowNotFound) => Ok(None),
Err(e) => Err(DatabaseError::QueryError {
command: "UPDATE".to_owned(),
error: e,
}),
}
}
/// Enqueue a `NewJob` into this PgQueue.
/// We take ownership of `NewJob` to enforce that a specific `NewJob` is only enqueued once.
pub async fn enqueue<
J: serde::Serialize + std::marker::Sync,
M: serde::Serialize + std::marker::Sync,
>(
&self,
job: NewJob<J, M>,
) -> PgQueueResult<()> {
// TODO: Escaping. I think sqlx doesn't support identifiers.
let base_query = r#"
INSERT INTO job_queue
(attempt, created_at, scheduled_at, max_attempts, metadata, parameters, queue, status, target)
VALUES
(0, NOW(), NOW(), $1, $2, $3, $4, 'available'::job_status, $5)
"#;
sqlx::query(base_query)
.bind(job.max_attempts)
.bind(&job.metadata)
.bind(&job.parameters)
.bind(&self.name)
.bind(&job.target)
.execute(&self.pool)
.await
.map_err(|error| DatabaseError::QueryError {
command: "INSERT".to_owned(),
error,
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::retry::RetryPolicy;
#[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug, Clone)]
struct JobMetadata {
team_id: u32,
plugin_config_id: i32,
plugin_id: i32,
}
impl Default for JobMetadata {
fn default() -> Self {
Self {
team_id: 0,
plugin_config_id: 1,
plugin_id: 2,
}
}
}
#[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug, Clone)]
struct JobParameters {
method: String,
body: String,
url: String,
}
impl Default for JobParameters {
fn default() -> Self {
Self {
method: "POST".to_string(),
body: "{\"event\":\"event-name\"}".to_string(),
url: "https://localhost".to_string(),
}
}
}
/// Use process id as a worker id for tests.
fn worker_id() -> String {
std::process::id().to_string()
}
/// Hardcoded test value for job target.
fn job_target() -> String {
"https://myhost/endpoint".to_owned()
}
#[sqlx::test(migrations = "../migrations")]
async fn test_can_dequeue_tx_job(db: PgPool) {
let job_target = job_target();
let job_metadata = JobMetadata::default();
let job_parameters = JobParameters::default();
let worker_id = worker_id();
let queue = PgQueue::new_from_pool("test_can_dequeue_tx_job", db).await;
let new_job = NewJob::new(1, job_metadata, job_parameters, &job_target);
queue.enqueue(new_job).await.expect("failed to enqueue job");
let mut batch: PgTransactionBatch<'_, JobParameters, JobMetadata> = queue
.dequeue_tx(&worker_id, 1)
.await
.expect("failed to dequeue jobs")
.expect("didn't find any jobs to dequeue");
let tx_job = batch.jobs.pop().unwrap();
assert_eq!(tx_job.job.attempt, 1);
assert!(tx_job.job.attempted_by.contains(&worker_id));
assert_eq!(tx_job.job.attempted_by.len(), 1);
assert_eq!(tx_job.job.max_attempts, 1);
assert_eq!(*tx_job.job.metadata.as_ref(), JobMetadata::default());
assert_eq!(*tx_job.job.parameters.as_ref(), JobParameters::default());
assert_eq!(tx_job.job.target, job_target);
// Transactional jobs must be completed, failed or retried before being dropped. This is
// to prevent logic bugs when using the shared txn.
tx_job.complete().await.expect("failed to complete job");
batch.commit().await.expect("failed to commit transaction");
}
#[sqlx::test(migrations = "../migrations")]
async fn test_can_dequeue_multiple_tx_jobs(db: PgPool) {
let job_target = job_target();
let job_metadata = JobMetadata::default();
let job_parameters = JobParameters::default();
let worker_id = worker_id();
let queue = PgQueue::new_from_pool("test_can_dequeue_multiple_tx_jobs", db).await;
for _ in 0..5 {
queue
.enqueue(NewJob::new(
1,
job_metadata.clone(),
job_parameters.clone(),
&job_target,
))
.await
.expect("failed to enqueue job");
}
// Only get 4 jobs, leaving one in the queue.
let limit = 4;
let mut batch: PgTransactionBatch<'_, JobParameters, JobMetadata> = queue
.dequeue_tx(&worker_id, limit)
.await
.expect("failed to dequeue job")
.expect("didn't find a job to dequeue");
assert_eq!(batch.jobs.len(), limit as usize);
// Complete those 4 and commit.
for job in std::mem::take(&mut batch.jobs) {
job.complete().await.expect("failed to complete job");
}
batch.commit().await.expect("failed to commit transaction");
// Try to get up to 4 jobs, but only 1 remains.
let mut batch: PgTransactionBatch<'_, JobParameters, JobMetadata> = queue
.dequeue_tx(&worker_id, limit)
.await
.expect("failed to dequeue job")
.expect("didn't find a job to dequeue");
assert_eq!(batch.jobs.len(), 1); // Only one job should have been left in the queue.
for job in std::mem::take(&mut batch.jobs) {
job.complete().await.expect("failed to complete job");
}
batch.commit().await.expect("failed to commit transaction");
}
#[sqlx::test(migrations = "../migrations")]
async fn test_dequeue_tx_returns_none_on_no_jobs(db: PgPool) {
let worker_id = worker_id();
let queue = PgQueue::new_from_pool("test_dequeue_tx_returns_none_on_no_jobs", db).await;
let batch: Option<PgTransactionBatch<'_, JobParameters, JobMetadata>> = queue
.dequeue_tx(&worker_id, 1)
.await
.expect("failed to dequeue job");
assert!(batch.is_none());
}
#[sqlx::test(migrations = "../migrations")]
async fn test_can_retry_job_with_remaining_attempts(db: PgPool) {
let job_target = job_target();
let job_parameters = JobParameters::default();
let job_metadata = JobMetadata::default();
let worker_id = worker_id();
let new_job = NewJob::new(2, job_metadata, job_parameters, &job_target);
let queue_name = "test_can_retry_job_with_remaining_attempts".to_owned();
let retry_policy = RetryPolicy::build(0, time::Duration::from_secs(0))
.queue(&queue_name)
.provide();
let queue = PgQueue::new_from_pool(&queue_name, db).await;
queue.enqueue(new_job).await.expect("failed to enqueue job");
let mut batch: PgTransactionBatch<'_, JobParameters, JobMetadata> = queue
.dequeue_tx(&worker_id, 1)
.await
.expect("failed to dequeue job")
.expect("didn't find a job to dequeue");
let job = batch.jobs.pop().unwrap();
let retry_interval = retry_policy.retry_interval(job.job.attempt as u32, None);
let retry_queue = retry_policy.retry_queue(&job.job.queue).to_owned();
drop(
job.retry(
"a very reasonable failure reason",
retry_interval,
&retry_queue,
)
.await
.expect("failed to retry job"),
);
batch.commit().await.expect("failed to commit transaction");
let retried_job: PgTransactionJob<JobParameters, JobMetadata> = queue
.dequeue_tx(&worker_id, 1)
.await
.expect("failed to dequeue job")
.expect("didn't find retried job to dequeue")
.jobs
.pop()
.unwrap();
assert_eq!(retried_job.job.attempt, 2);
assert!(retried_job.job.attempted_by.contains(&worker_id));
assert_eq!(retried_job.job.attempted_by.len(), 2);
assert_eq!(retried_job.job.max_attempts, 2);
assert_eq!(
*retried_job.job.parameters.as_ref(),
JobParameters::default()
);
assert_eq!(retried_job.job.target, job_target);
}
#[sqlx::test(migrations = "../migrations")]
async fn test_can_retry_job_to_different_queue(db: PgPool) {
let job_target = job_target();
let job_parameters = JobParameters::default();
let job_metadata = JobMetadata::default();
let worker_id = worker_id();
let new_job = NewJob::new(2, job_metadata, job_parameters, &job_target);
let queue_name = "test_can_retry_job_to_different_queue".to_owned();
let retry_queue_name = "test_can_retry_job_to_different_queue_retry".to_owned();
let retry_policy = RetryPolicy::build(0, time::Duration::from_secs(0))
.queue(&retry_queue_name)
.provide();
let queue = PgQueue::new_from_pool(&queue_name, db.clone()).await;
queue.enqueue(new_job).await.expect("failed to enqueue job");
let mut batch: PgTransactionBatch<JobParameters, JobMetadata> = queue
.dequeue_tx(&worker_id, 1)
.await
.expect("failed to dequeue job")
.expect("didn't find a job to dequeue");
let job = batch.jobs.pop().unwrap();
let retry_interval = retry_policy.retry_interval(job.job.attempt as u32, None);
let retry_queue = retry_policy.retry_queue(&job.job.queue).to_owned();
drop(
job.retry(
"a very reasonable failure reason",
retry_interval,
&retry_queue,
)
.await
.expect("failed to retry job"),
);
batch.commit().await.expect("failed to commit transaction");
let retried_job_not_found: Option<PgTransactionBatch<JobParameters, JobMetadata>> = queue
.dequeue_tx(&worker_id, 1)
.await
.expect("failed to dequeue job");
assert!(retried_job_not_found.is_none());
let queue = PgQueue::new_from_pool(&retry_queue_name, db).await;
let retried_job: PgTransactionJob<JobParameters, JobMetadata> = queue
.dequeue_tx(&worker_id, 1)
.await
.expect("failed to dequeue job")
.expect("job not found in retry queue")
.jobs
.pop()
.unwrap();
assert_eq!(retried_job.job.attempt, 2);
assert!(retried_job.job.attempted_by.contains(&worker_id));
assert_eq!(retried_job.job.attempted_by.len(), 2);
assert_eq!(retried_job.job.max_attempts, 2);
assert_eq!(
*retried_job.job.parameters.as_ref(),
JobParameters::default()
);
assert_eq!(retried_job.job.target, job_target);
}
#[sqlx::test(migrations = "../migrations")]
#[should_panic(expected = "failed to retry job")]
async fn test_cannot_retry_job_without_remaining_attempts(db: PgPool) {
let job_target = job_target();
let job_parameters = JobParameters::default();
let job_metadata = JobMetadata::default();
let worker_id = worker_id();
let new_job = NewJob::new(1, job_metadata, job_parameters, &job_target);
let retry_policy = RetryPolicy::build(0, time::Duration::from_secs(0)).provide();
let queue =
PgQueue::new_from_pool("test_cannot_retry_job_without_remaining_attempts", db).await;
queue.enqueue(new_job).await.expect("failed to enqueue job");
let job: PgTransactionJob<JobParameters, JobMetadata> = queue
.dequeue_tx(&worker_id, 1)
.await
.expect("failed to dequeue job")
.expect("didn't find a job to dequeue")
.jobs
.pop()
.unwrap();
let retry_interval = retry_policy.retry_interval(job.job.attempt as u32, None);
job.retry("a very reasonable failure reason", retry_interval, "any")
.await
.expect("failed to retry job");
}
}
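// Illustrative sketch: one pass of a worker driving this queue end to end, using
// hypothetical `ExampleParameters`/`ExampleMetadata` payload types. Everything else
// referenced below is defined in this module.
#[allow(dead_code)]
mod example_usage {
    use super::*;

    #[derive(serde::Serialize, serde::Deserialize)]
    struct ExampleParameters {
        url: String,
    }

    #[derive(serde::Serialize, serde::Deserialize)]
    struct ExampleMetadata {
        team_id: u32,
    }

    async fn run_one_pass(queue: &PgQueue) -> Result<(), DatabaseError> {
        // Enqueue a job that may be attempted up to 3 times.
        let new_job = NewJob::new(
            3,
            ExampleMetadata { team_id: 1 },
            ExampleParameters {
                url: "https://example.com".to_owned(),
            },
            "https://example.com",
        );
        queue.enqueue(new_job).await?;

        // Dequeue up to 10 jobs; they all share a single open transaction.
        if let Some(mut batch) = queue
            .dequeue_tx::<ExampleParameters, ExampleMetadata>("worker-1", 10)
            .await?
        {
            for job in std::mem::take(&mut batch.jobs) {
                // Do the actual work here, then finalize each job before dropping it.
                job.complete().await?;
            }
            // Committing releases the row locks and persists the status updates.
            batch.commit().await?;
        }
        Ok(())
    }
}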

View File

@@ -0,0 +1,225 @@
//! # Retry
//!
//! Module providing a `RetryPolicy` struct to configure job retrying.
use std::time;
#[derive(Clone, Debug)]
/// A retry policy to determine retry parameters for a job.
pub struct RetryPolicy {
/// Coefficient to multiply initial_interval with for every past attempt.
pub backoff_coefficient: u32,
/// The backoff interval for the first retry.
pub initial_interval: time::Duration,
/// The maximum possible backoff between retries.
pub maximum_interval: Option<time::Duration>,
/// An optional queue to send WebhookJob retries to.
pub queue: Option<String>,
}
impl RetryPolicy {
/// Initialize a `RetryPolicyBuilder`.
pub fn build(backoff_coefficient: u32, initial_interval: time::Duration) -> RetryPolicyBuilder {
RetryPolicyBuilder::new(backoff_coefficient, initial_interval)
}
/// Determine interval for retrying at a given attempt number.
/// If not `None`, this method will respect `preferred_retry_interval` as long as it falls within `candidate_interval <= preferred_retry_interval <= maximum_interval`.
pub fn retry_interval(
&self,
attempt: u32,
preferred_retry_interval: Option<time::Duration>,
) -> time::Duration {
let candidate_interval =
self.initial_interval * self.backoff_coefficient.pow(attempt.saturating_sub(1));
match (preferred_retry_interval, self.maximum_interval) {
(Some(duration), Some(max_interval)) => {
let min_interval_allowed = std::cmp::min(candidate_interval, max_interval);
if min_interval_allowed <= duration && duration <= max_interval {
duration
} else {
min_interval_allowed
}
}
(Some(duration), None) => std::cmp::max(candidate_interval, duration),
(None, Some(max_interval)) => std::cmp::min(candidate_interval, max_interval),
(None, None) => candidate_interval,
}
}
/// Determine the queue to be used for retrying.
/// Currently, the only input considered is whether a queue is configured in this RetryPolicy.
/// This may be extended in the future to support more decision parameters.
pub fn retry_queue<'s>(&'s self, current_queue: &'s str) -> &'s str {
if let Some(new_queue) = &self.queue {
new_queue
} else {
current_queue
}
}
}
impl Default for RetryPolicy {
fn default() -> Self {
RetryPolicyBuilder::default().provide()
}
}
/// Builder pattern struct to provide a `RetryPolicy`.
pub struct RetryPolicyBuilder {
/// Coefficient to multiply initial_interval with for every past attempt.
pub backoff_coefficient: u32,
/// The backoff interval for the first retry.
pub initial_interval: time::Duration,
/// The maximum possible backoff between retries.
pub maximum_interval: Option<time::Duration>,
/// An optional queue to send WebhookJob retries to.
pub queue: Option<String>,
}
impl Default for RetryPolicyBuilder {
fn default() -> Self {
Self {
backoff_coefficient: 2,
initial_interval: time::Duration::from_secs(1),
maximum_interval: None,
queue: None,
}
}
}
impl RetryPolicyBuilder {
pub fn new(backoff_coefficient: u32, initial_interval: time::Duration) -> Self {
Self {
backoff_coefficient,
initial_interval,
..RetryPolicyBuilder::default()
}
}
pub fn maximum_interval(mut self, interval: time::Duration) -> RetryPolicyBuilder {
self.maximum_interval = Some(interval);
self
}
pub fn queue(mut self, queue: &str) -> RetryPolicyBuilder {
self.queue = Some(queue.to_owned());
self
}
/// Provide a `RetryPolicy` according to build parameters provided thus far.
pub fn provide(&self) -> RetryPolicy {
RetryPolicy {
backoff_coefficient: self.backoff_coefficient,
initial_interval: self.initial_interval,
maximum_interval: self.maximum_interval,
queue: self.queue.clone(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_constant_retry_interval() {
let retry_policy = RetryPolicy::build(1, time::Duration::from_secs(2)).provide();
let first_interval = retry_policy.retry_interval(1, None);
let second_interval = retry_policy.retry_interval(2, None);
let third_interval = retry_policy.retry_interval(3, None);
assert_eq!(first_interval, time::Duration::from_secs(2));
assert_eq!(second_interval, time::Duration::from_secs(2));
assert_eq!(third_interval, time::Duration::from_secs(2));
}
#[test]
fn test_retry_interval_never_exceeds_maximum() {
let retry_policy = RetryPolicy::build(2, time::Duration::from_secs(2))
.maximum_interval(time::Duration::from_secs(4))
.provide();
let first_interval = retry_policy.retry_interval(1, None);
let second_interval = retry_policy.retry_interval(2, None);
let third_interval = retry_policy.retry_interval(3, None);
let fourth_interval = retry_policy.retry_interval(4, None);
assert_eq!(first_interval, time::Duration::from_secs(2));
assert_eq!(second_interval, time::Duration::from_secs(4));
assert_eq!(third_interval, time::Duration::from_secs(4));
assert_eq!(fourth_interval, time::Duration::from_secs(4));
}
#[test]
fn test_retry_interval_increases_with_coefficient() {
let retry_policy = RetryPolicy::build(2, time::Duration::from_secs(2)).provide();
let first_interval = retry_policy.retry_interval(1, None);
let second_interval = retry_policy.retry_interval(2, None);
let third_interval = retry_policy.retry_interval(3, None);
assert_eq!(first_interval, time::Duration::from_secs(2));
assert_eq!(second_interval, time::Duration::from_secs(4));
assert_eq!(third_interval, time::Duration::from_secs(8));
}
#[test]
fn test_retry_interval_respects_preferred() {
let retry_policy = RetryPolicy::build(1, time::Duration::from_secs(2)).provide();
let preferred = time::Duration::from_secs(999);
let first_interval = retry_policy.retry_interval(1, Some(preferred));
let second_interval = retry_policy.retry_interval(2, Some(preferred));
let third_interval = retry_policy.retry_interval(3, Some(preferred));
assert_eq!(first_interval, preferred);
assert_eq!(second_interval, preferred);
assert_eq!(third_interval, preferred);
}
#[test]
fn test_retry_interval_ignores_small_preferred() {
let retry_policy = RetryPolicy::build(1, time::Duration::from_secs(5)).provide();
let preferred = time::Duration::from_secs(2);
let first_interval = retry_policy.retry_interval(1, Some(preferred));
let second_interval = retry_policy.retry_interval(2, Some(preferred));
let third_interval = retry_policy.retry_interval(3, Some(preferred));
assert_eq!(first_interval, time::Duration::from_secs(5));
assert_eq!(second_interval, time::Duration::from_secs(5));
assert_eq!(third_interval, time::Duration::from_secs(5));
}
#[test]
fn test_retry_interval_ignores_large_preferred() {
let retry_policy = RetryPolicy::build(2, time::Duration::from_secs(2))
.maximum_interval(time::Duration::from_secs(4))
.provide();
let preferred = time::Duration::from_secs(10);
let first_interval = retry_policy.retry_interval(1, Some(preferred));
let second_interval = retry_policy.retry_interval(2, Some(preferred));
let third_interval = retry_policy.retry_interval(3, Some(preferred));
assert_eq!(first_interval, time::Duration::from_secs(2));
assert_eq!(second_interval, time::Duration::from_secs(4));
assert_eq!(third_interval, time::Duration::from_secs(4));
}
#[test]
fn test_returns_retry_queue_if_set() {
let retry_queue_name = "retry_queue".to_owned();
let retry_policy = RetryPolicy::build(0, time::Duration::from_secs(0))
.queue(&retry_queue_name)
.provide();
let current_queue = "queue".to_owned();
assert_eq!(retry_policy.retry_queue(&current_queue), retry_queue_name);
}
#[test]
fn test_returns_queue_if_retry_queue_not_set() {
let retry_policy = RetryPolicy::build(0, time::Duration::from_secs(0)).provide();
let current_queue = "queue".to_owned();
assert_eq!(retry_policy.retry_queue(&current_queue), current_queue);
}
}
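// Illustrative sketch of the backoff arithmetic above: the candidate interval is
// `initial_interval * backoff_coefficient^(attempt - 1)`, clamped to `maximum_interval`
// when one is set. Values below are made up for the example.
#[allow(dead_code)]
fn example_backoff_policy() {
    let policy = RetryPolicy::build(2, time::Duration::from_secs(5))
        .maximum_interval(time::Duration::from_secs(60))
        .queue("webhooks_retry")
        .provide();

    // 5s, 10s, 20s, 40s, then capped at 60s for every later attempt.
    assert_eq!(policy.retry_interval(1, None), time::Duration::from_secs(5));
    assert_eq!(policy.retry_interval(2, None), time::Duration::from_secs(10));
    assert_eq!(policy.retry_interval(5, None), time::Duration::from_secs(60));

    // With a retry queue configured, retries from any queue are redirected to it.
    assert_eq!(policy.retry_queue("webhooks"), "webhooks_retry");
}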

View File

@@ -0,0 +1,225 @@
use std::collections;
use std::convert::From;
use std::fmt;
use std::str::FromStr;
use serde::{de::Visitor, Deserialize, Serialize};
use crate::kafka_messages::app_metrics;
use crate::pgqueue::ParseError;
/// Supported HTTP methods for webhooks.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum HttpMethod {
DELETE,
GET,
PATCH,
POST,
PUT,
}
/// Allow casting `HttpMethod` from strings.
impl FromStr for HttpMethod {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_uppercase().as_ref() {
"DELETE" => Ok(HttpMethod::DELETE),
"GET" => Ok(HttpMethod::GET),
"PATCH" => Ok(HttpMethod::PATCH),
"POST" => Ok(HttpMethod::POST),
"PUT" => Ok(HttpMethod::PUT),
invalid => Err(ParseError::ParseHttpMethodError(invalid.to_owned())),
}
}
}
/// Implement `std::fmt::Display` to convert HttpMethod to string.
impl fmt::Display for HttpMethod {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
HttpMethod::DELETE => write!(f, "DELETE"),
HttpMethod::GET => write!(f, "GET"),
HttpMethod::PATCH => write!(f, "PATCH"),
HttpMethod::POST => write!(f, "POST"),
HttpMethod::PUT => write!(f, "PUT"),
}
}
}
struct HttpMethodVisitor;
impl<'de> Visitor<'de> for HttpMethodVisitor {
type Value = HttpMethod;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "the string representation of HttpMethod")
}
fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match HttpMethod::from_str(s) {
Ok(method) => Ok(method),
Err(_) => Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Str(s),
&self,
)),
}
}
}
/// Deserialize required to read `HttpMethod` from database.
impl<'de> Deserialize<'de> for HttpMethod {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_str(HttpMethodVisitor)
}
}
/// Serialize required to write `HttpMethod` to database.
impl Serialize for HttpMethod {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
/// Convenience to cast `HttpMethod` to `http::Method`.
/// Not all `http::Method` variants are valid `HttpMethod` variants, hence why we can't just use
/// the former everywhere or implement `From<http::Method>` for `HttpMethod`.
impl From<HttpMethod> for http::Method {
fn from(val: HttpMethod) -> Self {
match val {
HttpMethod::DELETE => http::Method::DELETE,
HttpMethod::GET => http::Method::GET,
HttpMethod::PATCH => http::Method::PATCH,
HttpMethod::POST => http::Method::POST,
HttpMethod::PUT => http::Method::PUT,
}
}
}
impl From<&HttpMethod> for http::Method {
fn from(val: &HttpMethod) -> Self {
match val {
HttpMethod::DELETE => http::Method::DELETE,
HttpMethod::GET => http::Method::GET,
HttpMethod::PATCH => http::Method::PATCH,
HttpMethod::POST => http::Method::POST,
HttpMethod::PUT => http::Method::PUT,
}
}
}
/// `JobParameters` required for the `WebhookWorker` to execute a webhook.
/// These parameters should match the exported Webhook interface that PostHog plugins implement.
/// See: https://github.com/PostHog/plugin-scaffold/blob/main/src/types.ts#L15.
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct WebhookJobParameters {
pub body: String,
pub headers: collections::HashMap<String, String>,
pub method: HttpMethod,
pub url: String,
}
/// `JobMetadata` required for the `WebhookWorker` to execute a webhook.
/// These should be set if the Webhook is associated with a plugin `composeWebhook` invocation.
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct WebhookJobMetadata {
pub team_id: u32,
pub plugin_id: i32,
pub plugin_config_id: i32,
}
/// An error originating during a Webhook Job invocation.
/// This is to be serialized to be stored as an error whenever retrying or failing a webhook job.
#[derive(Deserialize, Serialize, Debug)]
pub struct WebhookJobError {
pub r#type: app_metrics::ErrorType,
pub details: app_metrics::ErrorDetails,
}
/// Webhook jobs boil down to an HTTP request, so it's useful to have a way to convert from &reqwest::Error.
/// For the conversion we check all possible error types with the associated is_* methods provided by reqwest.
/// Some precision may be lost as our app_metrics::ErrorType does not support the same number of variants.
impl From<&reqwest::Error> for WebhookJobError {
fn from(error: &reqwest::Error) -> Self {
if error.is_timeout() {
WebhookJobError::new_timeout(&error.to_string())
} else if error.is_status() {
WebhookJobError::new_http_status(
error.status().expect("status code is defined").into(),
&error.to_string(),
)
} else {
// Catch all other errors as `app_metrics::ErrorType::Connection` errors.
// Not all of `reqwest::Error` may strictly be connection errors, so our supported error types may need to be extended
// depending on how strict error reporting has to be.
WebhookJobError::new_connection(&error.to_string())
}
}
}
impl WebhookJobError {
pub fn new_timeout(message: &str) -> Self {
let error_details = app_metrics::Error {
name: "Timeout Error".to_owned(),
message: Some(message.to_owned()),
stack: None,
};
Self {
r#type: app_metrics::ErrorType::TimeoutError,
details: app_metrics::ErrorDetails {
error: error_details,
},
}
}
pub fn new_connection(message: &str) -> Self {
let error_details = app_metrics::Error {
name: "Connection Error".to_owned(),
message: Some(message.to_owned()),
stack: None,
};
Self {
r#type: app_metrics::ErrorType::ConnectionError,
details: app_metrics::ErrorDetails {
error: error_details,
},
}
}
pub fn new_http_status(status_code: u16, message: &str) -> Self {
let error_details = app_metrics::Error {
name: "Bad Http Status".to_owned(),
message: Some(message.to_owned()),
stack: None,
};
Self {
r#type: app_metrics::ErrorType::BadHttpStatus(status_code),
details: app_metrics::ErrorDetails {
error: error_details,
},
}
}
pub fn new_parse(message: &str) -> Self {
let error_details = app_metrics::Error {
name: "Parse Error".to_owned(),
message: Some(message.to_owned()),
stack: None,
};
Self {
r#type: app_metrics::ErrorType::ParseError,
details: app_metrics::ErrorDetails {
error: error_details,
},
}
}
}
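// Illustrative sketch: building the payloads the webhook worker consumes and
// round-tripping them through serde, assuming `serde_json` is available to this crate.
// All values are made up for the example.
#[allow(dead_code)]
fn example_webhook_payloads() {
    let parameters = WebhookJobParameters {
        body: r#"{"event":"event-name"}"#.to_owned(),
        headers: collections::HashMap::from([(
            "Content-Type".to_owned(),
            "application/json".to_owned(),
        )]),
        method: HttpMethod::POST,
        url: "https://example.com/endpoint".to_owned(),
    };
    let metadata = WebhookJobMetadata {
        team_id: 1,
        plugin_id: 2,
        plugin_config_id: 3,
    };

    // Both types serialize to the JSON stored in the job row's JSONB columns.
    let json = serde_json::to_string(&parameters).expect("serializes");
    let restored: WebhookJobParameters = serde_json::from_str(&json).expect("deserializes");
    assert_eq!(restored, parameters);
    let _ = serde_json::to_string(&metadata).expect("serializes");

    // `HttpMethod` parses case-insensitively and renders as the uppercase verb.
    assert_eq!(HttpMethod::from_str("post").unwrap(), HttpMethod::POST);
    assert_eq!(HttpMethod::POST.to_string(), "POST");
}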

View File

@@ -0,0 +1,25 @@
[package]
name = "hook-janitor"
version = "0.1.0"
edition = "2021"
[lints]
workspace = true
[dependencies]
async-trait = { workspace = true }
axum = { workspace = true }
envconfig = { workspace = true }
eyre = { workspace = true }
futures = { workspace = true }
health = { path = "../common/health" }
hook-common = { path = "../hook-common" }
metrics = { workspace = true }
rdkafka = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
thiserror = { workspace = true }
time = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }

View File

@@ -0,0 +1,34 @@
use async_trait::async_trait;
use std::result::Result;
use std::str::FromStr;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum CleanerError {
#[error("invalid cleaner mode")]
InvalidCleanerMode,
}
// Mode names, used by config/environment parsing to verify the mode is supported.
#[derive(Debug)]
pub enum CleanerModeName {
Webhooks,
}
impl FromStr for CleanerModeName {
type Err = CleanerError;
fn from_str(s: &str) -> Result<Self, CleanerError> {
match s {
"webhooks" => Ok(CleanerModeName::Webhooks),
_ => Err(CleanerError::InvalidCleanerMode),
}
}
}
// Right now, all this trait does is allow us to call `cleanup` in a loop in `main.rs`. There may
// be other benefits as we build this out, or we could remove it if it doesn't end up being useful.
#[async_trait]
pub trait Cleaner {
async fn cleanup(&self);
}
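// Illustrative sketch: the smallest possible `Cleaner` implementor, showing the shape
// `main.rs` expects when it drives `cleanup` in a loop. `NoopCleaner` is hypothetical.
#[allow(dead_code)]
struct NoopCleaner;

#[async_trait]
impl Cleaner for NoopCleaner {
    async fn cleanup(&self) {
        // A real implementation (like the webhooks cleaner) aggregates metrics and
        // deletes finished rows here; this one only logs.
        tracing::info!("nothing to clean up");
    }
}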

View File

@@ -0,0 +1,57 @@
use envconfig::Envconfig;
#[derive(Envconfig)]
pub struct Config {
#[envconfig(from = "BIND_HOST", default = "0.0.0.0")]
pub host: String,
#[envconfig(from = "BIND_PORT", default = "3302")]
pub port: u16,
#[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")]
pub database_url: String,
#[envconfig(default = "30")]
pub cleanup_interval_secs: u64,
// The cleanup task needs to have special knowledge of the queue it's cleaning up. This is so it
// can do things like flush the proper app_metrics or plugin_log_entries, and so it knows what
// to expect in the job's payload JSONB column.
#[envconfig(default = "webhooks")]
pub mode: String,
#[envconfig(nested = true)]
pub kafka: KafkaConfig,
}
#[derive(Envconfig, Clone)]
pub struct KafkaConfig {
#[envconfig(default = "20")]
pub kafka_producer_linger_ms: u32, // Maximum time between producer batches during low traffic
#[envconfig(default = "400")]
pub kafka_producer_queue_mib: u32, // Size of the in-memory producer queue in mebibytes
#[envconfig(default = "20000")]
pub kafka_message_timeout_ms: u32, // Time before we stop retrying producing a message: 20 seconds
#[envconfig(default = "none")]
pub kafka_compression_codec: String, // none, gzip, snappy, lz4, zstd
#[envconfig(default = "false")]
pub kafka_tls: bool,
#[envconfig(default = "clickhouse_app_metrics")]
pub app_metrics_topic: String,
#[envconfig(default = "plugin_log_entries")]
pub plugin_log_entries_topic: String,
pub kafka_hosts: String,
}
impl Config {
pub fn bind(&self) -> String {
format!("{}:{}", self.host, self.port)
}
}
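// Illustrative sketch: the janitor is configured purely through environment variables;
// among the fields above, only `KAFKA_HOSTS` has no default. Values are made up.
#[allow(dead_code)]
fn example_config_from_env() {
    std::env::set_var("KAFKA_HOSTS", "kafka:9092");
    std::env::set_var("BIND_PORT", "3302");
    std::env::set_var("CLEANUP_INTERVAL_SECS", "60");

    let config = Config::init_from_env().expect("failed to load configuration from env");
    assert_eq!(config.bind(), "0.0.0.0:3302");
    assert_eq!(config.cleanup_interval_secs, 60);
    assert_eq!(config.kafka.kafka_hosts, "kafka:9092");
}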

View File

@@ -0,0 +1,166 @@
INSERT INTO
job_queue (
errors,
metadata,
attempted_at,
last_attempt_finished_at,
parameters,
queue,
status,
target
)
VALUES
-- team:1, plugin_config:2, completed in hour 20
(
NULL,
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'webhooks',
'completed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:2, completed in hour 20 (purposeful duplicate)
(
NULL,
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'webhooks',
'completed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:2, completed in hour 21 (different hour)
(
NULL,
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 21:01:18.799371+00',
'2023-12-19 21:01:18.799371+00',
'{}',
'webhooks',
'completed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:3, completed in hour 20 (different plugin_config)
(
NULL,
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 3}',
'2023-12-19 20:01:18.80335+00',
'2023-12-19 20:01:18.80335+00',
'{}',
'webhooks',
'completed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:2, completed but in a different queue
(
NULL,
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'not-webhooks',
'completed',
'https://myhost/endpoint'
),
-- team:2, plugin_config:4, completed in hour 20 (different team)
(
NULL,
'{"team_id": 2, "plugin_id": 99, "plugin_config_id": 4}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'webhooks',
'completed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:2, failed in hour 20
(
ARRAY ['{"type":"TimeoutError","details":{"error":{"name":"Timeout"}}}'::jsonb],
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'webhooks',
'failed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:2, failed in hour 20 (purposeful duplicate)
(
ARRAY ['{"type":"TimeoutError","details":{"error":{"name":"Timeout"}}}'::jsonb],
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'webhooks',
'failed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:2, failed in hour 20 (different error)
(
ARRAY ['{"type":"ConnectionError","details":{"error":{"name":"Connection Error"}}}'::jsonb],
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'webhooks',
'failed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:2, failed in hour 21 (different hour)
(
ARRAY ['{"type":"TimeoutError","details":{"error":{"name":"Timeout"}}}'::jsonb],
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 21:01:18.799371+00',
'2023-12-19 21:01:18.799371+00',
'{}',
'webhooks',
'failed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:3, failed in hour 20 (different plugin_config)
(
ARRAY ['{"type":"TimeoutError","details":{"error":{"name":"Timeout"}}}'::jsonb],
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 3}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'webhooks',
'failed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:2, failed but in a different queue
(
ARRAY ['{"type":"TimeoutError","details":{"error":{"name":"Timeout"}}}'::jsonb],
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'not-webhooks',
'failed',
'https://myhost/endpoint'
),
-- team:2, plugin_config:4, failed in hour 20 (purposeful duplicate)
(
ARRAY ['{"type":"TimeoutError","details":{"error":{"name":"Timeout"}}}'::jsonb],
'{"team_id": 2, "plugin_id": 99, "plugin_config_id": 4}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{}',
'webhooks',
'failed',
'https://myhost/endpoint'
),
-- team:1, plugin_config:2, available
(
NULL,
'{"team_id": 1, "plugin_id": 99, "plugin_config_id": 2}',
'2023-12-19 20:01:18.799371+00',
'2023-12-19 20:01:18.799371+00',
'{"body": "hello world", "headers": {}, "method": "POST", "url": "https://myhost/endpoint"}',
'webhooks',
'available',
'https://myhost/endpoint'
);

View File

@@ -0,0 +1,14 @@
use axum::{routing::get, Router};
use health::HealthRegistry;
use std::future::ready;
pub fn app(liveness: HealthRegistry) -> Router {
Router::new()
.route("/", get(index))
.route("/_readiness", get(index))
.route("/_liveness", get(move || ready(liveness.get_status())))
}
pub async fn index() -> &'static str {
"rusty-hook janitor"
}

View File

@@ -0,0 +1,3 @@
mod app;
pub use app::app;

View File

@@ -0,0 +1,57 @@
use crate::config::KafkaConfig;
use health::HealthHandle;
use rdkafka::error::KafkaError;
use rdkafka::producer::FutureProducer;
use rdkafka::ClientConfig;
use tracing::debug;
pub struct KafkaContext {
liveness: HealthHandle,
}
impl rdkafka::ClientContext for KafkaContext {
fn stats(&self, _: rdkafka::Statistics) {
// Signal liveness, as the main rdkafka loop is running and calling us
self.liveness.report_healthy_blocking();
// TODO: Take stats recording pieces that we want from `capture-rs`.
}
}
pub async fn create_kafka_producer(
config: &KafkaConfig,
liveness: HealthHandle,
) -> Result<FutureProducer<KafkaContext>, KafkaError> {
let mut client_config = ClientConfig::new();
client_config
.set("bootstrap.servers", &config.kafka_hosts)
.set("statistics.interval.ms", "10000")
.set("linger.ms", config.kafka_producer_linger_ms.to_string())
.set(
"message.timeout.ms",
config.kafka_message_timeout_ms.to_string(),
)
.set(
"compression.codec",
config.kafka_compression_codec.to_owned(),
)
.set(
"queue.buffering.max.kbytes",
(config.kafka_producer_queue_mib * 1024).to_string(),
);
if config.kafka_tls {
client_config
.set("security.protocol", "ssl")
.set("enable.ssl.certificate.verification", "false");
};
debug!("rdkafka configuration: {:?}", client_config);
let api: FutureProducer<KafkaContext> =
client_config.create_with_context(KafkaContext { liveness })?;
// TODO: ping the kafka brokers to confirm configuration is OK (copy capture)
Ok(api)
}

View File

@@ -0,0 +1,97 @@
use axum::Router;
use cleanup::{Cleaner, CleanerModeName};
use config::Config;
use envconfig::Envconfig;
use eyre::Result;
use futures::future::{select, Either};
use health::{HealthHandle, HealthRegistry};
use kafka_producer::create_kafka_producer;
use std::{str::FromStr, time::Duration};
use tokio::sync::Semaphore;
use webhooks::WebhookCleaner;
use hook_common::metrics::setup_metrics_routes;
mod cleanup;
mod config;
mod handlers;
mod kafka_producer;
mod webhooks;
async fn listen(app: Router, bind: String) -> Result<()> {
let listener = tokio::net::TcpListener::bind(bind).await?;
axum::serve(listener, app).await?;
Ok(())
}
async fn cleanup_loop(cleaner: Box<dyn Cleaner>, interval_secs: u64, liveness: HealthHandle) {
let semaphore = Semaphore::new(1);
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
loop {
let _permit = semaphore.acquire().await;
interval.tick().await;
liveness.report_healthy().await;
cleaner.cleanup().await;
drop(_permit);
}
}
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let config = Config::init_from_env().expect("failed to load configuration from env");
let mode_name = CleanerModeName::from_str(&config.mode)
.unwrap_or_else(|_| panic!("invalid cleaner mode: {}", config.mode));
let liveness = HealthRegistry::new("liveness");
let cleaner = match mode_name {
CleanerModeName::Webhooks => {
let kafka_liveness = liveness
.register("rdkafka".to_string(), time::Duration::seconds(30))
.await;
let kafka_producer = create_kafka_producer(&config.kafka, kafka_liveness)
.await
.expect("failed to create kafka producer");
Box::new(
WebhookCleaner::new(
&config.database_url,
kafka_producer,
config.kafka.app_metrics_topic.to_owned(),
)
.expect("unable to create webhook cleaner"),
)
}
};
let cleanup_liveness = liveness
.register(
"cleanup_loop".to_string(),
time::Duration::seconds(config.cleanup_interval_secs as i64 * 2),
)
.await;
let cleanup_loop = Box::pin(cleanup_loop(
cleaner,
config.cleanup_interval_secs,
cleanup_liveness,
));
let app = setup_metrics_routes(handlers::app(liveness));
let http_server = Box::pin(listen(app, config.bind()));
match select(http_server, cleanup_loop).await {
Either::Left((listen_result, _)) => match listen_result {
Ok(_) => {}
Err(e) => tracing::error!("failed to start hook-janitor http server, {}", e),
},
Either::Right((_, _)) => {
tracing::error!("hook-janitor cleanup task exited")
}
};
}

View File

@@ -0,0 +1,899 @@
use std::str::FromStr;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use futures::future::join_all;
use hook_common::webhook::WebhookJobError;
use rdkafka::error::KafkaError;
use rdkafka::producer::{FutureProducer, FutureRecord};
use serde_json::error::Error as SerdeError;
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions, Postgres};
use sqlx::types::{chrono, Uuid};
use sqlx::{Row, Transaction};
use thiserror::Error;
use tracing::{debug, error, info};
use crate::cleanup::Cleaner;
use crate::kafka_producer::KafkaContext;
use hook_common::kafka_messages::app_metrics::{AppMetric, AppMetricCategory};
use hook_common::metrics::get_current_timestamp_seconds;
#[derive(Error, Debug)]
pub enum WebhookCleanerError {
#[error("failed to create postgres pool: {error}")]
PoolCreationError { error: sqlx::Error },
#[error("failed to acquire conn: {error}")]
AcquireConnError { error: sqlx::Error },
#[error("failed to acquire conn and start txn: {error}")]
StartTxnError { error: sqlx::Error },
#[error("failed to get queue depth: {error}")]
GetQueueDepthError { error: sqlx::Error },
#[error("failed to get row count: {error}")]
GetRowCountError { error: sqlx::Error },
#[error("failed to get completed rows: {error}")]
GetCompletedRowsError { error: sqlx::Error },
#[error("failed to get failed rows: {error}")]
GetFailedRowsError { error: sqlx::Error },
#[error("failed to serialize rows: {error}")]
SerializeRowsError { error: SerdeError },
#[error("failed to produce to kafka: {error}")]
KafkaProduceError { error: KafkaError },
#[error("failed to produce to kafka (timeout)")]
KafkaProduceCanceled,
#[error("failed to delete rows: {error}")]
DeleteRowsError { error: sqlx::Error },
#[error("attempted to delete a different number of rows than expected")]
DeleteConsistencyError,
#[error("failed to rollback txn: {error}")]
RollbackTxnError { error: sqlx::Error },
#[error("failed to commit txn: {error}")]
CommitTxnError { error: sqlx::Error },
}
type Result<T, E = WebhookCleanerError> = std::result::Result<T, E>;
pub struct WebhookCleaner {
pg_pool: PgPool,
kafka_producer: FutureProducer<KafkaContext>,
app_metrics_topic: String,
}
#[derive(sqlx::FromRow, Debug)]
struct CompletedRow {
// App Metrics truncates/aggregates rows on the hour, so we take advantage of that to GROUP BY
// and aggregate to select fewer rows.
hour: DateTime<Utc>,
// A note about the `try_from`s: Postgres returns all of those types as `bigint` (i64), but
// we know their true sizes, and so we can convert them to the correct types here. If this
// ever fails then something has gone wrong.
#[sqlx(try_from = "i64")]
team_id: u32,
#[sqlx(try_from = "i64")]
plugin_config_id: i32,
#[sqlx(try_from = "i64")]
successes: u32,
}
impl From<CompletedRow> for AppMetric {
fn from(row: CompletedRow) -> Self {
AppMetric {
timestamp: row.hour,
team_id: row.team_id,
plugin_config_id: row.plugin_config_id,
job_id: None,
category: AppMetricCategory::Webhook,
successes: row.successes,
successes_on_retry: 0,
failures: 0,
error_uuid: None,
error_type: None,
error_details: None,
}
}
}
#[derive(sqlx::FromRow, Debug)]
struct FailedRow {
// App Metrics truncates/aggregates rows on the hour, so we take advantage of that to GROUP BY
// and aggregate to select fewer rows.
hour: DateTime<Utc>,
// A note about the `try_from`s: Postgres returns all of those types as `bigint` (i64), but
// we know their true sizes, and so we can convert them to the correct types here. If this
// ever fails then something has gone wrong.
#[sqlx(try_from = "i64")]
team_id: u32,
#[sqlx(try_from = "i64")]
plugin_config_id: i32,
#[sqlx(json)]
last_error: WebhookJobError,
#[sqlx(try_from = "i64")]
failures: u32,
}
#[derive(sqlx::FromRow, Debug)]
struct QueueDepth {
oldest_scheduled_at_untried: DateTime<Utc>,
count_untried: i64,
oldest_scheduled_at_retries: DateTime<Utc>,
count_retries: i64,
}
impl From<FailedRow> for AppMetric {
fn from(row: FailedRow) -> Self {
AppMetric {
timestamp: row.hour,
team_id: row.team_id,
plugin_config_id: row.plugin_config_id,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 0,
successes_on_retry: 0,
failures: row.failures,
error_uuid: Some(Uuid::now_v7()),
error_type: Some(row.last_error.r#type),
error_details: Some(row.last_error.details),
}
}
}
// A simple wrapper type that ensures we don't use any old Transaction object when we need one
// that has set the isolation level to serializable.
struct SerializableTxn<'a>(Transaction<'a, Postgres>);
struct CleanupStats {
rows_processed: u64,
completed_row_count: u64,
completed_agg_row_count: u64,
failed_row_count: u64,
failed_agg_row_count: u64,
}
impl WebhookCleaner {
pub fn new(
database_url: &str,
kafka_producer: FutureProducer<KafkaContext>,
app_metrics_topic: String,
) -> Result<Self> {
let options = PgConnectOptions::from_str(database_url)
.map_err(|error| WebhookCleanerError::PoolCreationError { error })?
.application_name("hook-janitor");
let pg_pool = PgPoolOptions::new()
.acquire_timeout(Duration::from_secs(10))
.connect_lazy_with(options);
Ok(Self {
pg_pool,
kafka_producer,
app_metrics_topic,
})
}
#[allow(dead_code)] // This is used in tests.
pub fn new_from_pool(
pg_pool: PgPool,
kafka_producer: FutureProducer<KafkaContext>,
app_metrics_topic: String,
) -> Result<Self> {
Ok(Self {
pg_pool,
kafka_producer,
app_metrics_topic,
})
}
async fn get_queue_depth(&self) -> Result<QueueDepth> {
let mut conn = self
.pg_pool
.acquire()
.await
.map_err(|e| WebhookCleanerError::AcquireConnError { error: e })?;
let base_query = r#"
SELECT
COALESCE(MIN(CASE WHEN attempt = 0 THEN scheduled_at END), now()) AS oldest_scheduled_at_untried,
COALESCE(SUM(CASE WHEN attempt = 0 THEN 1 ELSE 0 END), 0) AS count_untried,
COALESCE(MIN(CASE WHEN attempt > 0 THEN scheduled_at END), now()) AS oldest_scheduled_at_retries,
COALESCE(SUM(CASE WHEN attempt > 0 THEN 1 ELSE 0 END), 0) AS count_retries
FROM job_queue
WHERE status = 'available';
"#;
let row = sqlx::query_as::<_, QueueDepth>(base_query)
.fetch_one(&mut *conn)
.await
.map_err(|e| WebhookCleanerError::GetQueueDepthError { error: e })?;
Ok(row)
}
async fn start_serializable_txn(&self) -> Result<SerializableTxn> {
let mut tx = self
.pg_pool
.begin()
.await
.map_err(|e| WebhookCleanerError::StartTxnError { error: e })?;
// We use serializable isolation so that we observe a snapshot of the DB at the time we
// start the cleanup process. This prevents us from accidentally deleting rows that are
// added (or become 'completed' or 'failed') after we start the cleanup process.
//
// If we find that this has a significant performance impact, we could instead move
// rows to a temporary table for processing and then deletion.
sqlx::query("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
.execute(&mut *tx)
.await
.map_err(|e| WebhookCleanerError::StartTxnError { error: e })?;
Ok(SerializableTxn(tx))
}
async fn get_row_count_for_status(
&self,
tx: &mut SerializableTxn<'_>,
status: &str,
) -> Result<u64> {
let base_query = r#"
SELECT count(*) FROM job_queue
WHERE status = $1::job_status;
"#;
let count: i64 = sqlx::query(base_query)
.bind(status)
.fetch_one(&mut *tx.0)
.await
.map_err(|e| WebhookCleanerError::GetRowCountError { error: e })?
.get(0);
Ok(count as u64)
}
async fn get_completed_agg_rows(
&self,
tx: &mut SerializableTxn<'_>,
) -> Result<Vec<CompletedRow>> {
let base_query = r#"
SELECT DATE_TRUNC('hour', last_attempt_finished_at) AS hour,
(metadata->>'team_id')::bigint AS team_id,
(metadata->>'plugin_config_id')::bigint AS plugin_config_id,
count(*) as successes
FROM job_queue
WHERE status = 'completed'
GROUP BY hour, team_id, plugin_config_id
ORDER BY hour, team_id, plugin_config_id;
"#;
let rows = sqlx::query_as::<_, CompletedRow>(base_query)
.fetch_all(&mut *tx.0)
.await
.map_err(|e| WebhookCleanerError::GetCompletedRowsError { error: e })?;
Ok(rows)
}
async fn get_failed_agg_rows(&self, tx: &mut SerializableTxn<'_>) -> Result<Vec<FailedRow>> {
let base_query = r#"
SELECT DATE_TRUNC('hour', last_attempt_finished_at) AS hour,
(metadata->>'team_id')::bigint AS team_id,
(metadata->>'plugin_config_id')::bigint AS plugin_config_id,
errors[array_upper(errors, 1)] AS last_error,
count(*) as failures
FROM job_queue
WHERE status = 'failed'
GROUP BY hour, team_id, plugin_config_id, last_error
ORDER BY hour, team_id, plugin_config_id, last_error;
"#;
let rows = sqlx::query_as::<_, FailedRow>(base_query)
.fetch_all(&mut *tx.0)
.await
.map_err(|e| WebhookCleanerError::GetFailedRowsError { error: e })?;
Ok(rows)
}
async fn send_metrics_to_kafka(&self, metrics: Vec<AppMetric>) -> Result<()> {
if metrics.is_empty() {
return Ok(());
}
let payloads: Vec<String> = metrics
.into_iter()
.map(|metric| serde_json::to_string(&metric))
.collect::<Result<Vec<String>, SerdeError>>()
.map_err(|e| WebhookCleanerError::SerializeRowsError { error: e })?;
let mut delivery_futures = Vec::new();
for payload in payloads {
match self.kafka_producer.send_result(FutureRecord {
topic: self.app_metrics_topic.as_str(),
payload: Some(&payload),
partition: None,
key: None::<&str>,
timestamp: None,
headers: None,
}) {
Ok(future) => delivery_futures.push(future),
Err((error, _)) => return Err(WebhookCleanerError::KafkaProduceError { error }),
}
}
for result in join_all(delivery_futures).await {
match result {
Ok(Ok(_)) => {}
Ok(Err((error, _))) => {
return Err(WebhookCleanerError::KafkaProduceError { error })
}
Err(_) => {
// Cancelled due to timeout while retrying
return Err(WebhookCleanerError::KafkaProduceCanceled);
}
}
}
Ok(())
}
async fn delete_observed_rows(&self, tx: &mut SerializableTxn<'_>) -> Result<u64> {
// This DELETE is only safe because we are in serializable isolation mode, see the note
// in `start_serializable_txn`.
let base_query = r#"
DELETE FROM job_queue
WHERE status IN ('failed', 'completed')
"#;
let result = sqlx::query(base_query)
.execute(&mut *tx.0)
.await
.map_err(|e| WebhookCleanerError::DeleteRowsError { error: e })?;
Ok(result.rows_affected())
}
async fn rollback_txn(&self, tx: SerializableTxn<'_>) -> Result<()> {
tx.0.rollback()
.await
.map_err(|e| WebhookCleanerError::RollbackTxnError { error: e })?;
Ok(())
}
async fn commit_txn(&self, tx: SerializableTxn<'_>) -> Result<()> {
tx.0.commit()
.await
.map_err(|e| WebhookCleanerError::CommitTxnError { error: e })?;
Ok(())
}
async fn cleanup_impl(&self) -> Result<CleanupStats> {
debug!("WebhookCleaner starting cleanup");
// Note that we select all completed and failed rows without any pagination at the moment.
// We aggregate as much as possible with GROUP BY, truncating the timestamp down to the
// hour just like App Metrics does. A completed row is 24 bytes (and aggregates an entire
// hour per `plugin_config_id`), and a failed row is 104 bytes + the error message length
// (and aggregates an entire hour per `plugin_config_id` per `error`), so we can fit a lot
// of rows in memory. It seems unlikely we'll need to paginate, but that can be added in the
// future if necessary.
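//
// As a rough, illustrative upper bound (not a measurement): with ~200-byte error
// messages a failed aggregate row is on the order of ~300 bytes, so even 100,000
// aggregate rows would only be roughly 30 MB of memory.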
let untried_status = [("status", "untried")];
let retries_status = [("status", "retries")];
let queue_depth = self.get_queue_depth().await?;
metrics::gauge!("queue_depth_oldest_scheduled", &untried_status)
.set(queue_depth.oldest_scheduled_at_untried.timestamp() as f64);
metrics::gauge!("queue_depth", &untried_status).set(queue_depth.count_untried as f64);
metrics::gauge!("queue_depth_oldest_scheduled", &retries_status)
.set(queue_depth.oldest_scheduled_at_retries.timestamp() as f64);
metrics::gauge!("queue_depth", &retries_status).set(queue_depth.count_retries as f64);
let mut tx = self.start_serializable_txn().await?;
let (completed_row_count, completed_agg_row_count) = {
let completed_row_count = self.get_row_count_for_status(&mut tx, "completed").await?;
let completed_agg_rows = self.get_completed_agg_rows(&mut tx).await?;
let agg_row_count = completed_agg_rows.len() as u64;
let completed_app_metrics: Vec<AppMetric> =
completed_agg_rows.into_iter().map(Into::into).collect();
self.send_metrics_to_kafka(completed_app_metrics).await?;
(completed_row_count, agg_row_count)
};
let (failed_row_count, failed_agg_row_count) = {
let failed_row_count = self.get_row_count_for_status(&mut tx, "failed").await?;
let failed_agg_rows = self.get_failed_agg_rows(&mut tx).await?;
let agg_row_count = failed_agg_rows.len() as u64;
let failed_app_metrics: Vec<AppMetric> =
failed_agg_rows.into_iter().map(Into::into).collect();
self.send_metrics_to_kafka(failed_app_metrics).await?;
(failed_row_count, agg_row_count)
};
let mut rows_deleted = 0;
if completed_agg_row_count + failed_agg_row_count != 0 {
rows_deleted = self.delete_observed_rows(&mut tx).await?;
if rows_deleted != completed_row_count + failed_row_count {
// This should never happen, but if it does, we want to know about it (and abort the
// txn).
error!(
attempted_rows_deleted = rows_deleted,
completed_row_count = completed_row_count,
failed_row_count = failed_row_count,
"WebhookCleaner::cleanup attempted to delete a different number of rows than expected"
);
self.rollback_txn(tx).await?;
return Err(WebhookCleanerError::DeleteConsistencyError);
}
self.commit_txn(tx).await?;
}
Ok(CleanupStats {
rows_processed: rows_deleted,
completed_row_count,
completed_agg_row_count,
failed_row_count,
failed_agg_row_count,
})
}
}
#[async_trait]
impl Cleaner for WebhookCleaner {
async fn cleanup(&self) {
let start_time = Instant::now();
metrics::counter!("webhook_cleanup_attempts",).increment(1);
match self.cleanup_impl().await {
Ok(stats) => {
metrics::counter!("webhook_cleanup_success",).increment(1);
metrics::gauge!("webhook_cleanup_last_success_timestamp",)
.set(get_current_timestamp_seconds());
if stats.rows_processed > 0 {
let elapsed_time = start_time.elapsed().as_secs_f64();
metrics::histogram!("webhook_cleanup_duration").record(elapsed_time);
metrics::counter!("webhook_cleanup_rows_processed",)
.increment(stats.rows_processed);
metrics::counter!("webhook_cleanup_completed_row_count",)
.increment(stats.completed_row_count);
metrics::counter!("webhook_cleanup_completed_agg_row_count",)
.increment(stats.completed_agg_row_count);
metrics::counter!("webhook_cleanup_failed_row_count",)
.increment(stats.failed_row_count);
metrics::counter!("webhook_cleanup_failed_agg_row_count",)
.increment(stats.failed_agg_row_count);
info!(
rows_processed = stats.rows_processed,
completed_row_count = stats.completed_row_count,
completed_agg_row_count = stats.completed_agg_row_count,
failed_row_count = stats.failed_row_count,
failed_agg_row_count = stats.failed_agg_row_count,
"WebhookCleaner::cleanup finished"
);
} else {
debug!("WebhookCleaner finished cleanup, there were no rows to process");
}
}
Err(error) => {
metrics::counter!("webhook_cleanup_failures",).increment(1);
error!(error = ?error, "WebhookCleaner::cleanup failed");
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config;
use crate::kafka_producer::{create_kafka_producer, KafkaContext};
use health::HealthRegistry;
use hook_common::kafka_messages::app_metrics::{
Error as WebhookError, ErrorDetails, ErrorType,
};
use hook_common::pgqueue::PgQueueJob;
use hook_common::pgqueue::{NewJob, PgQueue, PgTransactionBatch};
use hook_common::webhook::{HttpMethod, WebhookJobMetadata, WebhookJobParameters};
use rdkafka::consumer::{Consumer, StreamConsumer};
use rdkafka::mocking::MockCluster;
use rdkafka::producer::{DefaultProducerContext, FutureProducer};
use rdkafka::types::{RDKafkaApiKey, RDKafkaRespErr};
use rdkafka::{ClientConfig, Message};
use sqlx::{PgPool, Row};
use std::collections::HashMap;
use std::str::FromStr;
const APP_METRICS_TOPIC: &str = "app_metrics";
async fn create_mock_kafka() -> (
MockCluster<'static, DefaultProducerContext>,
FutureProducer<KafkaContext>,
) {
let registry = HealthRegistry::new("liveness");
let handle = registry
.register("one".to_string(), time::Duration::seconds(30))
.await;
let cluster = MockCluster::new(1).expect("failed to create mock brokers");
let config = config::KafkaConfig {
kafka_producer_linger_ms: 0,
kafka_producer_queue_mib: 50,
kafka_message_timeout_ms: 5000,
kafka_compression_codec: "none".to_string(),
kafka_hosts: cluster.bootstrap_servers(),
app_metrics_topic: APP_METRICS_TOPIC.to_string(),
plugin_log_entries_topic: "plugin_log_entries".to_string(),
kafka_tls: false,
};
(
cluster,
create_kafka_producer(&config, handle)
.await
.expect("failed to create mocked kafka producer"),
)
}
fn check_app_metric_vector_equality(v1: &[AppMetric], v2: &[AppMetric]) {
// Ignores `error_uuid`s.
assert_eq!(v1.len(), v2.len());
for (item1, item2) in v1.iter().zip(v2) {
let mut item1 = item1.clone();
item1.error_uuid = None;
let mut item2 = item2.clone();
item2.error_uuid = None;
assert_eq!(item1, item2);
}
}
#[sqlx::test(migrations = "../migrations", fixtures("webhook_cleanup"))]
async fn test_cleanup_impl(db: PgPool) {
let (mock_cluster, mock_producer) = create_mock_kafka().await;
mock_cluster
.create_topic(APP_METRICS_TOPIC, 1, 1)
.expect("failed to create mock app_metrics topic");
let consumer: StreamConsumer = ClientConfig::new()
.set("bootstrap.servers", mock_cluster.bootstrap_servers())
.set("group.id", "mock")
.set("auto.offset.reset", "earliest")
.create()
.expect("failed to create mock consumer");
consumer.subscribe(&[APP_METRICS_TOPIC]).unwrap();
let webhook_cleaner =
WebhookCleaner::new_from_pool(db, mock_producer, APP_METRICS_TOPIC.to_owned())
.expect("unable to create webhook cleaner");
let cleanup_stats = webhook_cleaner
.cleanup_impl()
.await
.expect("webbook cleanup_impl failed");
// Rows that are not 'completed' or 'failed' should not be processed.
assert_eq!(cleanup_stats.rows_processed, 13);
let mut received_app_metrics = Vec::new();
for _ in 0..(cleanup_stats.completed_agg_row_count + cleanup_stats.failed_agg_row_count) {
let kafka_msg = consumer.recv().await.unwrap();
let payload_str = String::from_utf8(kafka_msg.payload().unwrap().to_vec()).unwrap();
let app_metric: AppMetric = serde_json::from_str(&payload_str).unwrap();
received_app_metrics.push(app_metric);
}
let expected_app_metrics = vec![
AppMetric {
timestamp: DateTime::<Utc>::from_str("2023-12-19T20:00:00Z").unwrap(),
team_id: 1,
plugin_config_id: 2,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 3,
successes_on_retry: 0,
failures: 0,
error_uuid: None,
error_type: None,
error_details: None,
},
AppMetric {
timestamp: DateTime::<Utc>::from_str("2023-12-19T20:00:00Z").unwrap(),
team_id: 1,
plugin_config_id: 3,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 1,
successes_on_retry: 0,
failures: 0,
error_uuid: None,
error_type: None,
error_details: None,
},
AppMetric {
timestamp: DateTime::<Utc>::from_str("2023-12-19T20:00:00Z").unwrap(),
team_id: 2,
plugin_config_id: 4,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 1,
successes_on_retry: 0,
failures: 0,
error_uuid: None,
error_type: None,
error_details: None,
},
AppMetric {
timestamp: DateTime::<Utc>::from_str("2023-12-19T21:00:00Z").unwrap(),
team_id: 1,
plugin_config_id: 2,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 1,
successes_on_retry: 0,
failures: 0,
error_uuid: None,
error_type: None,
error_details: None,
},
AppMetric {
timestamp: DateTime::<Utc>::from_str("2023-12-19T20:00:00Z").unwrap(),
team_id: 1,
plugin_config_id: 2,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 0,
successes_on_retry: 0,
failures: 1,
error_uuid: Some(Uuid::parse_str("018c8935-d038-714a-957c-0df43d42e377").unwrap()),
error_type: Some(ErrorType::ConnectionError),
error_details: Some(ErrorDetails {
error: WebhookError {
name: "Connection Error".to_owned(),
message: None,
stack: None,
},
}),
},
AppMetric {
timestamp: DateTime::<Utc>::from_str("2023-12-19T20:00:00Z").unwrap(),
team_id: 1,
plugin_config_id: 2,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 0,
successes_on_retry: 0,
failures: 3,
error_uuid: Some(Uuid::parse_str("018c8935-d038-714a-957c-0df43d42e377").unwrap()),
error_type: Some(ErrorType::TimeoutError),
error_details: Some(ErrorDetails {
error: WebhookError {
name: "Timeout".to_owned(),
message: None,
stack: None,
},
}),
},
AppMetric {
timestamp: DateTime::<Utc>::from_str("2023-12-19T20:00:00Z").unwrap(),
team_id: 1,
plugin_config_id: 3,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 0,
successes_on_retry: 0,
failures: 1,
error_uuid: Some(Uuid::parse_str("018c8935-d038-714a-957c-0df43d42e377").unwrap()),
error_type: Some(ErrorType::TimeoutError),
error_details: Some(ErrorDetails {
error: WebhookError {
name: "Timeout".to_owned(),
message: None,
stack: None,
},
}),
},
AppMetric {
timestamp: DateTime::<Utc>::from_str("2023-12-19T20:00:00Z").unwrap(),
team_id: 2,
plugin_config_id: 4,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 0,
successes_on_retry: 0,
failures: 1,
error_uuid: Some(Uuid::parse_str("018c8935-d038-714a-957c-0df43d42e377").unwrap()),
error_type: Some(ErrorType::TimeoutError),
error_details: Some(ErrorDetails {
error: WebhookError {
name: "Timeout".to_owned(),
message: None,
stack: None,
},
}),
},
AppMetric {
timestamp: DateTime::<Utc>::from_str("2023-12-19T21:00:00Z").unwrap(),
team_id: 1,
plugin_config_id: 2,
job_id: None,
category: AppMetricCategory::Webhook,
successes: 0,
successes_on_retry: 0,
failures: 1,
error_uuid: Some(Uuid::parse_str("018c8935-d038-714a-957c-0df43d42e377").unwrap()),
error_type: Some(ErrorType::TimeoutError),
error_details: Some(ErrorDetails {
error: WebhookError {
name: "Timeout".to_owned(),
message: None,
stack: None,
},
}),
},
];
check_app_metric_vector_equality(&expected_app_metrics, &received_app_metrics);
}
#[sqlx::test(migrations = "../migrations")]
async fn test_cleanup_impl_empty_queue(db: PgPool) {
let (mock_cluster, mock_producer) = create_mock_kafka().await;
mock_cluster
.create_topic(APP_METRICS_TOPIC, 1, 1)
.expect("failed to create mock app_metrics topic");
// No payload should be produced to kafka as the queue is empty.
// Set a non-retryable produce error that would bubble up when cleanup_impl is called.
let err = [RDKafkaRespErr::RD_KAFKA_RESP_ERR_MSG_SIZE_TOO_LARGE; 1];
mock_cluster.request_errors(RDKafkaApiKey::Produce, &err);
let consumer: StreamConsumer = ClientConfig::new()
.set("bootstrap.servers", mock_cluster.bootstrap_servers())
.set("group.id", "mock")
.set("auto.offset.reset", "earliest")
.create()
.expect("failed to create mock consumer");
consumer.subscribe(&[APP_METRICS_TOPIC]).unwrap();
let webhook_cleaner =
WebhookCleaner::new_from_pool(db, mock_producer, APP_METRICS_TOPIC.to_owned())
.expect("unable to create webhook cleaner");
let cleanup_stats = webhook_cleaner
.cleanup_impl()
.await
.expect("webbook cleanup_impl failed");
// Reported metrics are all zeroes
assert_eq!(cleanup_stats.rows_processed, 0);
assert_eq!(cleanup_stats.completed_row_count, 0);
assert_eq!(cleanup_stats.completed_agg_row_count, 0);
assert_eq!(cleanup_stats.failed_row_count, 0);
assert_eq!(cleanup_stats.failed_agg_row_count, 0);
}
#[sqlx::test(migrations = "../migrations", fixtures("webhook_cleanup"))]
async fn test_serializable_isolation(db: PgPool) {
let (_, mock_producer) = create_mock_kafka().await;
let webhook_cleaner =
WebhookCleaner::new_from_pool(db.clone(), mock_producer, APP_METRICS_TOPIC.to_owned())
.expect("unable to create webhook cleaner");
let queue = PgQueue::new_from_pool("webhooks", db.clone()).await;
async fn get_count_from_new_conn(db: &PgPool, status: &str) -> i64 {
let mut conn = db.acquire().await.unwrap();
let count: i64 =
sqlx::query("SELECT count(*) FROM job_queue WHERE status = $1::job_status")
.bind(&status)
.fetch_one(&mut *conn)
.await
.unwrap()
.get(0);
count
}
// Important! Serializable txn is started here.
let mut tx = webhook_cleaner.start_serializable_txn().await.unwrap();
webhook_cleaner
.get_completed_agg_rows(&mut tx)
.await
.unwrap();
webhook_cleaner.get_failed_agg_rows(&mut tx).await.unwrap();
// All 15 rows in the DB are visible from outside the txn.
// That's the 13 rows the cleaner will process, plus 1 available and 1 running.
assert_eq!(get_count_from_new_conn(&db, "completed").await, 6);
assert_eq!(get_count_from_new_conn(&db, "failed").await, 7);
assert_eq!(get_count_from_new_conn(&db, "available").await, 1);
{
// The fixtures include an available job, so let's complete it while the txn is open.
let mut batch: PgTransactionBatch<'_, WebhookJobParameters, WebhookJobMetadata> = queue
.dequeue_tx(&"worker_id", 1)
.await
.expect("failed to dequeue job")
.expect("didn't find a job to dequeue");
let webhook_job = batch.jobs.pop().unwrap();
webhook_job
.complete()
.await
.expect("failed to complete job");
batch.commit().await.expect("failed to commit batch");
}
{
// Enqueue and complete another job while the txn is open.
let job_parameters = WebhookJobParameters {
body: "foo".to_owned(),
headers: HashMap::new(),
method: HttpMethod::POST,
url: "http://example.com".to_owned(),
};
let job_metadata = WebhookJobMetadata {
team_id: 1,
plugin_id: 2,
plugin_config_id: 3,
};
let new_job = NewJob::new(1, job_metadata, job_parameters, &"target");
queue.enqueue(new_job).await.expect("failed to enqueue job");
let mut batch: PgTransactionBatch<'_, WebhookJobParameters, WebhookJobMetadata> = queue
.dequeue_tx(&"worker_id", 1)
.await
.expect("failed to dequeue job")
.expect("didn't find a job to dequeue");
let webhook_job = batch.jobs.pop().unwrap();
webhook_job
.complete()
.await
.expect("failed to complete job");
batch.commit().await.expect("failed to commit batch");
}
{
// Enqueue another available job while the txn is open.
let job_parameters = WebhookJobParameters {
body: "foo".to_owned(),
headers: HashMap::new(),
method: HttpMethod::POST,
url: "http://example.com".to_owned(),
};
let job_metadata = WebhookJobMetadata {
team_id: 1,
plugin_id: 2,
plugin_config_id: 3,
};
let new_job = NewJob::new(1, job_metadata, job_parameters, &"target");
queue.enqueue(new_job).await.expect("failed to enqueue job");
}
// There are now 2 more completed rows than before (the jobs completed above), visible from outside the txn.
assert_eq!(get_count_from_new_conn(&db, "completed").await, 8);
assert_eq!(get_count_from_new_conn(&db, "available").await, 1);
let rows_processed = webhook_cleaner.delete_observed_rows(&mut tx).await.unwrap();
// The 13 rows in the DB when the txn started should be deleted.
assert_eq!(rows_processed, 13);
// We haven't committed, so the rows are still visible from outside the txn.
assert_eq!(get_count_from_new_conn(&db, "completed").await, 8);
assert_eq!(get_count_from_new_conn(&db, "available").await, 1);
webhook_cleaner.commit_txn(tx).await.unwrap();
// Now that we have committed, what remains are:
// * The 1 available job we completed while the txn was open.
// * The 2 brand new jobs we added while the txn was open.
// * The 1 running job that didn't change.
assert_eq!(get_count_from_new_conn(&db, "completed").await, 2);
assert_eq!(get_count_from_new_conn(&db, "failed").await, 0);
assert_eq!(get_count_from_new_conn(&db, "available").await, 1);
}
}

View File

@@ -0,0 +1,25 @@
[package]
name = "hook-worker"
version = "0.1.0"
edition = "2021"
[lints]
workspace = true
[dependencies]
axum = { workspace = true }
chrono = { workspace = true }
envconfig = { workspace = true }
futures = "0.3"
health = { path = "../common/health" }
hook-common = { path = "../hook-common" }
http = { workspace = true }
metrics = { workspace = true }
reqwest = { workspace = true }
sqlx = { workspace = true }
thiserror = { workspace = true }
time = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
url = { version = "2.2" }

View File

@@ -0,0 +1,2 @@
# hook-worker
Consume and process webhook jobs

View File

@@ -0,0 +1,104 @@
use std::str::FromStr;
use std::time;
use envconfig::Envconfig;
#[derive(Envconfig, Clone)]
pub struct Config {
#[envconfig(from = "BIND_HOST", default = "0.0.0.0")]
pub host: String,
#[envconfig(from = "BIND_PORT", default = "3301")]
pub port: u16,
#[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")]
pub database_url: String,
#[envconfig(default = "worker")]
pub worker_name: String,
#[envconfig(default = "default")]
pub queue_name: NonEmptyString,
#[envconfig(default = "100")]
pub poll_interval: EnvMsDuration,
#[envconfig(default = "5000")]
pub request_timeout: EnvMsDuration,
#[envconfig(default = "1024")]
pub max_concurrent_jobs: usize,
#[envconfig(default = "100")]
pub max_pg_connections: u32,
#[envconfig(nested = true)]
pub retry_policy: RetryPolicyConfig,
#[envconfig(default = "1")]
pub dequeue_batch_size: u32,
#[envconfig(default = "false")]
pub allow_internal_ips: bool,
}
impl Config {
/// Produce a host:port address for binding a TcpListener.
pub fn bind(&self) -> String {
format!("{}:{}", self.host, self.port)
}
}
#[derive(Debug, Clone, Copy)]
pub struct EnvMsDuration(pub time::Duration);
#[derive(Debug, PartialEq, Eq)]
pub struct ParseEnvMsDurationError;
impl FromStr for EnvMsDuration {
type Err = ParseEnvMsDurationError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let ms = s.parse::<u64>().map_err(|_| ParseEnvMsDurationError)?;
Ok(EnvMsDuration(time::Duration::from_millis(ms)))
}
}
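// Illustration: a raw value of "5000" parses to EnvMsDuration(time::Duration::from_millis(5000)),
// so every EnvMsDuration setting in this module is expressed in milliseconds.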
#[derive(Envconfig, Clone)]
pub struct RetryPolicyConfig {
#[envconfig(default = "2")]
pub backoff_coefficient: u32,
#[envconfig(default = "1000")]
pub initial_interval: EnvMsDuration,
#[envconfig(default = "100000")]
pub maximum_interval: EnvMsDuration,
pub retry_queue_name: Option<NonEmptyString>,
}
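// These values only parameterize hook_common::retry::RetryPolicy (built in main.rs). With
// the defaults above this presumably yields an exponential backoff starting around 1s,
// growing by backoff_coefficient per attempt and capped at 100s, but the exact formula
// lives in hook_common and is not shown here.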
#[derive(Debug, Clone)]
pub struct NonEmptyString(pub String);
impl NonEmptyString {
pub fn as_str(&self) -> &str {
&self.0
}
}
#[derive(Debug, PartialEq, Eq)]
pub struct StringIsEmptyError;
impl FromStr for NonEmptyString {
type Err = StringIsEmptyError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.is_empty() {
Err(StringIsEmptyError)
} else {
Ok(NonEmptyString(s.to_owned()))
}
}
}

140
rust/hook-worker/src/dns.rs Normal file
View File

@@ -0,0 +1,140 @@
use std::error::Error as StdError;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::{fmt, io};
use futures::FutureExt;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use tokio::task::spawn_blocking;
pub struct NoPublicIPv4Error;
impl std::error::Error for NoPublicIPv4Error {}
impl fmt::Display for NoPublicIPv4Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "No public IPv4 found for specified host")
}
}
impl fmt::Debug for NoPublicIPv4Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "No public IPv4 found for specified host")
}
}
/// Internal reqwest type, copied here as part of Resolving
pub(crate) type BoxError = Box<dyn StdError + Send + Sync>;
/// Returns [`true`] if the address appears to be a globally reachable IPv4.
///
/// Trimmed-down version of the unstable IpAddr::is_global; move to it once it is stabilized.
fn is_global_ipv4(addr: &SocketAddr) -> bool {
match addr.ip() {
IpAddr::V4(ip) => {
!(ip.octets()[0] == 0 // "This network"
|| ip.is_private()
|| ip.is_loopback()
|| ip.is_link_local()
|| ip.is_broadcast())
}
IpAddr::V6(_) => false, // Our network does not currently support ipv6, let's ignore for now
}
}
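// For illustration: 8.8.8.8 passes this filter, while 0.1.2.3 ("this network"),
// 10.0.0.1 (private), 127.0.0.1 (loopback), 169.254.0.1 (link-local),
// 255.255.255.255 (broadcast) and every IPv6 address are rejected.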
/// DNS resolver using the stdlib resolver, but filtering results to only pass public IPv4 results.
///
/// Private and broadcast addresses are filtered out, so are IPv6 results for now (as our infra
/// does not currently support IPv6 routing anyway).
/// This is adapted from the GaiResolver in hyper and reqwest.
pub struct PublicIPv4Resolver {}
impl Resolve for PublicIPv4Resolver {
fn resolve(&self, name: Name) -> Resolving {
// Closure to call the system's resolver (blocking call) through the ToSocketAddrs trait.
let resolve_host = move || (name.as_str(), 0).to_socket_addrs();
// Execute the blocking call in a separate worker thread then process its result asynchronously.
// spawn_blocking returns a JoinHandle that implements Future<Result<(closure result), JoinError>>.
let future_result = spawn_blocking(resolve_host).map(|result| match result {
Ok(Ok(all_addrs)) => {
// Resolution succeeded, filter the results
let filtered_addr: Vec<SocketAddr> = all_addrs.filter(is_global_ipv4).collect();
if filtered_addr.is_empty() {
// No public IPs found, error out with PermissionDenied
let err: BoxError = Box::new(NoPublicIPv4Error);
Err(err)
} else {
// Pass remaining IPs in a boxed iterator for request to use.
let addrs: Addrs = Box::new(filtered_addr.into_iter());
Ok(addrs)
}
}
Ok(Err(err)) => {
// Resolution failed, pass error through in a Box
let err: BoxError = Box::new(err);
Err(err)
}
Err(join_err) => {
// The tokio task failed, pass as io::Error in a Box
let err: BoxError = Box::new(io::Error::from(join_err));
Err(err)
}
});
// Box the Future to satisfy the Resolving interface.
Box::pin(future_result)
}
}
#[cfg(test)]
mod tests {
use crate::dns::{NoPublicIPv4Error, PublicIPv4Resolver};
use reqwest::dns::{Name, Resolve};
use std::str::FromStr;
#[tokio::test]
async fn it_resolves_google_com() {
let resolver: PublicIPv4Resolver = PublicIPv4Resolver {};
let addrs = resolver
.resolve(Name::from_str("google.com").unwrap())
.await
.expect("lookup has failed");
assert!(addrs.count() > 0, "empty address list")
}
#[tokio::test]
async fn it_denies_ipv6_google_com() {
let resolver: PublicIPv4Resolver = PublicIPv4Resolver {};
match resolver
.resolve(Name::from_str("ipv6.google.com").unwrap())
.await
{
Ok(_) => panic!("should have failed"),
Err(err) => assert!(err.is::<NoPublicIPv4Error>()),
}
}
#[tokio::test]
async fn it_denies_localhost() {
let resolver: PublicIPv4Resolver = PublicIPv4Resolver {};
match resolver.resolve(Name::from_str("localhost").unwrap()).await {
Ok(_) => panic!("should have failed"),
Err(err) => assert!(err.is::<NoPublicIPv4Error>()),
}
}
#[tokio::test]
async fn it_bubbles_up_resolution_error() {
let resolver: PublicIPv4Resolver = PublicIPv4Resolver {};
match resolver
.resolve(Name::from_str("invalid.domain.unknown").unwrap())
.await
{
Ok(_) => panic!("should have failed"),
Err(err) => {
assert!(!err.is::<NoPublicIPv4Error>());
assert!(err
.to_string()
.contains("failed to lookup address information"))
}
}
}
}

View File

@@ -0,0 +1,152 @@
use std::error::Error;
use std::fmt;
use std::time;
use crate::dns::NoPublicIPv4Error;
use hook_common::{pgqueue, webhook::WebhookJobError};
use thiserror::Error;
/// Enumeration of error classes handled by `WebhookWorker`.
#[derive(Error, Debug)]
pub enum WebhookError {
#[error(transparent)]
Parse(#[from] WebhookParseError),
#[error(transparent)]
Request(#[from] WebhookRequestError),
}
/// Enumeration of parsing errors that can occur as `WebhookWorker` sets up a webhook.
#[derive(Error, Debug)]
pub enum WebhookParseError {
#[error("{0} is not a valid HttpMethod")]
ParseHttpMethodError(String),
#[error("error parsing webhook headers")]
ParseHeadersError(http::Error),
#[error("error parsing webhook url")]
ParseUrlError(url::ParseError),
}
/// Enumeration of request errors that can occur as `WebhookWorker` sends a request.
#[derive(Error, Debug)]
pub enum WebhookRequestError {
RetryableRequestError {
error: reqwest::Error,
response: Option<String>,
retry_after: Option<time::Duration>,
},
NonRetryableRetryableRequestError {
error: reqwest::Error,
response: Option<String>,
},
}
/// Enumeration of errors that can occur while handling a `reqwest::Response`.
/// Currently not matched on anywhere. Grouped here to provide a common error type for
/// `util::first_n_bytes_of_response`.
#[derive(Error, Debug)]
pub enum WebhookResponseError {
#[error("failed to parse a response as UTF8")]
ParseUTF8StringError(#[from] std::str::Utf8Error),
#[error("error while iterating over response body chunks")]
StreamIterationError(#[from] reqwest::Error),
#[error("attempted to slice a chunk of length {0} with an out of bounds index of {1}")]
ChunkOutOfBoundsError(usize, usize),
}
/// Implement display of `WebhookRequestError` by appending any response message, when
/// available, to the underlying `reqwest::Error`.
impl fmt::Display for WebhookRequestError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WebhookRequestError::RetryableRequestError {
error, response, ..
}
| WebhookRequestError::NonRetryableRetryableRequestError { error, response } => {
let response_message = match response {
Some(m) => m.to_string(),
None => "No response from the server".to_string(),
};
if is_error_source::<NoPublicIPv4Error>(error) {
writeln!(f, "{}: {}", error, NoPublicIPv4Error)?;
} else {
writeln!(f, "{}", error)?;
}
write!(f, "{}", response_message)?;
Ok(())
}
}
}
}
/// Implementation of `WebhookRequestError` designed to further describe the error.
/// In particular, we delegate some calls to the underlying `reqwest::Error` to provide more details.
impl WebhookRequestError {
pub fn is_timeout(&self) -> bool {
match self {
WebhookRequestError::RetryableRequestError { error, .. }
| WebhookRequestError::NonRetryableRetryableRequestError { error, .. } => {
error.is_timeout()
}
}
}
pub fn is_status(&self) -> bool {
match self {
WebhookRequestError::RetryableRequestError { error, .. }
| WebhookRequestError::NonRetryableRetryableRequestError { error, .. } => {
error.is_status()
}
}
}
pub fn status(&self) -> Option<http::StatusCode> {
match self {
WebhookRequestError::RetryableRequestError { error, .. }
| WebhookRequestError::NonRetryableRetryableRequestError { error, .. } => {
error.status()
}
}
}
}
impl From<&WebhookRequestError> for WebhookJobError {
fn from(error: &WebhookRequestError) -> Self {
if error.is_timeout() {
WebhookJobError::new_timeout(&error.to_string())
} else if error.is_status() {
WebhookJobError::new_http_status(
error.status().expect("status code is defined").into(),
&error.to_string(),
)
} else {
// Catch all other errors as `app_metrics::ErrorType::Connection` errors.
// Not all of `reqwest::Error` may strictly be connection errors, so our supported error types may need an extension
// depending on how strict error reporting has to be.
WebhookJobError::new_connection(&error.to_string())
}
}
}
/// Enumeration of errors related to initialization and consumption of webhook jobs.
#[derive(Error, Debug)]
pub enum WorkerError {
#[error("a database error occurred when executing a job")]
DatabaseError(#[from] pgqueue::DatabaseError),
#[error("a parsing error occurred in the underlying queue")]
QueueParseError(#[from] pgqueue::ParseError),
#[error("timed out while waiting for jobs to be available")]
TimeoutError,
}
/// Check the error and its sources (recursively) to return true if an error of the given type is found.
/// TODO: use Error::sources() when stable
pub fn is_error_source<T: Error + 'static>(err: &(dyn std::error::Error + 'static)) -> bool {
if err.is::<NoPublicIPv4Error>() {
return true;
}
match err.source() {
None => false,
Some(source) => is_error_source::<T>(source),
}
}
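// For example, `is_error_source::<NoPublicIPv4Error>(&some_reqwest_error)` walks the error
// and its chain of sources looking for the resolver's NoPublicIPv4Error; this is how the
// Display impl above and worker::send_webhook distinguish "blocked internal IP" from other
// request failures.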

View File

@@ -0,0 +1,5 @@
pub mod config;
pub mod dns;
pub mod error;
pub mod util;
pub mod worker;

View File

@@ -0,0 +1,78 @@
//! Consume `PgQueue` jobs to run webhook calls.
use axum::routing::get;
use axum::Router;
use envconfig::Envconfig;
use std::future::ready;
use health::HealthRegistry;
use hook_common::{
metrics::serve, metrics::setup_metrics_routes, pgqueue::PgQueue, retry::RetryPolicy,
};
use hook_worker::config::Config;
use hook_worker::error::WorkerError;
use hook_worker::worker::WebhookWorker;
#[tokio::main]
async fn main() -> Result<(), WorkerError> {
tracing_subscriber::fmt::init();
let config = Config::init_from_env().expect("Invalid configuration:");
let liveness = HealthRegistry::new("liveness");
let worker_liveness = liveness
.register("worker".to_string(), time::Duration::seconds(60)) // TODO: compute the value from worker params
.await;
let mut retry_policy_builder = RetryPolicy::build(
config.retry_policy.backoff_coefficient,
config.retry_policy.initial_interval.0,
)
.maximum_interval(config.retry_policy.maximum_interval.0);
retry_policy_builder = if let Some(retry_queue_name) = &config.retry_policy.retry_queue_name {
retry_policy_builder.queue(retry_queue_name.as_str())
} else {
retry_policy_builder
};
let queue = PgQueue::new(
config.queue_name.as_str(),
&config.database_url,
config.max_pg_connections,
"hook-worker",
)
.await
.expect("failed to initialize queue");
let worker = WebhookWorker::new(
&config.worker_name,
&queue,
config.dequeue_batch_size,
config.poll_interval.0,
config.request_timeout.0,
config.max_concurrent_jobs,
retry_policy_builder.provide(),
config.allow_internal_ips,
worker_liveness,
);
let router = Router::new()
.route("/", get(index))
.route("/_readiness", get(index))
.route("/_liveness", get(move || ready(liveness.get_status())));
let router = setup_metrics_routes(router);
let bind = config.bind();
tokio::task::spawn(async move {
serve(router, &bind)
.await
.expect("failed to start serving metrics");
});
worker.run().await;
Ok(())
}
pub async fn index() -> &'static str {
"rusty-hook worker"
}

View File

@@ -0,0 +1,35 @@
use crate::error::WebhookResponseError;
use futures::StreamExt;
use reqwest::Response;
pub async fn first_n_bytes_of_response(
response: Response,
n: usize,
) -> Result<String, WebhookResponseError> {
let mut body = response.bytes_stream();
let mut buffer = String::with_capacity(n);
while let Some(chunk) = body.next().await {
if buffer.len() >= n {
break;
}
let chunk = chunk?;
let chunk_str = std::str::from_utf8(&chunk)?;
let upper_bound = std::cmp::min(n - buffer.len(), chunk_str.len());
if let Some(partial_chunk_str) = chunk_str.get(0..upper_bound) {
buffer.push_str(partial_chunk_str);
} else {
// For whatever reason we are out of bounds. We should never land here
// given the `std::cmp::min` usage, but I am being extra careful by not
// using a slice index that would panic instead.
return Err(WebhookResponseError::ChunkOutOfBoundsError(
chunk_str.len(),
upper_bound,
));
}
}
Ok(buffer)
}
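// For example, the worker caps error response bodies at 10 KiB before attaching them to a
// job error: `first_n_bytes_of_response(response, 10 * 1024).await.ok()` (see
// worker::send_webhook).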

View File

@@ -0,0 +1,717 @@
use std::collections;
use std::sync::Arc;
use std::time;
use chrono::Utc;
use futures::future::join_all;
use health::HealthHandle;
use hook_common::pgqueue::PgTransactionBatch;
use hook_common::{
pgqueue::{Job, PgQueue, PgQueueJob, PgTransactionJob, RetryError, RetryInvalidError},
retry::RetryPolicy,
webhook::{HttpMethod, WebhookJobError, WebhookJobMetadata, WebhookJobParameters},
};
use http::StatusCode;
use reqwest::{header, Client};
use tokio::sync;
use tracing::error;
use crate::dns::{NoPublicIPv4Error, PublicIPv4Resolver};
use crate::error::{
is_error_source, WebhookError, WebhookParseError, WebhookRequestError, WorkerError,
};
use crate::util::first_n_bytes_of_response;
/// A WebhookJob is any `PgQueueJob` with `WebhookJobParameters` and `WebhookJobMetadata`.
trait WebhookJob: PgQueueJob + std::marker::Send {
fn parameters(&self) -> &WebhookJobParameters;
fn metadata(&self) -> &WebhookJobMetadata;
fn job(&self) -> &Job<WebhookJobParameters, WebhookJobMetadata>;
fn attempt(&self) -> i32 {
self.job().attempt
}
fn queue(&self) -> String {
self.job().queue.to_owned()
}
fn target(&self) -> String {
self.job().target.to_owned()
}
}
impl WebhookJob for PgTransactionJob<'_, WebhookJobParameters, WebhookJobMetadata> {
fn parameters(&self) -> &WebhookJobParameters {
&self.job.parameters
}
fn metadata(&self) -> &WebhookJobMetadata {
&self.job.metadata
}
fn job(&self) -> &Job<WebhookJobParameters, WebhookJobMetadata> {
&self.job
}
}
/// A worker to poll `PgQueue` and spawn tasks to process webhooks when a job becomes available.
pub struct WebhookWorker<'p> {
/// An identifier for this worker. Used to mark jobs we have consumed.
name: String,
/// The queue we will be dequeuing jobs from.
queue: &'p PgQueue,
/// The maximum number of jobs to dequeue in one query.
dequeue_batch_size: u32,
/// The interval for polling the queue.
poll_interval: time::Duration,
/// The client used for HTTP requests.
client: reqwest::Client,
/// Maximum number of concurrent jobs being processed.
max_concurrent_jobs: usize,
/// The retry policy used to calculate retry intervals when a job fails with a retryable error.
retry_policy: RetryPolicy,
/// The liveness check handle, called periodically to report that the worker is healthy.
liveness: HealthHandle,
}
pub fn build_http_client(
request_timeout: time::Duration,
allow_internal_ips: bool,
) -> reqwest::Result<Client> {
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
let mut client_builder = reqwest::Client::builder()
.default_headers(headers)
.user_agent("PostHog Webhook Worker")
.timeout(request_timeout);
if !allow_internal_ips {
client_builder = client_builder.dns_resolver(Arc::new(PublicIPv4Resolver {}))
}
client_builder.build()
}
impl<'p> WebhookWorker<'p> {
#[allow(clippy::too_many_arguments)]
pub fn new(
name: &str,
queue: &'p PgQueue,
dequeue_batch_size: u32,
poll_interval: time::Duration,
request_timeout: time::Duration,
max_concurrent_jobs: usize,
retry_policy: RetryPolicy,
allow_internal_ips: bool,
liveness: HealthHandle,
) -> Self {
let client = build_http_client(request_timeout, allow_internal_ips)
.expect("failed to construct reqwest client for webhook worker");
Self {
name: name.to_owned(),
queue,
dequeue_batch_size,
poll_interval,
client,
max_concurrent_jobs,
retry_policy,
liveness,
}
}
/// Wait until at least one job becomes available in our queue in transactional mode.
async fn wait_for_jobs_tx<'a>(
&self,
) -> PgTransactionBatch<'a, WebhookJobParameters, WebhookJobMetadata> {
let mut interval = tokio::time::interval(self.poll_interval);
loop {
interval.tick().await;
self.liveness.report_healthy().await;
match self
.queue
.dequeue_tx(&self.name, self.dequeue_batch_size)
.await
{
Ok(Some(batch)) => return batch,
Ok(None) => continue,
Err(error) => {
error!("error while trying to dequeue_tx job: {}", error);
continue;
}
}
}
}
/// Run this worker to continuously process any jobs that become available.
pub async fn run(&self) {
let semaphore = Arc::new(sync::Semaphore::new(self.max_concurrent_jobs));
let report_semaphore_utilization = || {
metrics::gauge!("webhook_worker_saturation_percent")
.set(1f64 - semaphore.available_permits() as f64 / self.max_concurrent_jobs as f64);
};
let dequeue_batch_size_histogram = metrics::histogram!("webhook_dequeue_batch_size");
loop {
report_semaphore_utilization();
// TODO: We could grab semaphore permits here using something like:
// `min(semaphore.available_permits(), dequeue_batch_size)`
// And then dequeue only up to that many jobs. We'd then need to hand back the
// difference in permits based on how many jobs were dequeued.
let mut batch = self.wait_for_jobs_tx().await;
dequeue_batch_size_histogram.record(batch.jobs.len() as f64);
// Get enough permits for the jobs before spawning a task.
let permits = semaphore
.clone()
.acquire_many_owned(batch.jobs.len() as u32)
.await
.expect("semaphore has been closed");
let client = self.client.clone();
let retry_policy = self.retry_policy.clone();
tokio::spawn(async move {
let mut futures = Vec::new();
// We have to `take` the Vec of jobs from the batch to avoid a borrow checker
// error below when we commit.
for job in std::mem::take(&mut batch.jobs) {
let client = client.clone();
let retry_policy = retry_policy.clone();
let future =
async move { process_webhook_job(client, job, &retry_policy).await };
futures.push(future);
}
let results = join_all(futures).await;
for result in results {
if let Err(e) = result {
error!("error processing webhook job: {}", e);
}
}
let _ = batch.commit().await.map_err(|e| {
error!("error committing transactional batch: {}", e);
});
drop(permits);
});
}
}
}
/// Process a webhook job by transitioning it to its appropriate state after its request is sent.
/// After we finish, the webhook job will be set as completed (if the request was successful), retryable (if the request
/// was unsuccessful but we can still attempt a retry), or failed (if the request was unsuccessful and no more retries
/// may be attempted).
///
/// A webhook job is considered retryable after a failing request if:
/// 1. The job has attempts remaining (i.e. hasn't reached `max_attempts`), and...
/// 2. The status code indicates retrying at a later point could resolve the issue. This means: 429 and any 5XX.
///
/// # Arguments
///
/// * `client`: An HTTP client to execute the webhook job request.
/// * `webhook_job`: The webhook job to process as dequeued from `hook_common::pgqueue::PgQueue`.
/// * `retry_policy`: The retry policy used to set retry parameters if a job fails and has remaining attempts.
async fn process_webhook_job<W: WebhookJob>(
client: reqwest::Client,
webhook_job: W,
retry_policy: &RetryPolicy,
) -> Result<(), WorkerError> {
let parameters = webhook_job.parameters();
let labels = [("queue", webhook_job.queue())];
metrics::counter!("webhook_jobs_total", &labels).increment(1);
let now = tokio::time::Instant::now();
let send_result = send_webhook(
client,
&parameters.method,
&parameters.url,
&parameters.headers,
parameters.body.clone(),
)
.await;
let elapsed = now.elapsed().as_secs_f64();
match send_result {
Ok(_) => {
let created_at = webhook_job.job().created_at;
let retries = webhook_job.job().attempt - 1;
let labels_with_retries = [
("queue", webhook_job.queue()),
("retries", retries.to_string()),
];
webhook_job.complete().await.map_err(|error| {
metrics::counter!("webhook_jobs_database_error", &labels).increment(1);
error
})?;
let insert_to_complete_duration = Utc::now() - created_at;
metrics::histogram!(
"webhook_jobs_insert_to_complete_duration_seconds",
&labels_with_retries
)
.record((insert_to_complete_duration.num_milliseconds() as f64) / 1_000_f64);
metrics::counter!("webhook_jobs_completed", &labels).increment(1);
metrics::histogram!("webhook_jobs_processing_duration_seconds", &labels)
.record(elapsed);
Ok(())
}
Err(WebhookError::Parse(WebhookParseError::ParseHeadersError(e))) => {
webhook_job
.fail(WebhookJobError::new_parse(&e.to_string()))
.await
.map_err(|job_error| {
metrics::counter!("webhook_jobs_database_error", &labels).increment(1);
job_error
})?;
metrics::counter!("webhook_jobs_failed", &labels).increment(1);
Ok(())
}
Err(WebhookError::Parse(WebhookParseError::ParseHttpMethodError(e))) => {
webhook_job
.fail(WebhookJobError::new_parse(&e))
.await
.map_err(|job_error| {
metrics::counter!("webhook_jobs_database_error", &labels).increment(1);
job_error
})?;
metrics::counter!("webhook_jobs_failed", &labels).increment(1);
Ok(())
}
Err(WebhookError::Parse(WebhookParseError::ParseUrlError(e))) => {
webhook_job
.fail(WebhookJobError::new_parse(&e.to_string()))
.await
.map_err(|job_error| {
metrics::counter!("webhook_jobs_database_error", &labels).increment(1);
job_error
})?;
metrics::counter!("webhook_jobs_failed", &labels).increment(1);
Ok(())
}
Err(WebhookError::Request(request_error)) => {
let webhook_job_error = WebhookJobError::from(&request_error);
match request_error {
WebhookRequestError::RetryableRequestError {
error, retry_after, ..
} => {
let retry_interval =
retry_policy.retry_interval(webhook_job.attempt() as u32, retry_after);
let current_queue = webhook_job.queue();
let retry_queue = retry_policy.retry_queue(&current_queue);
match webhook_job
.retry(webhook_job_error, retry_interval, retry_queue)
.await
{
Ok(_) => {
metrics::counter!("webhook_jobs_retried", &labels).increment(1);
Ok(())
}
Err(RetryError::RetryInvalidError(RetryInvalidError {
job: webhook_job,
..
})) => {
webhook_job
.fail(WebhookJobError::from(&error))
.await
.map_err(|job_error| {
metrics::counter!("webhook_jobs_database_error", &labels)
.increment(1);
job_error
})?;
metrics::counter!("webhook_jobs_failed", &labels).increment(1);
Ok(())
}
Err(RetryError::DatabaseError(job_error)) => {
metrics::counter!("webhook_jobs_database_error", &labels).increment(1);
Err(WorkerError::from(job_error))
}
}
}
WebhookRequestError::NonRetryableRetryableRequestError { .. } => {
webhook_job
.fail(webhook_job_error)
.await
.map_err(|job_error| {
metrics::counter!("webhook_jobs_database_error", &labels).increment(1);
job_error
})?;
metrics::counter!("webhook_jobs_failed", &labels).increment(1);
Ok(())
}
}
}
}
}
/// Make an HTTP request to a webhook endpoint.
///
/// # Arguments
///
/// * `client`: An HTTP client to execute the HTTP request.
/// * `method`: The HTTP method to use in the HTTP request.
/// * `url`: The URL we are targeting with our request. Parsing this URL can fail.
/// * `headers`: Key, value pairs of HTTP headers in a `std::collections::HashMap`. Can fail if headers are not valid.
/// * `body`: The body of the request. Ownership is required.
async fn send_webhook(
client: reqwest::Client,
method: &HttpMethod,
url: &str,
headers: &collections::HashMap<String, String>,
body: String,
) -> Result<reqwest::Response, WebhookError> {
let method: http::Method = method.into();
let url: reqwest::Url = (url).parse().map_err(WebhookParseError::ParseUrlError)?;
let headers: reqwest::header::HeaderMap = (headers)
.try_into()
.map_err(WebhookParseError::ParseHeadersError)?;
let body = reqwest::Body::from(body);
let response = client
.request(method, url)
.headers(headers)
.body(body)
.send()
.await
.map_err(|e| {
if is_error_source::<NoPublicIPv4Error>(&e) {
WebhookRequestError::NonRetryableRetryableRequestError {
error: e,
response: None,
}
} else {
WebhookRequestError::RetryableRequestError {
error: e,
response: None,
retry_after: None,
}
}
})?;
let retry_after = parse_retry_after_header(response.headers());
match response.error_for_status_ref() {
Ok(_) => Ok(response),
Err(err) => {
if is_retryable_status(
err.status()
.expect("status code is set as error is generated from a response"),
) {
Err(WebhookError::Request(
WebhookRequestError::RetryableRequestError {
error: err,
// TODO: Make amount of bytes configurable.
response: first_n_bytes_of_response(response, 10 * 1024).await.ok(),
retry_after,
},
))
} else {
Err(WebhookError::Request(
WebhookRequestError::NonRetryableRetryableRequestError {
error: err,
response: first_n_bytes_of_response(response, 10 * 1024).await.ok(),
},
))
}
}
}
}
fn is_retryable_status(status: StatusCode) -> bool {
status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()
}
/// Attempt to parse a chrono::Duration from a Retry-After header, returning None if not possible.
/// Retry-After header can specify a date in RFC2822 or a number of seconds; we try to parse both.
/// If a Retry-After header is not present in the provided `header_map`, `None` is returned.
///
/// # Arguments
///
/// * `header_map`: A `&reqwest::HeaderMap` of response headers that could contain Retry-After.
fn parse_retry_after_header(header_map: &reqwest::header::HeaderMap) -> Option<time::Duration> {
let retry_after_header = header_map.get(reqwest::header::RETRY_AFTER);
let retry_after = match retry_after_header {
Some(header_value) => match header_value.to_str() {
Ok(s) => s,
Err(_) => {
return None;
}
},
None => {
return None;
}
};
if let Ok(u) = retry_after.parse::<u64>() {
let duration = time::Duration::from_secs(u);
return Some(duration);
}
if let Ok(dt) = chrono::DateTime::parse_from_rfc2822(retry_after) {
let duration =
chrono::DateTime::<chrono::offset::Utc>::from(dt) - chrono::offset::Utc::now();
// This can only fail when negative, in which case we return None.
return duration.to_std().ok();
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
// Note we are ignoring some warnings in this module.
// This is due to a long-standing cargo bug that reports imports and helper functions as unused.
// See: https://github.com/rust-lang/rust/issues/46379.
use health::HealthRegistry;
use hook_common::pgqueue::{DatabaseError, NewJob};
use sqlx::PgPool;
/// Use process id as a worker id for tests.
fn worker_id() -> String {
std::process::id().to_string()
}
/// Get a request client or panic
fn localhost_client() -> Client {
build_http_client(Duration::from_secs(1), true).expect("failed to create client")
}
async fn enqueue_job(
queue: &PgQueue,
max_attempts: i32,
job_parameters: WebhookJobParameters,
job_metadata: WebhookJobMetadata,
) -> Result<(), DatabaseError> {
let job_target = job_parameters.url.to_owned();
let new_job = NewJob::new(max_attempts, job_metadata, job_parameters, &job_target);
queue.enqueue(new_job).await?;
Ok(())
}
#[test]
fn test_is_retryable_status() {
assert!(!is_retryable_status(http::StatusCode::FORBIDDEN));
assert!(!is_retryable_status(http::StatusCode::BAD_REQUEST));
assert!(is_retryable_status(http::StatusCode::TOO_MANY_REQUESTS));
assert!(is_retryable_status(http::StatusCode::INTERNAL_SERVER_ERROR));
}
#[test]
fn test_parse_retry_after_header() {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(reqwest::header::RETRY_AFTER, "120".parse().unwrap());
let duration = parse_retry_after_header(&headers).unwrap();
assert_eq!(duration, time::Duration::from_secs(120));
headers.remove(reqwest::header::RETRY_AFTER);
let duration = parse_retry_after_header(&headers);
assert_eq!(duration, None);
headers.insert(
reqwest::header::RETRY_AFTER,
"Wed, 21 Oct 2015 07:28:00 GMT".parse().unwrap(),
);
let duration = parse_retry_after_header(&headers);
assert_eq!(duration, None);
}
#[sqlx::test(migrations = "../migrations")]
async fn test_wait_for_job(db: PgPool) {
let worker_id = worker_id();
let queue_name = "test_wait_for_job".to_string();
let queue = PgQueue::new_from_pool(&queue_name, db).await;
let webhook_job_parameters = WebhookJobParameters {
body: "a webhook job body. much wow.".to_owned(),
headers: collections::HashMap::new(),
method: HttpMethod::POST,
url: "localhost".to_owned(),
};
let webhook_job_metadata = WebhookJobMetadata {
team_id: 1,
plugin_id: 2,
plugin_config_id: 3,
};
let registry = HealthRegistry::new("liveness");
let liveness = registry
.register("worker".to_string(), ::time::Duration::seconds(30))
.await;
// enqueue takes ownership of the job enqueued to avoid bugs that can cause duplicate jobs.
// Normally, a separate application would be enqueueing jobs for us to consume, so no ownership
// conflicts would arise. However, in this test we need to do the enqueueing ourselves.
// So, we clone the job to keep it around and assert the values returned by wait_for_jobs_tx.
enqueue_job(
&queue,
1,
webhook_job_parameters.clone(),
webhook_job_metadata,
)
.await
.expect("failed to enqueue job");
let worker = WebhookWorker::new(
&worker_id,
&queue,
1,
time::Duration::from_millis(100),
time::Duration::from_millis(5000),
10,
RetryPolicy::default(),
false,
liveness,
);
let mut batch = worker.wait_for_jobs_tx().await;
let consumed_job = batch.jobs.pop().unwrap();
assert_eq!(consumed_job.job.attempt, 1);
assert!(consumed_job.job.attempted_by.contains(&worker_id));
assert_eq!(consumed_job.job.attempted_by.len(), 1);
assert_eq!(consumed_job.job.max_attempts, 1);
assert_eq!(
*consumed_job.job.parameters.as_ref(),
webhook_job_parameters
);
assert_eq!(consumed_job.job.target, webhook_job_parameters.url);
consumed_job
.complete()
.await
.expect("job not successfully completed");
batch.commit().await.expect("failed to commit batch");
assert!(registry.get_status().healthy)
}
#[tokio::test]
async fn test_send_webhook() {
let method = HttpMethod::POST;
let url = "http://localhost:18081/echo";
let headers = collections::HashMap::new();
let body = "a very relevant request body";
let response = send_webhook(localhost_client(), &method, url, &headers, body.to_owned())
.await
.expect("send_webhook failed");
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response.text().await.expect("failed to read response body"),
body.to_owned(),
);
}
#[tokio::test]
async fn test_error_message_contains_response_body() {
let method = HttpMethod::POST;
let url = "http://localhost:18081/fail";
let headers = collections::HashMap::new();
let body = "this is an error message";
let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned())
.await
.err()
.expect("request didn't fail when it should have failed");
assert!(matches!(err, WebhookError::Request(..)));
if let WebhookError::Request(request_error) = err {
assert_eq!(request_error.status(), Some(StatusCode::BAD_REQUEST));
assert!(request_error.to_string().contains(body));
// This is the display implementation of reqwest. Just checking it is still there.
// See: https://github.com/seanmonstar/reqwest/blob/master/src/error.rs
assert!(request_error.to_string().contains(
"HTTP status client error (400 Bad Request) for url (http://localhost:18081/fail)"
));
}
}
#[tokio::test]
async fn test_error_message_contains_up_to_n_bytes_of_response_body() {
let method = HttpMethod::POST;
let url = "http://localhost:18081/fail";
let headers = collections::HashMap::new();
// This is double the current hardcoded amount of bytes.
// TODO: Make this configurable and change it here too.
let body = (0..20 * 1024).map(|_| "a").collect::<Vec<_>>().concat();
let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned())
.await
.err()
.expect("request didn't fail when it should have failed");
assert!(matches!(err, WebhookError::Request(..)));
if let WebhookError::Request(request_error) = err {
assert_eq!(request_error.status(), Some(StatusCode::BAD_REQUEST));
assert!(request_error.to_string().contains(&body[0..10 * 1024]));
// The 81 bytes account for the reqwest error message as described below.
assert_eq!(request_error.to_string().len(), 10 * 1024 + 81);
// This is the display implementation of reqwest. Just checking it is still there.
// See: https://github.com/seanmonstar/reqwest/blob/master/src/error.rs
assert!(request_error.to_string().contains(
"HTTP status client error (400 Bad Request) for url (http://localhost:18081/fail)"
));
}
}
#[tokio::test]
async fn test_private_ips_denied() {
let method = HttpMethod::POST;
let url = "http://localhost:18081/echo";
let headers = collections::HashMap::new();
let body = "a very relevant request body";
let filtering_client =
build_http_client(Duration::from_secs(1), false).expect("failed to create client");
let err = send_webhook(filtering_client, &method, url, &headers, body.to_owned())
.await
.err()
.expect("request didn't fail when it should have failed");
assert!(matches!(err, WebhookError::Request(..)));
if let WebhookError::Request(request_error) = err {
assert_eq!(request_error.status(), None);
assert!(request_error
.to_string()
.contains("No public IPv4 found for specified host"));
if let WebhookRequestError::RetryableRequestError { .. } = request_error {
panic!("error should not be retryable")
}
} else {
panic!("unexpected error type {:?}", err)
}
}
}

View File

@@ -0,0 +1,29 @@
CREATE TYPE job_status AS ENUM(
'available',
'completed',
'failed',
'running'
);
CREATE TABLE job_queue(
id BIGSERIAL PRIMARY KEY,
attempt INT NOT NULL DEFAULT 0,
attempted_at TIMESTAMPTZ DEFAULT NULL,
attempted_by TEXT [] DEFAULT ARRAY [] :: TEXT [],
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
errors JSONB [],
max_attempts INT NOT NULL DEFAULT 1,
metadata JSONB,
last_attempt_finished_at TIMESTAMPTZ DEFAULT NULL,
parameters JSONB,
queue TEXT NOT NULL DEFAULT 'default' :: text,
scheduled_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
status job_status NOT NULL DEFAULT 'available' :: job_status,
target TEXT NOT NULL
);
-- Needed for `dequeue` queries
CREATE INDEX idx_queue_scheduled_at ON job_queue(queue, status, scheduled_at, attempt);
-- Needed for UPDATE-ing incomplete jobs with a specific target (i.e. slow destinations)
CREATE INDEX idx_queue_target ON job_queue(queue, status, target);

View File

@@ -0,0 +1,10 @@
-- Dequeue is not hitting this index, so dropping is safe this time.
DROP INDEX idx_queue_scheduled_at;
/*
Partial index used for dequeuing from job_queue.
Dequeue only looks at available jobs so a partial index serves us well.
Moreover, dequeue sorts jobs by attempt and scheduled_at, which matches this index.
*/
CREATE INDEX idx_queue_dequeue_partial ON job_queue(queue, attempt, scheduled_at) WHERE status = 'available' :: job_status;
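/*
Sketch of the dequeue query shape this index is meant to serve (illustrative only, not the
exact statement issued by hook_common::pgqueue):

  SELECT ... FROM job_queue
  WHERE queue = $1 AND status = 'available' AND scheduled_at <= NOW()
  ORDER BY attempt, scheduled_at
  LIMIT $2;
*/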