refactor(temporal-worker): Handle signal setting in parent function (#31313)

Tomás Farías Santana
2025-04-22 11:12:09 +02:00
committed by GitHub
parent 11b347000b
commit 7918f2f5fc
3 changed files with 117 additions and 72 deletions

View File

@@ -3,36 +3,46 @@
 set -e
 cleanup() {
     echo "Stopping worker..."
     if kill -0 "$worker_pid" >/dev/null 2>&1; then
         echo "Signaling worker to initialize graceful shutdown"
         kill -SIGTERM "$worker_pid"
         echo "Worker signaled for graceful shutdown"
     else
-        echo "Worker process is not running."
+        echo "Could not initiate graceful shutdown as worker process is not running"
     fi
 }
 trap cleanup SIGINT SIGTERM EXIT
-python3 manage.py start_temporal_worker "$@" &
+echo "Initializing Temporal Worker"
+python3 manage.py start_temporal_worker "$@" &
 worker_pid=$!
 # Run wait in a loop in case we trap SIGINT or SIGTERM.
 # In both cases, wait will terminate early, potentially not waiting for graceful shutdown.
-while wait $worker_pid
-do
-    status=$?
-    # If we exit with SIGTERM, status will be 128 + 15.
-    # If we exit with SIGINT, status will be 128 + 2.
-    if [ $status -eq 143 ] || [ $status -eq 130 ]; then
-        echo "Received signal $(($status - 128)), waiting for worker to finish"
-    elif [ $status -eq 0 ]; then
-        echo "Worker exited normally, terminating wait"
+while true; do
+    if wait $worker_pid; then
+        echo "Worker exited normally, exiting"
         break
     else
-        echo "Worker exited with unexpected exit status $status, terminating wait"
-        break
+        status=$?
+        # Worker process could finish before the trap handler yields control back to us.
+        # So, we first re-check if we are still running.
+        if ! kill -0 "$worker_pid" >/dev/null 2>&1; then
+            echo "Worker process no longer exists, exiting"
+            break
+        fi
+        # If we exit with SIGTERM, status will be 128 + 15.
+        # If we exit with SIGINT, status will be 128 + 2.
+        if [ $status -eq 143 ] || [ $status -eq 130 ]; then
+            echo "Received signal $(($status - 128)), waiting for worker to finish"
+        elif [ $status -eq 0 ]; then
+            echo "Worker exited normally, exiting"
+            break
+        else
+            echo "Worker exited with unexpected exit status $status, exiting"
+            break
+        fi
     fi
 done
 cleanup
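
The exit-status checks in the loop above rely on the POSIX convention that a process terminated by signal N, or a `wait` interrupted by a trapped signal N, reports status 128 + N, so SIGTERM shows up as 143 and SIGINT as 130. A minimal Python sketch of that convention (illustrative only, not part of this commit; assumes a Unix system with a `sleep` binary):

import signal
import subprocess
import time

# Start a long-running child, then terminate it with SIGTERM, as the trap handler does.
child = subprocess.Popen(["sleep", "60"])
time.sleep(0.1)
child.send_signal(signal.SIGTERM)
child.wait()

# Python reports death-by-signal as a negative return code; a shell's $? for the same
# event is 128 + the signal number, which is what the loop above matches against.
assert child.returncode == -signal.SIGTERM
print(f"shell-style exit status: {128 + signal.SIGTERM}")  # prints 143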

View File

@@ -1,9 +1,11 @@
 import asyncio
 import datetime as dt
-import logging
+import functools
+import signal
 import structlog
 from temporalio import workflow
+from temporalio.worker import Worker
 with workflow.unsafe.imports_passed_through():
     from django.conf import settings
@@ -22,11 +24,7 @@ from posthog.temporal.batch_exports import (
     ACTIVITIES as BATCH_EXPORTS_ACTIVITIES,
     WORKFLOWS as BATCH_EXPORTS_WORKFLOWS,
 )
-from posthog.temporal.common.worker import start_worker
-from posthog.temporal.session_recordings import (
-    ACTIVITIES as SESSION_RECORDINGS_ACTIVITIES,
-    WORKFLOWS as SESSION_RECORDINGS_WORKFLOWS,
-)
+from posthog.temporal.common.worker import create_worker
 from posthog.temporal.data_imports.settings import ACTIVITIES as DATA_SYNC_ACTIVITIES, WORKFLOWS as DATA_SYNC_WORKFLOWS
 from posthog.temporal.data_modeling import ACTIVITIES as DATA_MODELING_ACTIVITIES, WORKFLOWS as DATA_MODELING_WORKFLOWS
 from posthog.temporal.delete_persons import (
@@ -34,12 +32,18 @@ from posthog.temporal.delete_persons import (
     WORKFLOWS as DELETE_PERSONS_WORKFLOWS,
 )
 from posthog.temporal.proxy_service import ACTIVITIES as PROXY_SERVICE_ACTIVITIES, WORKFLOWS as PROXY_SERVICE_WORKFLOWS
-from posthog.temporal.tests.utils.workflow import ACTIVITIES as TEST_ACTIVITIES, WORKFLOWS as TEST_WORKFLOWS
-from posthog.temporal.usage_reports import ACTIVITIES as USAGE_REPORTS_ACTIVITIES, WORKFLOWS as USAGE_REPORTS_WORKFLOWS
+from posthog.temporal.quota_limiting import (
+    ACTIVITIES as QUOTA_LIMITING_ACTIVITIES,
+    WORKFLOWS as QUOTA_LIMITING_WORKFLOWS,
+)
+from posthog.temporal.session_recordings import (
+    ACTIVITIES as SESSION_RECORDINGS_ACTIVITIES,
+    WORKFLOWS as SESSION_RECORDINGS_WORKFLOWS,
+)
+from posthog.temporal.tests.utils.workflow import ACTIVITIES as TEST_ACTIVITIES, WORKFLOWS as TEST_WORKFLOWS
+from posthog.temporal.usage_reports import ACTIVITIES as USAGE_REPORTS_ACTIVITIES, WORKFLOWS as USAGE_REPORTS_WORKFLOWS
 logger = structlog.get_logger(__name__)
 WORKFLOWS_DICT = {
     SYNC_BATCH_EXPORTS_TASK_QUEUE: BATCH_EXPORTS_WORKFLOWS,
@@ -149,27 +153,62 @@ class Command(BaseCommand):
         if options["client_key"]:
             options["client_key"] = "--SECRET--"
-        logging.info(f"Starting Temporal Worker with options: {options}")
+        structlog.reset_defaults()
+        logger.info(f"Starting Temporal Worker with options: {options}")
         metrics_port = int(options["metrics_port"])
-        asyncio.run(
-            start_worker(
-                temporal_host,
-                temporal_port,
-                metrics_port=metrics_port,
-                namespace=namespace,
-                task_queue=task_queue,
-                server_root_ca_cert=server_root_ca_cert,
-                client_cert=client_cert,
-                client_key=client_key,
-                workflows=workflows,  # type: ignore
-                activities=activities,
-                graceful_shutdown_timeout=dt.timedelta(seconds=graceful_shutdown_timeout_seconds)
-                if graceful_shutdown_timeout_seconds is not None
-                else None,
-                max_concurrent_workflow_tasks=max_concurrent_workflow_tasks,
-                max_concurrent_activities=max_concurrent_activities,
+        shutdown_task = None
+
+        def shutdown_worker_on_signal(worker: Worker, sig: signal.Signals, loop: asyncio.events.AbstractEventLoop):
+            """Shutdown Temporal worker on receiving signal."""
+            nonlocal shutdown_task
+
+            logger.info("Signal %s received", sig)
+
+            if worker.is_shutdown:
+                logger.info("Temporal worker already shut down")
+                return
+
+            logger.info("Initiating Temporal worker shutdown")
+            shutdown_task = loop.create_task(worker.shutdown())
+            logger.info("Finished Temporal worker shutdown")
+
+        with asyncio.Runner() as runner:
+            worker = runner.run(
+                create_worker(
+                    temporal_host,
+                    temporal_port,
+                    metrics_port=metrics_port,
+                    namespace=namespace,
+                    task_queue=task_queue,
+                    server_root_ca_cert=server_root_ca_cert,
+                    client_cert=client_cert,
+                    client_key=client_key,
+                    workflows=workflows,  # type: ignore
+                    activities=activities,
+                    graceful_shutdown_timeout=dt.timedelta(seconds=graceful_shutdown_timeout_seconds)
+                    if graceful_shutdown_timeout_seconds is not None
+                    else None,
+                    max_concurrent_workflow_tasks=max_concurrent_workflow_tasks,
+                    max_concurrent_activities=max_concurrent_activities,
+                )
             )
-        )
+
+            loop = runner.get_loop()
+            for sig in (signal.SIGTERM, signal.SIGINT):
+                loop.add_signal_handler(
+                    sig,
+                    functools.partial(shutdown_worker_on_signal, worker=worker, sig=sig, loop=loop),
+                )
+
+            runner.run(worker.run())
+
+            if shutdown_task:
+                _ = runner.run(asyncio.wait([shutdown_task]))
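
The change above moves signal handling out of the worker coroutine and into the parent: handlers are registered on the Runner's loop before `worker.run()` is awaited, and any shutdown task scheduled by a handler is awaited afterwards, before the Runner closes the loop. A minimal, self-contained sketch of that pattern (Unix-only, Python 3.11+ for `asyncio.Runner`), using a stub in place of the Temporal Worker; `StubWorker` and `on_signal` are illustrative names, not from the codebase:

import asyncio
import functools
import signal


class StubWorker:
    """Stand-in for temporalio.worker.Worker: run() blocks until shutdown() is called."""

    def __init__(self) -> None:
        self.is_shutdown = False
        self._stop = asyncio.Event()

    async def run(self) -> None:
        await self._stop.wait()

    async def shutdown(self) -> None:
        self.is_shutdown = True
        self._stop.set()


def main() -> None:
    shutdown_task = None

    def on_signal(worker: StubWorker, sig: signal.Signals, loop: asyncio.AbstractEventLoop) -> None:
        # Called by the event loop in the main thread, so scheduling a task here is safe.
        nonlocal shutdown_task
        if worker.is_shutdown:
            return
        shutdown_task = loop.create_task(worker.shutdown())

    with asyncio.Runner() as runner:
        worker = StubWorker()  # the management command awaits create_worker(...) here instead

        loop = runner.get_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, functools.partial(on_signal, worker=worker, sig=sig, loop=loop))

        # Blocks until a signal handler schedules shutdown() and run() returns.
        runner.run(worker.run())

        # Let the scheduled shutdown finish before the Runner closes its loop.
        if shutdown_task:
            runner.run(asyncio.wait([shutdown_task]))


if __name__ == "__main__":
    main()

Sending SIGTERM or pressing Ctrl-C unblocks `run()`, and the shutdown task the handler created is given a chance to complete before the loop is torn down, which is the property the old in-coroutine handlers could not guarantee.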

View File

@@ -1,11 +1,8 @@
-import asyncio
 import collections.abc
 import datetime as dt
-import signal
 from concurrent.futures import ThreadPoolExecutor
 import structlog
 from django.conf import settings
 from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig
 from temporalio.worker import UnsandboxedWorkflowRunner, Worker
@@ -16,14 +13,7 @@ from posthog.temporal.common.sentry import SentryInterceptor
 logger = structlog.get_logger(__name__)
-def _debug_pyarrows():
-    if settings.PYARROW_DEBUG_LOGGING:
-        import pyarrow as pa
-        pa.log_memory_allocations(enable=True)
-async def start_worker(
+async def create_worker(
     host: str,
     port: int,
     metrics_port: int,
@@ -37,9 +27,28 @@ async def start_worker(
     graceful_shutdown_timeout: dt.timedelta | None = None,
     max_concurrent_workflow_tasks: int | None = None,
     max_concurrent_activities: int | None = None,
-):
-    _debug_pyarrows()
+) -> Worker:
+    """Connect to Temporal server and return a Worker.
+
+    Arguments:
+        host: The Temporal Server host.
+        port: The Temporal Server port.
+        metrics_port: Port used to serve Prometheus metrics.
+        namespace: The Temporal namespace to connect to.
+        task_queue: The task queue the worker will listen on.
+        workflows: Workflows the worker is configured to run.
+        activities: Activities the worker is configured to run.
+        server_root_ca_cert: Root CA to validate the server certificate against.
+        client_cert: Client certificate for TLS.
+        client_key: Client private key for TLS.
+        graceful_shutdown_timeout: Time to wait for graceful shutdown. By default we
+            will wait 5 minutes. This should always be less than any timeouts used by
+            deployment orchestrators.
+        max_concurrent_workflow_tasks: Maximum number of concurrent workflow tasks
+            the worker can handle. Defaults to 50.
+        max_concurrent_activities: Maximum number of concurrent activity tasks the
+            worker can handle. Defaults to 50.
+    """
     runtime = Runtime(telemetry=TelemetryConfig(metrics=PrometheusConfig(bind_address=f"0.0.0.0:{metrics_port:d}")))
     client = await connect(
         host,
@@ -66,17 +75,4 @@ async def start_worker(
         # min(heartbeat_timeout * 0.8, max_heartbeat_throttle_interval).
         max_heartbeat_throttle_interval=dt.timedelta(seconds=5),
     )
-    # catch the TERM and INT signals, and stop the worker gracefully
-    # https://github.com/temporalio/sdk-python#worker-shutdown
-    async def shutdown_worker(s: str):
-        logger.info("%s received, initiating Temporal worker shutdown", s)
-        await worker.shutdown()
-        logger.info("Finished Temporal worker shutdown")
-    loop = asyncio.get_event_loop()
-    shutdown_tasks = set()
-    loop.add_signal_handler(signal.SIGINT, lambda: shutdown_tasks.add(asyncio.create_task(shutdown_worker("SIGINT"))))
-    loop.add_signal_handler(signal.SIGTERM, lambda: shutdown_tasks.add(asyncio.create_task(shutdown_worker("SIGTERM"))))
-    await worker.run()
+    return worker
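
With `start_worker` split into `create_worker` (connect and build) plus a caller-driven run/shutdown, other entry points would use it roughly as in this hedged sketch; it assumes a configured Django settings module, a reachable Temporal server, that the TLS arguments accept None when unused, and the host, ports, namespace, and task queue below are placeholder values:

import asyncio

from posthog.temporal.common.worker import create_worker
from posthog.temporal.tests.utils.workflow import ACTIVITIES as TEST_ACTIVITIES, WORKFLOWS as TEST_WORKFLOWS


async def main() -> None:
    # create_worker() only connects and builds the Worker; running it, and deciding
    # when to call worker.shutdown(), is now the caller's responsibility.
    worker = await create_worker(
        "localhost",
        7233,
        metrics_port=8001,
        namespace="default",
        task_queue="test-task-queue",
        server_root_ca_cert=None,
        client_cert=None,
        client_key=None,
        workflows=TEST_WORKFLOWS,
        activities=TEST_ACTIVITIES,
        graceful_shutdown_timeout=None,
    )
    # No graceful-shutdown handling here; that is exactly what the management command
    # layers on top with its parent-level signal handlers.
    await worker.run()


asyncio.run(main())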