feat(array): temporal workflow for sandbox environment (#39082)

Joshua Snyder
2025-10-07 15:57:02 +01:00
committed by GitHub
parent 6831cbdf5c
commit 06d233507e
48 changed files with 3708 additions and 210 deletions

View File

@@ -49,6 +49,7 @@ jobs:
backend_files: ${{ steps.filter.outputs.backend_files }}
migrations: ${{ steps.filter.outputs.migrations }}
migrations_files: ${{ steps.filter.outputs.migrations_files }}
tasks_temporal: ${{ steps.filter.outputs.tasks_temporal }}
steps:
# For pull requests it's not necessary to checkout the code, but we
# also want this to run on master so we need to checkout
@@ -96,6 +97,8 @@ jobs:
- 'posthog/migrations/*.py'
- 'products/*/backend/migrations/*.py'
- 'products/*/migrations/*.py' # Legacy structure
tasks_temporal:
- 'products/tasks/backend/temporal/**/*'
check-migrations:
needs: [changes]
@@ -722,9 +725,10 @@ jobs:
shell: bash
env:
AWS_S3_ALLOW_UNSAFE_RENAME: 'true'
RUNLOOP_API_KEY: ${{ needs.changes.outputs.tasks_temporal == 'true' && secrets.RUNLOOP_API_KEY || '' }}
run: |
set +e
pytest posthog/temporal products/batch_exports/backend/tests/temporal -m "not async_migrations" \
pytest posthog/temporal products/batch_exports/backend/tests/temporal products/tasks/backend/temporal -m "not async_migrations" \
--splits ${{ matrix.concurrency }} --group ${{ matrix.group }} \
--durations=100 --durations-min=1.0 --store-durations \
--splitting-algorithm=duration_based_chunks \

View File

@@ -1,7 +1,40 @@
import inspect
from collections.abc import Callable, Coroutine
from datetime import datetime
from functools import wraps
from typing import Any, ParamSpec, TypeVar
from asgiref.sync import sync_to_async
from temporalio import workflow
P = ParamSpec("P")
T = TypeVar("T")
def asyncify(fn: Callable[P, T]) -> Callable[P, Coroutine[Any, Any, T]]:
"""Decorator to convert a sync function using sync_to_async - this preserves type hints for Temporal's serialization while allowing sync Django ORM code.
This preserves type hints for Temporal's serialization while allowing
sync Django ORM code.
Usage:
@activity.defn
@asyncify
def my_activity(task_id: str) -> TaskDetails:
task = Task.objects.get(id=task_id)
return TaskDetails(...)
"""
if inspect.iscoroutinefunction(fn):
raise TypeError(
f"@asyncify should only be used on sync functions. " f"'{fn.__name__}' is already async. Remove @asyncify."
)
@wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return await sync_to_async(fn)(*args, **kwargs)
return wrapper
def get_scheduled_start_time():
"""Return the start time of a workflow.

View File

@@ -0,0 +1,19 @@
SETUP_REPOSITORY_PROMPT = """
Your goal is to set up the repository in the current environment.
You are operating in a sandbox environment. You must install all necessary dependencies and set up the environment so that it is ready to execute code tasks.
CONTEXT:
CWD: {cwd}
REPOSITORY: {repository}
INSTRUCTIONS:
1. Install all dependencies necessary to run the repository
2. Run any setup scripts that are available
3. Verify the setup by running the tests or a build, if available
DO NOT make any code changes to the repository. The final state of this sandbox's disk is what will be used for subsequent tasks, so do not leave any cruft behind, and make sure the repository is in a ready-to-use state.
"""

View File

@@ -0,0 +1,26 @@
# Generated by Django 4.2.22 on 2025-10-04 16:57
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
("tasks", "0008_task_task_number"),
]
operations = [
migrations.AddField(
model_name="task",
name="created_by",
field=models.ForeignKey(
blank=True,
db_index=False,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to=settings.AUTH_USER_MODEL,
),
),
]

View File

@@ -1 +1 @@
0008_task_task_number
0009_task_created_by

View File

@@ -231,6 +231,7 @@ class Task(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
team = models.ForeignKey("posthog.Team", on_delete=models.CASCADE)
created_by = models.ForeignKey("posthog.User", on_delete=models.SET_NULL, null=True, blank=True, db_index=False)
task_number = models.IntegerField(null=True, blank=True)
title = models.CharField(max_length=255)
description = models.TextField()
@@ -331,12 +332,13 @@ class Task(models.Model):
"""
config = self.repository_config
if config.get("organization") and config.get("repository"):
full_name = f"{config.get('organization')}/{config.get('repository')}".lower()
return [
{
"org": config.get("organization"),
"repo": config.get("repository"),
"integration_id": self.github_integration_id,
"full_name": f"{config.get('organization')}/{config.get('repository')}",
"full_name": full_name,
}
]
return []
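As a quick illustration of the change above (values are hypothetical), a repository_config of {"organization": "PostHog", "repository": "posthog-js"} now produces:

[
    {
        "org": "PostHog",
        "repo": "posthog-js",
        "integration_id": 42,  # illustrative github_integration_id
        "full_name": "posthog/posthog-js",  # lowercased by this change
    }
]

The lowercased full_name matches the repository.lower() normalization that SandboxAgent applies later in this diff.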

View File

@@ -85,6 +85,9 @@ class TaskSerializer(serializers.ModelSerializer):
def create(self, validated_data):
validated_data["team"] = self.context["team"]
if "request" in self.context and hasattr(self.context["request"], "user"):
validated_data["created_by"] = self.context["request"].user
# Set default GitHub integration if not provided
if not validated_data.get("github_integration"):
default_integration = Integration.objects.filter(team=self.context["team"], kind="github").first()

View File

@@ -1,8 +1,13 @@
import logging
from typing import Optional
from django.conf import settings
from pydantic import BaseModel
from .sandbox_environment import ExecutionResult, SandboxEnvironment, SandboxEnvironmentConfig
from products.tasks.backend.lib.constants import SETUP_REPOSITORY_PROMPT
from .sandbox_environment import ExecutionResult, SandboxEnvironment
logger = logging.getLogger(__name__)
@@ -11,7 +16,8 @@ REPOSITORY_TARGET_DIR = "repo"
DEFAULT_TASK_TIMEOUT_SECONDS = 20 * 60 # 20 minutes
class SandboxAgentConfig(BaseModel):
class SandboxAgentCreateConfig(BaseModel):
name: str
repository_url: str
github_token: str
task_id: str
@@ -20,83 +26,76 @@ class SandboxAgentConfig(BaseModel):
class SandboxAgent:
"""
Agent that uses sandbox environments to execute tasks.
"""
"""Agent that uses sandbox environments to execute tasks."""
config: SandboxAgentConfig
sandbox: SandboxEnvironment
def __init__(self, sandbox: SandboxEnvironment, config: SandboxAgentConfig):
def __init__(self, sandbox: SandboxEnvironment):
self.sandbox = sandbox
self.config = config
@classmethod
async def create(
cls,
sandbox: SandboxEnvironment,
config: SandboxAgentConfig,
) -> "SandboxAgent":
environment_variables = {
"REPOSITORY_URL": config.repository_url,
"POSTHOG_CLI_TOKEN": config.posthog_personal_api_key,
"POSTHOG_CLI_ENV_ID": config.posthog_project_id,
}
async def clone_repository(self, repository: str, github_token: Optional[str] = "") -> ExecutionResult:
if not self.sandbox.is_running:
raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
sandbox_config = SandboxEnvironmentConfig(
name=sandbox.config.name,
template=sandbox.config.template,
environment_variables=environment_variables,
entrypoint=sandbox.config.entrypoint,
org, repo = repository.lower().split("/")
repo_url = (
f"https://x-access-token:{github_token}@github.com/{org}/{repo}.git"
if github_token
else f"https://github.com/{org}/{repo}.git"
)
sandbox = await SandboxEnvironment.create(sandbox_config)
agent = cls(sandbox, config)
target_path = f"/tmp/workspace/repos/{org}/{repo}"
return agent
async def setup_repository(self) -> ExecutionResult:
if not self.sandbox.is_running:
raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
return await self.clone_repository(self.config.repository_url)
async def clone_repository(self, repo_url: str) -> ExecutionResult:
if not self.sandbox.is_running:
raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
if repo_url.startswith("https://github.com/"):
auth_url = repo_url.replace(
"https://github.com/",
f"https://x-access-token:{self.config.github_token}@github.com/",
)
else:
raise ValueError("Only GitHub is supported")
clone_command = f"git clone {auth_url} {WORKING_DIR}/{REPOSITORY_TARGET_DIR}"
logger.info(f"Cloning repository {repo_url} to {self.repository_dir} in sandbox {self.sandbox.id}")
return await self.sandbox.execute(clone_command)
async def execute_task(self) -> ExecutionResult:
"""Execute Claude Code commands in the sandbox."""
if not self.sandbox.is_running:
raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
full_command = f"cd {self.repository_dir} && {self.get_task_command()}"
logger.info(
f"Executing task {self.config.task_id} in directory {self.repository_dir} in sandbox {self.sandbox.id}"
# Wipe existing directory if present, then clone
clone_command = (
f"rm -rf {target_path} && "
f"mkdir -p /tmp/workspace/repos/{org} && "
f"cd /tmp/workspace/repos/{org} && "
f"git clone {repo_url} {repo}"
)
return await self.sandbox.execute(full_command, timeout_seconds=DEFAULT_TASK_TIMEOUT_SECONDS)
def get_task_command(self) -> str:
"""Get the command to execute the task."""
# TODO: Replace with actual task execution: posthog-cli task run --task-id {self.config.task_id}
return "posthog-cli --help"
logger.info(f"Cloning repository {repository} to {target_path} in sandbox {self.sandbox.id}")
return await self.sandbox.execute(clone_command, timeout_seconds=5 * 60)
async def setup_repository(self, repository: str) -> ExecutionResult:
"""Setup a repository for snapshotting using the PostHog Code Agent."""
if not self.sandbox.is_running:
raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
org, repo = repository.lower().split("/")
repo_path = f"/tmp/workspace/repos/{org}/{repo}"
check_result = await self.sandbox.execute(f"test -d {repo_path} && echo 'exists' || echo 'missing'")
if "missing" in check_result.stdout:
raise RuntimeError(f"Repository path {repo_path} does not exist. Clone the repository first.")
setup_command = f"cd {repo_path} && {self._get_setup_command(repo_path)}"
logger.info(f"Running code agent setup for {repository} in sandbox {self.sandbox.id}")
return await self.sandbox.execute(setup_command, timeout_seconds=15 * 60)
async def execute_task(self, task_id: str, repository: str) -> ExecutionResult:
if not self.sandbox.is_running:
raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
org, repo = repository.lower().split("/")
repo_path = f"/tmp/workspace/repos/{org}/{repo}"
command = f"cd {repo_path} && {self._get_task_command(task_id)}"
logger.info(f"Executing task {task_id} in {repo_path} in sandbox {self.sandbox.id}")
return await self.sandbox.execute(command, timeout_seconds=DEFAULT_TASK_TIMEOUT_SECONDS)
# TODO: Replace these once our coding agent is ready
def _get_task_command(self, task_id: str) -> str:
# return f"npx @posthog/code-agent@latest --yes --task-id {task_id}"
return f"export ANTHROPIC_API_KEY={settings.ANTHROPIC_API_KEY} && claude --dangerously-skip-permissions -p 'replace the readme with an ice cream cone'"
def _get_setup_command(self, repo_path: str) -> str:
# return f"npx @posthog/code-agent@latest --yes --prompt '{SETUP_REPOSITORY_PROMPT.format(cwd=repo_path, repository=repo_path)}'"
return f"export ANTHROPIC_API_KEY={settings.ANTHROPIC_API_KEY} && claude --dangerously-skip-permissions -p '{SETUP_REPOSITORY_PROMPT.format(cwd=repo_path, repository=repo_path)}'"
async def destroy(self) -> None:
"""Destroy the underlying sandbox."""
await self.sandbox.destroy()
async def __aenter__(self):
@@ -105,14 +104,6 @@ class SandboxAgent:
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.destroy()
@property
def working_dir(self) -> str:
return WORKING_DIR
@property
def repository_dir(self) -> str:
return f"{WORKING_DIR}/{REPOSITORY_TARGET_DIR}"
@property
def is_running(self) -> bool:
return self.sandbox.is_running
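Putting the reworked SandboxAgent API together, a hedged end-to-end sketch (repository, task id and sandbox name are illustrative; the real workflow splits these calls across separate activities):

import asyncio

from products.tasks.backend.services.sandbox_agent import SandboxAgent
from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)

async def run_example() -> None:
    config = SandboxEnvironmentConfig(
        name="example-sandbox",
        template=SandboxEnvironmentTemplate.DEFAULT_BASE,
    )
    sandbox = await SandboxEnvironment.create(config)
    agent = SandboxAgent(sandbox)
    try:
        # Public repository, so no GitHub token is needed for the clone.
        await agent.clone_repository("PostHog/posthog-js", github_token="")
        await agent.setup_repository("PostHog/posthog-js")
        result = await agent.execute_task("example-task-id", "PostHog/posthog-js")
        print(result.exit_code, result.stdout[:200])
    finally:
        await agent.destroy()

asyncio.run(run_example())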

View File

@@ -3,8 +3,22 @@ import logging
from enum import Enum
from typing import Optional
from asgiref.sync import sync_to_async
from pydantic import BaseModel
from runloop_api_client import AsyncRunloop
from runloop_api_client import (
AsyncRunloop,
BadRequestError as RunloopBadRequestError,
NotFoundError as RunloopNotFoundError,
)
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.temporal.exceptions import (
SandboxCleanupError,
SandboxExecutionError,
SandboxNotFoundError,
SandboxProvisionError,
SnapshotCreationError,
)
logger = logging.getLogger(__name__)
@@ -20,8 +34,13 @@ class SandboxEnvironmentStatus(str, Enum):
SHUTDOWN = "shutdown"
class SandboxEnvironmentSnapshotStatus(str, Enum):
IN_PROGRESS = "in_progress"
COMPLETE = "complete"
ERROR = "error"
class SandboxEnvironmentTemplate(str, Enum):
UBUNTU_LATEST_X86_64 = "ubuntu_latest_x86_64"
DEFAULT_BASE = "default_base"
@@ -38,6 +57,9 @@ class SandboxEnvironmentConfig(BaseModel):
default_execution_timeout_seconds: int = 10 * 60 # 10 minutes
environment_variables: Optional[dict[str, str]] = None
entrypoint: Optional[str] = None
snapshot_id: Optional[str] = None
ttl_seconds: int = 60 * 30 # 30 minutes
metadata: Optional[dict[str, str]] = None
def get_runloop_client() -> AsyncRunloop:
@@ -49,7 +71,6 @@ def get_runloop_client() -> AsyncRunloop:
TEMPLATE_TO_BLUEPRINT_NAME = {
SandboxEnvironmentTemplate.DEFAULT_BASE: "sandbox-base-1",
SandboxEnvironmentTemplate.DEFAULT_BASE: "bpt_318UuXYGZbyYyl12hArAL",
}
BLUEPRINT_NAME_TO_TEMPLATE = {v: k for k, v in TEMPLATE_TO_BLUEPRINT_NAME.items()}
@@ -79,23 +100,44 @@ class SandboxEnvironment:
blueprint_name = TEMPLATE_TO_BLUEPRINT_NAME.get(config.template)
if not blueprint_name:
raise RuntimeError(f"Unknown template for sandbox {config.name}")
raise SandboxProvisionError(
f"Unknown template for sandbox {config.name}", {"template": str(config.template), "config": config}
)
snapshot_external_id = None
if config.snapshot_id:
snapshot = await sync_to_async(SandboxSnapshot.objects.get)(id=config.snapshot_id)
if snapshot.status == SandboxSnapshot.Status.COMPLETE:
snapshot_external_id = snapshot.external_id
try:
# Wait for devbox to be running before returning
devbox = await client.devboxes.create_and_await_running(
name=config.name,
blueprint_name=blueprint_name,
environment_variables=config.environment_variables or {},
entrypoint=config.entrypoint,
)
create_kwargs = {
"name": config.name,
"environment_variables": config.environment_variables or {},
"entrypoint": config.entrypoint,
"metadata": config.metadata or {},
"launch_parameters": {
"keep_alive_time_seconds": config.ttl_seconds,
},
}
if snapshot_external_id:
create_kwargs["snapshot_id"] = snapshot_external_id
else:
create_kwargs["blueprint_name"] = blueprint_name
devbox = await client.devboxes.create_and_await_running(**create_kwargs) # type: ignore[arg-type]
except Exception as e:
logger.exception(f"Failed to create sandbox: {e}")
raise RuntimeError(f"Failed to create sandbox: {e}")
raise SandboxProvisionError(f"Failed to create sandbox", {"config": config, "error": str(e)})
sandbox = SandboxEnvironment(id=devbox.id, status=SandboxEnvironmentStatus(devbox.status), config=config)
assert sandbox.is_running
logger.info(f"Created sandbox {sandbox.id} with status: {devbox.status}")
return sandbox
@@ -107,15 +149,11 @@ class SandboxEnvironment:
try:
devbox = await client.devboxes.retrieve(sandbox_id)
if not devbox.blueprint_id:
raise RuntimeError(f"Unknown template for sandbox {sandbox_id}")
template = SandboxEnvironmentTemplate.DEFAULT_BASE
blueprint = await client.blueprints.retrieve(devbox.blueprint_id)
template = BLUEPRINT_NAME_TO_TEMPLATE[blueprint.name]
if not template:
raise RuntimeError(f"Unknown template for sandbox {sandbox_id}")
if devbox.blueprint_id:
blueprint = await client.blueprints.retrieve(devbox.blueprint_id)
template = BLUEPRINT_NAME_TO_TEMPLATE.get(blueprint.name, SandboxEnvironmentTemplate.DEFAULT_BASE)
config = SandboxEnvironmentConfig(name=devbox.name or f"sandbox-{sandbox_id}", template=template)
@@ -126,8 +164,14 @@ class SandboxEnvironment:
return sandbox
except Exception as e:
logger.exception(f"Failed to retrieve sandbox {sandbox_id}: {e}")
raise RuntimeError(f"Failed to retrieve sandbox {sandbox_id}: {e}")
if isinstance(e, RunloopNotFoundError | RunloopBadRequestError):
if "non-existent-sandbox-id" in str(e) or isinstance(e, RunloopNotFoundError):
raise SandboxNotFoundError(
f"Sandbox {sandbox_id} not found", {"sandbox_id": sandbox_id, "error": str(e)}
)
raise SandboxProvisionError(
f"Failed to retrieve sandbox {sandbox_id}", {"sandbox_id": sandbox_id, "error": str(e)}
)
async def execute(
self,
@@ -135,7 +179,10 @@ class SandboxEnvironment:
timeout_seconds: Optional[int] = None,
) -> ExecutionResult:
if not self.is_running:
raise RuntimeError(f"Sandbox not in running state. Current status: {self.status}")
raise SandboxExecutionError(
f"Sandbox not in running state. Current status: {self.status}",
{"sandbox_id": self.id, "status": str(self.status)},
)
if timeout_seconds is None:
timeout_seconds = self.config.default_execution_timeout_seconds
@@ -157,15 +204,59 @@ class SandboxEnvironment:
result = ExecutionResult(
stdout=final_execution.stdout,
stderr=final_execution.stderr,
exit_code=final_execution.exit_status,
exit_code=final_execution.exit_status or 0,
error=getattr(final_execution, "error", None),
)
if result.error:
logger.warning(f"Command execution had error: {result.error}")
return result
async def initiate_snapshot(self, metadata: Optional[dict[str, str]] = None) -> str:
if not self.is_running:
raise SandboxExecutionError(
f"Sandbox not in running state. Current status: {self.status}",
{"sandbox_id": self.id, "status": str(self.status)},
)
try:
devbox = await self._client.devboxes.retrieve(self.id)
snapshot = await self._client.devboxes.snapshot_disk_async(devbox.id, metadata=metadata)
snapshot_id = snapshot.id
logger.info(f"Initiated snapshot for sandbox {self.id}, snapshot ID: {snapshot_id}")
return snapshot_id
except Exception as e:
logger.exception(f"Failed to initiate snapshot: {e}")
raise SnapshotCreationError(f"Failed to initiate snapshot: {e}", {"sandbox_id": self.id, "error": str(e)})
@staticmethod
async def delete_snapshot(external_id: str) -> None:
client = get_runloop_client()
logger.info(f"Deleting snapshot {external_id}")
await client.devboxes.disk_snapshots.delete(external_id)
logger.info(f"Deleted snapshot {external_id}")
@staticmethod
async def get_snapshot_status(external_id: str) -> SandboxEnvironmentSnapshotStatus:
try:
client = get_runloop_client()
logger.info(f"Getting snapshot status for {external_id}")
snapshot = await client.devboxes.disk_snapshots.query_status(external_id)
logger.info(f"Retrieved snapshot status for {external_id}: {snapshot.status}")
return SandboxEnvironmentSnapshotStatus(snapshot.status)
except Exception as e:
logger.exception(f"Failed to get snapshot status: {e}")
raise SnapshotCreationError(
f"Failed to get snapshot status: {e}", {"external_id": external_id, "error": str(e)}
)
async def destroy(self) -> None:
try:
await self._client.devboxes.shutdown(self.id)
@@ -176,7 +267,7 @@ class SandboxEnvironment:
except Exception as e:
logger.exception(f"Failed to destroy sandbox: {e}")
raise RuntimeError(f"Failed to destroy sandbox: {e}")
raise SandboxCleanupError(f"Failed to destroy sandbox: {e}", {"sandbox_id": self.id, "error": str(e)})
async def __aenter__(self):
return self
@@ -187,3 +278,7 @@ class SandboxEnvironment:
@property
def is_running(self) -> bool:
return self.status == SandboxEnvironmentStatus.RUNNING
@property
def name(self) -> str:
return self.config.name
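A hedged sketch of the snapshot round trip these additions enable (metadata and timings are illustrative; the real workflow does this inside the create_snapshot activity further down):

import asyncio

from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentSnapshotStatus,
)

async def snapshot_when_ready(sandbox: SandboxEnvironment) -> str:
    # Kick off an async disk snapshot, then poll Runloop until it settles.
    external_id = await sandbox.initiate_snapshot({"purpose": "example"})
    while True:
        status = await SandboxEnvironment.get_snapshot_status(external_id)
        if status == SandboxEnvironmentSnapshotStatus.COMPLETE:
            return external_id
        if status == SandboxEnvironmentSnapshotStatus.ERROR:
            raise RuntimeError(f"Snapshot {external_id} failed")
        await asyncio.sleep(15)

A later sandbox can then be booted from the saved state by passing the SandboxSnapshot row's id as snapshot_id in SandboxEnvironmentConfig, as create_sandbox_from_snapshot does below.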

View File

@@ -1,80 +0,0 @@
import os
import pytest
from products.tasks.backend.services.sandbox_agent import SandboxAgent, SandboxAgentConfig
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
@pytest.mark.asyncio
class TestSandboxAgentIntegration:
# We only run these tests when we have a Runloop API key set, we don't want to run them in CI since they create real sandbox environments and are slow.
@pytest.fixture(scope="class", autouse=True)
def check_api_key(self):
if not os.environ.get("RUNLOOP_API_KEY"):
pytest.skip("RUNLOOP_API_KEY not set, skipping integration tests")
@pytest.fixture
def mock_github_token(self):
"""Provide a mock GitHub token for testing."""
return "ghp_mock_token_for_testing_12345678901234567890"
@pytest.fixture
def mock_posthog_credentials(self):
"""Provide mock PostHog credentials for testing."""
return {"personal_api_key": "phx_mock_personal_api_key_123456789", "project_id": "test-project-id-123"}
@pytest.fixture
def public_repo_url(self):
"""Use a small public repository for testing."""
return "https://github.com/octocat/Hello-World"
async def test_complete_sandbox_agent_workflow(self, mock_github_token, public_repo_url, mock_posthog_credentials):
"""Comprehensive test covering agent lifecycle, repo cloning, and PostHog CLI execution."""
sandbox_config = SandboxEnvironmentConfig(
name="posthog-agent-test-complete", template=SandboxEnvironmentTemplate.DEFAULT_BASE
)
sandbox = await SandboxEnvironment.create(sandbox_config)
agent_config = SandboxAgentConfig(
repository_url=public_repo_url,
github_token=mock_github_token,
task_id="test",
posthog_personal_api_key=mock_posthog_credentials["personal_api_key"],
posthog_project_id=mock_posthog_credentials["project_id"],
)
async with await SandboxAgent.create(sandbox, agent_config) as agent:
assert agent.id is not None
assert agent.is_running
assert agent.working_dir == "/tmp/workspace"
assert agent.repository_dir == "/tmp/workspace/repo"
setup_result = await agent.setup_repository()
assert setup_result.exit_code == 0
assert setup_result.error is None
check_result = await agent.sandbox.execute("ls -la /tmp/workspace/repo")
assert check_result.exit_code == 0
assert ".git" in check_result.stdout
env_check = await agent.sandbox.execute("printenv")
assert "REPOSITORY_URL" in env_check.stdout
assert "POSTHOG_CLI_TOKEN" in env_check.stdout
assert "POSTHOG_CLI_ENV_ID" in env_check.stdout
cli_result = await agent.execute_task()
assert cli_result.exit_code == 0
assert "posthog-cli" in cli_result.stdout.lower() or "usage" in cli_result.stdout.lower()
context_result = await agent.sandbox.execute(f"cd {agent.repository_dir} && pwd")
assert context_result.exit_code == 0
assert agent.repository_dir in context_result.stdout
assert not agent.is_running

View File

@@ -7,6 +7,22 @@ from .github_activities import (
create_pr_activity,
create_pr_and_update_task_activity,
)
from .process_task.activities import (
check_snapshot_exists_for_repository,
cleanup_personal_api_key,
cleanup_sandbox,
clone_repository,
create_sandbox_from_snapshot,
create_snapshot,
execute_task_in_sandbox,
get_sandbox_for_setup,
get_task_details,
inject_github_token,
inject_personal_api_key,
setup_repository,
track_workflow_event,
)
from .process_task.workflow import ProcessTaskWorkflow
from .workflow_activities import (
check_temporal_workflow_permissions_activity,
execute_agent_for_transition_activity,
@@ -20,6 +36,7 @@ from .workflows import WorkflowAgnosticTaskProcessingWorkflow
WORKFLOWS = [
WorkflowAgnosticTaskProcessingWorkflow,
ProcessTaskWorkflow,
]
ACTIVITIES = [
@@ -39,4 +56,18 @@ ACTIVITIES = [
move_task_to_stage_activity,
trigger_task_processing_activity,
should_trigger_agent_workflow_activity,
# process_task activities
get_task_details,
check_snapshot_exists_for_repository,
get_sandbox_for_setup,
clone_repository,
inject_github_token,
inject_personal_api_key,
setup_repository,
create_snapshot,
create_sandbox_from_snapshot,
execute_task_in_sandbox,
cleanup_personal_api_key,
cleanup_sandbox,
track_workflow_event,
]
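The worker wiring itself is not part of this diff, so as a hedged sketch (the import path is assumed from the package layout), the two lists above would typically be handed to a Temporal worker on the tasks queue like this:

from temporalio.worker import Worker

from posthog.constants import TASKS_TASK_QUEUE
from posthog.temporal.common.client import async_connect
from products.tasks.backend.temporal import ACTIVITIES, WORKFLOWS

async def run_tasks_worker() -> None:
    client = await async_connect()
    worker = Worker(
        client,
        task_queue=TASKS_TASK_QUEUE,
        workflows=WORKFLOWS,
        activities=ACTIVITIES,
    )
    await worker.run()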

View File

@@ -1,36 +1,44 @@
import time
import uuid
import asyncio
import logging
from typing import Optional
import posthoganalytics
from temporalio.common import RetryPolicy, WorkflowIDReusePolicy
from posthog.constants import TASKS_TASK_QUEUE
from posthog.models.team.team import Team
from posthog.models.user import User
from posthog.temporal.common.client import async_connect
from .inputs import TaskProcessingInputs
logger = logging.getLogger(__name__)
async def _execute_task_processing_workflow(task_id: str, team_id: int, user_id: Optional[int] = None) -> str:
"""Execute the task processing workflow asynchronously."""
inputs = TaskProcessingInputs(task_id=task_id, team_id=team_id, user_id=user_id)
# Create unique workflow ID based on task and timestamp
import time
import uuid
import logging
# Use high-resolution timestamp + random suffix to avoid collisions when re-triggering within the same second
async def _execute_task_processing_workflow(
task_id: str, team_id: int, user_id: Optional[int] = None, use_sandbox: bool = False
) -> str:
workflow_id = f"task-processing-{task_id}-{int(time.time()*1000)}-{uuid.uuid4().hex[:8]}"
logging.getLogger(__name__).info(f"Starting workflow {workflow_id} for task {task_id}")
workflow_input: str | TaskProcessingInputs
if use_sandbox:
workflow_name = "process-task"
workflow_input = task_id
else:
workflow_name = "process-task-workflow-agnostic"
workflow_input = TaskProcessingInputs(task_id=task_id, team_id=team_id, user_id=user_id)
logger.info(f"Starting workflow {workflow_name} ({workflow_id}) for task {task_id}")
client = await async_connect()
retry_policy = RetryPolicy(maximum_attempts=3)
result = await client.execute_workflow(
"process-task-workflow-agnostic",
inputs,
workflow_name,
workflow_input,
id=workflow_id,
id_reuse_policy=WorkflowIDReusePolicy.ALLOW_DUPLICATE_FAILED_ONLY,
task_queue=TASKS_TASK_QUEUE,
@@ -47,21 +55,14 @@ def execute_task_processing_workflow(task_id: str, team_id: int, user_id: Option
but doesn't wait for completion.
"""
try:
import logging
import threading
logger = logging.getLogger(__name__)
# Always offload to a dedicated thread with its own event loop.
# This is safer when called from within a Temporal activity (already running an event loop)
# and from sync Django views. It avoids create_task() being cancelled when the caller loop ends.
def run_workflow() -> None:
try:
# Check feature flag in the thread where we can make sync Django calls
import posthoganalytics
from posthog.models.team.team import Team
from posthog.models.user import User
try:
if not user_id:
@@ -93,8 +94,20 @@ def execute_task_processing_workflow(task_id: str, team_id: int, user_id: Option
logger.exception(f"Error checking feature flag for task workflow: {e}")
return
logger.info(f"Triggering workflow for task {task_id}")
asyncio.run(_execute_task_processing_workflow(task_id, team_id, user_id))
# Check feature flag for sandbox-based workflow
use_sandbox = posthoganalytics.feature_enabled(
"tasks-sandbox",
user.distinct_id,
groups={"organization": str(team.organization.id)},
group_properties={"organization": {"id": str(team.organization.id)}},
only_evaluate_locally=False,
send_feature_flag_events=False,
)
logger.info(
f"Triggering workflow for task {task_id} (sandbox: {use_sandbox}, workflow: {'process-task' if use_sandbox else 'process-task-workflow-agnostic'})"
)
asyncio.run(_execute_task_processing_workflow(task_id, team_id, user_id, use_sandbox=use_sandbox))
logger.info(f"Workflow completed for task {task_id}")
except Exception as e:
logger.exception(f"Workflow execution failed for task {task_id}: {e}")
@@ -105,8 +118,5 @@ def execute_task_processing_workflow(task_id: str, team_id: int, user_id: Option
except Exception as e:
# Don't let workflow execution failures break the main operation
import logging
logger = logging.getLogger(__name__)
logger.exception(f"Failed to execute task processing workflow: {e}")
# Don't re-raise to avoid breaking the API call

View File

@@ -0,0 +1,161 @@
import random
import pytest
from asgiref.sync import sync_to_async
from temporalio.testing import ActivityEnvironment
from posthog.models import Organization, OrganizationMembership, Team
from posthog.models.integration import Integration
from posthog.models.user import User
from posthog.temporal.common.logger import configure_logger
from products.tasks.backend.models import Task, TaskWorkflow, WorkflowStage
@pytest.fixture
def activity_environment():
"""Return a testing temporal ActivityEnvironment."""
return ActivityEnvironment()
@pytest.fixture
def organization():
"""A test organization."""
name = f"TasksTestOrg-{random.randint(1, 99999)}"
org = Organization.objects.create(name=name, is_ai_data_processing_approved=True)
org.save()
yield org
org.delete()
@pytest.fixture
def team(organization):
"""A test team."""
name = f"TasksTestTeam-{random.randint(1, 99999)}"
team = Team.objects.create(organization=organization, name=name)
team.save()
yield team
team.delete()
@pytest.fixture
async def aorganization():
"""Async test organization."""
name = f"TasksTestOrg-{random.randint(1, 99999)}"
org = await sync_to_async(Organization.objects.create)(name=name, is_ai_data_processing_approved=True)
yield org
await sync_to_async(org.delete)()
@pytest.fixture
async def ateam(aorganization):
"""Async test team."""
name = f"TasksTestTeam-{random.randint(1, 99999)}"
team = await sync_to_async(Team.objects.create)(organization=aorganization, name=name)
yield team
await sync_to_async(team.delete)()
@pytest.fixture
async def task_workflow(ateam):
"""Create a test workflow with stages."""
workflow = await sync_to_async(TaskWorkflow.objects.create)(
team=ateam,
name="Test Workflow",
description="Test workflow for temporal activities",
is_default=True,
is_active=True,
)
stages = []
for i, (name, key, color) in enumerate(
[
("Backlog", "backlog", "#6b7280"),
("Ready", "ready", "#3b82f6"),
("In Progress", "in_progress", "#10b981"),
("Done", "done", "#22c55e"),
]
):
stage = await sync_to_async(WorkflowStage.objects.create)(
workflow=workflow,
name=name,
key=key,
position=i,
color=color,
is_manual_only=(i != 2), # Only "In Progress" is not manual
agent_name="claude_code_agent" if i == 2 else None,
)
stages.append(stage)
yield workflow, stages
await sync_to_async(workflow.delete)()
@pytest.fixture
async def github_integration(ateam):
"""Create a test GitHub integration."""
integration = await sync_to_async(Integration.objects.create)(
team=ateam,
kind="github",
sensitive_config={"access_token": "fake_token"},
)
yield integration
await sync_to_async(integration.delete)()
@pytest.fixture
async def auser(ateam):
user = await sync_to_async(User.objects.create)(
email=f"test-{random.randint(1, 99999)}@example.com",
password="testpassword123",
)
await sync_to_async(OrganizationMembership.objects.create)(
user=user,
organization_id=ateam.organization_id,
)
yield user
await sync_to_async(user.delete)()
@pytest.fixture
async def test_task(ateam, auser, task_workflow, github_integration):
"""Create a test task."""
workflow, stages = task_workflow
backlog_stage = stages[0]
task = await sync_to_async(Task.objects.create)(
team=ateam,
created_by=auser,
title="Test Task for Temporal Activities",
description="This is a test task for testing temporal activities",
origin_product=Task.OriginProduct.USER_CREATED,
workflow=workflow,
current_stage=backlog_stage,
position=0,
github_integration=github_integration,
repository_config={"organization": "PostHog", "repository": "posthog-js"},
)
yield task
await sync_to_async(task.delete)()
@pytest.fixture(autouse=True)
def configure_logger_auto() -> None:
"""Configure logger when running in a Temporal activity environment."""
configure_logger(cache_logger_on_first_use=False)

View File

@@ -0,0 +1,127 @@
from typing import Optional
from temporalio.exceptions import ApplicationError
class ProcessTaskError(ApplicationError):
def __init__(self, message: str, context: Optional[dict] = None, **kwargs):
self.context = context or {}
super().__init__(message, self.context, **kwargs)
class ProcessTaskFatalError(ProcessTaskError):
"""Fatal errors that should not be retried."""
def __init__(self, message: str, context: Optional[dict] = None):
super().__init__(message, context, non_retryable=True)
class ProcessTaskTransientError(ProcessTaskError):
"""Transient errors that may succeed on retry."""
def __init__(self, message: str, context: Optional[dict] = None):
super().__init__(message, context, non_retryable=False)
class TaskNotFoundError(ProcessTaskFatalError):
pass
class TaskInvalidStateError(ProcessTaskFatalError):
pass
class SandboxProvisionError(ProcessTaskTransientError):
"""Failed to provision sandbox environment."""
pass
class SandboxNotFoundError(ProcessTaskFatalError):
"""Sandbox does not exist."""
pass
class SandboxExecutionError(ProcessTaskTransientError):
"""Error during sandbox command execution."""
pass
class SandboxTimeoutError(ProcessTaskTransientError):
"""Sandbox operation timed out."""
pass
class SandboxCleanupError(ProcessTaskTransientError):
"""Error during sandbox cleanup/destruction."""
pass
class SnapshotNotFoundError(ProcessTaskTransientError):
"""Snapshot does not exist."""
pass
class SnapshotNotReadyError(ProcessTaskTransientError):
"""Snapshot exists but is not ready for use."""
pass
class SnapshotCreationError(ProcessTaskTransientError):
"""Failed to create snapshot."""
pass
class RepositoryCloneError(ProcessTaskTransientError):
"""Failed to clone repository."""
pass
class RepositorySetupError(ProcessTaskTransientError):
"""Failed to setup repository (install dependencies, etc)."""
pass
class GitHubIntegrationError(ProcessTaskFatalError):
"""GitHub integration not found or invalid."""
pass
class GitHubAuthenticationError(ProcessTaskFatalError):
"""Failed to authenticate with GitHub."""
pass
class PersonalAPIKeyError(ProcessTaskTransientError):
"""Failed to create or inject personal API key."""
pass
class TaskExecutionFailedError(ProcessTaskError):
"""Task execution completed but with non-zero exit code."""
def __init__(
self,
message: str,
exit_code: int,
stdout: str = "",
stderr: str = "",
context: Optional[dict] = None,
non_retryable: bool = False,
):
self.exit_code = exit_code
self.stdout = stdout
self.stderr = stderr
super().__init__(message, context, non_retryable=non_retryable)
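Because every class here ultimately subclasses temporalio's ApplicationError, whether Temporal retries an activity that raises one of them is driven by the non_retryable flag set in the constructors above, not by the retry policy alone. A hedged sketch of how a workflow caller observes this (activity name, timeout and policy are illustrative):

from datetime import timedelta

from temporalio import workflow
from temporalio.common import RetryPolicy

async def call_get_task_details(task_id: str):
    # Must be awaited from inside a @workflow.defn workflow's run method.
    # Transient subclasses (e.g. SandboxProvisionError) are retried up to
    # maximum_attempts; fatal subclasses (e.g. TaskNotFoundError) fail the
    # activity on the first attempt because non_retryable=True.
    return await workflow.execute_activity(
        "get_task_details",
        task_id,
        start_to_close_timeout=timedelta(minutes=1),
        retry_policy=RetryPolicy(maximum_attempts=3),
    )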

View File

@@ -0,0 +1,157 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Optional
import posthoganalytics
from temporalio import activity, workflow
from posthog.temporal.common.logger import get_logger
logger = get_logger(__name__)
def get_bound_logger(**context: Any):
return logger.bind(**context)
def log_with_activity_context(message: str, **extra_context: Any) -> None:
bound_logger = logger.bind(**extra_context)
if activity.in_activity():
info = activity.info()
bound_logger = bound_logger.bind(
activity_id=info.activity_id,
activity_type=info.activity_type,
attempt=info.attempt,
)
bound_logger.info(message)
def log_with_workflow_context(message: str, **extra_context: Any) -> None:
bound_logger = logger.bind(**extra_context)
if workflow.in_workflow():
info = workflow.info()
bound_logger = bound_logger.bind(
workflow_id=info.workflow_id,
workflow_run_id=info.run_id,
workflow_type=info.workflow_type,
)
bound_logger.info(message)
@asynccontextmanager
async def log_activity_execution(
activity_name: str,
distinct_id: Optional[str] = None,
**context: Any,
) -> AsyncIterator[None]:
"""Context manager for activity execution with automatic logging and analytics.
Automatically tracks:
- process_task_activity_started
- process_task_activity_completed
- process_task_activity_failed
Usage:
async with log_activity_execution(
"clone_repository",
distinct_id=f"user_{user_id}",
task_id=task_id,
repository=repo
):
result = await do_work()
return result
"""
bound_logger = logger.bind(**context)
if activity.in_activity():
info = activity.info()
bound_logger = bound_logger.bind(
activity_id=info.activity_id,
activity_type=info.activity_type,
attempt=info.attempt,
)
bound_logger.info(f"{activity_name} started")
if distinct_id:
track_event(
"process_task_activity_started",
distinct_id=distinct_id,
properties={"activity_name": activity_name, **context},
)
try:
yield
bound_logger.info(f"{activity_name} completed successfully")
if distinct_id:
track_event(
"process_task_activity_completed",
distinct_id=distinct_id,
properties={"activity_name": activity_name, **context},
)
except Exception as e:
bound_logger.exception(
f"{activity_name} failed",
error_type=type(e).__name__,
error_message=str(e),
)
if distinct_id:
track_event(
"process_task_activity_failed",
distinct_id=distinct_id,
properties={
"activity_name": activity_name,
"error_type": type(e).__name__,
"error_message": str(e)[:500],
**context,
},
)
raise
def track_event(
event_name: str,
distinct_id: str,
properties: Optional[dict[str, Any]] = None,
) -> None:
try:
enriched_properties = {**(properties or {})}
if activity.in_activity():
activity_info = activity.info()
enriched_properties.update(
{
"temporal_activity_id": activity_info.activity_id,
"temporal_activity_type": activity_info.activity_type,
"temporal_workflow_id": activity_info.workflow_id,
"temporal_workflow_run_id": activity_info.workflow_run_id,
"temporal_attempt": activity_info.attempt,
}
)
elif workflow.in_workflow() and not workflow.unsafe.is_replaying():
workflow_info = workflow.info()
enriched_properties.update(
{
"temporal_workflow_id": workflow_info.workflow_id,
"temporal_workflow_run_id": workflow_info.run_id,
"temporal_workflow_type": workflow_info.workflow_type,
}
)
posthoganalytics.capture(
distinct_id=distinct_id,
event=event_name,
properties=enriched_properties,
)
logger.debug(f"Tracked event: {event_name}", **enriched_properties)
except Exception as e:
logger.warning(f"Failed to track event {event_name}", exc_info=e)

View File

@@ -0,0 +1 @@
# Agent workflow for executing tasks in sandboxes

View File

@@ -0,0 +1,29 @@
from .check_snapshot_exists_for_repository import check_snapshot_exists_for_repository
from .cleanup_personal_api_key import cleanup_personal_api_key
from .cleanup_sandbox import cleanup_sandbox
from .clone_repository import clone_repository
from .create_sandbox_from_snapshot import create_sandbox_from_snapshot
from .create_snapshot import create_snapshot
from .execute_task_in_sandbox import execute_task_in_sandbox
from .get_sandbox_for_setup import get_sandbox_for_setup
from .get_task_details import get_task_details
from .inject_github_token import inject_github_token
from .inject_personal_api_key import inject_personal_api_key
from .setup_repository import setup_repository
from .track_workflow_event import track_workflow_event
__all__ = [
"check_snapshot_exists_for_repository",
"cleanup_personal_api_key",
"cleanup_sandbox",
"clone_repository",
"create_sandbox_from_snapshot",
"create_snapshot",
"execute_task_in_sandbox",
"get_sandbox_for_setup",
"get_task_details",
"inject_github_token",
"inject_personal_api_key",
"setup_repository",
"track_workflow_event",
]

View File

@@ -0,0 +1,40 @@
from dataclasses import dataclass
from temporalio import activity
from posthog.temporal.common.utils import asyncify
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.temporal.observability import log_with_activity_context
@dataclass
class CheckSnapshotExistsForRepositoryInput:
github_integration_id: int
repository: str
@dataclass
class CheckSnapshotExistsForRepositoryOutput:
exists: bool
snapshot_id: str | None
@activity.defn
@asyncify
def check_snapshot_exists_for_repository(
input: CheckSnapshotExistsForRepositoryInput,
) -> CheckSnapshotExistsForRepositoryOutput:
"""Check if a repository exists in the latest complete snapshot."""
log_with_activity_context(
"Checking if snapshot exists for repository",
github_integration_id=input.github_integration_id,
repository=input.repository,
)
snapshot = SandboxSnapshot.get_latest_snapshot_with_repos(input.github_integration_id, [input.repository])
if snapshot:
return CheckSnapshotExistsForRepositoryOutput(exists=True, snapshot_id=str(snapshot.id))
return CheckSnapshotExistsForRepositoryOutput(exists=False, snapshot_id=None)

View File

@@ -0,0 +1,10 @@
from temporalio import activity
from posthog.models import PersonalAPIKey
from posthog.temporal.common.utils import asyncify
@activity.defn
@asyncify
def cleanup_personal_api_key(personal_api_key_id: str) -> None:
PersonalAPIKey.objects.filter(id=personal_api_key_id).delete()

View File

@@ -0,0 +1,28 @@
import logging
from dataclasses import dataclass
from temporalio import activity
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import SandboxNotFoundError
from products.tasks.backend.temporal.observability import log_activity_execution
logger = logging.getLogger(__name__)
@dataclass
class CleanupSandboxInput:
sandbox_id: str
@activity.defn
async def cleanup_sandbox(input: CleanupSandboxInput) -> None:
async with log_activity_execution(
"cleanup_sandbox",
sandbox_id=input.sandbox_id,
):
try:
sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)
await sandbox.destroy()
except SandboxNotFoundError:
pass

View File

@@ -0,0 +1,63 @@
from dataclasses import dataclass
from temporalio import activity
from products.tasks.backend.services.sandbox_agent import SandboxAgent
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import GitHubAuthenticationError, RepositoryCloneError
from products.tasks.backend.temporal.observability import log_activity_execution
from ..utils import get_github_token
@dataclass
class CloneRepositoryInput:
sandbox_id: str
repository: str
github_integration_id: int
task_id: str
distinct_id: str
@activity.defn
async def clone_repository(input: CloneRepositoryInput) -> str:
"""Clone repository into sandbox. Idempotent: wipes existing directory. Returns clone logs."""
async with log_activity_execution(
"clone_repository",
distinct_id=input.distinct_id,
task_id=input.task_id,
sandbox_id=input.sandbox_id,
repository=input.repository,
):
try:
github_token = await get_github_token(input.github_integration_id)
except Exception as e:
raise GitHubAuthenticationError(
f"Failed to get GitHub token for integration {input.github_integration_id}",
{"github_integration_id": input.github_integration_id, "error": str(e)},
)
sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)
agent = SandboxAgent(sandbox)
try:
result = await agent.clone_repository(input.repository, github_token)
except Exception as e:
raise RepositoryCloneError(
f"Failed to clone repository {input.repository}",
{"repository": input.repository, "sandbox_id": input.sandbox_id, "error": str(e)},
)
if result.exit_code != 0:
raise RepositoryCloneError(
f"Git clone failed with exit code {result.exit_code}",
{
"repository": input.repository,
"exit_code": result.exit_code,
"stderr": result.stderr[:500],
},
)
# NOTE: git clone returns its output in stderr
return result.stderr

View File

@@ -0,0 +1,56 @@
from dataclasses import dataclass
from django.core.exceptions import ObjectDoesNotExist
from asgiref.sync import sync_to_async
from temporalio import activity
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import SnapshotNotFoundError, SnapshotNotReadyError
from products.tasks.backend.temporal.observability import log_activity_execution
from products.tasks.backend.temporal.process_task.utils import get_sandbox_name_for_task
@dataclass
class CreateSandboxFromSnapshotInput:
snapshot_id: str
task_id: str
distinct_id: str
@activity.defn
async def create_sandbox_from_snapshot(input: CreateSandboxFromSnapshotInput) -> str:
"""Create a sandbox from a snapshot for task execution. Returns sandbox_id when running."""
async with log_activity_execution(
"create_sandbox_from_snapshot",
distinct_id=input.distinct_id,
task_id=input.task_id,
snapshot_id=input.snapshot_id,
):
try:
snapshot = await sync_to_async(SandboxSnapshot.objects.get)(id=input.snapshot_id)
except ObjectDoesNotExist:
raise SnapshotNotFoundError(f"Snapshot {input.snapshot_id} not found", {"snapshot_id": input.snapshot_id})
if snapshot.status != SandboxSnapshot.Status.COMPLETE:
raise SnapshotNotReadyError(
f"Snapshot {input.snapshot_id} is not ready (status: {snapshot.status})",
{"snapshot_id": input.snapshot_id, "status": snapshot.status},
)
config = SandboxEnvironmentConfig(
name=get_sandbox_name_for_task(input.task_id),
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
environment_variables={},
snapshot_id=str(snapshot.id),
metadata={"task_id": input.task_id},
)
sandbox = await SandboxEnvironment.create(config)
return sandbox.id

View File

@@ -0,0 +1,81 @@
import json
import asyncio
from dataclasses import dataclass
from asgiref.sync import sync_to_async
from temporalio import activity
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import SandboxTimeoutError, SnapshotCreationError
from products.tasks.backend.temporal.observability import log_activity_execution
@dataclass
class CreateSnapshotInput:
sandbox_id: str
github_integration_id: int
team_id: int
repository: str
task_id: str
distinct_id: str
@activity.defn
async def create_snapshot(input: CreateSnapshotInput) -> str:
"""
Create and finalize snapshot. Initiates snapshot, polls until complete,
and saves the snapshot record. Returns snapshot_id.
"""
async with log_activity_execution(
"create_snapshot",
distinct_id=input.distinct_id,
task_id=input.task_id,
sandbox_id=input.sandbox_id,
repository=input.repository,
):
base_snapshot = await sync_to_async(SandboxSnapshot.get_latest_snapshot_for_integration)(
input.github_integration_id
)
base_repos = base_snapshot.repos if base_snapshot else []
new_repos: list[str] = list({*base_repos, input.repository})
sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)
snapshot_external_id = await sandbox.initiate_snapshot(
{
"integration_id": str(input.github_integration_id),
"team_id": str(input.team_id),
"repositories": json.dumps(new_repos),
"base_snapshot_id": str(base_snapshot.id) if base_snapshot else "",
}
)
max_polls = 80
for _ in range(max_polls):
status = await SandboxEnvironment.get_snapshot_status(snapshot_external_id)
if status.value == "complete":
break
elif status.value == "error":
raise SnapshotCreationError(
"Snapshot creation failed",
{"snapshot_external_id": snapshot_external_id, "repository": input.repository},
)
await asyncio.sleep(15)
else:
raise SandboxTimeoutError(
"Snapshot creation timed out after 20 minutes",
{"snapshot_external_id": snapshot_external_id, "repository": input.repository},
)
snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
integration_id=input.github_integration_id,
repos=new_repos,
external_id=snapshot_external_id,
status=SandboxSnapshot.Status.COMPLETE,
)
return str(snapshot.id)

View File

@@ -0,0 +1,64 @@
from dataclasses import dataclass
from typing import Optional
from temporalio import activity
from products.tasks.backend.services.sandbox_agent import SandboxAgent
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import SandboxExecutionError, TaskExecutionFailedError
from products.tasks.backend.temporal.observability import log_activity_execution
@dataclass
class ExecuteTaskInput:
sandbox_id: str
task_id: str
repository: str
distinct_id: str
@dataclass
class ExecuteTaskOutput:
stdout: str
stderr: str
exit_code: int
error: Optional[str] = None
@activity.defn
async def execute_task_in_sandbox(input: ExecuteTaskInput) -> ExecuteTaskOutput:
"""Execute the code agent task in the sandbox."""
async with log_activity_execution(
"execute_task_in_sandbox",
distinct_id=input.distinct_id,
task_id=input.task_id,
sandbox_id=input.sandbox_id,
repository=input.repository,
):
sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)
agent = SandboxAgent(sandbox)
try:
result = await agent.execute_task(input.task_id, input.repository)
except Exception as e:
raise SandboxExecutionError(
f"Failed to execute task in sandbox",
{"task_id": input.task_id, "sandbox_id": input.sandbox_id, "error": str(e)},
)
if result.exit_code != 0:
raise TaskExecutionFailedError(
f"Task execution failed with exit code {result.exit_code}",
exit_code=result.exit_code,
stdout=result.stdout,
stderr=result.stderr,
context={"task_id": input.task_id, "sandbox_id": input.sandbox_id},
)
return ExecuteTaskOutput(
stdout=result.stdout,
stderr=result.stderr,
exit_code=result.exit_code,
error=result.error,
)

View File

@@ -0,0 +1,48 @@
from dataclasses import dataclass
from asgiref.sync import sync_to_async
from temporalio import activity
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.observability import log_activity_execution
from products.tasks.backend.temporal.process_task.utils import get_sandbox_name_for_task
@dataclass
class GetSandboxForSetupInput:
github_integration_id: int
team_id: int
task_id: str
distinct_id: str
@activity.defn
async def get_sandbox_for_setup(input: GetSandboxForSetupInput) -> str:
"""
Get sandbox for setup. Searches for existing snapshot to use as base,
otherwise uses default template. Returns sandbox_id when sandbox is running.
"""
async with log_activity_execution(
"get_sandbox_for_setup",
distinct_id=input.distinct_id,
task_id=input.task_id,
github_integration_id=input.github_integration_id,
):
snapshot = await sync_to_async(SandboxSnapshot.get_latest_snapshot_for_integration)(input.github_integration_id)
config = SandboxEnvironmentConfig(
name=get_sandbox_name_for_task(input.task_id),
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
environment_variables={},
snapshot_id=str(snapshot.id) if snapshot else None,
metadata={"task_id": input.task_id},
)
sandbox = await SandboxEnvironment.create(config)
return sandbox.id

View File

@@ -0,0 +1,76 @@
from dataclasses import dataclass
from django.core.exceptions import ObjectDoesNotExist
from temporalio import activity
from posthog.temporal.common.utils import asyncify
from products.tasks.backend.models import Task
from products.tasks.backend.temporal.exceptions import TaskInvalidStateError, TaskNotFoundError
from products.tasks.backend.temporal.observability import log_with_activity_context
@dataclass
class TaskDetails:
task_id: str
team_id: int
github_integration_id: int
repository: str
distinct_id: str
@activity.defn
@asyncify
def get_task_details(task_id: str) -> TaskDetails:
log_with_activity_context("Fetching task details", task_id=task_id)
try:
task = Task.objects.select_related("created_by").get(id=task_id)
except ObjectDoesNotExist:
raise TaskNotFoundError(f"Task {task_id} not found", {"task_id": task_id})
if not task.github_integration_id:
raise TaskInvalidStateError(
f"Task {task_id} has no GitHub integration",
{"task_id": task_id},
)
if not task.primary_repository:
raise TaskInvalidStateError(
f"Task {task_id} has no primary repository configured",
{"task_id": task_id},
)
repository_full_name = task.primary_repository.get("full_name")
if not repository_full_name:
raise TaskInvalidStateError(
f"Task {task_id} primary repository missing full_name",
{"task_id": task_id},
)
if not task.created_by:
raise TaskInvalidStateError(
f"Task {task_id} has no created_by user",
{"task_id": task_id},
)
assert task.created_by is not None
distinct_id = task.created_by.distinct_id or "process_task_workflow"
log_with_activity_context(
"Task details retrieved successfully",
task_id=task_id,
team_id=task.team_id,
repository=repository_full_name,
distinct_id=distinct_id,
)
return TaskDetails(
task_id=str(task.id),
team_id=task.team_id,
github_integration_id=task.github_integration_id,
repository=repository_full_name,
distinct_id=distinct_id,
)

View File

@@ -0,0 +1,49 @@
import shlex
from dataclasses import dataclass
from temporalio import activity
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import GitHubAuthenticationError, SandboxExecutionError
from products.tasks.backend.temporal.observability import log_activity_execution
from ..utils import get_github_token
@dataclass
class InjectGitHubTokenInput:
sandbox_id: str
github_integration_id: int
task_id: str
distinct_id: str
@activity.defn
async def inject_github_token(input: InjectGitHubTokenInput) -> None:
async with log_activity_execution(
"inject_github_token",
distinct_id=input.distinct_id,
task_id=input.task_id,
sandbox_id=input.sandbox_id,
github_integration_id=input.github_integration_id,
):
try:
github_token = await get_github_token(input.github_integration_id) or ""
except Exception as e:
raise GitHubAuthenticationError(
f"Failed to get GitHub token for integration {input.github_integration_id}",
{"github_integration_id": input.github_integration_id, "error": str(e)},
)
sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)
escaped_github_token = shlex.quote(github_token)
result = await sandbox.execute(
f"echo 'export GITHUB_TOKEN={escaped_github_token}' >> ~/.bash_profile && echo 'export GITHUB_TOKEN=\"{escaped_github_token}\"' >> ~/.bashrc"
)
if result.exit_code != 0:
raise SandboxExecutionError(
f"Failed to inject GitHub token into sandbox",
{"sandbox_id": input.sandbox_id, "exit_code": result.exit_code, "stderr": result.stderr[:500]},
)

View File

@@ -0,0 +1,119 @@
import shlex
from dataclasses import dataclass
from django.core.exceptions import ObjectDoesNotExist
from temporalio import activity
from posthog.models import PersonalAPIKey
from posthog.models.personal_api_key import hash_key_value
from posthog.models.utils import generate_random_token_personal, mask_key_value
from posthog.scopes import API_SCOPE_OBJECTS
from posthog.temporal.common.utils import asyncify
from products.tasks.backend.models import Task
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import (
PersonalAPIKeyError,
SandboxExecutionError,
TaskInvalidStateError,
TaskNotFoundError,
)
from products.tasks.backend.temporal.observability import log_activity_execution
@dataclass
class InjectPersonalAPIKeyInput:
sandbox_id: str
task_id: str
distinct_id: str
@dataclass
class InjectPersonalAPIKeyOutput:
personal_api_key_id: str
@asyncify
def _get_task(task_id: str) -> Task:
return Task.objects.select_related("created_by").get(id=task_id)
@asyncify
def _create_personal_api_key(task: Task) -> tuple[str, PersonalAPIKey]:
scopes = _get_default_scopes()
value = generate_random_token_personal()
mask_value = mask_key_value(value)
secure_value = hash_key_value(value)
if not task.created_by:
raise TaskInvalidStateError(f"Task {task.id} has no created_by user", {"task_id": task.id})
assert task.created_by is not None
personal_api_key = PersonalAPIKey.objects.create(
user=task.created_by,
label=f"Task Agent - {task.title[:20]}",
secure_value=secure_value,
mask_value=mask_value,
scopes=scopes,
scoped_teams=[task.team_id],
)
return value, personal_api_key
@activity.defn
async def inject_personal_api_key(input: InjectPersonalAPIKeyInput) -> InjectPersonalAPIKeyOutput:
async with log_activity_execution(
"inject_personal_api_key",
distinct_id=input.distinct_id,
task_id=input.task_id,
sandbox_id=input.sandbox_id,
):
try:
task = await _get_task(input.task_id)
except ObjectDoesNotExist:
raise TaskNotFoundError(f"Task {input.task_id} not found", {"task_id": input.task_id})
if not task.created_by:
raise TaskInvalidStateError(f"Task {input.task_id} has no created_by user", {"task_id": input.task_id})
try:
api_key_tuple: tuple[str, PersonalAPIKey] = await _create_personal_api_key(task)
value, personal_api_key = api_key_tuple
except Exception as e:
raise PersonalAPIKeyError(
f"Failed to create personal API key for task {input.task_id}",
{"task_id": input.task_id, "error": str(e)},
)
sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)
escaped_value = shlex.quote(value)
result = await sandbox.execute(
f"echo 'export POSTHOG_PERSONAL_API_KEY={escaped_value}' >> ~/.bash_profile && echo 'export POSTHOG_PERSONAL_API_KEY={escaped_value}' >> ~/.bashrc"
)
if result.exit_code != 0:
raise SandboxExecutionError(
f"Failed to inject personal API key into sandbox",
{"sandbox_id": input.sandbox_id, "exit_code": result.exit_code, "stderr": result.stderr[:500]},
)
return InjectPersonalAPIKeyOutput(personal_api_key_id=personal_api_key.id)
def _get_default_scopes() -> list[str]:
"""
Get default scopes for task agent API keys.
TODO: Make scopes configurable per task in the future.
For now, we provide read access to most resources.
"""
read_scopes = [f"{obj}:read" for obj in API_SCOPE_OBJECTS if obj not in ["INTERNAL"]]
return read_scopes
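
For context, here is a minimal sketch of how code running inside the sandbox could consume the injected key. This is not part of this commit; the endpoint URL and the use of requests are illustrative assumptions.

import os

import requests  # assumption: an HTTP client is available in the sandbox image

# inject_personal_api_key exports POSTHOG_PERSONAL_API_KEY via ~/.bash_profile and ~/.bashrc,
# so any shell-spawned process in the sandbox can read it from its environment.
api_key = os.environ["POSTHOG_PERSONAL_API_KEY"]

response = requests.get(
    "https://us.posthog.com/api/projects/@current/insights/",  # illustrative endpoint
    headers={"Authorization": f"Bearer {api_key}"},
    timeout=30,
)
response.raise_for_status()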

View File

@@ -0,0 +1,47 @@
from dataclasses import dataclass
from temporalio import activity
from products.tasks.backend.services.sandbox_agent import SandboxAgent
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import RepositorySetupError
from products.tasks.backend.temporal.observability import log_activity_execution
@dataclass
class SetupRepositoryInput:
sandbox_id: str
repository: str
task_id: str
distinct_id: str
@activity.defn
async def setup_repository(input: SetupRepositoryInput) -> str:
"""Run code agent setup on repository. Returns setup logs."""
async with log_activity_execution(
"setup_repository",
distinct_id=input.distinct_id,
task_id=input.task_id,
sandbox_id=input.sandbox_id,
repository=input.repository,
):
sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)
agent = SandboxAgent(sandbox)
try:
result = await agent.setup_repository(input.repository)
except Exception as e:
raise RepositorySetupError(
f"Failed to setup repository {input.repository}",
{"repository": input.repository, "sandbox_id": input.sandbox_id, "error": str(e)},
)
if result.exit_code != 0:
raise RepositorySetupError(
f"Repository setup failed with exit code {result.exit_code}",
{"repository": input.repository, "exit_code": result.exit_code, "stderr": result.stderr[:500]},
)
return result.stdout
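
As a rough sketch, this is how the activity might be wired into a Temporal workflow; the workflow class, timeout, and retry values below are assumptions for illustration, not taken from this commit's workflow code.

from datetime import timedelta

from temporalio import workflow
from temporalio.common import RetryPolicy

# setup_repository and SetupRepositoryInput refer to the activity and input defined above.


@workflow.defn
class ExampleSetupWorkflow:
    """Hypothetical workflow, for illustration only."""

    @workflow.run
    async def run(self, input: SetupRepositoryInput) -> str:
        # Delegate repository setup to the activity defined above; the timeout and
        # retry policy here are placeholders.
        return await workflow.execute_activity(
            setup_repository,
            input,
            start_to_close_timeout=timedelta(minutes=30),
            retry_policy=RetryPolicy(maximum_attempts=1),
        )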

View File

@@ -0,0 +1,23 @@
"""
Snapshot constants for activity tests.
These snapshots already exist in Runloop and give tests a consistent, pre-built base to start from.
"""
from typing import TypedDict
class TestSnapshot(TypedDict):
external_id: str
repos: list[str]
SNAPSHOTS = [
TestSnapshot(external_id="snp_31DY4EmLlBZFy1aHV2IN2", repos=[]),
TestSnapshot(external_id="snp_31DY5L7W4ismYpImz22wN", repos=["posthog/posthog-js"]),
TestSnapshot(external_id="snp_31DY9PDHgbhD3NDgA6DGe", repos=["posthog/posthog-js", "posthog/posthog"]),
]
BASE_SNAPSHOT = SNAPSHOTS[0]
POSTHOG_JS_SNAPSHOT = SNAPSHOTS[1]
MULTI_REPO_SNAPSHOT = SNAPSHOTS[2]
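
For illustration, a hypothetical helper (not part of this commit) showing how a test could pick the snapshot whose baked-in repos already include a given repository:

from typing import Optional


def snapshot_for_repo(repo: str) -> Optional[TestSnapshot]:
    """Return the first snapshot that already contains `repo`, if any."""
    for snapshot in SNAPSHOTS:
        if repo in snapshot["repos"]:
            return snapshot
    return None


assert snapshot_for_repo("posthog/posthog-js") == POSTHOG_JS_SNAPSHOT
assert snapshot_for_repo("unknown/repo") is None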

View File

@@ -0,0 +1,198 @@
import time
import pytest
from asgiref.sync import sync_to_async
from posthog.models.integration import Integration
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.temporal.process_task.activities.check_snapshot_exists_for_repository import (
CheckSnapshotExistsForRepositoryInput,
CheckSnapshotExistsForRepositoryOutput,
check_snapshot_exists_for_repository,
)
class TestCheckSnapshotExistsForRepositoryActivity:
async def _create_snapshot(
self, github_integration, repos, status=SandboxSnapshot.Status.COMPLETE, external_id="test-snap-123"
):
return await sync_to_async(SandboxSnapshot.objects.create)(
integration=github_integration,
repos=repos,
status=status,
external_id=external_id,
)
async def _cleanup_snapshot(self, snapshot):
await sync_to_async(snapshot.delete)()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_check_snapshot_exists_for_repository_found(self, activity_environment, github_integration):
snapshot = await self._create_snapshot(
github_integration, repos=["test-owner/test-repo", "other-owner/other-repo"]
)
try:
input_data = CheckSnapshotExistsForRepositoryInput(
github_integration_id=github_integration.id, repository="test-owner/test-repo"
)
result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)
assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
assert result.exists is True
assert result.snapshot_id == str(snapshot.id)
finally:
await self._cleanup_snapshot(snapshot)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_check_snapshot_exists_for_repository_not_found(self, activity_environment, github_integration):
input_data = CheckSnapshotExistsForRepositoryInput(
github_integration_id=github_integration.id, repository="nonexistent/repo"
)
result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)
assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
assert result.exists is False
assert result.snapshot_id is None
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_check_snapshot_exists_for_repository_repo_not_in_snapshot(
self, activity_environment, github_integration
):
snapshot = await self._create_snapshot(
github_integration, repos=["other-owner/other-repo", "another-owner/another-repo"]
)
try:
input_data = CheckSnapshotExistsForRepositoryInput(
github_integration_id=github_integration.id, repository="test-owner/test-repo"
)
result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)
assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
assert result.exists is False
assert result.snapshot_id is None
finally:
await self._cleanup_snapshot(snapshot)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_check_snapshot_exists_for_repository_ignores_incomplete_snapshots(
self, activity_environment, github_integration
):
# Create snapshots with different statuses
in_progress_snapshot = await self._create_snapshot(
github_integration,
repos=["test-owner/test-repo"],
status=SandboxSnapshot.Status.IN_PROGRESS,
external_id="in-progress-snap",
)
error_snapshot = await self._create_snapshot(
github_integration,
repos=["test-owner/test-repo"],
status=SandboxSnapshot.Status.ERROR,
external_id="error-snap",
)
try:
input_data = CheckSnapshotExistsForRepositoryInput(
github_integration_id=github_integration.id, repository="test-owner/test-repo"
)
result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)
# Should not find incomplete snapshots
assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
assert result.exists is False
assert result.snapshot_id is None
finally:
await self._cleanup_snapshot(in_progress_snapshot)
await self._cleanup_snapshot(error_snapshot)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_check_snapshot_exists_for_repository_returns_latest_complete(
self, activity_environment, github_integration
):
# Create multiple snapshots, with the latest being complete
older_snapshot = await self._create_snapshot(
github_integration,
repos=["test-owner/test-repo"],
status=SandboxSnapshot.Status.COMPLETE,
external_id="older-snap",
)
# Add delay to ensure different created_at times
time.sleep(0.01)
newer_snapshot = await self._create_snapshot(
github_integration,
repos=["test-owner/test-repo", "other-owner/other-repo"],
status=SandboxSnapshot.Status.COMPLETE,
external_id="newer-snap",
)
try:
input_data = CheckSnapshotExistsForRepositoryInput(
github_integration_id=github_integration.id, repository="test-owner/test-repo"
)
result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)
# Should return the newer snapshot
assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
assert result.exists is True
assert result.snapshot_id == str(newer_snapshot.id)
finally:
await self._cleanup_snapshot(older_snapshot)
await self._cleanup_snapshot(newer_snapshot)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_check_snapshot_exists_for_repository_case_insensitive(
self, activity_environment, github_integration
):
# Create snapshot with mixed case repository name
snapshot = await self._create_snapshot(github_integration, repos=["TestOwner/TestRepo"])
try:
input_data = CheckSnapshotExistsForRepositoryInput(
github_integration_id=github_integration.id, repository="testowner/testrepo"
)
result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)
assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
assert result.exists is True
assert result.snapshot_id == str(snapshot.id)
finally:
await self._cleanup_snapshot(snapshot)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_check_snapshot_exists_for_repository_different_integration(
self, activity_environment, github_integration, ateam
):
other_integration = await sync_to_async(Integration.objects.create)(
team=ateam,
kind="github",
config={"access_token": "other_fake_token"},
)
snapshot = await self._create_snapshot(other_integration, repos=["test-owner/test-repo"])
try:
input_data = CheckSnapshotExistsForRepositoryInput(
github_integration_id=github_integration.id, repository="test-owner/test-repo"
)
result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)
# Should not find snapshot from different integration
assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
assert result.exists is False
assert result.snapshot_id is None
finally:
await self._cleanup_snapshot(snapshot)
await sync_to_async(other_integration.delete)()

View File

@@ -0,0 +1,100 @@
import os
import asyncio
import pytest
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.process_task.activities.cleanup_sandbox import CleanupSandboxInput, cleanup_sandbox
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestCleanupSandboxActivity:
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_cleanup_sandbox_success(self, activity_environment):
config = SandboxEnvironmentConfig(
name="test-cleanup-sandbox",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = await SandboxEnvironment.create(config)
sandbox_id = sandbox.id
existing_sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert existing_sandbox.id == sandbox_id
input_data = CleanupSandboxInput(sandbox_id=sandbox_id)
await activity_environment.run(cleanup_sandbox, input_data)
cleaned_sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert cleaned_sandbox.status.value == "shutdown"
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_cleanup_sandbox_not_found_does_not_raise(self, activity_environment):
input_data = CleanupSandboxInput(sandbox_id="non-existent-sandbox-id")
# cleanup_sandbox is idempotent and doesn't raise if sandbox doesn't exist
await activity_environment.run(cleanup_sandbox, input_data)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_cleanup_sandbox_idempotency(self, activity_environment):
config = SandboxEnvironmentConfig(
name="test-cleanup-idempotent",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = await SandboxEnvironment.create(config)
sandbox_id = sandbox.id
input_data = CleanupSandboxInput(sandbox_id=sandbox_id)
# First cleanup - should succeed
await activity_environment.run(cleanup_sandbox, input_data)
cleaned_sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert cleaned_sandbox.status.value == "shutdown"
# Second cleanup - should still work on shutdown sandbox
await activity_environment.run(cleanup_sandbox, input_data)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_cleanup_sandbox_during_execution(self, activity_environment):
config = SandboxEnvironmentConfig(
name="test-cleanup-during-execution",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = await SandboxEnvironment.create(config)
sandbox_id = sandbox.id
async def run_long_command():
try:
await sandbox.execute("sleep 30", timeout_seconds=60)
except Exception:
pass
long_task = asyncio.create_task(run_long_command())
# Give it a moment to start
await asyncio.sleep(5)
input_data = CleanupSandboxInput(sandbox_id=sandbox_id)
await activity_environment.run(cleanup_sandbox, input_data)
long_task.cancel()
try:
await long_task
except asyncio.CancelledError:
pass
remaining_sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert remaining_sandbox.status.value in ["shutdown", "failure"]

View File

@@ -0,0 +1,230 @@
import os
import pytest
from unittest.mock import patch
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import RepositoryCloneError, SandboxNotFoundError
from products.tasks.backend.temporal.process_task.activities.clone_repository import (
CloneRepositoryInput,
clone_repository,
)
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestCloneRepositoryActivity:
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_clone_repository_success_and_directory_structure(self, activity_environment, github_integration):
config = SandboxEnvironmentConfig(
name="test-clone-success-and-structure",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
input_data = CloneRepositoryInput(
sandbox_id=sandbox.id,
repository="PostHog/posthog-js",
github_integration_id=github_integration.id,
task_id="test-task-123",
distinct_id="test-user-id",
)
with patch(
"products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
) as mock_get_token:
mock_get_token.return_value = ""
result = await activity_environment.run(clone_repository, input_data)
# Verify we got output (git clone outputs to stderr)
assert result is not None
assert "posthog-js" in result
# Verify the repository actually exists in the sandbox
check_result = await sandbox.execute("ls -la /tmp/workspace/repos/posthog/")
assert "posthog-js" in check_result.stdout
# Verify it's a git repository
git_check = await sandbox.execute("cd /tmp/workspace/repos/posthog/posthog-js && git status")
assert git_check.exit_code == 0
assert "On branch" in git_check.stdout or "HEAD" in git_check.stdout
# Verify directory structure is correct
structure_check = await sandbox.execute("find /tmp/workspace/repos -type d | head -10")
assert "/tmp/workspace/repos/posthog" in structure_check.stdout
assert "/tmp/workspace/repos/posthog/posthog-js" in structure_check.stdout
# Verify we can navigate the structure
nav_check = await sandbox.execute("cd /tmp/workspace/repos/posthog/posthog-js && pwd")
assert "/tmp/workspace/repos/posthog/posthog-js" in nav_check.stdout
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_clone_repository_idempotency(self, activity_environment, github_integration):
config = SandboxEnvironmentConfig(
name="test-clone-repository-idempotent",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
input_data = CloneRepositoryInput(
sandbox_id=sandbox.id,
repository="PostHog/posthog-js",
github_integration_id=github_integration.id,
task_id="test-task-idempotent",
distinct_id="test-user-id",
)
with patch(
"products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
) as mock_get_token:
mock_get_token.return_value = ""
# First clone
result1 = await activity_environment.run(clone_repository, input_data)
assert result1 is not None
# Create a file to verify it gets wiped
await sandbox.execute("echo 'test' > /tmp/workspace/repos/posthog/posthog-js/test_file.txt")
# Verify file exists
check_file = await sandbox.execute("cat /tmp/workspace/repos/posthog/posthog-js/test_file.txt")
assert "test" in check_file.stdout
# Second clone (should wipe and re-clone)
result2 = await activity_environment.run(clone_repository, input_data)
assert "Cloning into 'posthog-js'" in result2 or "posthog-js" in result2
# Verify test file was removed (proving idempotency)
check_file_after = await sandbox.execute(
"ls /tmp/workspace/repos/posthog/posthog-js/test_file.txt 2>&1"
)
assert (
"No such file" in check_file_after.stdout
or "No such file" in check_file_after.stderr
or check_file_after.exit_code != 0
)
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_clone_repository_private_repo_no_token(self, activity_environment, github_integration):
config = SandboxEnvironmentConfig(
name="test-clone-repository-auth-fail",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
input_data = CloneRepositoryInput(
sandbox_id=sandbox.id,
repository="PostHog/private-test-repo-that-does-not-exist",
github_integration_id=github_integration.id,
task_id="test-task-auth-fail",
distinct_id="test-user-id",
)
with patch(
"products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
) as mock_get_token:
mock_get_token.return_value = "invalid-token"
with pytest.raises(RepositoryCloneError) as exc_info:
await activity_environment.run(clone_repository, input_data)
assert "Git clone failed" in str(exc_info.value)
# Verify repository doesn't exist
check_result = await sandbox.execute("ls /tmp/workspace/repos/posthog/ 2>&1")
assert "private-test-repo" not in check_result.stdout
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_clone_repository_multiple_repos(self, activity_environment, github_integration):
config = SandboxEnvironmentConfig(
name="test-clone-multiple-repos",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
repos = ["PostHog/posthog-js", "PostHog/posthog.com"]
with patch(
"products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
) as mock_get_token:
mock_get_token.return_value = "" # Public repos don't need auth
for repo in repos:
input_data = CloneRepositoryInput(
sandbox_id=sandbox.id,
repository=repo,
github_integration_id=github_integration.id,
task_id=f"test-task-{repo.split('/')[1]}",
distinct_id="test-user-id",
)
result = await activity_environment.run(clone_repository, input_data)
repo_name = repo.split("/")[1]
assert repo_name in result
# Verify both repos exist
check_result = await sandbox.execute("ls /tmp/workspace/repos/posthog/")
assert "posthog-js" in check_result.stdout
assert "posthog.com" in check_result.stdout
# Verify they're both git repositories
for repo in repos:
repo_name = repo.split("/")[1]
git_check = await sandbox.execute(f"cd /tmp/workspace/repos/posthog/{repo_name} && git remote -v")
assert git_check.exit_code == 0
assert "github.com" in git_check.stdout
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_clone_repository_sandbox_not_found(self, activity_environment, github_integration):
input_data = CloneRepositoryInput(
sandbox_id="non-existent-sandbox-id",
repository="posthog/posthog-js",
github_integration_id=github_integration.id,
task_id="test-task-not-found",
distinct_id="test-user-id",
)
with patch(
"products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
) as mock_get_token:
mock_get_token.return_value = ""
with pytest.raises(SandboxNotFoundError):
await activity_environment.run(clone_repository, input_data)

View File

@@ -0,0 +1,94 @@
import os
import uuid
import pytest
from asgiref.sync import sync_to_async
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import SandboxProvisionError, SnapshotNotFoundError
from products.tasks.backend.temporal.process_task.activities.create_sandbox_from_snapshot import (
CreateSandboxFromSnapshotInput,
create_sandbox_from_snapshot,
)
from .constants import BASE_SNAPSHOT
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestCreateSandboxFromSnapshotActivity:
async def _create_snapshot(self, github_integration, external_id=None, status=SandboxSnapshot.Status.COMPLETE):
if external_id is None:
external_id = str(uuid.uuid4())
return await sync_to_async(SandboxSnapshot.objects.create)(
integration=github_integration,
external_id=external_id,
status=status,
)
async def _cleanup_snapshot(self, snapshot):
await sync_to_async(snapshot.delete)()
async def _cleanup_sandbox(self, sandbox_id):
sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_create_sandbox_from_snapshot_success(self, activity_environment, github_integration):
snapshot = await self._create_snapshot(github_integration, external_id=BASE_SNAPSHOT["external_id"])
task_id = "test-task-123"
sandbox_id = None
try:
input_data = CreateSandboxFromSnapshotInput(
snapshot_id=str(snapshot.id), task_id=task_id, distinct_id="test-user-id"
)
sandbox_id = await activity_environment.run(create_sandbox_from_snapshot, input_data)
assert isinstance(sandbox_id, str)
assert len(sandbox_id) > 0
sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert sandbox.id == sandbox_id
assert sandbox.status in ["pending", "initializing", "running"]
finally:
await self._cleanup_snapshot(snapshot)
if sandbox_id:
await self._cleanup_sandbox(sandbox_id)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_create_sandbox_from_snapshot_not_found(self, activity_environment):
input_data = CreateSandboxFromSnapshotInput(
snapshot_id=str(uuid.uuid4()),
task_id="test-task-456",
distinct_id="test-user-id",
)
with pytest.raises(SnapshotNotFoundError):
await activity_environment.run(create_sandbox_from_snapshot, input_data)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_create_sandbox_from_snapshot_with_invalid_external_id(
self, activity_environment, github_integration
):
snapshot = await self._create_snapshot(github_integration, external_id="invalid-snapshot-id")
task_id = "test-task-789"
sandbox_id = None
try:
input_data = CreateSandboxFromSnapshotInput(
snapshot_id=str(snapshot.id), task_id=task_id, distinct_id="test-user-id"
)
with pytest.raises(SandboxProvisionError):
sandbox_id = await activity_environment.run(create_sandbox_from_snapshot, input_data)
finally:
await self._cleanup_snapshot(snapshot)
if sandbox_id:
await self._cleanup_sandbox(sandbox_id)

View File

@@ -0,0 +1,142 @@
import os
import uuid
import pytest
from asgiref.sync import sync_to_async
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import SandboxNotFoundError
from products.tasks.backend.temporal.process_task.activities.create_snapshot import CreateSnapshotInput, create_snapshot
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestCreateSnapshotActivity:
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_create_snapshot_real(self, activity_environment, github_integration, ateam):
"""Test real snapshot creation with actual sandbox."""
config = SandboxEnvironmentConfig(
name="test-create-snapshot",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
created_snapshot = None
created_snapshot_external_id = None
try:
# Create a real sandbox
sandbox = await SandboxEnvironment.create(config)
input_data = CreateSnapshotInput(
sandbox_id=sandbox.id,
github_integration_id=github_integration.id,
team_id=ateam.id,
repository="test-owner/test-repo",
task_id="test-task-123",
distinct_id="test-user-id",
)
# This will create a real snapshot and wait for it to complete
result = await activity_environment.run(create_snapshot, input_data)
# Verify a UUID was returned
assert result is not None
uuid.UUID(result) # Should not raise
# Verify snapshot was created in the database
created_snapshot = await sync_to_async(SandboxSnapshot.objects.get)(id=result)
created_snapshot_external_id = created_snapshot.external_id
assert created_snapshot.external_id is not None
assert created_snapshot.integration_id == github_integration.id
assert "test-owner/test-repo" in created_snapshot.repos
assert created_snapshot.status == SandboxSnapshot.Status.COMPLETE
# Verify the snapshot exists in provider
snapshot_status = await SandboxEnvironment.get_snapshot_status(created_snapshot.external_id)
assert snapshot_status.value == "complete"
finally:
if sandbox:
await sandbox.destroy()
if created_snapshot:
await sync_to_async(created_snapshot.delete)()
if created_snapshot_external_id:
await SandboxEnvironment.delete_snapshot(created_snapshot_external_id)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_create_snapshot_with_existing_base_snapshot(self, activity_environment, github_integration, ateam):
"""Test snapshot creation with existing base snapshot repos."""
# Create a base snapshot in the database (using a fake external ID since we're not creating it in Runloop)
base_snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
integration=github_integration,
external_id=f"fake_base_{uuid.uuid4().hex[:8]}",
repos=["existing-owner/existing-repo"],
status=SandboxSnapshot.Status.COMPLETE,
)
config = SandboxEnvironmentConfig(
name="test-create-snapshot-with-base",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
created_snapshot = None
created_snapshot_external_id = None
try:
sandbox = await SandboxEnvironment.create(config)
input_data = CreateSnapshotInput(
sandbox_id=sandbox.id,
github_integration_id=github_integration.id,
team_id=ateam.id,
repository="new-owner/new-repo",
task_id="test-task-with-base",
distinct_id="test-user-id",
)
result = await activity_environment.run(create_snapshot, input_data)
# Verify new snapshot includes both repos
created_snapshot = await sync_to_async(SandboxSnapshot.objects.get)(id=result)
created_snapshot_external_id = created_snapshot.external_id
assert created_snapshot.external_id is not None
assert "existing-owner/existing-repo" in created_snapshot.repos
assert "new-owner/new-repo" in created_snapshot.repos
assert len(created_snapshot.repos) == 2
# Verify the snapshot actually exists in Runloop
snapshot_status = await SandboxEnvironment.get_snapshot_status(created_snapshot.external_id)
assert snapshot_status.value == "complete"
finally:
await sync_to_async(base_snapshot.delete)()
if sandbox:
await sandbox.destroy()
if created_snapshot:
await sync_to_async(created_snapshot.delete)()
if created_snapshot_external_id:
await SandboxEnvironment.delete_snapshot(created_snapshot_external_id)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_create_snapshot_sandbox_not_found(self, activity_environment, github_integration, ateam):
input_data = CreateSnapshotInput(
sandbox_id="non-existent-sandbox-id",
github_integration_id=github_integration.id,
team_id=ateam.id,
repository="test-owner/test-repo",
task_id="test-task-not-found",
distinct_id="test-user-id",
)
with pytest.raises(SandboxNotFoundError):
await activity_environment.run(create_snapshot, input_data)

View File

@@ -0,0 +1,213 @@
import os
import pytest
from unittest.mock import patch
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import SandboxNotFoundError, TaskExecutionFailedError
from products.tasks.backend.temporal.process_task.activities.clone_repository import (
CloneRepositoryInput,
clone_repository,
)
from products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox import (
ExecuteTaskInput,
execute_task_in_sandbox,
)
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestExecuteTaskInSandboxActivity:
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_execute_task_success(self, activity_environment, github_integration):
"""Test successful task execution in sandbox."""
config = SandboxEnvironmentConfig(
name="test-execute-task",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
clone_input = CloneRepositoryInput(
sandbox_id=sandbox.id,
repository="PostHog/posthog-js",
github_integration_id=github_integration.id,
task_id="test-task-123",
distinct_id="test-user-id",
)
with patch(
"products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
) as mock_get_token:
mock_get_token.return_value = ""
await activity_environment.run(clone_repository, clone_input)
# We mock the _get_task_command to run a simple command instead of the code agent
with patch(
"products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
) as mock_task_cmd:
mock_task_cmd.return_value = "echo 'Task executed successfully'"
input_data = ExecuteTaskInput(
sandbox_id=sandbox.id,
task_id="test-task-123",
repository="PostHog/posthog-js",
distinct_id="test-user-id",
)
await activity_environment.run(execute_task_in_sandbox, input_data)
mock_task_cmd.assert_called_once_with("test-task-123")
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_execute_task_failure(self, activity_environment, github_integration):
"""Test task execution failure handling."""
config = SandboxEnvironmentConfig(
name="test-execute-task-fail",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
clone_input = CloneRepositoryInput(
sandbox_id=sandbox.id,
repository="PostHog/posthog-js",
github_integration_id=github_integration.id,
task_id="test-task-fail",
distinct_id="test-user-id",
)
with patch(
"products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
) as mock_get_token:
mock_get_token.return_value = ""
await activity_environment.run(clone_repository, clone_input)
# We mock the _get_task_command to run a failing command
with patch(
"products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
) as mock_task_cmd:
mock_task_cmd.return_value = "exit 1" # Command that fails
input_data = ExecuteTaskInput(
sandbox_id=sandbox.id,
task_id="test-task-fail",
repository="PostHog/posthog-js",
distinct_id="test-user-id",
)
with pytest.raises(TaskExecutionFailedError):
await activity_environment.run(execute_task_in_sandbox, input_data)
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_execute_task_repository_not_found(self, activity_environment):
"""Test task execution when repository doesn't exist in sandbox."""
config = SandboxEnvironmentConfig(
name="test-execute-task-no-repo",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
# We don't clone any repository, just try to execute task
with patch(
"products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
) as mock_task_cmd:
mock_task_cmd.return_value = "ls -la"
input_data = ExecuteTaskInput(
sandbox_id=sandbox.id,
task_id="test-task-no-repo",
repository="PostHog/posthog-js",
distinct_id="test-user-id",
)
with pytest.raises(TaskExecutionFailedError):
await activity_environment.run(execute_task_in_sandbox, input_data)
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_execute_task_sandbox_not_found(self, activity_environment):
input_data = ExecuteTaskInput(
sandbox_id="non-existent-sandbox-id",
task_id="test-task",
repository="PostHog/posthog-js",
distinct_id="test-user-id",
)
with pytest.raises(SandboxNotFoundError):
await activity_environment.run(execute_task_in_sandbox, input_data)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_execute_task_with_different_repositories(self, activity_environment, github_integration):
config = SandboxEnvironmentConfig(
name="test-execute-different-repos",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
repos_to_test = ["PostHog/posthog-js", "PostHog/posthog.com"]
with patch(
"products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
) as mock_get_token:
mock_get_token.return_value = ""
for repo in repos_to_test:
clone_input = CloneRepositoryInput(
sandbox_id=sandbox.id,
repository=repo,
github_integration_id=github_integration.id,
task_id=f"test-task-{repo.split('/')[1]}",
distinct_id="test-user-id",
)
await activity_environment.run(clone_repository, clone_input)
# Execute task in each repository
with patch(
"products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
) as mock_task_cmd:
mock_task_cmd.return_value = f"echo 'Working in {repo}'"
input_data = ExecuteTaskInput(
sandbox_id=sandbox.id,
task_id=f"test-task-{repo.split('/')[1]}",
repository=repo,
distinct_id="test-user-id",
)
await activity_environment.run(execute_task_in_sandbox, input_data)
mock_task_cmd.assert_called_once_with(f"test-task-{repo.split('/')[1]}")
finally:
if sandbox:
await sandbox.destroy()

View File

@@ -0,0 +1,165 @@
import os
import uuid
import pytest
from asgiref.sync import sync_to_async
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.process_task.activities.get_sandbox_for_setup import (
GetSandboxForSetupInput,
get_sandbox_for_setup,
)
from products.tasks.backend.temporal.process_task.utils import get_sandbox_name_for_task
from .constants import BASE_SNAPSHOT
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestGetSandboxForSetupActivity:
"""Test suite for the get_sandbox_for_setup activity."""
async def _create_snapshot(self, github_integration, external_id=None, status=SandboxSnapshot.Status.COMPLETE):
"""Helper method to create a snapshot."""
if external_id is None:
external_id = str(uuid.uuid4())
return await sync_to_async(SandboxSnapshot.objects.create)(
integration=github_integration,
external_id=external_id,
status=status,
)
async def _cleanup_snapshot(self, snapshot):
"""Helper method to clean up a snapshot."""
await sync_to_async(snapshot.delete)()
async def _cleanup_sandbox(self, sandbox_id):
"""Helper method to clean up a sandbox."""
sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_get_sandbox_for_setup_with_existing_snapshot(self, activity_environment, github_integration, ateam):
snapshot = await self._create_snapshot(github_integration, external_id=BASE_SNAPSHOT["external_id"])
task_id = "test-task-123"
sandbox_id = None
try:
input_data = GetSandboxForSetupInput(
github_integration_id=github_integration.id,
team_id=ateam.id,
task_id=task_id,
distinct_id="test-user-id",
)
sandbox_id = await activity_environment.run(get_sandbox_for_setup, input_data)
assert isinstance(sandbox_id, str)
assert len(sandbox_id) > 0
# Verify sandbox was created
sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert sandbox.id == sandbox_id
assert sandbox.status in ["pending", "initializing", "running"]
finally:
await self._cleanup_snapshot(snapshot)
if sandbox_id:
await self._cleanup_sandbox(sandbox_id)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_get_sandbox_for_setup_without_existing_snapshot(
self, activity_environment, github_integration, ateam
):
task_id = "test-task-456"
sandbox_id = None
try:
input_data = GetSandboxForSetupInput(
github_integration_id=github_integration.id,
team_id=ateam.id,
task_id=task_id,
distinct_id="test-user-id",
)
sandbox_id = await activity_environment.run(get_sandbox_for_setup, input_data)
assert isinstance(sandbox_id, str)
assert len(sandbox_id) > 0
# Verify sandbox was created
sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert sandbox.id == sandbox_id
assert sandbox.status in ["pending", "initializing", "running"]
finally:
if sandbox_id:
await self._cleanup_sandbox(sandbox_id)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_get_sandbox_for_setup_ignores_incomplete_snapshots(
self, activity_environment, github_integration, ateam
):
# Create snapshots with incomplete status
in_progress_snapshot = await self._create_snapshot(
github_integration, status=SandboxSnapshot.Status.IN_PROGRESS
)
error_snapshot = await self._create_snapshot(github_integration, status=SandboxSnapshot.Status.ERROR)
task_id = "test-task-789"
sandbox_id = None
try:
input_data = GetSandboxForSetupInput(
github_integration_id=github_integration.id,
team_id=ateam.id,
task_id=task_id,
distinct_id="test-user-id",
)
sandbox_id = await activity_environment.run(get_sandbox_for_setup, input_data)
assert isinstance(sandbox_id, str)
assert len(sandbox_id) > 0
# Verify sandbox was created (should not use incomplete snapshots as base)
sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert sandbox.id == sandbox_id
finally:
await self._cleanup_snapshot(in_progress_snapshot)
await self._cleanup_snapshot(error_snapshot)
if sandbox_id:
await self._cleanup_sandbox(sandbox_id)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_get_sandbox_for_setup_sandbox_name_generation(self, activity_environment, github_integration, ateam):
task_id = "special-task-id-with-uuid-abc123"
sandbox_id = None
try:
input_data = GetSandboxForSetupInput(
github_integration_id=github_integration.id,
team_id=ateam.id,
task_id=task_id,
distinct_id="test-user-id",
)
sandbox_id = await activity_environment.run(get_sandbox_for_setup, input_data)
assert isinstance(sandbox_id, str)
assert len(sandbox_id) > 0
# Verify sandbox exists
sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert sandbox.id == sandbox_id
assert sandbox.name == get_sandbox_name_for_task(task_id)
finally:
if sandbox_id:
await self._cleanup_sandbox(sandbox_id)

View File

@@ -0,0 +1,96 @@
import pytest
from django.core.exceptions import ValidationError
from asgiref.sync import sync_to_async
from products.tasks.backend.models import Task
from products.tasks.backend.temporal.exceptions import TaskInvalidStateError, TaskNotFoundError
from products.tasks.backend.temporal.process_task.activities.get_task_details import TaskDetails, get_task_details
class TestGetTaskDetailsActivity:
async def _create_task_with_repo(self, ateam, auser, task_workflow, github_integration, repo_config):
workflow, stages = task_workflow
backlog_stage = stages[0]
return await sync_to_async(Task.objects.create)(
team=ateam,
title="Test Task",
description="Test task description",
origin_product=Task.OriginProduct.USER_CREATED,
workflow=workflow,
current_stage=backlog_stage,
position=0,
github_integration=github_integration,
repository_config=repo_config,
created_by=auser,
)
async def _cleanup_task(self, task):
await sync_to_async(task.delete)()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_get_task_details_success(self, activity_environment, test_task):
result = await activity_environment.run(get_task_details, str(test_task.id))
assert isinstance(result, TaskDetails)
assert result.task_id == str(test_task.id)
assert result.team_id == test_task.team_id
assert result.github_integration_id == test_task.github_integration_id
assert result.repository == "posthog/posthog-js"
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_get_task_details_task_not_found(self, activity_environment):
non_existent_task_id = "550e8400-e29b-41d4-a716-446655440000"
with pytest.raises(TaskNotFoundError):
await activity_environment.run(get_task_details, non_existent_task_id)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_get_task_details_invalid_uuid(self, activity_environment):
invalid_task_id = "not-a-uuid"
with pytest.raises(ValidationError):
await activity_environment.run(get_task_details, invalid_task_id)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_get_task_details_with_different_repository(
self, activity_environment, ateam, auser, task_workflow, github_integration
):
task = await self._create_task_with_repo(
ateam, auser, task_workflow, github_integration, {"organization": "posthog", "repository": "posthog-js"}
)
try:
result = await activity_environment.run(get_task_details, str(task.id))
assert result.task_id == str(task.id)
assert result.team_id == task.team_id
assert result.github_integration_id == github_integration.id
assert result.repository == "posthog/posthog-js"
finally:
await self._cleanup_task(task)
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_get_task_details_with_missing_repository(
self, activity_environment, ateam, auser, task_workflow, github_integration
):
task = await self._create_task_with_repo(
ateam,
auser,
task_workflow,
github_integration,
{"organization": "test-org"}, # Missing "repository" key
)
try:
with pytest.raises(TaskInvalidStateError):
await activity_environment.run(get_task_details, str(task.id))
finally:
await self._cleanup_task(task)

View File

@@ -0,0 +1,72 @@
import os
import pytest
from unittest.mock import patch
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import SandboxNotFoundError
from products.tasks.backend.temporal.process_task.activities.inject_github_token import (
InjectGitHubTokenInput,
inject_github_token,
)
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestInjectGitHubTokenActivity:
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_inject_github_token_success(self, activity_environment, github_integration):
config = SandboxEnvironmentConfig(
name="test-inject-token-success",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
input_data = InjectGitHubTokenInput(
sandbox_id=sandbox.id,
github_integration_id=github_integration.id,
task_id="test-task-123",
distinct_id="test-user-id",
)
test_token = "ghp_test_token_12345"
with patch(
"products.tasks.backend.temporal.process_task.activities.inject_github_token.get_github_token"
) as mock_get_token:
mock_get_token.return_value = test_token
await activity_environment.run(inject_github_token, input_data)
check_result = await sandbox.execute("bash -c 'source ~/.bashrc && echo $GITHUB_TOKEN'")
assert check_result.exit_code == 0
assert test_token in check_result.stdout
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_inject_github_token_sandbox_not_found(self, activity_environment, github_integration):
input_data = InjectGitHubTokenInput(
sandbox_id="non-existent-sandbox-id",
github_integration_id=github_integration.id,
task_id="test-task-not-found",
distinct_id="test-user-id",
)
with patch(
"products.tasks.backend.temporal.process_task.activities.inject_github_token.get_github_token"
) as mock_get_token:
mock_get_token.return_value = "test_token"
with pytest.raises(SandboxNotFoundError):
await activity_environment.run(inject_github_token, input_data)

View File

@@ -0,0 +1,102 @@
import os
import pytest
from asgiref.sync import sync_to_async
from posthog.models import PersonalAPIKey
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import SandboxNotFoundError, TaskInvalidStateError
from products.tasks.backend.temporal.process_task.activities.inject_personal_api_key import (
InjectPersonalAPIKeyInput,
InjectPersonalAPIKeyOutput,
inject_personal_api_key,
)
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestInjectPersonalAPIKeyActivity:
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_inject_personal_api_key_success(self, activity_environment, test_task):
config = SandboxEnvironmentConfig(
name="test-inject-api-key-success",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
input_data = InjectPersonalAPIKeyInput(
sandbox_id=sandbox.id,
task_id=str(test_task.id),
distinct_id="test-user-id",
)
result: InjectPersonalAPIKeyOutput = await activity_environment.run(inject_personal_api_key, input_data)
assert result.personal_api_key_id is not None
api_key = await sync_to_async(PersonalAPIKey.objects.get)(id=result.personal_api_key_id)
assert api_key.user_id == test_task.created_by_id
assert api_key.scopes is not None
assert len(api_key.scopes) > 0
assert api_key.scoped_teams == [test_task.team_id]
assert f"Task Agent - {test_task.title[:20]}" == api_key.label
check_result = await sandbox.execute("bash -c 'source ~/.bashrc && echo $POSTHOG_PERSONAL_API_KEY'")
assert check_result.exit_code == 0
api_key_value = check_result.stdout.strip()
assert api_key_value.startswith("phx_")
await sync_to_async(api_key.delete)()
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_inject_personal_api_key_no_user(self, activity_environment, test_task):
config = SandboxEnvironmentConfig(
name="test-inject-api-key-no-user",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
test_task.created_by = None
await sync_to_async(test_task.save)()
input_data = InjectPersonalAPIKeyInput(
sandbox_id=sandbox.id,
task_id=str(test_task.id),
distinct_id="test-user-id",
)
with pytest.raises(TaskInvalidStateError):
await activity_environment.run(inject_personal_api_key, input_data)
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_inject_personal_api_key_sandbox_not_found(self, activity_environment, test_task):
input_data = InjectPersonalAPIKeyInput(
sandbox_id="non-existent-sandbox-id",
task_id=str(test_task.id),
distinct_id="test-user-id",
)
with pytest.raises(SandboxNotFoundError):
await activity_environment.run(inject_personal_api_key, input_data)

View File

@@ -0,0 +1,118 @@
import os
import pytest
from unittest.mock import patch
from products.tasks.backend.services.sandbox_environment import (
SandboxEnvironment,
SandboxEnvironmentConfig,
SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import RepositorySetupError, SandboxNotFoundError
from products.tasks.backend.temporal.process_task.activities.clone_repository import (
CloneRepositoryInput,
clone_repository,
)
from products.tasks.backend.temporal.process_task.activities.setup_repository import (
SetupRepositoryInput,
setup_repository,
)
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestSetupRepositoryActivity:
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_setup_repository_success(self, activity_environment, github_integration):
config = SandboxEnvironmentConfig(
name="test-setup-repository",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
clone_input = CloneRepositoryInput(
sandbox_id=sandbox.id,
repository="posthog/posthog-js",
github_integration_id=github_integration.id,
task_id="test-task-123",
distinct_id="test-user-id",
)
with patch(
"products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
) as mock_get_token:
mock_get_token.return_value = ""
await activity_environment.run(clone_repository, clone_input)
check_before = await sandbox.execute(
"ls -la /tmp/workspace/repos/posthog/posthog-js/ | grep node_modules || echo 'no node_modules'"
)
assert "no node_modules" in check_before.stdout
# Mock _get_setup_command in the setup_repository activity to run a plain `pnpm install` for this test, instead of invoking the coding agent
with patch(
"products.tasks.backend.temporal.process_task.activities.setup_repository.SandboxAgent._get_setup_command"
) as mock_setup_cmd:
mock_setup_cmd.return_value = "pnpm install"
setup_input = SetupRepositoryInput(
sandbox_id=sandbox.id,
repository="posthog/posthog-js",
task_id="test-task-123",
distinct_id="test-user-id",
)
result = await activity_environment.run(setup_repository, setup_input)
assert result is not None
check_after = await sandbox.execute(
"ls -la /tmp/workspace/repos/posthog/posthog-js/ | grep node_modules || echo 'no node_modules'"
)
assert "node_modules" in check_after.stdout
assert "no node_modules" not in check_after.stdout
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_setup_repository_without_clone(self, activity_environment):
config = SandboxEnvironmentConfig(
name="test-setup-no-clone",
template=SandboxEnvironmentTemplate.DEFAULT_BASE,
)
sandbox = None
try:
sandbox = await SandboxEnvironment.create(config)
setup_input = SetupRepositoryInput(
sandbox_id=sandbox.id,
repository="posthog/posthog-js",
task_id="test-task-no-clone",
distinct_id="test-user-id",
)
with pytest.raises(RepositorySetupError):
await activity_environment.run(setup_repository, setup_input)
finally:
if sandbox:
await sandbox.destroy()
@pytest.mark.asyncio
@pytest.mark.django_db
async def test_setup_repository_sandbox_not_found(self, activity_environment):
setup_input = SetupRepositoryInput(
sandbox_id="non-existent-sandbox-id",
repository="posthog/posthog-js",
task_id="test-task-not-found",
distinct_id="test-user-id",
)
with pytest.raises(SandboxNotFoundError):
await activity_environment.run(setup_repository, setup_input)

View File

@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import Any
import posthoganalytics
from temporalio import activity
from posthog.temporal.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class TrackWorkflowEventInput:
event_name: str
distinct_id: str
properties: dict[str, Any]
@activity.defn
def track_workflow_event(input: TrackWorkflowEventInput) -> None:
"""Track workflow-level events to PostHog."""
try:
posthoganalytics.capture(
distinct_id=input.distinct_id,
event=input.event_name,
properties=input.properties,
)
except Exception:
logger.exception(
"Failed to track workflow event",
event_name=input.event_name,
distinct_id=input.distinct_id,
properties=input.properties,
)
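
Since the activity is a plain synchronous function, it can also be exercised directly in a unit test; a minimal sketch, where the event name and properties are made up for illustration:

# Illustrative usage only; the event name and properties are hypothetical.
track_workflow_event(
    TrackWorkflowEventInput(
        event_name="task_workflow_completed",
        distinct_id="user-distinct-id",
        properties={"task_id": "task-123", "success": True},
    )
)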

View File

@@ -0,0 +1,284 @@
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import cast
import pytest
from unittest.mock import patch
from asgiref.sync import sync_to_async
from temporalio.common import RetryPolicy
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import UnsandboxedWorkflowRunner, Worker
from posthog import constants
from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment, SandboxEnvironmentStatus
from products.tasks.backend.temporal.process_task.activities.check_snapshot_exists_for_repository import (
check_snapshot_exists_for_repository,
)
from products.tasks.backend.temporal.process_task.activities.cleanup_personal_api_key import cleanup_personal_api_key
from products.tasks.backend.temporal.process_task.activities.cleanup_sandbox import cleanup_sandbox
from products.tasks.backend.temporal.process_task.activities.clone_repository import clone_repository
from products.tasks.backend.temporal.process_task.activities.create_sandbox_from_snapshot import (
create_sandbox_from_snapshot,
)
from products.tasks.backend.temporal.process_task.activities.create_snapshot import create_snapshot
from products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox import execute_task_in_sandbox
from products.tasks.backend.temporal.process_task.activities.get_sandbox_for_setup import get_sandbox_for_setup
from products.tasks.backend.temporal.process_task.activities.get_task_details import get_task_details
from products.tasks.backend.temporal.process_task.activities.inject_github_token import inject_github_token
from products.tasks.backend.temporal.process_task.activities.inject_personal_api_key import inject_personal_api_key
from products.tasks.backend.temporal.process_task.activities.setup_repository import setup_repository
from products.tasks.backend.temporal.process_task.activities.tests.constants import POSTHOG_JS_SNAPSHOT
from products.tasks.backend.temporal.process_task.activities.track_workflow_event import track_workflow_event
from products.tasks.backend.temporal.process_task.workflow import ProcessTaskOutput, ProcessTaskWorkflow
pytestmark = [pytest.mark.asyncio, pytest.mark.django_db]
@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestProcessTaskWorkflow:
"""
End-to-end workflow tests using real Runloop sandboxes.
These tests create actual sandboxes and snapshots, mocking only the agent's setup and task
execution commands to avoid running the full AI agent. This allows us to verify:
- Snapshot creation and reuse
- Sandbox lifecycle management
- Proper cleanup on success and failure
"""
async def _run_workflow(self, task_id: str, mock_task_command: str = "echo 'task complete'") -> ProcessTaskOutput:
workflow_id = str(uuid.uuid4())
with (
patch(
"products.tasks.backend.temporal.process_task.activities.setup_repository.SandboxAgent._get_setup_command"
) as mock_setup,
patch(
"products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
) as mock_task,
):
mock_setup.return_value = "pnpm install"
mock_task.return_value = mock_task_command
async with (
await WorkflowEnvironment.start_time_skipping() as env,
Worker(
env.client,
task_queue=constants.TASKS_TASK_QUEUE,
workflows=[ProcessTaskWorkflow],
activities=[
get_task_details,
check_snapshot_exists_for_repository,
get_sandbox_for_setup,
clone_repository,
setup_repository,
create_snapshot,
create_sandbox_from_snapshot,
inject_github_token,
inject_personal_api_key,
execute_task_in_sandbox,
cleanup_sandbox,
cleanup_personal_api_key,
track_workflow_event,
],
workflow_runner=UnsandboxedWorkflowRunner(),
activity_executor=ThreadPoolExecutor(max_workers=10),
),
):
result = await env.client.execute_workflow(
ProcessTaskWorkflow.run,
task_id,
id=workflow_id,
task_queue=constants.TASKS_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=timedelta(minutes=60),
)
return result
async def _verify_file_in_sandbox(self, sandbox_id: str, filepath: str) -> bool:
"""Verify a file exists in a sandbox."""
sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
result = await sandbox.execute(f"test -f {filepath} && echo 'exists' || echo 'missing'")
return "exists" in result.stdout
async def test_workflow_with_existing_snapshot_reuses_snapshot(self, test_task, github_integration):
snapshot, _ = await sync_to_async(SandboxSnapshot.objects.get_or_create)(
external_id=POSTHOG_JS_SNAPSHOT["external_id"],
defaults={
"integration": github_integration,
"repos": POSTHOG_JS_SNAPSHOT["repos"],
"status": SandboxSnapshot.Status.COMPLETE,
},
)
try:
result = await self._run_workflow(test_task.id)
assert result.success is True
assert result.task_result is not None
assert result.task_result.exit_code == 0
assert "task complete" in result.task_result.stdout
snapshots_query = SandboxSnapshot.objects.filter(integration=github_integration).order_by("-created_at")
snapshots = cast(list[SandboxSnapshot], await sync_to_async(list)(snapshots_query)) # type: ignore[call-arg]
assert len(snapshots) == 1
assert snapshots[0].id == snapshot.id
assert "posthog/posthog-js" in snapshots[0].repos
finally:
await sync_to_async(snapshot.delete)()
async def test_workflow_creates_snapshot_for_new_repository(self, test_task, github_integration):
created_snapshots = []
try:
result = await self._run_workflow(test_task.id)
assert result.success is True
assert result.task_result is not None
assert result.task_result.exit_code == 0
snapshots_query = SandboxSnapshot.objects.filter(
integration=github_integration, status=SandboxSnapshot.Status.COMPLETE
).order_by("-created_at")
snapshots = cast(list[SandboxSnapshot], await sync_to_async(list)(snapshots_query)) # type: ignore[call-arg]
assert len(snapshots) >= 1
latest_snapshot = snapshots[0]
assert "posthog/posthog-js" in latest_snapshot.repos
assert latest_snapshot.status == SandboxSnapshot.Status.COMPLETE
assert latest_snapshot.external_id is not None
created_snapshots.append(latest_snapshot)
finally:
for snapshot in created_snapshots:
try:
if snapshot.external_id:
await SandboxEnvironment.delete_snapshot(snapshot.external_id)
await sync_to_async(snapshot.delete)()
except Exception:
pass
async def test_workflow_executes_task_in_sandbox(self, test_task, github_integration):
snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
integration=github_integration,
external_id=POSTHOG_JS_SNAPSHOT["external_id"],
repos=POSTHOG_JS_SNAPSHOT["repos"],
status=SandboxSnapshot.Status.COMPLETE,
)
custom_message = f"workflow_test_{uuid.uuid4().hex[:8]}"
try:
result = await self._run_workflow(test_task.id, mock_task_command=f"echo '{custom_message}'")
assert result.success is True
assert result.task_result is not None
assert result.task_result.exit_code == 0
assert custom_message in result.task_result.stdout
finally:
await sync_to_async(snapshot.delete)()
async def test_workflow_cleans_up_sandbox_on_success(self, test_task, github_integration):
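        """The task sandbox is shut down after a successful run."""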
snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
integration=github_integration,
external_id=POSTHOG_JS_SNAPSHOT["external_id"],
repos=POSTHOG_JS_SNAPSHOT["repos"],
status=SandboxSnapshot.Status.COMPLETE,
)
try:
result = await self._run_workflow(test_task.id)
assert result.success is True
assert result.task_result is not None
assert result.sandbox_id is not None
sandbox = await SandboxEnvironment.get_by_id(result.sandbox_id)
assert sandbox.status == SandboxEnvironmentStatus.SHUTDOWN.value
finally:
await sync_to_async(snapshot.delete)()
async def test_workflow_cleans_up_sandbox_on_failure(self, test_task, github_integration):
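        """The task sandbox is shut down even when the task command exits non-zero."""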
snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
integration=github_integration,
external_id=POSTHOG_JS_SNAPSHOT["external_id"],
repos=POSTHOG_JS_SNAPSHOT["repos"],
status=SandboxSnapshot.Status.COMPLETE,
)
try:
result = await self._run_workflow(test_task.id, mock_task_command="exit 1")
assert result.success is False
assert result.error is not None
assert result.task_result is None
assert result.sandbox_id is not None
sandbox_id = result.sandbox_id
sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
assert sandbox.status == SandboxEnvironmentStatus.SHUTDOWN.value
finally:
await sync_to_async(snapshot.delete)()
async def test_workflow_handles_missing_task(self):
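        """A task id that does not exist produces a failed result with an error message rather than an unhandled exception."""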
fake_task_id = str(uuid.uuid4())
result = await self._run_workflow(fake_task_id)
assert result.success is False
assert result.error is not None
assert "activity task failed" in result.error.lower() or "failed" in result.error.lower()
async def test_workflow_full_cycle_no_snapshot(self, test_task, github_integration):
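        """Full cycle without a pre-existing snapshot: the first run builds a snapshot and the second run reuses it, so the snapshot count stays the same."""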
created_snapshots = []
try:
result = await self._run_workflow(test_task.id)
assert result.success is True
assert result.task_result is not None
assert result.task_result.exit_code == 0
snapshots_query = SandboxSnapshot.objects.filter(
integration=github_integration, status=SandboxSnapshot.Status.COMPLETE
).order_by("-created_at")
snapshots = cast(list[SandboxSnapshot], await sync_to_async(list)(snapshots_query)) # type: ignore[call-arg]
assert len(snapshots) >= 1
latest_snapshot = snapshots[0]
assert "posthog/posthog-js" in latest_snapshot.repos
assert latest_snapshot.status == SandboxSnapshot.Status.COMPLETE
created_snapshots.append(latest_snapshot)
result2 = await self._run_workflow(test_task.id)
assert result2.success is True
assert result2.task_result is not None
snapshots_after_query = SandboxSnapshot.objects.filter(
integration=github_integration, status=SandboxSnapshot.Status.COMPLETE
).order_by("-created_at")
snapshots_after = cast(list[SandboxSnapshot], await sync_to_async(list)(snapshots_after_query)) # type: ignore[call-arg]
assert len(snapshots_after) == len(snapshots)
finally:
for snapshot in created_snapshots:
try:
if snapshot.external_id:
await SandboxEnvironment.delete_snapshot(snapshot.external_id)
await sync_to_async(snapshot.delete)()
except Exception:
pass

View File

@@ -0,0 +1,19 @@
from typing import Optional
from posthog.models.integration import GitHubIntegration, Integration
from posthog.temporal.common.utils import asyncify
@asyncify
def get_github_token(github_integration_id: int) -> Optional[str]:
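    """Return the GitHub access token for the integration, refreshing it first if it has expired."""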
integration = Integration.objects.get(id=github_integration_id)
github_integration = GitHubIntegration(integration)
if github_integration.access_token_expired():
github_integration.refresh_access_token()
return github_integration.integration.access_token or None
def get_sandbox_name_for_task(task_id: str) -> str:
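    """Return the deterministic sandbox name used for a task."""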
return f"task-sandbox-{task_id}"

View File

@@ -0,0 +1,318 @@
import json
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional
import temporalio
from temporalio import workflow
from temporalio.common import RetryPolicy
from posthog.temporal.common.base import PostHogWorkflow
from posthog.temporal.common.logger import get_logger
from .activities.check_snapshot_exists_for_repository import (
CheckSnapshotExistsForRepositoryInput,
check_snapshot_exists_for_repository,
)
from .activities.cleanup_personal_api_key import cleanup_personal_api_key
from .activities.cleanup_sandbox import CleanupSandboxInput, cleanup_sandbox
from .activities.clone_repository import CloneRepositoryInput, clone_repository
from .activities.create_sandbox_from_snapshot import CreateSandboxFromSnapshotInput, create_sandbox_from_snapshot
from .activities.create_snapshot import CreateSnapshotInput, create_snapshot
from .activities.execute_task_in_sandbox import ExecuteTaskInput, ExecuteTaskOutput, execute_task_in_sandbox
from .activities.get_sandbox_for_setup import GetSandboxForSetupInput, get_sandbox_for_setup
from .activities.get_task_details import TaskDetails, get_task_details
from .activities.inject_github_token import InjectGitHubTokenInput, inject_github_token
from .activities.inject_personal_api_key import (
InjectPersonalAPIKeyInput,
InjectPersonalAPIKeyOutput,
inject_personal_api_key,
)
from .activities.setup_repository import SetupRepositoryInput, setup_repository
from .activities.track_workflow_event import TrackWorkflowEventInput, track_workflow_event
logger = get_logger(__name__)
@dataclass
class ProcessTaskOutput:
success: bool
task_result: Optional[ExecuteTaskOutput] = None
error: Optional[str] = None
sandbox_id: Optional[str] = None
@temporalio.workflow.defn(name="process-task")
class ProcessTaskWorkflow(PostHogWorkflow):
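    """Run a task inside an isolated sandbox environment.

    The workflow fetches the task details, reuses or builds a repository
    snapshot, boots a sandbox from that snapshot, injects GitHub and personal
    API credentials, executes the task, and always cleans up the sandbox and
    the personal API key afterwards.
    """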
def __init__(self) -> None:
self._task_details: Optional[TaskDetails] = None
@property
def task_details(self) -> TaskDetails:
if self._task_details is None:
raise RuntimeError("task_details accessed before being set")
return self._task_details
@staticmethod
def parse_inputs(inputs: list[str]) -> str:
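        """Extract the task id from the workflow input, a single JSON-encoded argument of the form {"task_id": "..."}."""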
loaded = json.loads(inputs[0])
return loaded["task_id"]
@temporalio.workflow.run
async def run(self, task_id: str) -> ProcessTaskOutput:
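        """Orchestrate the full task lifecycle and report the outcome.

        Failures are caught and returned as a ProcessTaskOutput with
        success=False; the personal API key and sandbox are cleaned up in the
        finally block regardless of outcome.
        """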
sandbox_id = None
personal_api_key_id = None
try:
self._task_details = await self._get_task_details(task_id)
await self._track_workflow_event(
"process_task_workflow_started",
{
"task_id": self.task_details.task_id,
"repository": self.task_details.repository,
"team_id": self.task_details.team_id,
},
)
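            # Reuse an existing snapshot of the repository when available, otherwise
            # build one, then boot the task sandbox from it.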
snapshot_id = await self._get_snapshot_for_repository()
sandbox_id = await self._create_sandbox_from_snapshot(snapshot_id)
await self._inject_github_token(sandbox_id)
api_key_output = await self._inject_personal_api_key(sandbox_id)
personal_api_key_id = api_key_output.personal_api_key_id
result = await self._execute_task_in_sandbox(sandbox_id)
await self._track_workflow_event(
"process_task_workflow_completed",
{
"task_id": self.task_details.task_id,
"sandbox_id": sandbox_id,
"exit_code": result.exit_code,
},
)
return ProcessTaskOutput(
success=True,
task_result=result,
error=None,
sandbox_id=sandbox_id,
)
except Exception as e:
if self._task_details:
await self._track_workflow_event(
"process_task_workflow_failed",
{
"task_id": self.task_details.task_id,
"error_type": type(e).__name__,
"error_message": str(e)[:500],
"sandbox_id": sandbox_id,
},
)
return ProcessTaskOutput(
success=False,
task_result=None,
error=str(e),
sandbox_id=sandbox_id,
)
finally:
if personal_api_key_id:
await self._cleanup_personal_api_key(personal_api_key_id)
if sandbox_id:
await self._cleanup_sandbox(sandbox_id)
async def _get_task_details(self, task_id: str) -> TaskDetails:
return await workflow.execute_activity(
get_task_details,
task_id,
start_to_close_timeout=timedelta(seconds=30),
retry_policy=RetryPolicy(maximum_attempts=3),
)
async def _get_snapshot_for_repository(self) -> str:
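        """Return the id of an existing snapshot covering the task's repository, building a new one if none exists."""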
check_input = CheckSnapshotExistsForRepositoryInput(
github_integration_id=self.task_details.github_integration_id,
repository=self.task_details.repository,
)
check_result = await workflow.execute_activity(
check_snapshot_exists_for_repository,
check_input,
start_to_close_timeout=timedelta(seconds=30),
retry_policy=RetryPolicy(maximum_attempts=3),
)
if check_result.snapshot_id:
return check_result.snapshot_id
return await self._setup_snapshot_with_repository()
async def _get_sandbox_for_setup(self) -> str:
get_sandbox_input = GetSandboxForSetupInput(
github_integration_id=self.task_details.github_integration_id,
team_id=self.task_details.team_id,
task_id=self.task_details.task_id,
distinct_id=self.task_details.distinct_id,
)
return await workflow.execute_activity(
get_sandbox_for_setup,
get_sandbox_input,
start_to_close_timeout=timedelta(minutes=10),
retry_policy=RetryPolicy(maximum_attempts=2),
)
async def _clone_repository_in_sandbox(self, sandbox_id: str) -> None:
clone_input = CloneRepositoryInput(
sandbox_id=sandbox_id,
repository=self.task_details.repository,
github_integration_id=self.task_details.github_integration_id,
task_id=self.task_details.task_id,
distinct_id=self.task_details.distinct_id,
)
await workflow.execute_activity(
clone_repository,
clone_input,
start_to_close_timeout=timedelta(minutes=10),
retry_policy=RetryPolicy(maximum_attempts=2),
)
async def _setup_repository_in_sandbox(self, sandbox_id: str) -> None:
setup_repo_input = SetupRepositoryInput(
sandbox_id=sandbox_id,
repository=self.task_details.repository,
task_id=self.task_details.task_id,
distinct_id=self.task_details.distinct_id,
)
await workflow.execute_activity(
setup_repository,
setup_repo_input,
start_to_close_timeout=timedelta(minutes=15),
retry_policy=RetryPolicy(maximum_attempts=1),
)
async def _snapshot_sandbox(self, sandbox_id: str) -> str:
snapshot_input = CreateSnapshotInput(
sandbox_id=sandbox_id,
github_integration_id=self.task_details.github_integration_id,
team_id=self.task_details.team_id,
repository=self.task_details.repository,
task_id=self.task_details.task_id,
distinct_id=self.task_details.distinct_id,
)
return await workflow.execute_activity(
create_snapshot,
snapshot_input,
start_to_close_timeout=timedelta(minutes=25),
retry_policy=RetryPolicy(maximum_attempts=3),
)
async def _cleanup_sandbox(self, sandbox_id: str) -> None:
cleanup_input = CleanupSandboxInput(sandbox_id=sandbox_id)
await workflow.execute_activity(
cleanup_sandbox,
cleanup_input,
start_to_close_timeout=timedelta(minutes=5),
retry_policy=RetryPolicy(maximum_attempts=3),
)
async def _setup_snapshot_with_repository(self) -> str:
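        """Provision a setup sandbox, clone and set up the repository inside it, snapshot the result, and always tear the setup sandbox down."""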
setup_sandbox_id = None
try:
setup_sandbox_id = await self._get_sandbox_for_setup()
await self._clone_repository_in_sandbox(setup_sandbox_id)
await self._setup_repository_in_sandbox(setup_sandbox_id)
snapshot_id = await self._snapshot_sandbox(setup_sandbox_id)
return snapshot_id
finally:
if setup_sandbox_id:
await self._cleanup_sandbox(setup_sandbox_id)
async def _create_sandbox_from_snapshot(self, snapshot_id: str) -> str:
create_sandbox_input = CreateSandboxFromSnapshotInput(
snapshot_id=snapshot_id,
task_id=self.task_details.task_id,
distinct_id=self.task_details.distinct_id,
)
return await workflow.execute_activity(
create_sandbox_from_snapshot,
create_sandbox_input,
start_to_close_timeout=timedelta(minutes=5),
retry_policy=RetryPolicy(maximum_attempts=2),
)
async def _inject_github_token(self, sandbox_id: str) -> None:
inject_token_input = InjectGitHubTokenInput(
sandbox_id=sandbox_id,
github_integration_id=self.task_details.github_integration_id,
task_id=self.task_details.task_id,
distinct_id=self.task_details.distinct_id,
)
await workflow.execute_activity(
inject_github_token,
inject_token_input,
start_to_close_timeout=timedelta(minutes=5),
retry_policy=RetryPolicy(maximum_attempts=3),
)
async def _inject_personal_api_key(self, sandbox_id: str) -> InjectPersonalAPIKeyOutput:
inject_key_input = InjectPersonalAPIKeyInput(
sandbox_id=sandbox_id,
task_id=self.task_details.task_id,
distinct_id=self.task_details.distinct_id,
)
return await workflow.execute_activity(
inject_personal_api_key,
inject_key_input,
start_to_close_timeout=timedelta(minutes=5),
retry_policy=RetryPolicy(maximum_attempts=3),
)
async def _cleanup_personal_api_key(self, personal_api_key_id: str) -> None:
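        """Best-effort cleanup of the temporary personal API key; failures are logged as warnings instead of failing the workflow."""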
try:
await workflow.execute_activity(
cleanup_personal_api_key,
personal_api_key_id,
start_to_close_timeout=timedelta(minutes=5),
retry_policy=RetryPolicy(maximum_attempts=3),
)
except Exception as e:
logger.warning(f"Failed to cleanup personal API key {personal_api_key_id}: {e}")
async def _execute_task_in_sandbox(self, sandbox_id: str) -> ExecuteTaskOutput:
execute_input = ExecuteTaskInput(
sandbox_id=sandbox_id,
task_id=self.task_details.task_id,
repository=self.task_details.repository,
distinct_id=self.task_details.distinct_id,
)
return await workflow.execute_activity(
execute_task_in_sandbox,
execute_input,
start_to_close_timeout=timedelta(minutes=30),
retry_policy=RetryPolicy(maximum_attempts=1),
)
async def _track_workflow_event(self, event_name: str, properties: dict) -> None:
track_input = TrackWorkflowEventInput(
event_name=event_name,
distinct_id=self.task_details.distinct_id,
properties=properties,
)
await workflow.execute_activity(
track_workflow_event,
track_input,
start_to_close_timeout=timedelta(seconds=10),
retry_policy=RetryPolicy(maximum_attempts=1),
)

View File

@@ -513,7 +513,7 @@ class TestTask(TestCase):
self.assertEqual(repo_list[0]["org"], "PostHog")
self.assertEqual(repo_list[0]["repo"], "posthog")
self.assertEqual(repo_list[0]["integration_id"], integration.id)
self.assertEqual(repo_list[0]["full_name"], "PostHog/posthog")
self.assertEqual(repo_list[0]["full_name"], "posthog/posthog")
def test_repository_list_empty(self):
task = Task.objects.create(