mirror of
https://github.com/BillyOutlast/posthog.git
synced 2026-02-04 03:01:23 +01:00
refactor(temporal-worker): Handle signal setting in parent function (#31313)
This commit is contained in:
committed by
GitHub
parent
11b347000b
commit
7918f2f5fc
@@ -3,36 +3,46 @@
|
||||
set -e
|
||||
|
||||
# Forward a graceful-shutdown request to the background worker.
# Installed as the handler for the SIGINT, SIGTERM and EXIT traps; relies on
# $worker_pid having been set to the PID of the backgrounded worker.
cleanup() {
    echo "Stopping worker..."

    # `kill -0` delivers no signal; it only reports whether the PID can be signaled.
    if ! kill -0 "$worker_pid" >/dev/null 2>&1; then
        echo "Worker process is not running."
        echo "Could not initiate graceful shutdown as worker process is not running"
        return
    fi

    echo "Signaling worker to initialize graceful shutdown"
    kill -SIGTERM "$worker_pid"
    echo "Worker signaled for graceful shutdown"
}
|
||||
|
||||
trap cleanup SIGINT SIGTERM EXIT
|
||||
|
||||
python3 manage.py start_temporal_worker "$@" &
|
||||
echo "Initializing Temporal Worker"
|
||||
|
||||
python3 manage.py start_temporal_worker "$@" &
|
||||
worker_pid=$!
|
||||
|
||||
# Run wait in a loop in case we trap SIGINT or SIGTERM.
|
||||
# In both cases, wait will terminate early, potentially not waiting for graceful shutdown.
|
||||
while wait $worker_pid
|
||||
do
|
||||
status=$?
|
||||
# If we exit with SIGTERM, status will be 128 + 15.
|
||||
# If we exit with SIGINT, status will be 128 + 2.
|
||||
if [ $status -eq 143 ] || [ $status -eq 130 ]; then
|
||||
echo "Received signal $(($status - 128)), waiting for worker to finish"
|
||||
elif [ $status -eq 0 ]; then
|
||||
echo "Worker exited normally, terminating wait"
|
||||
while true; do
|
||||
if wait $worker_pid; then
|
||||
echo "Worker exited normally, exiting"
|
||||
break
|
||||
else
|
||||
echo "Worker exited with unexpected exit status $status, terminating wait"
|
||||
break
|
||||
status=$?
|
||||
|
||||
# Worker process could finish before the trap handler yields control back to us.
|
||||
# So, we first re-check if we are still running.
|
||||
if ! kill -0 "$worker_pid" >/dev/null 2>&1; then
|
||||
echo "Worker process no longer exists, exiting"
|
||||
break
|
||||
fi
|
||||
|
||||
# If we exit with SIGTERM, status will be 128 + 15.
|
||||
# If we exit with SIGINT, status will be 128 + 2.
|
||||
if [ $status -eq 143 ] || [ $status -eq 130 ]; then
|
||||
echo "Received signal $(($status - 128)), waiting for worker to finish"
|
||||
elif [ $status -eq 0 ]; then
|
||||
echo "Worker exited normally, exiting"
|
||||
break
|
||||
else
|
||||
echo "Worker exited with unexpected exit status $status, exiting"
|
||||
break
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
cleanup
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import logging
|
||||
import functools
|
||||
import signal
|
||||
|
||||
import structlog
|
||||
from temporalio import workflow
|
||||
from temporalio.worker import Worker
|
||||
|
||||
with workflow.unsafe.imports_passed_through():
|
||||
from django.conf import settings
|
||||
@@ -22,11 +24,7 @@ from posthog.temporal.batch_exports import (
|
||||
ACTIVITIES as BATCH_EXPORTS_ACTIVITIES,
|
||||
WORKFLOWS as BATCH_EXPORTS_WORKFLOWS,
|
||||
)
|
||||
from posthog.temporal.common.worker import start_worker
|
||||
from posthog.temporal.session_recordings import (
|
||||
ACTIVITIES as SESSION_RECORDINGS_ACTIVITIES,
|
||||
WORKFLOWS as SESSION_RECORDINGS_WORKFLOWS,
|
||||
)
|
||||
from posthog.temporal.common.worker import create_worker
|
||||
from posthog.temporal.data_imports.settings import ACTIVITIES as DATA_SYNC_ACTIVITIES, WORKFLOWS as DATA_SYNC_WORKFLOWS
|
||||
from posthog.temporal.data_modeling import ACTIVITIES as DATA_MODELING_ACTIVITIES, WORKFLOWS as DATA_MODELING_WORKFLOWS
|
||||
from posthog.temporal.delete_persons import (
|
||||
@@ -34,12 +32,18 @@ from posthog.temporal.delete_persons import (
|
||||
WORKFLOWS as DELETE_PERSONS_WORKFLOWS,
|
||||
)
|
||||
from posthog.temporal.proxy_service import ACTIVITIES as PROXY_SERVICE_ACTIVITIES, WORKFLOWS as PROXY_SERVICE_WORKFLOWS
|
||||
from posthog.temporal.tests.utils.workflow import ACTIVITIES as TEST_ACTIVITIES, WORKFLOWS as TEST_WORKFLOWS
|
||||
from posthog.temporal.usage_reports import ACTIVITIES as USAGE_REPORTS_ACTIVITIES, WORKFLOWS as USAGE_REPORTS_WORKFLOWS
|
||||
from posthog.temporal.quota_limiting import (
|
||||
ACTIVITIES as QUOTA_LIMITING_ACTIVITIES,
|
||||
WORKFLOWS as QUOTA_LIMITING_WORKFLOWS,
|
||||
)
|
||||
from posthog.temporal.session_recordings import (
|
||||
ACTIVITIES as SESSION_RECORDINGS_ACTIVITIES,
|
||||
WORKFLOWS as SESSION_RECORDINGS_WORKFLOWS,
|
||||
)
|
||||
from posthog.temporal.tests.utils.workflow import ACTIVITIES as TEST_ACTIVITIES, WORKFLOWS as TEST_WORKFLOWS
|
||||
from posthog.temporal.usage_reports import ACTIVITIES as USAGE_REPORTS_ACTIVITIES, WORKFLOWS as USAGE_REPORTS_WORKFLOWS
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
WORKFLOWS_DICT = {
|
||||
SYNC_BATCH_EXPORTS_TASK_QUEUE: BATCH_EXPORTS_WORKFLOWS,
|
||||
@@ -149,27 +153,62 @@ class Command(BaseCommand):
|
||||
|
||||
if options["client_key"]:
|
||||
options["client_key"] = "--SECRET--"
|
||||
logging.info(f"Starting Temporal Worker with options: {options}")
|
||||
|
||||
structlog.reset_defaults()
|
||||
|
||||
logger.info(f"Starting Temporal Worker with options: {options}")
|
||||
|
||||
metrics_port = int(options["metrics_port"])
|
||||
|
||||
asyncio.run(
|
||||
start_worker(
|
||||
temporal_host,
|
||||
temporal_port,
|
||||
metrics_port=metrics_port,
|
||||
namespace=namespace,
|
||||
task_queue=task_queue,
|
||||
server_root_ca_cert=server_root_ca_cert,
|
||||
client_cert=client_cert,
|
||||
client_key=client_key,
|
||||
workflows=workflows, # type: ignore
|
||||
activities=activities,
|
||||
graceful_shutdown_timeout=dt.timedelta(seconds=graceful_shutdown_timeout_seconds)
|
||||
if graceful_shutdown_timeout_seconds is not None
|
||||
else None,
|
||||
max_concurrent_workflow_tasks=max_concurrent_workflow_tasks,
|
||||
max_concurrent_activities=max_concurrent_activities,
|
||||
shutdown_task = None
|
||||
|
||||
def shutdown_worker_on_signal(worker: Worker, sig: signal.Signals, loop: asyncio.events.AbstractEventLoop):
|
||||
"""Shutdown Temporal worker on receiving signal."""
|
||||
nonlocal shutdown_task
|
||||
|
||||
logger.info("Signal %s received", sig)
|
||||
|
||||
if worker.is_shutdown:
|
||||
logger.info("Temporal worker already shut down")
|
||||
return
|
||||
|
||||
logger.info("Initiating Temporal worker shutdown")
|
||||
shutdown_task = loop.create_task(worker.shutdown())
|
||||
logger.info("Finished Temporal worker shutdown")
|
||||
|
||||
with asyncio.Runner() as runner:
|
||||
worker = runner.run(
|
||||
create_worker(
|
||||
temporal_host,
|
||||
temporal_port,
|
||||
metrics_port=metrics_port,
|
||||
namespace=namespace,
|
||||
task_queue=task_queue,
|
||||
server_root_ca_cert=server_root_ca_cert,
|
||||
client_cert=client_cert,
|
||||
client_key=client_key,
|
||||
workflows=workflows, # type: ignore
|
||||
activities=activities,
|
||||
graceful_shutdown_timeout=dt.timedelta(seconds=graceful_shutdown_timeout_seconds)
|
||||
if graceful_shutdown_timeout_seconds is not None
|
||||
else None,
|
||||
max_concurrent_workflow_tasks=max_concurrent_workflow_tasks,
|
||||
max_concurrent_activities=max_concurrent_activities,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
loop = runner.get_loop()
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(
|
||||
sig,
|
||||
functools.partial(shutdown_worker_on_signal, worker=worker, sig=sig, loop=loop),
|
||||
)
|
||||
loop.add_signal_handler(
|
||||
sig,
|
||||
functools.partial(shutdown_worker_on_signal, worker=worker, sig=sig, loop=loop),
|
||||
)
|
||||
|
||||
runner.run(worker.run())
|
||||
|
||||
if shutdown_task:
|
||||
_ = runner.run(asyncio.wait([shutdown_task]))
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import asyncio
|
||||
import collections.abc
|
||||
import datetime as dt
|
||||
import signal
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import structlog
|
||||
from django.conf import settings
|
||||
from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig
|
||||
from temporalio.worker import UnsandboxedWorkflowRunner, Worker
|
||||
|
||||
@@ -16,14 +13,7 @@ from posthog.temporal.common.sentry import SentryInterceptor
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
def _debug_pyarrows():
    """Turn on PyArrow memory-allocation logging when enabled in settings.

    Does nothing unless ``settings.PYARROW_DEBUG_LOGGING`` is truthy. PyArrow
    is imported lazily so the import only happens when the flag is set.
    """
    if not settings.PYARROW_DEBUG_LOGGING:
        return

    import pyarrow as pa

    pa.log_memory_allocations(enable=True)
|
||||
|
||||
|
||||
async def start_worker(
|
||||
async def create_worker(
|
||||
host: str,
|
||||
port: int,
|
||||
metrics_port: int,
|
||||
@@ -37,9 +27,28 @@ async def start_worker(
|
||||
graceful_shutdown_timeout: dt.timedelta | None = None,
|
||||
max_concurrent_workflow_tasks: int | None = None,
|
||||
max_concurrent_activities: int | None = None,
|
||||
):
|
||||
_debug_pyarrows()
|
||||
) -> Worker:
|
||||
"""Connect to Temporal server and return a Worker.
|
||||
|
||||
Arguments:
|
||||
host: The Temporal Server host.
|
||||
port: The Temporal Server port.
|
||||
metrics_port: Port used to serve Prometheus metrics.
|
||||
namespace: The Temporal namespace to connect to.
|
||||
task_queue: The task queue the worker will listen on.
|
||||
workflows: Workflows the worker is configured to run.
|
||||
activities: Activities the worker is configured to run.
|
||||
server_root_ca_cert: Root CA to validate the server certificate against.
|
||||
client_cert: Client certificate for TLS.
|
||||
client_key: Client private key for TLS.
|
||||
graceful_shutdown_timeout: Time to wait for graceful shutdown, given as a `datetime.timedelta`.
|
||||
By default we will wait 5 minutes. This should be always less than any
|
||||
timeouts used by deployment orchestrators.
|
||||
max_concurrent_workflow_tasks: Maximum number of concurrent workflow tasks
|
||||
the worker can handle. Defaults to 50.
|
||||
max_concurrent_activities: Maximum number of concurrent activity tasks the
|
||||
worker can handle. Defaults to 50.
|
||||
"""
|
||||
runtime = Runtime(telemetry=TelemetryConfig(metrics=PrometheusConfig(bind_address=f"0.0.0.0:{metrics_port:d}")))
|
||||
client = await connect(
|
||||
host,
|
||||
@@ -66,17 +75,4 @@ async def start_worker(
|
||||
# min(heartbeat_timeout * 0.8, max_heartbeat_throttle_interval).
|
||||
max_heartbeat_throttle_interval=dt.timedelta(seconds=5),
|
||||
)
|
||||
|
||||
# catch the TERM and INT signals, and stop the worker gracefully
|
||||
# https://github.com/temporalio/sdk-python#worker-shutdown
|
||||
async def shutdown_worker(s: str):
|
||||
logger.info("%s received, initiating Temporal worker shutdown", s)
|
||||
await worker.shutdown()
|
||||
logger.info("Finished Temporal worker shutdown")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
shutdown_tasks = set()
|
||||
loop.add_signal_handler(signal.SIGINT, lambda: shutdown_tasks.add(asyncio.create_task(shutdown_worker("SIGINT"))))
|
||||
loop.add_signal_handler(signal.SIGTERM, lambda: shutdown_tasks.add(asyncio.create_task(shutdown_worker("SIGTERM"))))
|
||||
|
||||
await worker.run()
|
||||
return worker
|
||||
|
||||
Reference in New Issue
Block a user