feat(array): temporal workflow for sandbox environment (#39082)

.github/workflows/ci-backend.yml (6 changes, vendored)

@@ -49,6 +49,7 @@ jobs:
       backend_files: ${{ steps.filter.outputs.backend_files }}
       migrations: ${{ steps.filter.outputs.migrations }}
       migrations_files: ${{ steps.filter.outputs.migrations_files }}
+      tasks_temporal: ${{ steps.filter.outputs.tasks_temporal }}
     steps:
       # For pull requests it's not necessary to checkout the code, but we
      # also want this to run on master so we need to checkout
@@ -96,6 +97,8 @@ jobs:
           - 'posthog/migrations/*.py'
           - 'products/*/backend/migrations/*.py'
           - 'products/*/migrations/*.py' # Legacy structure
+          tasks_temporal:
+            - 'products/tasks/backend/temporal/**/*'

   check-migrations:
     needs: [changes]
@@ -722,9 +725,10 @@ jobs:
         shell: bash
         env:
           AWS_S3_ALLOW_UNSAFE_RENAME: 'true'
+          RUNLOOP_API_KEY: ${{ needs.changes.outputs.tasks_temporal == 'true' && secrets.RUNLOOP_API_KEY || '' }}
         run: |
           set +e
-          pytest posthog/temporal products/batch_exports/backend/tests/temporal -m "not async_migrations" \
+          pytest posthog/temporal products/batch_exports/backend/tests/temporal products/tasks/backend/temporal -m "not async_migrations" \
            --splits ${{ matrix.concurrency }} --group ${{ matrix.group }} \
            --durations=100 --durations-min=1.0 --store-durations \
            --splitting-algorithm=duration_based_chunks \

posthog/temporal/common/utils.py

@@ -1,7 +1,40 @@
+import inspect
+from collections.abc import Callable, Coroutine
 from datetime import datetime
+from functools import wraps
+from typing import Any, ParamSpec, TypeVar

+from asgiref.sync import sync_to_async
 from temporalio import workflow

+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+def asyncify(fn: Callable[P, T]) -> Callable[P, Coroutine[Any, Any, T]]:
+    """Decorator to convert a sync function using sync_to_async.
+
+    This preserves type hints for Temporal's serialization while allowing
+    sync Django ORM code.
+
+    Usage:
+        @activity.defn
+        @asyncify
+        def my_activity(task_id: str) -> TaskDetails:
+            task = Task.objects.get(id=task_id)
+            return TaskDetails(...)
+    """
+    if inspect.iscoroutinefunction(fn):
+        raise TypeError(
+            f"@asyncify should only be used on sync functions. " f"'{fn.__name__}' is already async. Remove @asyncify."
+        )
+
+    @wraps(fn)
+    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+        return await sync_to_async(fn)(*args, **kwargs)
+
+    return wrapper
+
+
 def get_scheduled_start_time():
     """Return the start time of a workflow.
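A minimal sketch of what `@asyncify` produces at call time, with a stand-in function in place of Django ORM access (`fetch_title` is hypothetical, not part of this commit):

```python
import asyncio

from asgiref.sync import sync_to_async  # the same helper the decorator wraps


def fetch_title(task_id: str) -> str:
    # stand-in for sync Django ORM code such as Task.objects.get(...)
    return f"Task {task_id}"


async def main() -> None:
    # @asyncify turns fetch_title into this awaitable call; the sync body
    # runs in a worker thread so the event loop is not blocked
    title = await sync_to_async(fetch_title)("123")
    assert title == "Task 123"


asyncio.run(main())
```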

products/tasks/backend/lib/constants.py (new file, 19 lines)

@@ -0,0 +1,19 @@
+SETUP_REPOSITORY_PROMPT = """
+Your goal is to setup the repository in the current environment.
+
+You are operating in a sandbox environment. You must install all dependencies necessary and setup the environment such that it is ready for executing code tasks.
+
+CONTEXT:
+
+CWD: {cwd}
+
+REPOSITORY: {repository}
+
+INSTRUCTIONS:
+
+1. Install all dependencies necessary to run the repository
+2. Run any setup scripts that are available
+3. Verify the setup by running tests or build if available
+
+DO NOT make any code changes to the repository. The final state of the disk of this sandbox is what will be used for subsequent tasks, so do not leave any cruft behind, and make sure the repository is in a ready to use state.
+"""
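The prompt is a plain `str.format` template. `_get_setup_command` later in this commit fills both placeholders with the on-disk repository path; an illustration with example paths:

```python
from products.tasks.backend.lib.constants import SETUP_REPOSITORY_PROMPT

# both placeholders receive the repo path, mirroring _get_setup_command below
prompt = SETUP_REPOSITORY_PROMPT.format(
    cwd="/tmp/workspace/repos/posthog/posthog-js",
    repository="/tmp/workspace/repos/posthog/posthog-js",
)
assert "CWD: /tmp/workspace/repos/posthog/posthog-js" in prompt
```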

products/tasks/backend/migrations/0009_task_created_by.py (new file, 26 lines)

@@ -0,0 +1,26 @@
+# Generated by Django 4.2.22 on 2025-10-04 16:57
+
+import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
+        ("tasks", "0008_task_task_number"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="task",
+            name="created_by",
+            field=models.ForeignKey(
+                blank=True,
+                db_index=False,
+                null=True,
+                on_delete=django.db.models.deletion.SET_NULL,
+                to=settings.AUTH_USER_MODEL,
+            ),
+        ),
+    ]

@@ -1 +1 @@
-0008_task_task_number
+0009_task_created_by

@@ -231,6 +231,7 @@ class Task(models.Model):

     id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
     team = models.ForeignKey("posthog.Team", on_delete=models.CASCADE)
+    created_by = models.ForeignKey("posthog.User", on_delete=models.SET_NULL, null=True, blank=True, db_index=False)
     task_number = models.IntegerField(null=True, blank=True)
     title = models.CharField(max_length=255)
     description = models.TextField()
@@ -331,12 +332,13 @@ class Task(models.Model):
         """
         config = self.repository_config
         if config.get("organization") and config.get("repository"):
+            full_name = f"{config.get('organization')}/{config.get('repository')}".lower()
             return [
                 {
                     "org": config.get("organization"),
                     "repo": config.get("repository"),
                     "integration_id": self.github_integration_id,
-                    "full_name": f"{config.get('organization')}/{config.get('repository')}",
+                    "full_name": full_name,
                 }
             ]
         return []

@@ -85,6 +85,9 @@ class TaskSerializer(serializers.ModelSerializer):
     def create(self, validated_data):
         validated_data["team"] = self.context["team"]

+        if "request" in self.context and hasattr(self.context["request"], "user"):
+            validated_data["created_by"] = self.context["request"].user
+
         # Set default GitHub integration if not provided
         if not validated_data.get("github_integration"):
             default_integration = Integration.objects.filter(team=self.context["team"], kind="github").first()
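A framework-free restatement of the new guard, using `SimpleNamespace` stand-ins for the DRF request (the helper name is hypothetical):

```python
from types import SimpleNamespace


def attach_created_by(validated_data: dict, context: dict) -> dict:
    # mirrors TaskSerializer.create: only set created_by when the serializer
    # context carries a request object that has a user attribute
    if "request" in context and hasattr(context["request"], "user"):
        validated_data["created_by"] = context["request"].user
    return validated_data


user = SimpleNamespace(email="dev@example.com")
data = attach_created_by({}, {"request": SimpleNamespace(user=user)})
assert data["created_by"] is user
```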

@@ -1,8 +1,13 @@
 import logging
+from typing import Optional

 from django.conf import settings
+
 from pydantic import BaseModel

-from .sandbox_environment import ExecutionResult, SandboxEnvironment, SandboxEnvironmentConfig
+from products.tasks.backend.lib.constants import SETUP_REPOSITORY_PROMPT
+
+from .sandbox_environment import ExecutionResult, SandboxEnvironment

 logger = logging.getLogger(__name__)
@@ -11,7 +16,8 @@ REPOSITORY_TARGET_DIR = "repo"
 DEFAULT_TASK_TIMEOUT_SECONDS = 20 * 60  # 20 minutes


-class SandboxAgentConfig(BaseModel):
+class SandboxAgentCreateConfig(BaseModel):
+    name: str
     repository_url: str
     github_token: str
     task_id: str
@@ -20,83 +26,76 @@ class SandboxAgentConfig(BaseModel):


 class SandboxAgent:
-    """
-    Agent that uses sandbox environments to execute tasks.
-    """
+    """Agent that uses sandbox environments to execute tasks."""

-    config: SandboxAgentConfig
     sandbox: SandboxEnvironment

-    def __init__(self, sandbox: SandboxEnvironment, config: SandboxAgentConfig):
+    def __init__(self, sandbox: SandboxEnvironment):
         self.sandbox = sandbox
-        self.config = config

-    @classmethod
-    async def create(
-        cls,
-        sandbox: SandboxEnvironment,
-        config: SandboxAgentConfig,
-    ) -> "SandboxAgent":
-        environment_variables = {
-            "REPOSITORY_URL": config.repository_url,
-            "POSTHOG_CLI_TOKEN": config.posthog_personal_api_key,
-            "POSTHOG_CLI_ENV_ID": config.posthog_project_id,
-        }
+    async def clone_repository(self, repository: str, github_token: Optional[str] = "") -> ExecutionResult:
+        if not self.sandbox.is_running:
+            raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")

-        sandbox_config = SandboxEnvironmentConfig(
-            name=sandbox.config.name,
-            template=sandbox.config.template,
-            environment_variables=environment_variables,
-            entrypoint=sandbox.config.entrypoint,
+        org, repo = repository.lower().split("/")
+        repo_url = (
+            f"https://x-access-token:{github_token}@github.com/{org}/{repo}.git"
+            if github_token
+            else f"https://github.com/{org}/{repo}.git"
         )

-        sandbox = await SandboxEnvironment.create(sandbox_config)
-        agent = cls(sandbox, config)
+        target_path = f"/tmp/workspace/repos/{org}/{repo}"

-        return agent
-
-    async def setup_repository(self) -> ExecutionResult:
-        if not self.sandbox.is_running:
-            raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
-
-        return await self.clone_repository(self.config.repository_url)
-
-    async def clone_repository(self, repo_url: str) -> ExecutionResult:
-        if not self.sandbox.is_running:
-            raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
-
-        if repo_url.startswith("https://github.com/"):
-            auth_url = repo_url.replace(
-                "https://github.com/",
-                f"https://x-access-token:{self.config.github_token}@github.com/",
-            )
-        else:
-            raise ValueError("Only GitHub is supported")
-
-        clone_command = f"git clone {auth_url} {WORKING_DIR}/{REPOSITORY_TARGET_DIR}"
-
-        logger.info(f"Cloning repository {repo_url} to {self.repository_dir} in sandbox {self.sandbox.id}")
-        return await self.sandbox.execute(clone_command)
-
-    async def execute_task(self) -> ExecutionResult:
-        """Execute Claude Code commands in the sandbox."""
-        if not self.sandbox.is_running:
-            raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
-
-        full_command = f"cd {self.repository_dir} && {self.get_task_command()}"
-
-        logger.info(
-            f"Executing task {self.config.task_id} in directory {self.repository_dir} in sandbox {self.sandbox.id}"
+        # Wipe existing directory if present, then clone
+        clone_command = (
+            f"rm -rf {target_path} && "
+            f"mkdir -p /tmp/workspace/repos/{org} && "
+            f"cd /tmp/workspace/repos/{org} && "
+            f"git clone {repo_url} {repo}"
         )
-        return await self.sandbox.execute(full_command, timeout_seconds=DEFAULT_TASK_TIMEOUT_SECONDS)
-
-    def get_task_command(self) -> str:
-        """Get the command to execute the task."""
-        # TODO: Replace with actual task execution: posthog-cli task run --task-id {self.config.task_id}
-        return "posthog-cli --help"
+        logger.info(f"Cloning repository {repository} to {target_path} in sandbox {self.sandbox.id}")
+        return await self.sandbox.execute(clone_command, timeout_seconds=5 * 60)

+    async def setup_repository(self, repository: str) -> ExecutionResult:
+        """Setup a repository for snapshotting using the PostHog Code Agent."""
+        if not self.sandbox.is_running:
+            raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
+
+        org, repo = repository.lower().split("/")
+        repo_path = f"/tmp/workspace/repos/{org}/{repo}"
+
+        check_result = await self.sandbox.execute(f"test -d {repo_path} && echo 'exists' || echo 'missing'")
+        if "missing" in check_result.stdout:
+            raise RuntimeError(f"Repository path {repo_path} does not exist. Clone the repository first.")
+
+        setup_command = f"cd {repo_path} && {self._get_setup_command(repo_path)}"
+
+        logger.info(f"Running code agent setup for {repository} in sandbox {self.sandbox.id}")
+        return await self.sandbox.execute(setup_command, timeout_seconds=15 * 60)
+
+    async def execute_task(self, task_id: str, repository: str) -> ExecutionResult:
+        if not self.sandbox.is_running:
+            raise RuntimeError(f"Sandbox not in running state. Current status: {self.sandbox.status}")
+
+        org, repo = repository.lower().split("/")
+        repo_path = f"/tmp/workspace/repos/{org}/{repo}"
+
+        command = f"cd {repo_path} && {self._get_task_command(task_id)}"
+
+        logger.info(f"Executing task {task_id} in {repo_path} in sandbox {self.sandbox.id}")
+        return await self.sandbox.execute(command, timeout_seconds=DEFAULT_TASK_TIMEOUT_SECONDS)
+
+    # TODO: Replace these once our coding agent is ready
+    def _get_task_command(self, task_id: str) -> str:
+        # return f"npx @posthog/code-agent@latest --yes --task-id {task_id}"
+        return f"export ANTHROPIC_API_KEY={settings.ANTHROPIC_API_KEY} && claude --dangerously-skip-permissions -p 'replace the readme with an ice cream cone'"
+
+    def _get_setup_command(self, repo_path: str) -> str:
+        # return f"npx @posthog/code-agent@latest --yes --prompt '{SETUP_REPOSITORY_PROMPT.format(cwd=repo_path, repository=repo_path)}'"
+        return f"export ANTHROPIC_API_KEY={settings.ANTHROPIC_API_KEY} && claude --dangerously-skip-permissions -p '{SETUP_REPOSITORY_PROMPT.format(cwd=repo_path, repository=repo_path)}'"

     async def destroy(self) -> None:
         """Destroy the underlying sandbox."""
         await self.sandbox.destroy()

     async def __aenter__(self):
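The rewritten `clone_repository` embeds the installation token with GitHub's `x-access-token` basic-auth convention and derives a per-repo path. The URL/path construction in isolation (the token value is fake):

```python
from typing import Optional


def build_clone_url(repository: str, github_token: Optional[str] = "") -> tuple[str, str]:
    # same normalization as SandboxAgent.clone_repository: lowercase org/repo,
    # embed the token only when one is provided
    org, repo = repository.lower().split("/")
    repo_url = (
        f"https://x-access-token:{github_token}@github.com/{org}/{repo}.git"
        if github_token
        else f"https://github.com/{org}/{repo}.git"
    )
    return repo_url, f"/tmp/workspace/repos/{org}/{repo}"


url, path = build_clone_url("PostHog/posthog-js", "ghs_fake_token")
assert url == "https://x-access-token:ghs_fake_token@github.com/posthog/posthog-js.git"
assert path == "/tmp/workspace/repos/posthog/posthog-js"
```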
@@ -105,14 +104,6 @@ class SandboxAgent:
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.destroy()

-    @property
-    def working_dir(self) -> str:
-        return WORKING_DIR
-
-    @property
-    def repository_dir(self) -> str:
-        return f"{WORKING_DIR}/{REPOSITORY_TARGET_DIR}"
-
     @property
     def is_running(self) -> bool:
         return self.sandbox.is_running

@@ -3,8 +3,22 @@ import logging
 from enum import Enum
 from typing import Optional

+from asgiref.sync import sync_to_async
 from pydantic import BaseModel
-from runloop_api_client import AsyncRunloop
+from runloop_api_client import (
+    AsyncRunloop,
+    BadRequestError as RunloopBadRequestError,
+    NotFoundError as RunloopNotFoundError,
+)
+
+from products.tasks.backend.models import SandboxSnapshot
+from products.tasks.backend.temporal.exceptions import (
+    SandboxCleanupError,
+    SandboxExecutionError,
+    SandboxNotFoundError,
+    SandboxProvisionError,
+    SnapshotCreationError,
+)

 logger = logging.getLogger(__name__)
@@ -20,8 +34,13 @@ class SandboxEnvironmentStatus(str, Enum):
     SHUTDOWN = "shutdown"


+class SandboxEnvironmentSnapshotStatus(str, Enum):
+    IN_PROGRESS = "in_progress"
+    COMPLETE = "complete"
+    ERROR = "error"
+
+
 class SandboxEnvironmentTemplate(str, Enum):
     UBUNTU_LATEST_X86_64 = "ubuntu_latest_x86_64"
     DEFAULT_BASE = "default_base"
@@ -38,6 +57,9 @@ class SandboxEnvironmentConfig(BaseModel):
     default_execution_timeout_seconds: int = 10 * 60  # 10 minutes
     environment_variables: Optional[dict[str, str]] = None
     entrypoint: Optional[str] = None
+    snapshot_id: Optional[str] = None
+    ttl_seconds: int = 60 * 30  # 30 minutes
+    metadata: Optional[dict[str, str]] = None


 def get_runloop_client() -> AsyncRunloop:
@@ -49,7 +71,6 @@ def get_runloop_client() -> AsyncRunloop:

 TEMPLATE_TO_BLUEPRINT_NAME = {
-    SandboxEnvironmentTemplate.DEFAULT_BASE: "sandbox-base-1",
+    SandboxEnvironmentTemplate.DEFAULT_BASE: "bpt_318UuXYGZbyYyl12hArAL",
 }

 BLUEPRINT_NAME_TO_TEMPLATE = {v: k for k, v in TEMPLATE_TO_BLUEPRINT_NAME.items()}
@@ -79,23 +100,44 @@ class SandboxEnvironment:
         blueprint_name = TEMPLATE_TO_BLUEPRINT_NAME.get(config.template)

         if not blueprint_name:
-            raise RuntimeError(f"Unknown template for sandbox {config.name}")
+            raise SandboxProvisionError(
+                f"Unknown template for sandbox {config.name}", {"template": str(config.template), "config": config}
+            )
+
+        snapshot_external_id = None
+
+        if config.snapshot_id:
+            snapshot = await sync_to_async(SandboxSnapshot.objects.get)(id=config.snapshot_id)
+
+            if snapshot.status == SandboxSnapshot.Status.COMPLETE:
+                snapshot_external_id = snapshot.external_id

         try:
-            # Wait for devbox to be running before returning
-            devbox = await client.devboxes.create_and_await_running(
-                name=config.name,
-                blueprint_name=blueprint_name,
-                environment_variables=config.environment_variables or {},
-                entrypoint=config.entrypoint,
-            )
+            create_kwargs = {
+                "name": config.name,
+                "environment_variables": config.environment_variables or {},
+                "entrypoint": config.entrypoint,
+                "metadata": config.metadata or {},
+                "launch_parameters": {
+                    "keep_alive_time_seconds": config.ttl_seconds,
+                },
+            }
+            if snapshot_external_id:
+                create_kwargs["snapshot_id"] = snapshot_external_id
+            else:
+                create_kwargs["blueprint_name"] = blueprint_name
+
+            devbox = await client.devboxes.create_and_await_running(**create_kwargs)  # type: ignore[arg-type]

         except Exception as e:
             logger.exception(f"Failed to create sandbox: {e}")
-            raise RuntimeError(f"Failed to create sandbox: {e}")
+            raise SandboxProvisionError(f"Failed to create sandbox", {"config": config, "error": str(e)})

         sandbox = SandboxEnvironment(id=devbox.id, status=SandboxEnvironmentStatus(devbox.status), config=config)

         assert sandbox.is_running

         logger.info(f"Created sandbox {sandbox.id} with status: {devbox.status}")

         return sandbox
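The heart of snapshot reuse is the branch between `snapshot_id` and `blueprint_name`: a devbox boots from a disk snapshot when one is ready, otherwise from the base blueprint. The decision reduced to dict-building (no Runloop calls):

```python
from typing import Optional


def build_create_kwargs(
    name: str,
    blueprint_name: str,
    snapshot_external_id: Optional[str],
    ttl_seconds: int = 60 * 30,
) -> dict:
    # mirrors SandboxEnvironment.create: the snapshot, if present, wins
    kwargs: dict = {
        "name": name,
        "launch_parameters": {"keep_alive_time_seconds": ttl_seconds},
    }
    if snapshot_external_id:
        kwargs["snapshot_id"] = snapshot_external_id
    else:
        kwargs["blueprint_name"] = blueprint_name
    return kwargs


assert "snapshot_id" in build_create_kwargs("sb", "bpt_123", "snap_456")
assert "blueprint_name" in build_create_kwargs("sb", "bpt_123", None)
```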
@@ -107,15 +149,11 @@ class SandboxEnvironment:
         try:
             devbox = await client.devboxes.retrieve(sandbox_id)

-            if not devbox.blueprint_id:
-                raise RuntimeError(f"Unknown template for sandbox {sandbox_id}")
+            template = SandboxEnvironmentTemplate.DEFAULT_BASE

-            blueprint = await client.blueprints.retrieve(devbox.blueprint_id)
-
-            template = BLUEPRINT_NAME_TO_TEMPLATE[blueprint.name]
-
-            if not template:
-                raise RuntimeError(f"Unknown template for sandbox {sandbox_id}")
+            if devbox.blueprint_id:
+                blueprint = await client.blueprints.retrieve(devbox.blueprint_id)
+                template = BLUEPRINT_NAME_TO_TEMPLATE.get(blueprint.name, SandboxEnvironmentTemplate.DEFAULT_BASE)

             config = SandboxEnvironmentConfig(name=devbox.name or f"sandbox-{sandbox_id}", template=template)

@@ -126,8 +164,14 @@ class SandboxEnvironment:
             return sandbox

         except Exception as e:
             logger.exception(f"Failed to retrieve sandbox {sandbox_id}: {e}")
-            raise RuntimeError(f"Failed to retrieve sandbox {sandbox_id}: {e}")
+            if isinstance(e, RunloopNotFoundError | RunloopBadRequestError):
+                if "non-existent-sandbox-id" in str(e) or isinstance(e, RunloopNotFoundError):
+                    raise SandboxNotFoundError(
+                        f"Sandbox {sandbox_id} not found", {"sandbox_id": sandbox_id, "error": str(e)}
+                    )
+            raise SandboxProvisionError(
+                f"Failed to retrieve sandbox {sandbox_id}", {"sandbox_id": sandbox_id, "error": str(e)}
+            )

     async def execute(
         self,
@@ -135,7 +179,10 @@ class SandboxEnvironment:
         timeout_seconds: Optional[int] = None,
     ) -> ExecutionResult:
         if not self.is_running:
-            raise RuntimeError(f"Sandbox not in running state. Current status: {self.status}")
+            raise SandboxExecutionError(
+                f"Sandbox not in running state. Current status: {self.status}",
+                {"sandbox_id": self.id, "status": str(self.status)},
+            )

         if timeout_seconds is None:
             timeout_seconds = self.config.default_execution_timeout_seconds
@@ -157,15 +204,59 @@ class SandboxEnvironment:
         result = ExecutionResult(
             stdout=final_execution.stdout,
             stderr=final_execution.stderr,
-            exit_code=final_execution.exit_status,
+            exit_code=final_execution.exit_status or 0,
             error=getattr(final_execution, "error", None),
         )

         if result.error:
             logger.warning(f"Command execution had error: {result.error}")

         return result

+    async def initiate_snapshot(self, metadata: Optional[dict[str, str]] = None) -> str:
+        if not self.is_running:
+            raise SandboxExecutionError(
+                f"Sandbox not in running state. Current status: {self.status}",
+                {"sandbox_id": self.id, "status": str(self.status)},
+            )
+
+        try:
+            devbox = await self._client.devboxes.retrieve(self.id)
+
+            snapshot = await self._client.devboxes.snapshot_disk_async(devbox.id, metadata=metadata)
+
+            snapshot_id = snapshot.id
+
+            logger.info(f"Initiated snapshot for sandbox {self.id}, snapshot ID: {snapshot_id}")
+
+            return snapshot_id
+
+        except Exception as e:
+            logger.exception(f"Failed to initiate snapshot: {e}")
+            raise SnapshotCreationError(f"Failed to initiate snapshot: {e}", {"sandbox_id": self.id, "error": str(e)})
+
+    @staticmethod
+    async def delete_snapshot(external_id: str) -> None:
+        client = get_runloop_client()
+        logger.info(f"Deleting snapshot {external_id}")
+        await client.devboxes.disk_snapshots.delete(external_id)
+        logger.info(f"Deleted snapshot {external_id}")
+
+    @staticmethod
+    async def get_snapshot_status(external_id: str) -> SandboxEnvironmentSnapshotStatus:
+        try:
+            client = get_runloop_client()
+
+            logger.info(f"Getting snapshot status for {external_id}")
+
+            snapshot = await client.devboxes.disk_snapshots.query_status(external_id)
+
+            logger.info(f"Retrieved snapshot status for {external_id}: {snapshot.status}")
+
+            return SandboxEnvironmentSnapshotStatus(snapshot.status)
+        except Exception as e:
+            logger.exception(f"Failed to get snapshot status: {e}")
+            raise SnapshotCreationError(
+                f"Failed to get snapshot status: {e}", {"external_id": external_id, "error": str(e)}
+            )

     async def destroy(self) -> None:
         try:
             await self._client.devboxes.shutdown(self.id)
@@ -176,7 +267,7 @@ class SandboxEnvironment:

         except Exception as e:
             logger.exception(f"Failed to destroy sandbox: {e}")
-            raise RuntimeError(f"Failed to destroy sandbox: {e}")
+            raise SandboxCleanupError(f"Failed to destroy sandbox: {e}", {"sandbox_id": self.id, "error": str(e)})

     async def __aenter__(self):
         return self

@@ -187,3 +278,7 @@ class SandboxEnvironment:
     @property
     def is_running(self) -> bool:
         return self.status == SandboxEnvironmentStatus.RUNNING
+
+    @property
+    def name(self) -> str:
+        return self.config.name
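Taken together, the class is built for `async with`: `__aenter__` returns the sandbox and `__aexit__` destroys it even when the body raises. A usage sketch, assuming the PostHog codebase on the path and a `RUNLOOP_API_KEY` in the environment:

```python
import asyncio

from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)


async def main() -> None:
    config = SandboxEnvironmentConfig(
        name="demo-sandbox",
        template=SandboxEnvironmentTemplate.DEFAULT_BASE,
    )
    # create() waits for the devbox to reach RUNNING; __aexit__ shuts it down
    async with await SandboxEnvironment.create(config) as sandbox:
        result = await sandbox.execute("echo hello")
        print(result.stdout)


asyncio.run(main())
```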

@@ -1,80 +0,0 @@
-import os
-
-import pytest
-
-from products.tasks.backend.services.sandbox_agent import SandboxAgent, SandboxAgentConfig
-from products.tasks.backend.services.sandbox_environment import (
-    SandboxEnvironment,
-    SandboxEnvironmentConfig,
-    SandboxEnvironmentTemplate,
-)
-
-
-@pytest.mark.asyncio
-class TestSandboxAgentIntegration:
-    # We only run these tests when we have a Runloop API key set, we don't want to run them in CI since they create real sandbox environments and are slow.
-    @pytest.fixture(scope="class", autouse=True)
-    def check_api_key(self):
-        if not os.environ.get("RUNLOOP_API_KEY"):
-            pytest.skip("RUNLOOP_API_KEY not set, skipping integration tests")
-
-    @pytest.fixture
-    def mock_github_token(self):
-        """Provide a mock GitHub token for testing."""
-        return "ghp_mock_token_for_testing_12345678901234567890"
-
-    @pytest.fixture
-    def mock_posthog_credentials(self):
-        """Provide mock PostHog credentials for testing."""
-        return {"personal_api_key": "phx_mock_personal_api_key_123456789", "project_id": "test-project-id-123"}
-
-    @pytest.fixture
-    def public_repo_url(self):
-        """Use a small public repository for testing."""
-        return "https://github.com/octocat/Hello-World"
-
-    async def test_complete_sandbox_agent_workflow(self, mock_github_token, public_repo_url, mock_posthog_credentials):
-        """Comprehensive test covering agent lifecycle, repo cloning, and PostHog CLI execution."""
-        sandbox_config = SandboxEnvironmentConfig(
-            name="posthog-agent-test-complete", template=SandboxEnvironmentTemplate.DEFAULT_BASE
-        )
-        sandbox = await SandboxEnvironment.create(sandbox_config)
-
-        agent_config = SandboxAgentConfig(
-            repository_url=public_repo_url,
-            github_token=mock_github_token,
-            task_id="test",
-            posthog_personal_api_key=mock_posthog_credentials["personal_api_key"],
-            posthog_project_id=mock_posthog_credentials["project_id"],
-        )
-
-        async with await SandboxAgent.create(sandbox, agent_config) as agent:
-            assert agent.id is not None
-            assert agent.is_running
-            assert agent.working_dir == "/tmp/workspace"
-            assert agent.repository_dir == "/tmp/workspace/repo"
-
-            setup_result = await agent.setup_repository()
-            assert setup_result.exit_code == 0
-            assert setup_result.error is None
-
-            check_result = await agent.sandbox.execute("ls -la /tmp/workspace/repo")
-            assert check_result.exit_code == 0
-            assert ".git" in check_result.stdout
-
-            env_check = await agent.sandbox.execute("printenv")
-
-            assert "REPOSITORY_URL" in env_check.stdout
-            assert "POSTHOG_CLI_TOKEN" in env_check.stdout
-            assert "POSTHOG_CLI_ENV_ID" in env_check.stdout
-
-            cli_result = await agent.execute_task()
-            assert cli_result.exit_code == 0
-
-            assert "posthog-cli" in cli_result.stdout.lower() or "usage" in cli_result.stdout.lower()
-
-            context_result = await agent.sandbox.execute(f"cd {agent.repository_dir} && pwd")
-            assert context_result.exit_code == 0
-            assert agent.repository_dir in context_result.stdout
-
-        assert not agent.is_running

@@ -7,6 +7,22 @@ from .github_activities import (
     create_pr_activity,
     create_pr_and_update_task_activity,
 )
+from .process_task.activities import (
+    check_snapshot_exists_for_repository,
+    cleanup_personal_api_key,
+    cleanup_sandbox,
+    clone_repository,
+    create_sandbox_from_snapshot,
+    create_snapshot,
+    execute_task_in_sandbox,
+    get_sandbox_for_setup,
+    get_task_details,
+    inject_github_token,
+    inject_personal_api_key,
+    setup_repository,
+    track_workflow_event,
+)
+from .process_task.workflow import ProcessTaskWorkflow
 from .workflow_activities import (
     check_temporal_workflow_permissions_activity,
     execute_agent_for_transition_activity,
@@ -20,6 +36,7 @@ from .workflows import WorkflowAgnosticTaskProcessingWorkflow

 WORKFLOWS = [
     WorkflowAgnosticTaskProcessingWorkflow,
+    ProcessTaskWorkflow,
 ]

 ACTIVITIES = [
@@ -39,4 +56,18 @@ ACTIVITIES = [
     move_task_to_stage_activity,
     trigger_task_processing_activity,
     should_trigger_agent_workflow_activity,
+    # process_task activities
+    get_task_details,
+    check_snapshot_exists_for_repository,
+    get_sandbox_for_setup,
+    clone_repository,
+    inject_github_token,
+    inject_personal_api_key,
+    setup_repository,
+    create_snapshot,
+    create_sandbox_from_snapshot,
+    execute_task_in_sandbox,
+    cleanup_personal_api_key,
+    cleanup_sandbox,
+    track_workflow_event,
 ]
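These two lists are the worker's registration surface. A sketch of how a Temporal worker typically consumes them; the server address and the literal queue name are assumptions (PostHog resolves the queue from `TASKS_TASK_QUEUE`):

```python
import asyncio

from temporalio.client import Client
from temporalio.worker import Worker

from products.tasks.backend.temporal import ACTIVITIES, WORKFLOWS


async def main() -> None:
    client = await Client.connect("localhost:7233")  # address is an assumption
    worker = Worker(
        client,
        task_queue="tasks-task-queue",  # placeholder for TASKS_TASK_QUEUE
        workflows=WORKFLOWS,
        activities=ACTIVITIES,
    )
    await worker.run()


asyncio.run(main())
```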

@@ -1,36 +1,44 @@
+import time
+import uuid
 import asyncio
+import logging
 from typing import Optional

+import posthoganalytics
 from temporalio.common import RetryPolicy, WorkflowIDReusePolicy

 from posthog.constants import TASKS_TASK_QUEUE
+from posthog.models.team.team import Team
+from posthog.models.user import User
 from posthog.temporal.common.client import async_connect

 from .inputs import TaskProcessingInputs

+logger = logging.getLogger(__name__)
+

-async def _execute_task_processing_workflow(task_id: str, team_id: int, user_id: Optional[int] = None) -> str:
-    """Execute the task processing workflow asynchronously."""
-
-    inputs = TaskProcessingInputs(task_id=task_id, team_id=team_id, user_id=user_id)
-
-    # Create unique workflow ID based on task and timestamp
-    import time
-    import uuid
-    import logging
-
-    # Use high-resolution timestamp + random suffix to avoid collisions when re-triggering within the same second
+async def _execute_task_processing_workflow(
+    task_id: str, team_id: int, user_id: Optional[int] = None, use_sandbox: bool = False
+) -> str:
     workflow_id = f"task-processing-{task_id}-{int(time.time()*1000)}-{uuid.uuid4().hex[:8]}"

-    logging.getLogger(__name__).info(f"Starting workflow {workflow_id} for task {task_id}")
+    workflow_input: str | TaskProcessingInputs
+    if use_sandbox:
+        workflow_name = "process-task"
+        workflow_input = task_id
+    else:
+        workflow_name = "process-task-workflow-agnostic"
+        workflow_input = TaskProcessingInputs(task_id=task_id, team_id=team_id, user_id=user_id)
+
+    logger.info(f"Starting workflow {workflow_name} ({workflow_id}) for task {task_id}")

     client = await async_connect()

     retry_policy = RetryPolicy(maximum_attempts=3)

     result = await client.execute_workflow(
-        "process-task-workflow-agnostic",
-        inputs,
+        workflow_name,
+        workflow_input,
         id=workflow_id,
         id_reuse_policy=WorkflowIDReusePolicy.ALLOW_DUPLICATE_FAILED_ONLY,
         task_queue=TASKS_TASK_QUEUE,
@@ -47,21 +55,14 @@ def execute_task_processing_workflow(task_id: str, team_id: int, user_id: Option
     but doesn't wait for completion.
     """
     try:
-        import logging
         import threading

-        logger = logging.getLogger(__name__)
-
         # Always offload to a dedicated thread with its own event loop.
         # This is safer when called from within a Temporal activity (already running an event loop)
         # and from sync Django views. It avoids create_task() being cancelled when the caller loop ends.
         def run_workflow() -> None:
             try:
-                # Check feature flag in the thread where we can make sync Django calls
-                import posthoganalytics
-
-                from posthog.models.team.team import Team
-                from posthog.models.user import User
-
                 try:
                     if not user_id:
@@ -93,8 +94,20 @@ def execute_task_processing_workflow(task_id: str, team_id: int, user_id: Option
                 logger.exception(f"Error checking feature flag for task workflow: {e}")
                 return

-            logger.info(f"Triggering workflow for task {task_id}")
-            asyncio.run(_execute_task_processing_workflow(task_id, team_id, user_id))
+            # Check feature flag for sandbox-based workflow
+            use_sandbox = posthoganalytics.feature_enabled(
+                "tasks-sandbox",
+                user.distinct_id,
+                groups={"organization": str(team.organization.id)},
+                group_properties={"organization": {"id": str(team.organization.id)}},
+                only_evaluate_locally=False,
+                send_feature_flag_events=False,
+            )
+
+            logger.info(
+                f"Triggering workflow for task {task_id} (sandbox: {use_sandbox}, workflow: {'process-task' if use_sandbox else 'process-task-workflow-agnostic'})"
+            )
+            asyncio.run(_execute_task_processing_workflow(task_id, team_id, user_id, use_sandbox=use_sandbox))
             logger.info(f"Workflow completed for task {task_id}")
         except Exception as e:
             logger.exception(f"Workflow execution failed for task {task_id}: {e}")
@@ -105,8 +118,5 @@ def execute_task_processing_workflow(task_id: str, team_id: int, user_id: Option

     except Exception as e:
         # Don't let workflow execution failures break the main operation
-        import logging
-
-        logger = logging.getLogger(__name__)
         logger.exception(f"Failed to execute task processing workflow: {e}")
         # Don't re-raise to avoid breaking the API call
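The surrounding function (only partially visible in these hunks) offloads the Temporal call to a dedicated thread so that `asyncio.run` never collides with a caller's event loop. The skeleton of that pattern, with the workflow call stubbed out:

```python
import asyncio
import threading


async def _start_workflow(task_id: str) -> str:
    await asyncio.sleep(0)  # stand-in for client.execute_workflow(...)
    return f"done:{task_id}"


def trigger(task_id: str) -> None:
    # fire-and-forget: a fresh thread gets its own event loop via asyncio.run,
    # which is safe from sync Django views and from inside Temporal activities
    def run_workflow() -> None:
        asyncio.run(_start_workflow(task_id))

    threading.Thread(target=run_workflow, daemon=True).start()


trigger("123")
```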

products/tasks/backend/temporal/conftest.py (new file, 161 lines)

@@ -0,0 +1,161 @@
+import random
+
+import pytest
+
+from asgiref.sync import sync_to_async
+from temporalio.testing import ActivityEnvironment
+
+from posthog.models import Organization, OrganizationMembership, Team
+from posthog.models.integration import Integration
+from posthog.models.user import User
+from posthog.temporal.common.logger import configure_logger
+
+from products.tasks.backend.models import Task, TaskWorkflow, WorkflowStage
+
+
+@pytest.fixture
+def activity_environment():
+    """Return a testing temporal ActivityEnvironment."""
+    return ActivityEnvironment()
+
+
+@pytest.fixture
+def organization():
+    """A test organization."""
+    name = f"TasksTestOrg-{random.randint(1, 99999)}"
+    org = Organization.objects.create(name=name, is_ai_data_processing_approved=True)
+    org.save()
+
+    yield org
+
+    org.delete()
+
+
+@pytest.fixture
+def team(organization):
+    """A test team."""
+    name = f"TasksTestTeam-{random.randint(1, 99999)}"
+    team = Team.objects.create(organization=organization, name=name)
+    team.save()
+
+    yield team
+
+    team.delete()
+
+
+@pytest.fixture
+async def aorganization():
+    """Async test organization."""
+    name = f"TasksTestOrg-{random.randint(1, 99999)}"
+    org = await sync_to_async(Organization.objects.create)(name=name, is_ai_data_processing_approved=True)
+
+    yield org
+
+    await sync_to_async(org.delete)()
+
+
+@pytest.fixture
+async def ateam(aorganization):
+    """Async test team."""
+    name = f"TasksTestTeam-{random.randint(1, 99999)}"
+    team = await sync_to_async(Team.objects.create)(organization=aorganization, name=name)
+
+    yield team
+
+    await sync_to_async(team.delete)()
+
+
+@pytest.fixture
+async def task_workflow(ateam):
+    """Create a test workflow with stages."""
+    workflow = await sync_to_async(TaskWorkflow.objects.create)(
+        team=ateam,
+        name="Test Workflow",
+        description="Test workflow for temporal activities",
+        is_default=True,
+        is_active=True,
+    )
+
+    stages = []
+    for i, (name, key, color) in enumerate(
+        [
+            ("Backlog", "backlog", "#6b7280"),
+            ("Ready", "ready", "#3b82f6"),
+            ("In Progress", "in_progress", "#10b981"),
+            ("Done", "done", "#22c55e"),
+        ]
+    ):
+        stage = await sync_to_async(WorkflowStage.objects.create)(
+            workflow=workflow,
+            name=name,
+            key=key,
+            position=i,
+            color=color,
+            is_manual_only=(i != 2),  # Only "In Progress" is not manual
+            agent_name="claude_code_agent" if i == 2 else None,
+        )
+        stages.append(stage)
+
+    yield workflow, stages
+
+    await sync_to_async(workflow.delete)()
+
+
+@pytest.fixture
+async def github_integration(ateam):
+    """Create a test GitHub integration."""
+    integration = await sync_to_async(Integration.objects.create)(
+        team=ateam,
+        kind="github",
+        sensitive_config={"access_token": "fake_token"},
+    )
+
+    yield integration
+
+    await sync_to_async(integration.delete)()
+
+
+@pytest.fixture
+async def auser(ateam):
+    user = await sync_to_async(User.objects.create)(
+        email=f"test-{random.randint(1, 99999)}@example.com",
+        password="testpassword123",
+    )
+
+    await sync_to_async(OrganizationMembership.objects.create)(
+        user=user,
+        organization_id=ateam.organization_id,
+    )
+
+    yield user
+    await sync_to_async(user.delete)()
+
+
+@pytest.fixture
+async def test_task(ateam, auser, task_workflow, github_integration):
+    """Create a test task."""
+    workflow, stages = task_workflow
+    backlog_stage = stages[0]
+
+    task = await sync_to_async(Task.objects.create)(
+        team=ateam,
+        created_by=auser,
+        title="Test Task for Temporal Activities",
+        description="This is a test task for testing temporal activities",
+        origin_product=Task.OriginProduct.USER_CREATED,
+        workflow=workflow,
+        current_stage=backlog_stage,
+        position=0,
+        github_integration=github_integration,
+        repository_config={"organization": "PostHog", "repository": "posthog-js"},
+    )
+
+    yield task
+
+    await sync_to_async(task.delete)()
+
+
+@pytest.fixture(autouse=True)
+def configure_logger_auto() -> None:
+    """Configure logger when running in a Temporal activity environment."""
+    configure_logger(cache_logger_on_first_use=False)
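A sketch of how these fixtures combine in a test; `echo_activity` is hypothetical, while `ActivityEnvironment.run` is the standard temporalio testing entry point:

```python
import pytest
from temporalio import activity
from temporalio.testing import ActivityEnvironment


@activity.defn
async def echo_activity(value: str) -> str:
    return value  # illustration only, not an activity from this commit


@pytest.mark.asyncio
async def test_echo(activity_environment: ActivityEnvironment):
    # run() executes the activity with activity.info() wired up
    result = await activity_environment.run(echo_activity, "hello")
    assert result == "hello"
```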

products/tasks/backend/temporal/exceptions.py (new file, 127 lines)

@@ -0,0 +1,127 @@
+from typing import Optional
+
+from temporalio.exceptions import ApplicationError
+
+
+class ProcessTaskError(ApplicationError):
+    def __init__(self, message: str, context: Optional[dict] = None, **kwargs):
+        self.context = context or {}
+        super().__init__(message, self.context, **kwargs)
+
+
+class ProcessTaskFatalError(ProcessTaskError):
+    """Fatal errors that should not be retried."""
+
+    def __init__(self, message: str, context: Optional[dict] = None):
+        super().__init__(message, context, non_retryable=True)
+
+
+class ProcessTaskTransientError(ProcessTaskError):
+    """Transient errors that may succeed on retry."""
+
+    def __init__(self, message: str, context: Optional[dict] = None):
+        super().__init__(message, context, non_retryable=False)
+
+
+class TaskNotFoundError(ProcessTaskFatalError):
+    pass
+
+
+class TaskInvalidStateError(ProcessTaskFatalError):
+    pass
+
+
+class SandboxProvisionError(ProcessTaskTransientError):
+    """Failed to provision sandbox environment."""
+
+    pass
+
+
+class SandboxNotFoundError(ProcessTaskFatalError):
+    """Sandbox does not exist."""
+
+    pass
+
+
+class SandboxExecutionError(ProcessTaskTransientError):
+    """Error during sandbox command execution."""
+
+    pass
+
+
+class SandboxTimeoutError(ProcessTaskTransientError):
+    """Sandbox operation timed out."""
+
+    pass
+
+
+class SandboxCleanupError(ProcessTaskTransientError):
+    """Error during sandbox cleanup/destruction."""
+
+    pass
+
+
+class SnapshotNotFoundError(ProcessTaskTransientError):
+    """Snapshot does not exist."""
+
+    pass
+
+
+class SnapshotNotReadyError(ProcessTaskTransientError):
+    """Snapshot exists but is not ready for use."""
+
+    pass
+
+
+class SnapshotCreationError(ProcessTaskTransientError):
+    """Failed to create snapshot."""
+
+    pass
+
+
+class RepositoryCloneError(ProcessTaskTransientError):
+    """Failed to clone repository."""
+
+    pass
+
+
+class RepositorySetupError(ProcessTaskTransientError):
+    """Failed to setup repository (install dependencies, etc)."""
+
+    pass
+
+
+class GitHubIntegrationError(ProcessTaskFatalError):
+    """GitHub integration not found or invalid."""
+
+    pass
+
+
+class GitHubAuthenticationError(ProcessTaskFatalError):
+    """Failed to authenticate with GitHub."""
+
+    pass
+
+
+class PersonalAPIKeyError(ProcessTaskTransientError):
+    """Failed to create or inject personal API key."""
+
+    pass
+
+
+class TaskExecutionFailedError(ProcessTaskError):
+    """Task execution completed but with non-zero exit code."""
+
+    def __init__(
+        self,
+        message: str,
+        exit_code: int,
+        stdout: str = "",
+        stderr: str = "",
+        context: Optional[dict] = None,
+        non_retryable: bool = False,
+    ):
+        self.exit_code = exit_code
+        self.stdout = stdout
+        self.stderr = stderr
+        super().__init__(message, context, non_retryable=non_retryable)
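Since everything derives from `temporalio.exceptions.ApplicationError`, the fatal/transient split maps straight onto Temporal's retry behavior via the `non_retryable` flag. A quick check, assuming the module is importable:

```python
from products.tasks.backend.temporal.exceptions import (
    SandboxNotFoundError,
    SandboxProvisionError,
)

fatal = SandboxNotFoundError("sandbox sb-1 not found", {"sandbox_id": "sb-1"})
transient = SandboxProvisionError("runloop 5xx", {"attempt": 1})

# Temporal will not retry an activity that raises a non_retryable ApplicationError
assert fatal.non_retryable is True
assert transient.non_retryable is False
assert fatal.context == {"sandbox_id": "sb-1"}
```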

products/tasks/backend/temporal/observability.py (new file, 157 lines)

@@ -0,0 +1,157 @@
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from typing import Any, Optional
+
+import posthoganalytics
+from temporalio import activity, workflow
+
+from posthog.temporal.common.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+def get_bound_logger(**context: Any):
+    return logger.bind(**context)
+
+
+def log_with_activity_context(message: str, **extra_context: Any) -> None:
+    bound_logger = logger.bind(**extra_context)
+
+    if activity.in_activity():
+        info = activity.info()
+        bound_logger = bound_logger.bind(
+            activity_id=info.activity_id,
+            activity_type=info.activity_type,
+            attempt=info.attempt,
+        )
+
+    bound_logger.info(message)
+
+
+def log_with_workflow_context(message: str, **extra_context: Any) -> None:
+    bound_logger = logger.bind(**extra_context)
+
+    if workflow.in_workflow():
+        info = workflow.info()
+        bound_logger = bound_logger.bind(
+            workflow_id=info.workflow_id,
+            workflow_run_id=info.run_id,
+            workflow_type=info.workflow_type,
+        )
+
+    bound_logger.info(message)
+
+
+@asynccontextmanager
+async def log_activity_execution(
+    activity_name: str,
+    distinct_id: Optional[str] = None,
+    **context: Any,
+) -> AsyncIterator[None]:
+    """Context manager for activity execution with automatic logging and analytics.
+
+    Automatically tracks:
+    - process_task_activity_started
+    - process_task_activity_completed
+    - process_task_activity_failed
+
+    Usage:
+        async with log_activity_execution(
+            "clone_repository",
+            distinct_id=f"user_{user_id}",
+            task_id=task_id,
+            repository=repo
+        ):
+            result = await do_work()
+            return result
+    """
+    bound_logger = logger.bind(**context)
+
+    if activity.in_activity():
+        info = activity.info()
+        bound_logger = bound_logger.bind(
+            activity_id=info.activity_id,
+            activity_type=info.activity_type,
+            attempt=info.attempt,
+        )
+
+    bound_logger.info(f"{activity_name} started")
+
+    if distinct_id:
+        track_event(
+            "process_task_activity_started",
+            distinct_id=distinct_id,
+            properties={"activity_name": activity_name, **context},
+        )
+
+    try:
+        yield
+        bound_logger.info(f"{activity_name} completed successfully")
+
+        if distinct_id:
+            track_event(
+                "process_task_activity_completed",
+                distinct_id=distinct_id,
+                properties={"activity_name": activity_name, **context},
+            )
+    except Exception as e:
+        bound_logger.exception(
+            f"{activity_name} failed",
+            error_type=type(e).__name__,
+            error_message=str(e),
+        )
+
+        if distinct_id:
+            track_event(
+                "process_task_activity_failed",
+                distinct_id=distinct_id,
+                properties={
+                    "activity_name": activity_name,
+                    "error_type": type(e).__name__,
+                    "error_message": str(e)[:500],
+                    **context,
+                },
+            )
+
+        raise
+
+
+def track_event(
+    event_name: str,
+    distinct_id: str,
+    properties: Optional[dict[str, Any]] = None,
+) -> None:
+    try:
+        enriched_properties = {**(properties or {})}
+
+        if activity.in_activity():
+            activity_info = activity.info()
+            enriched_properties.update(
+                {
+                    "temporal_activity_id": activity_info.activity_id,
+                    "temporal_activity_type": activity_info.activity_type,
+                    "temporal_workflow_id": activity_info.workflow_id,
+                    "temporal_workflow_run_id": activity_info.workflow_run_id,
+                    "temporal_attempt": activity_info.attempt,
+                }
+            )
+        elif workflow.in_workflow() and not workflow.unsafe.is_replaying():
+            workflow_info = workflow.info()
+            enriched_properties.update(
+                {
+                    "temporal_workflow_id": workflow_info.workflow_id,
+                    "temporal_workflow_run_id": workflow_info.run_id,
+                    "temporal_workflow_type": workflow_info.workflow_type,
+                }
+            )
+
+        posthoganalytics.capture(
+            distinct_id=distinct_id,
+            event=event_name,
+            properties=enriched_properties,
+        )
+
+        logger.debug(f"Tracked event: {event_name}", **enriched_properties)
+
+    except Exception as e:
+        logger.warning(f"Failed to track event {event_name}", exc_info=e)
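Stripped of the PostHog-specific logging and analytics, `log_activity_execution` is a standard `@asynccontextmanager` start/complete/fail wrapper; a generic, self-contained version of the pattern:

```python
import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def traced(name: str):
    # emit a start event, then a completed or failed event, re-raising errors
    print(f"{name} started")
    try:
        yield
        print(f"{name} completed")
    except Exception as e:
        print(f"{name} failed: {type(e).__name__}")
        raise


async def main() -> None:
    async with traced("clone_repository"):
        await asyncio.sleep(0)  # the activity body goes here


asyncio.run(main())
```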

products/tasks/backend/temporal/process_task/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
+# Agent workflow for executing tasks in sandboxes

products/tasks/backend/temporal/process_task/activities/__init__.py (new file, 29 lines)

@@ -0,0 +1,29 @@
+from .check_snapshot_exists_for_repository import check_snapshot_exists_for_repository
+from .cleanup_personal_api_key import cleanup_personal_api_key
+from .cleanup_sandbox import cleanup_sandbox
+from .clone_repository import clone_repository
+from .create_sandbox_from_snapshot import create_sandbox_from_snapshot
+from .create_snapshot import create_snapshot
+from .execute_task_in_sandbox import execute_task_in_sandbox
+from .get_sandbox_for_setup import get_sandbox_for_setup
+from .get_task_details import get_task_details
+from .inject_github_token import inject_github_token
+from .inject_personal_api_key import inject_personal_api_key
+from .setup_repository import setup_repository
+from .track_workflow_event import track_workflow_event
+
+__all__ = [
+    "check_snapshot_exists_for_repository",
+    "cleanup_personal_api_key",
+    "cleanup_sandbox",
+    "clone_repository",
+    "create_sandbox_from_snapshot",
+    "create_snapshot",
+    "execute_task_in_sandbox",
+    "get_sandbox_for_setup",
+    "get_task_details",
+    "inject_github_token",
+    "inject_personal_api_key",
+    "setup_repository",
+    "track_workflow_event",
+]

products/tasks/backend/temporal/process_task/activities/check_snapshot_exists_for_repository.py (new file, 40 lines)

@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+
+from temporalio import activity
+
+from posthog.temporal.common.utils import asyncify
+
+from products.tasks.backend.models import SandboxSnapshot
+from products.tasks.backend.temporal.observability import log_with_activity_context
+
+
+@dataclass
+class CheckSnapshotExistsForRepositoryInput:
+    github_integration_id: int
+    repository: str
+
+
+@dataclass
+class CheckSnapshotExistsForRepositoryOutput:
+    exists: bool
+    snapshot_id: str | None
+
+
+@activity.defn
+@asyncify
+def check_snapshot_exists_for_repository(
+    input: CheckSnapshotExistsForRepositoryInput,
+) -> CheckSnapshotExistsForRepositoryOutput:
+    """Check if a repository exists in the latest complete snapshot."""
+    log_with_activity_context(
+        "Checking if snapshot exists for repository",
+        github_integration_id=input.github_integration_id,
+        repository=input.repository,
+    )
+
+    snapshot = SandboxSnapshot.get_latest_snapshot_with_repos(input.github_integration_id, [input.repository])
+
+    if snapshot:
+        return CheckSnapshotExistsForRepositoryOutput(exists=True, snapshot_id=str(snapshot.id))
+
+    return CheckSnapshotExistsForRepositoryOutput(exists=False, snapshot_id=None)

products/tasks/backend/temporal/process_task/activities/cleanup_personal_api_key.py (new file, 10 lines)

@@ -0,0 +1,10 @@
+from temporalio import activity
+
+from posthog.models import PersonalAPIKey
+from posthog.temporal.common.utils import asyncify
+
+
+@activity.defn
+@asyncify
+def cleanup_personal_api_key(personal_api_key_id: str) -> None:
+    PersonalAPIKey.objects.filter(id=personal_api_key_id).delete()

products/tasks/backend/temporal/process_task/activities/cleanup_sandbox.py (new file, 28 lines)

@@ -0,0 +1,28 @@
+import logging
+from dataclasses import dataclass
+
+from temporalio import activity
+
+from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
+from products.tasks.backend.temporal.exceptions import SandboxNotFoundError
+from products.tasks.backend.temporal.observability import log_activity_execution
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CleanupSandboxInput:
+    sandbox_id: str
+
+
+@activity.defn
+async def cleanup_sandbox(input: CleanupSandboxInput) -> None:
+    async with log_activity_execution(
+        "cleanup_sandbox",
+        sandbox_id=input.sandbox_id,
+    ):
+        try:
+            sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)
+            await sandbox.destroy()
+        except SandboxNotFoundError:
+            pass

products/tasks/backend/temporal/process_task/activities/clone_repository.py (new file, 63 lines)

@@ -0,0 +1,63 @@
+from dataclasses import dataclass
+
+from temporalio import activity
+
+from products.tasks.backend.services.sandbox_agent import SandboxAgent
+from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
+from products.tasks.backend.temporal.exceptions import GitHubAuthenticationError, RepositoryCloneError
+from products.tasks.backend.temporal.observability import log_activity_execution
+
+from ..utils import get_github_token
+
+
+@dataclass
+class CloneRepositoryInput:
+    sandbox_id: str
+    repository: str
+    github_integration_id: int
+    task_id: str
+    distinct_id: str
+
+
+@activity.defn
+async def clone_repository(input: CloneRepositoryInput) -> str:
+    """Clone repository into sandbox. Idempotent: wipes existing directory. Returns clone logs."""
+    async with log_activity_execution(
+        "clone_repository",
+        distinct_id=input.distinct_id,
+        task_id=input.task_id,
+        sandbox_id=input.sandbox_id,
+        repository=input.repository,
+    ):
+        try:
+            github_token = await get_github_token(input.github_integration_id)
+        except Exception as e:
+            raise GitHubAuthenticationError(
+                f"Failed to get GitHub token for integration {input.github_integration_id}",
+                {"github_integration_id": input.github_integration_id, "error": str(e)},
+            )
+
+        sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)
+
+        agent = SandboxAgent(sandbox)
+
+        try:
+            result = await agent.clone_repository(input.repository, github_token)
+        except Exception as e:
+            raise RepositoryCloneError(
+                f"Failed to clone repository {input.repository}",
+                {"repository": input.repository, "sandbox_id": input.sandbox_id, "error": str(e)},
+            )
+
+        if result.exit_code != 0:
+            raise RepositoryCloneError(
+                f"Git clone failed with exit code {result.exit_code}",
+                {
+                    "repository": input.repository,
+                    "exit_code": result.exit_code,
+                    "stderr": result.stderr[:500],
+                },
+            )
+
+        # NOTE: git clone returns its output in stderr
+        return result.stderr
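The closing comment is easy to verify: git writes clone progress to stderr, keeping stdout empty. A standalone check (needs `git` and network access; octocat/Hello-World is the same public repo the removed test used):

```python
import subprocess

# git clone reports progress on stderr; stdout stays empty
proc = subprocess.run(
    ["git", "clone", "--depth", "1",
     "https://github.com/octocat/Hello-World", "/tmp/hello-world-demo"],
    capture_output=True,
    text=True,
)
print("stdout:", repr(proc.stdout))  # typically ''
print("stderr:", proc.stderr)        # "Cloning into '/tmp/hello-world-demo'..."
```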

products/tasks/backend/temporal/process_task/activities/create_sandbox_from_snapshot.py (new file, 56 lines)

@@ -0,0 +1,56 @@
+from dataclasses import dataclass
+
+from django.core.exceptions import ObjectDoesNotExist
+
+from asgiref.sync import sync_to_async
+from temporalio import activity
+
+from products.tasks.backend.models import SandboxSnapshot
+from products.tasks.backend.services.sandbox_environment import (
+    SandboxEnvironment,
+    SandboxEnvironmentConfig,
+    SandboxEnvironmentTemplate,
+)
+from products.tasks.backend.temporal.exceptions import SnapshotNotFoundError, SnapshotNotReadyError
+from products.tasks.backend.temporal.observability import log_activity_execution
+from products.tasks.backend.temporal.process_task.utils import get_sandbox_name_for_task
+
+
+@dataclass
+class CreateSandboxFromSnapshotInput:
+    snapshot_id: str
+    task_id: str
+    distinct_id: str
+
+
+@activity.defn
+async def create_sandbox_from_snapshot(input: CreateSandboxFromSnapshotInput) -> str:
+    """Create a sandbox from a snapshot for task execution. Returns sandbox_id when running."""
+    async with log_activity_execution(
+        "create_sandbox_from_snapshot",
+        distinct_id=input.distinct_id,
+        task_id=input.task_id,
+        snapshot_id=input.snapshot_id,
+    ):
+        try:
+            snapshot = await sync_to_async(SandboxSnapshot.objects.get)(id=input.snapshot_id)
+        except ObjectDoesNotExist:
+            raise SnapshotNotFoundError(f"Snapshot {input.snapshot_id} not found", {"snapshot_id": input.snapshot_id})
+
+        if snapshot.status != SandboxSnapshot.Status.COMPLETE:
+            raise SnapshotNotReadyError(
+                f"Snapshot {input.snapshot_id} is not ready (status: {snapshot.status})",
+                {"snapshot_id": input.snapshot_id, "status": snapshot.status},
+            )
+
+        config = SandboxEnvironmentConfig(
+            name=get_sandbox_name_for_task(input.task_id),
+            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
+            environment_variables={},
+            snapshot_id=str(snapshot.id),
+            metadata={"task_id": input.task_id},
+        )
+
+        sandbox = await SandboxEnvironment.create(config)
+
+        return sandbox.id
@@ -0,0 +1,81 @@
import json
import asyncio
from dataclasses import dataclass

from asgiref.sync import sync_to_async
from temporalio import activity

from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import SandboxTimeoutError, SnapshotCreationError
from products.tasks.backend.temporal.observability import log_activity_execution


@dataclass
class CreateSnapshotInput:
    sandbox_id: str
    github_integration_id: int
    team_id: int
    repository: str
    task_id: str
    distinct_id: str


@activity.defn
async def create_snapshot(input: CreateSnapshotInput) -> str:
    """
    Create and finalize snapshot. Initiates snapshot, polls until complete,
    and saves the snapshot record. Returns snapshot_id.
    """
    async with log_activity_execution(
        "create_snapshot",
        distinct_id=input.distinct_id,
        task_id=input.task_id,
        sandbox_id=input.sandbox_id,
        repository=input.repository,
    ):
        base_snapshot = await sync_to_async(SandboxSnapshot.get_latest_snapshot_for_integration)(
            input.github_integration_id
        )

        base_repos = base_snapshot.repos if base_snapshot else []
        new_repos: list[str] = list({*base_repos, input.repository})

        sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)

        snapshot_external_id = await sandbox.initiate_snapshot(
            {
                "integration_id": str(input.github_integration_id),
                "team_id": str(input.team_id),
                "repositories": json.dumps(new_repos),
                "base_snapshot_id": str(base_snapshot.id) if base_snapshot else "",
            }
        )

        max_polls = 80
        for _ in range(max_polls):
            status = await SandboxEnvironment.get_snapshot_status(snapshot_external_id)

            if status.value == "complete":
                break
            elif status.value == "error":
                raise SnapshotCreationError(
                    "Snapshot creation failed",
                    {"snapshot_external_id": snapshot_external_id, "repository": input.repository},
                )

            await asyncio.sleep(15)
        else:
            raise SandboxTimeoutError(
                "Snapshot creation timed out after 20 minutes",
                {"snapshot_external_id": snapshot_external_id, "repository": input.repository},
            )

        snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
            integration_id=input.github_integration_id,
            repos=new_repos,
            external_id=snapshot_external_id,
            status=SandboxSnapshot.Status.COMPLETE,
        )

        return str(snapshot.id)
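The polling loop above leans on Python's for/else: the else branch fires only when all max_polls iterations run without a break, and 80 polls at 15 s apiece is where the 20-minute figure in the timeout message comes from. The same shape in isolation:

# Generic poll-until-complete sketch using the same for/else shape as create_snapshot.
import asyncio

async def poll_until_complete(get_status, interval_s: float = 15, max_polls: int = 80) -> None:
    for _ in range(max_polls):
        status = await get_status()
        if status == "complete":
            break
        if status == "error":
            raise RuntimeError("operation failed")
        await asyncio.sleep(interval_s)
    else:  # no break: every poll was exhausted
        raise TimeoutError(f"not complete after {max_polls * interval_s:.0f}s")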
@@ -0,0 +1,64 @@
from dataclasses import dataclass
from typing import Optional

from temporalio import activity

from products.tasks.backend.services.sandbox_agent import SandboxAgent
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import SandboxExecutionError, TaskExecutionFailedError
from products.tasks.backend.temporal.observability import log_activity_execution


@dataclass
class ExecuteTaskInput:
    sandbox_id: str
    task_id: str
    repository: str
    distinct_id: str


@dataclass
class ExecuteTaskOutput:
    stdout: str
    stderr: str
    exit_code: int
    error: Optional[str] = None


@activity.defn
async def execute_task_in_sandbox(input: ExecuteTaskInput) -> ExecuteTaskOutput:
    """Execute the code agent task in the sandbox."""
    async with log_activity_execution(
        "execute_task_in_sandbox",
        distinct_id=input.distinct_id,
        task_id=input.task_id,
        sandbox_id=input.sandbox_id,
        repository=input.repository,
    ):
        sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)

        agent = SandboxAgent(sandbox)

        try:
            result = await agent.execute_task(input.task_id, input.repository)
        except Exception as e:
            raise SandboxExecutionError(
                "Failed to execute task in sandbox",
                {"task_id": input.task_id, "sandbox_id": input.sandbox_id, "error": str(e)},
            )

        if result.exit_code != 0:
            raise TaskExecutionFailedError(
                f"Task execution failed with exit code {result.exit_code}",
                exit_code=result.exit_code,
                stdout=result.stdout,
                stderr=result.stderr,
                context={"task_id": input.task_id, "sandbox_id": input.sandbox_id},
            )

        return ExecuteTaskOutput(
            stdout=result.stdout,
            stderr=result.stderr,
            exit_code=result.exit_code,
            error=result.error,
        )
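Note the two failure modes: SandboxExecutionError signals that the sandbox/transport layer failed (plausibly worth retrying), while TaskExecutionFailedError means the agent command ran but exited non-zero, with the captured output attached for debugging. A hypothetical caller-side fragment (attribute names assumed to mirror the constructor kwargs above):

# Hypothetical caller-side handling; attribute names are assumptions mirroring
# the constructor kwargs above (exit_code, stdout, stderr).
try:
    output = await run_agent_activity()  # stand-in for executing the activity
except TaskExecutionFailedError as err:
    print(f"agent exited with {err.exit_code}")
    print(err.stderr)
    raise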
@@ -0,0 +1,48 @@
from dataclasses import dataclass

from asgiref.sync import sync_to_async
from temporalio import activity

from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.observability import log_activity_execution
from products.tasks.backend.temporal.process_task.utils import get_sandbox_name_for_task


@dataclass
class GetSandboxForSetupInput:
    github_integration_id: int
    team_id: int
    task_id: str
    distinct_id: str


@activity.defn
async def get_sandbox_for_setup(input: GetSandboxForSetupInput) -> str:
    """
    Get sandbox for setup. Searches for an existing snapshot to use as the base,
    otherwise falls back to the default template. Returns sandbox_id when the sandbox is running.
    """
    async with log_activity_execution(
        "get_sandbox_for_setup",
        distinct_id=input.distinct_id,
        task_id=input.task_id,
        github_integration_id=input.github_integration_id,
    ):
        snapshot = await sync_to_async(SandboxSnapshot.get_latest_snapshot_for_integration)(input.github_integration_id)

        config = SandboxEnvironmentConfig(
            name=get_sandbox_name_for_task(input.task_id),
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
            environment_variables={},
            snapshot_id=str(snapshot.id) if snapshot else None,
            metadata={"task_id": input.task_id},
        )

        sandbox = await SandboxEnvironment.create(config)

        return sandbox.id
@@ -0,0 +1,76 @@
from dataclasses import dataclass

from django.core.exceptions import ObjectDoesNotExist

from temporalio import activity

from posthog.temporal.common.utils import asyncify

from products.tasks.backend.models import Task
from products.tasks.backend.temporal.exceptions import TaskInvalidStateError, TaskNotFoundError
from products.tasks.backend.temporal.observability import log_with_activity_context


@dataclass
class TaskDetails:
    task_id: str
    team_id: int
    github_integration_id: int
    repository: str
    distinct_id: str


@activity.defn
@asyncify
def get_task_details(task_id: str) -> TaskDetails:
    log_with_activity_context("Fetching task details", task_id=task_id)

    try:
        task = Task.objects.select_related("created_by").get(id=task_id)
    except ObjectDoesNotExist:
        raise TaskNotFoundError(f"Task {task_id} not found", {"task_id": task_id})

    if not task.github_integration_id:
        raise TaskInvalidStateError(
            f"Task {task_id} has no GitHub integration",
            {"task_id": task_id},
        )

    if not task.primary_repository:
        raise TaskInvalidStateError(
            f"Task {task_id} has no primary repository configured",
            {"task_id": task_id},
        )

    repository_full_name = task.primary_repository.get("full_name")
    if not repository_full_name:
        raise TaskInvalidStateError(
            f"Task {task_id} primary repository missing full_name",
            {"task_id": task_id},
        )

    if not task.created_by:
        raise TaskInvalidStateError(
            f"Task {task_id} has no created_by user",
            {"task_id": task_id},
        )

    # Redundant with the check above, but narrows the Optional for the type checker.
    assert task.created_by is not None

    distinct_id = task.created_by.distinct_id or "process_task_workflow"

    log_with_activity_context(
        "Task details retrieved successfully",
        task_id=task_id,
        team_id=task.team_id,
        repository=repository_full_name,
        distinct_id=distinct_id,
    )

    return TaskDetails(
        task_id=str(task.id),
        team_id=task.team_id,
        github_integration_id=task.github_integration_id,
        repository=repository_full_name,
        distinct_id=distinct_id,
    )
@@ -0,0 +1,49 @@
import shlex
from dataclasses import dataclass

from temporalio import activity

from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import GitHubAuthenticationError, SandboxExecutionError
from products.tasks.backend.temporal.observability import log_activity_execution

from ..utils import get_github_token


@dataclass
class InjectGitHubTokenInput:
    sandbox_id: str
    github_integration_id: int
    task_id: str
    distinct_id: str


@activity.defn
async def inject_github_token(input: InjectGitHubTokenInput) -> None:
    async with log_activity_execution(
        "inject_github_token",
        distinct_id=input.distinct_id,
        task_id=input.task_id,
        sandbox_id=input.sandbox_id,
        github_integration_id=input.github_integration_id,
    ):
        try:
            github_token = await get_github_token(input.github_integration_id) or ""
        except Exception as e:
            raise GitHubAuthenticationError(
                f"Failed to get GitHub token for integration {input.github_integration_id}",
                {"github_integration_id": input.github_integration_id, "error": str(e)},
            )

        sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)

        # shlex.quote already yields a shell-safe token, so both rc files get the same line.
        escaped_github_token = shlex.quote(github_token)
        result = await sandbox.execute(
            f"echo 'export GITHUB_TOKEN={escaped_github_token}' >> ~/.bash_profile && echo 'export GITHUB_TOKEN={escaped_github_token}' >> ~/.bashrc"
        )

        if result.exit_code != 0:
            raise SandboxExecutionError(
                "Failed to inject GitHub token into sandbox",
                {"sandbox_id": input.sandbox_id, "exit_code": result.exit_code, "stderr": result.stderr[:500]},
            )
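shlex.quote, used above, returns a string unchanged when it is already shell-safe and otherwise wraps it in single quotes with embedded quotes escaped, which is what keeps the export line intact for arbitrary token values:

# shlex.quote behavior the token injection relies on (stdlib, deterministic).
import shlex

print(shlex.quote("ghs_plainToken123"))   # ghs_plainToken123  (unchanged: already safe)
print(shlex.quote("to ken'with$chars"))   # 'to ken'"'"'with$chars'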
@@ -0,0 +1,119 @@
import shlex
from dataclasses import dataclass

from django.core.exceptions import ObjectDoesNotExist

from temporalio import activity

from posthog.models import PersonalAPIKey
from posthog.models.personal_api_key import hash_key_value
from posthog.models.utils import generate_random_token_personal, mask_key_value
from posthog.scopes import API_SCOPE_OBJECTS
from posthog.temporal.common.utils import asyncify

from products.tasks.backend.models import Task
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import (
    PersonalAPIKeyError,
    SandboxExecutionError,
    TaskInvalidStateError,
    TaskNotFoundError,
)
from products.tasks.backend.temporal.observability import log_activity_execution


@dataclass
class InjectPersonalAPIKeyInput:
    sandbox_id: str
    task_id: str
    distinct_id: str


@dataclass
class InjectPersonalAPIKeyOutput:
    personal_api_key_id: str


@asyncify
def _get_task(task_id: str) -> Task:
    return Task.objects.select_related("created_by").get(id=task_id)


@asyncify
def _create_personal_api_key(task: Task) -> tuple[str, PersonalAPIKey]:
    scopes = _get_default_scopes()

    value = generate_random_token_personal()

    mask_value = mask_key_value(value)
    secure_value = hash_key_value(value)

    if not task.created_by:
        raise TaskInvalidStateError(f"Task {task.id} has no created_by user", {"task_id": task.id})

    assert task.created_by is not None

    personal_api_key = PersonalAPIKey.objects.create(
        user=task.created_by,
        label=f"Task Agent - {task.title[:20]}",
        secure_value=secure_value,
        mask_value=mask_value,
        scopes=scopes,
        scoped_teams=[task.team_id],
    )

    return value, personal_api_key


@activity.defn
async def inject_personal_api_key(input: InjectPersonalAPIKeyInput) -> InjectPersonalAPIKeyOutput:
    async with log_activity_execution(
        "inject_personal_api_key",
        distinct_id=input.distinct_id,
        task_id=input.task_id,
        sandbox_id=input.sandbox_id,
    ):
        try:
            task = await _get_task(input.task_id)
        except ObjectDoesNotExist:
            raise TaskNotFoundError(f"Task {input.task_id} not found", {"task_id": input.task_id})

        if not task.created_by:
            raise TaskInvalidStateError(f"Task {input.task_id} has no created_by user", {"task_id": input.task_id})

        try:
            api_key_tuple: tuple[str, PersonalAPIKey] = await _create_personal_api_key(task)
            value, personal_api_key = api_key_tuple
        except Exception as e:
            raise PersonalAPIKeyError(
                f"Failed to create personal API key for task {input.task_id}",
                {"task_id": input.task_id, "error": str(e)},
            )

        sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)

        escaped_value = shlex.quote(value)

        result = await sandbox.execute(
            f"echo 'export POSTHOG_PERSONAL_API_KEY={escaped_value}' >> ~/.bash_profile && echo 'export POSTHOG_PERSONAL_API_KEY={escaped_value}' >> ~/.bashrc"
        )

        if result.exit_code != 0:
            raise SandboxExecutionError(
                "Failed to inject personal API key into sandbox",
                {"sandbox_id": input.sandbox_id, "exit_code": result.exit_code, "stderr": result.stderr[:500]},
            )

        return InjectPersonalAPIKeyOutput(personal_api_key_id=personal_api_key.id)


def _get_default_scopes() -> list[str]:
    """
    Get default scopes for task agent API keys.

    TODO: Make scopes configurable per task in the future.
    For now, we provide read access to most resources.
    """
    read_scopes = [f"{obj}:read" for obj in API_SCOPE_OBJECTS if obj not in ["INTERNAL"]]

    return read_scopes
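_get_default_scopes derives one "<object>:read" scope per entry of API_SCOPE_OBJECTS, skipping INTERNAL. With a hypothetical two-element subset of that list, the output shape is:

# Hypothetical subset of API_SCOPE_OBJECTS, to show the shape of the derived scopes.
API_SCOPE_OBJECTS = ["dashboard", "insight", "INTERNAL"]

read_scopes = [f"{obj}:read" for obj in API_SCOPE_OBJECTS if obj not in ["INTERNAL"]]
print(read_scopes)  # ['dashboard:read', 'insight:read']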
@@ -0,0 +1,47 @@
from dataclasses import dataclass

from temporalio import activity

from products.tasks.backend.services.sandbox_agent import SandboxAgent
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import RepositorySetupError
from products.tasks.backend.temporal.observability import log_activity_execution


@dataclass
class SetupRepositoryInput:
    sandbox_id: str
    repository: str
    task_id: str
    distinct_id: str


@activity.defn
async def setup_repository(input: SetupRepositoryInput) -> str:
    """Run code agent setup on repository. Returns setup logs."""
    async with log_activity_execution(
        "setup_repository",
        distinct_id=input.distinct_id,
        task_id=input.task_id,
        sandbox_id=input.sandbox_id,
        repository=input.repository,
    ):
        sandbox = await SandboxEnvironment.get_by_id(input.sandbox_id)

        agent = SandboxAgent(sandbox)

        try:
            result = await agent.setup_repository(input.repository)
        except Exception as e:
            raise RepositorySetupError(
                f"Failed to setup repository {input.repository}",
                {"repository": input.repository, "sandbox_id": input.sandbox_id, "error": str(e)},
            )

        if result.exit_code != 0:
            raise RepositorySetupError(
                f"Repository setup failed with exit code {result.exit_code}",
                {"repository": input.repository, "exit_code": result.exit_code, "stderr": result.stderr[:500]},
            )

        return result.stdout
@@ -0,0 +1,23 @@
"""
Snapshot constants for activity tests.
These snapshots are created in Runloop and can be used for consistent testing.
"""

from typing import TypedDict


class TestSnapshot(TypedDict):
    external_id: str
    repos: list[str]


SNAPSHOTS = [
    TestSnapshot(external_id="snp_31DY4EmLlBZFy1aHV2IN2", repos=[]),
    TestSnapshot(external_id="snp_31DY5L7W4ismYpImz22wN", repos=["posthog/posthog-js"]),
    TestSnapshot(external_id="snp_31DY9PDHgbhD3NDgA6DGe", repos=["posthog/posthog-js", "posthog/posthog"]),
]


BASE_SNAPSHOT = SNAPSHOTS[0]
POSTHOG_JS_SNAPSHOT = SNAPSHOTS[1]
MULTI_REPO_SNAPSHOT = SNAPSHOTS[2]
@@ -0,0 +1,198 @@
import time

import pytest

from asgiref.sync import sync_to_async

from posthog.models.integration import Integration

from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.temporal.process_task.activities.check_snapshot_exists_for_repository import (
    CheckSnapshotExistsForRepositoryInput,
    CheckSnapshotExistsForRepositoryOutput,
    check_snapshot_exists_for_repository,
)


class TestCheckSnapshotExistsForRepositoryActivity:
    async def _create_snapshot(
        self, github_integration, repos, status=SandboxSnapshot.Status.COMPLETE, external_id="test-snap-123"
    ):
        return await sync_to_async(SandboxSnapshot.objects.create)(
            integration=github_integration,
            repos=repos,
            status=status,
            external_id=external_id,
        )

    async def _cleanup_snapshot(self, snapshot):
        await sync_to_async(snapshot.delete)()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_check_snapshot_exists_for_repository_found(self, activity_environment, github_integration):
        snapshot = await self._create_snapshot(
            github_integration, repos=["test-owner/test-repo", "other-owner/other-repo"]
        )

        try:
            input_data = CheckSnapshotExistsForRepositoryInput(
                github_integration_id=github_integration.id, repository="test-owner/test-repo"
            )
            result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)

            assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
            assert result.exists is True
            assert result.snapshot_id == str(snapshot.id)
        finally:
            await self._cleanup_snapshot(snapshot)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_check_snapshot_exists_for_repository_not_found(self, activity_environment, github_integration):
        input_data = CheckSnapshotExistsForRepositoryInput(
            github_integration_id=github_integration.id, repository="nonexistent/repo"
        )
        result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)

        assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
        assert result.exists is False
        assert result.snapshot_id is None

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_check_snapshot_exists_for_repository_repo_not_in_snapshot(
        self, activity_environment, github_integration
    ):
        snapshot = await self._create_snapshot(
            github_integration, repos=["other-owner/other-repo", "another-owner/another-repo"]
        )

        try:
            input_data = CheckSnapshotExistsForRepositoryInput(
                github_integration_id=github_integration.id, repository="test-owner/test-repo"
            )
            result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)

            assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
            assert result.exists is False
            assert result.snapshot_id is None
        finally:
            await self._cleanup_snapshot(snapshot)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_check_snapshot_exists_for_repository_ignores_incomplete_snapshots(
        self, activity_environment, github_integration
    ):
        # Create snapshots with different statuses
        in_progress_snapshot = await self._create_snapshot(
            github_integration,
            repos=["test-owner/test-repo"],
            status=SandboxSnapshot.Status.IN_PROGRESS,
            external_id="in-progress-snap",
        )
        error_snapshot = await self._create_snapshot(
            github_integration,
            repos=["test-owner/test-repo"],
            status=SandboxSnapshot.Status.ERROR,
            external_id="error-snap",
        )

        try:
            input_data = CheckSnapshotExistsForRepositoryInput(
                github_integration_id=github_integration.id, repository="test-owner/test-repo"
            )
            result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)

            # Should not find incomplete snapshots
            assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
            assert result.exists is False
            assert result.snapshot_id is None
        finally:
            await self._cleanup_snapshot(in_progress_snapshot)
            await self._cleanup_snapshot(error_snapshot)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_check_snapshot_exists_for_repository_returns_latest_complete(
        self, activity_environment, github_integration
    ):
        # Create multiple snapshots, with the latest being complete
        older_snapshot = await self._create_snapshot(
            github_integration,
            repos=["test-owner/test-repo"],
            status=SandboxSnapshot.Status.COMPLETE,
            external_id="older-snap",
        )

        # Add delay to ensure different created_at times
        time.sleep(0.01)

        newer_snapshot = await self._create_snapshot(
            github_integration,
            repos=["test-owner/test-repo", "other-owner/other-repo"],
            status=SandboxSnapshot.Status.COMPLETE,
            external_id="newer-snap",
        )

        try:
            input_data = CheckSnapshotExistsForRepositoryInput(
                github_integration_id=github_integration.id, repository="test-owner/test-repo"
            )
            result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)

            # Should return the newer snapshot
            assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
            assert result.exists is True
            assert result.snapshot_id == str(newer_snapshot.id)
        finally:
            await self._cleanup_snapshot(older_snapshot)
            await self._cleanup_snapshot(newer_snapshot)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_check_snapshot_exists_for_repository_case_insensitive(
        self, activity_environment, github_integration
    ):
        # Create snapshot with mixed case repository name
        snapshot = await self._create_snapshot(github_integration, repos=["TestOwner/TestRepo"])

        try:
            input_data = CheckSnapshotExistsForRepositoryInput(
                github_integration_id=github_integration.id, repository="testowner/testrepo"
            )
            result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)

            assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
            assert result.exists is True
            assert result.snapshot_id == str(snapshot.id)
        finally:
            await self._cleanup_snapshot(snapshot)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_check_snapshot_exists_for_repository_different_integration(
        self, activity_environment, github_integration, ateam
    ):
        other_integration = await sync_to_async(Integration.objects.create)(
            team=ateam,
            kind="github",
            config={"access_token": "other_fake_token"},
        )

        snapshot = await self._create_snapshot(other_integration, repos=["test-owner/test-repo"])

        try:
            input_data = CheckSnapshotExistsForRepositoryInput(
                github_integration_id=github_integration.id, repository="test-owner/test-repo"
            )
            result = await activity_environment.run(check_snapshot_exists_for_repository, input_data)

            # Should not find snapshot from different integration
            assert isinstance(result, CheckSnapshotExistsForRepositoryOutput)
            assert result.exists is False
            assert result.snapshot_id is None
        finally:
            await self._cleanup_snapshot(snapshot)
            await sync_to_async(other_integration.delete)()
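The tests above consume activity_environment and github_integration as fixtures whose definitions sit outside this diff. A plausible conftest sketch for the former, assuming temporalio's testing helper (the fixture name is an assumption from usage):

# Plausible conftest sketch - not part of this diff; ActivityEnvironment is temporalio's test helper.
import pytest
from temporalio.testing import ActivityEnvironment


@pytest.fixture
def activity_environment() -> ActivityEnvironment:
    return ActivityEnvironment()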
@@ -0,0 +1,100 @@
import os
import asyncio

import pytest

from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.process_task.activities.cleanup_sandbox import CleanupSandboxInput, cleanup_sandbox


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestCleanupSandboxActivity:
    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_cleanup_sandbox_success(self, activity_environment):
        config = SandboxEnvironmentConfig(
            name="test-cleanup-sandbox",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = await SandboxEnvironment.create(config)
        sandbox_id = sandbox.id

        existing_sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
        assert existing_sandbox.id == sandbox_id

        input_data = CleanupSandboxInput(sandbox_id=sandbox_id)

        await activity_environment.run(cleanup_sandbox, input_data)

        cleaned_sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
        assert cleaned_sandbox.status.value == "shutdown"

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_cleanup_sandbox_not_found_does_not_raise(self, activity_environment):
        input_data = CleanupSandboxInput(sandbox_id="non-existent-sandbox-id")

        # cleanup_sandbox is idempotent and doesn't raise if sandbox doesn't exist
        await activity_environment.run(cleanup_sandbox, input_data)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_cleanup_sandbox_idempotency(self, activity_environment):
        config = SandboxEnvironmentConfig(
            name="test-cleanup-idempotent",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = await SandboxEnvironment.create(config)
        sandbox_id = sandbox.id

        input_data = CleanupSandboxInput(sandbox_id=sandbox_id)

        # First cleanup - should succeed
        await activity_environment.run(cleanup_sandbox, input_data)

        cleaned_sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
        assert cleaned_sandbox.status.value == "shutdown"

        # Second cleanup - should still work on shutdown sandbox
        await activity_environment.run(cleanup_sandbox, input_data)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_cleanup_sandbox_during_execution(self, activity_environment):
        config = SandboxEnvironmentConfig(
            name="test-cleanup-during-execution",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = await SandboxEnvironment.create(config)
        sandbox_id = sandbox.id

        async def run_long_command():
            try:
                await sandbox.execute("sleep 30", timeout_seconds=60)
            except Exception:
                pass

        long_task = asyncio.create_task(run_long_command())

        # Give it a moment to start
        await asyncio.sleep(5)

        input_data = CleanupSandboxInput(sandbox_id=sandbox_id)
        await activity_environment.run(cleanup_sandbox, input_data)

        long_task.cancel()
        try:
            await long_task
        except asyncio.CancelledError:
            pass

        remaining_sandbox = await SandboxEnvironment.get_by_id(sandbox_id)

        assert remaining_sandbox.status.value in ["shutdown", "failure"]
@@ -0,0 +1,230 @@
import os

import pytest
from unittest.mock import patch

from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import RepositoryCloneError, SandboxNotFoundError
from products.tasks.backend.temporal.process_task.activities.clone_repository import (
    CloneRepositoryInput,
    clone_repository,
)


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestCloneRepositoryActivity:
    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_clone_repository_success_and_directory_structure(self, activity_environment, github_integration):
        config = SandboxEnvironmentConfig(
            name="test-clone-success-and-structure",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            input_data = CloneRepositoryInput(
                sandbox_id=sandbox.id,
                repository="PostHog/posthog-js",
                github_integration_id=github_integration.id,
                task_id="test-task-123",
                distinct_id="test-user-id",
            )

            with patch(
                "products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
            ) as mock_get_token:
                mock_get_token.return_value = ""

                result = await activity_environment.run(clone_repository, input_data)

                # Verify we got output (git clone outputs to stderr)
                assert result is not None
                assert "posthog-js" in result

                # Verify the repository actually exists in the sandbox
                check_result = await sandbox.execute("ls -la /tmp/workspace/repos/posthog/")
                assert "posthog-js" in check_result.stdout

                # Verify it's a git repository
                git_check = await sandbox.execute("cd /tmp/workspace/repos/posthog/posthog-js && git status")
                assert git_check.exit_code == 0
                assert "On branch" in git_check.stdout or "HEAD" in git_check.stdout

                # Verify directory structure is correct
                structure_check = await sandbox.execute("find /tmp/workspace/repos -type d | head -10")
                assert "/tmp/workspace/repos/posthog" in structure_check.stdout
                assert "/tmp/workspace/repos/posthog/posthog-js" in structure_check.stdout

                # Verify we can navigate the structure
                nav_check = await sandbox.execute("cd /tmp/workspace/repos/posthog/posthog-js && pwd")
                assert "/tmp/workspace/repos/posthog/posthog-js" in nav_check.stdout

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_clone_repository_idempotency(self, activity_environment, github_integration):
        config = SandboxEnvironmentConfig(
            name="test-clone-repository-idempotent",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            input_data = CloneRepositoryInput(
                sandbox_id=sandbox.id,
                repository="PostHog/posthog-js",
                github_integration_id=github_integration.id,
                task_id="test-task-idempotent",
                distinct_id="test-user-id",
            )

            with patch(
                "products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
            ) as mock_get_token:
                mock_get_token.return_value = ""

                # First clone
                result1 = await activity_environment.run(clone_repository, input_data)
                assert result1 is not None

                # Create a file to verify it gets wiped
                await sandbox.execute("echo 'test' > /tmp/workspace/repos/posthog/posthog-js/test_file.txt")

                # Verify file exists
                check_file = await sandbox.execute("cat /tmp/workspace/repos/posthog/posthog-js/test_file.txt")
                assert "test" in check_file.stdout

                # Second clone (should wipe and re-clone)
                result2 = await activity_environment.run(clone_repository, input_data)
                assert "Cloning into 'posthog-js'" in result2 or "posthog-js" in result2

                # Verify test file was removed (proving idempotency)
                check_file_after = await sandbox.execute(
                    "ls /tmp/workspace/repos/posthog/posthog-js/test_file.txt 2>&1"
                )
                assert (
                    "No such file" in check_file_after.stdout
                    or "No such file" in check_file_after.stderr
                    or check_file_after.exit_code != 0
                )

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_clone_repository_private_repo_no_token(self, activity_environment, github_integration):
        config = SandboxEnvironmentConfig(
            name="test-clone-repository-auth-fail",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            input_data = CloneRepositoryInput(
                sandbox_id=sandbox.id,
                repository="PostHog/private-test-repo-that-does-not-exist",
                github_integration_id=github_integration.id,
                task_id="test-task-auth-fail",
                distinct_id="test-user-id",
            )

            with patch(
                "products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
            ) as mock_get_token:
                mock_get_token.return_value = "invalid-token"

                with pytest.raises(RepositoryCloneError) as exc_info:
                    await activity_environment.run(clone_repository, input_data)

                assert "Git clone failed" in str(exc_info.value)

                # Verify repository doesn't exist
                check_result = await sandbox.execute("ls /tmp/workspace/repos/posthog/ 2>&1")
                assert "private-test-repo" not in check_result.stdout

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_clone_repository_multiple_repos(self, activity_environment, github_integration):
        config = SandboxEnvironmentConfig(
            name="test-clone-multiple-repos",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            repos = ["PostHog/posthog-js", "PostHog/posthog.com"]

            with patch(
                "products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
            ) as mock_get_token:
                mock_get_token.return_value = ""  # Public repos don't need auth

                for repo in repos:
                    input_data = CloneRepositoryInput(
                        sandbox_id=sandbox.id,
                        repository=repo,
                        github_integration_id=github_integration.id,
                        task_id=f"test-task-{repo.split('/')[1]}",
                        distinct_id="test-user-id",
                    )

                    result = await activity_environment.run(clone_repository, input_data)
                    repo_name = repo.split("/")[1]
                    assert repo_name in result

                # Verify both repos exist
                check_result = await sandbox.execute("ls /tmp/workspace/repos/posthog/")
                assert "posthog-js" in check_result.stdout
                assert "posthog.com" in check_result.stdout

                # Verify they're both git repositories
                for repo in repos:
                    repo_name = repo.split("/")[1]
                    git_check = await sandbox.execute(f"cd /tmp/workspace/repos/posthog/{repo_name} && git remote -v")
                    assert git_check.exit_code == 0
                    assert "github.com" in git_check.stdout

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_clone_repository_sandbox_not_found(self, activity_environment, github_integration):
        input_data = CloneRepositoryInput(
            sandbox_id="non-existent-sandbox-id",
            repository="posthog/posthog-js",
            github_integration_id=github_integration.id,
            task_id="test-task-not-found",
            distinct_id="test-user-id",
        )

        with patch(
            "products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
        ) as mock_get_token:
            mock_get_token.return_value = ""

            with pytest.raises(SandboxNotFoundError):
                await activity_environment.run(clone_repository, input_data)
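Note the patch target in these tests: get_github_token is patched where the activity module looks it up (the clone_repository activity module), not where it is defined, which is the standard unittest.mock rule. In miniature, with hypothetical module names:

# mock.patch targets the lookup site, not the definition site (hypothetical modules).
# mypkg/auth.py:      def get_token(): ...
# mypkg/consumer.py:  from mypkg.auth import get_token
from unittest.mock import patch

with patch("mypkg.consumer.get_token", return_value=""):
    ...  # code calling mypkg.consumer.get_token sees the mock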
@@ -0,0 +1,94 @@
import os
import uuid

import pytest

from asgiref.sync import sync_to_async

from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.exceptions import SandboxProvisionError, SnapshotNotFoundError
from products.tasks.backend.temporal.process_task.activities.create_sandbox_from_snapshot import (
    CreateSandboxFromSnapshotInput,
    create_sandbox_from_snapshot,
)

from .constants import BASE_SNAPSHOT


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestCreateSandboxFromSnapshotActivity:
    async def _create_snapshot(self, github_integration, external_id=None, status=SandboxSnapshot.Status.COMPLETE):
        if external_id is None:
            external_id = str(uuid.uuid4())
        return await sync_to_async(SandboxSnapshot.objects.create)(
            integration=github_integration,
            external_id=external_id,
            status=status,
        )

    async def _cleanup_snapshot(self, snapshot):
        await sync_to_async(snapshot.delete)()

    async def _cleanup_sandbox(self, sandbox_id):
        sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
        await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_create_sandbox_from_snapshot_success(self, activity_environment, github_integration):
        snapshot = await self._create_snapshot(github_integration, external_id=BASE_SNAPSHOT["external_id"])
        task_id = "test-task-123"
        sandbox_id = None

        try:
            input_data = CreateSandboxFromSnapshotInput(
                snapshot_id=str(snapshot.id), task_id=task_id, distinct_id="test-user-id"
            )
            sandbox_id = await activity_environment.run(create_sandbox_from_snapshot, input_data)

            assert isinstance(sandbox_id, str)
            assert len(sandbox_id) > 0

            sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
            assert sandbox.id == sandbox_id
            assert sandbox.status in ["pending", "initializing", "running"]

        finally:
            await self._cleanup_snapshot(snapshot)
            if sandbox_id:
                await self._cleanup_sandbox(sandbox_id)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_create_sandbox_from_snapshot_not_found(self, activity_environment):
        input_data = CreateSandboxFromSnapshotInput(
            snapshot_id=str(uuid.uuid4()),
            task_id="test-task-456",
            distinct_id="test-user-id",
        )

        with pytest.raises(SnapshotNotFoundError):
            await activity_environment.run(create_sandbox_from_snapshot, input_data)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_create_sandbox_from_snapshot_with_invalid_external_id(
        self, activity_environment, github_integration
    ):
        snapshot = await self._create_snapshot(github_integration, external_id="invalid-snapshot-id")
        task_id = "test-task-789"
        sandbox_id = None

        try:
            input_data = CreateSandboxFromSnapshotInput(
                snapshot_id=str(snapshot.id), task_id=task_id, distinct_id="test-user-id"
            )

            with pytest.raises(SandboxProvisionError):
                sandbox_id = await activity_environment.run(create_sandbox_from_snapshot, input_data)

        finally:
            await self._cleanup_snapshot(snapshot)
            if sandbox_id:
                await self._cleanup_sandbox(sandbox_id)
@@ -0,0 +1,142 @@
import os
import uuid

import pytest

from asgiref.sync import sync_to_async

from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import SandboxNotFoundError
from products.tasks.backend.temporal.process_task.activities.create_snapshot import CreateSnapshotInput, create_snapshot


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestCreateSnapshotActivity:
    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_create_snapshot_real(self, activity_environment, github_integration, ateam):
        """Test real snapshot creation with actual sandbox."""
        config = SandboxEnvironmentConfig(
            name="test-create-snapshot",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        created_snapshot = None
        created_snapshot_external_id = None
        try:
            # Create a real sandbox
            sandbox = await SandboxEnvironment.create(config)

            input_data = CreateSnapshotInput(
                sandbox_id=sandbox.id,
                github_integration_id=github_integration.id,
                team_id=ateam.id,
                repository="test-owner/test-repo",
                task_id="test-task-123",
                distinct_id="test-user-id",
            )

            # This will create a real snapshot and wait for it to complete
            result = await activity_environment.run(create_snapshot, input_data)

            # Verify a UUID was returned
            assert result is not None
            uuid.UUID(result)  # Should not raise

            # Verify snapshot was created in the database
            created_snapshot = await sync_to_async(SandboxSnapshot.objects.get)(id=result)
            created_snapshot_external_id = created_snapshot.external_id
            assert created_snapshot.external_id is not None
            assert created_snapshot.integration_id == github_integration.id
            assert "test-owner/test-repo" in created_snapshot.repos
            assert created_snapshot.status == SandboxSnapshot.Status.COMPLETE

            # Verify the snapshot exists in provider
            snapshot_status = await SandboxEnvironment.get_snapshot_status(created_snapshot.external_id)
            assert snapshot_status.value == "complete"

        finally:
            if sandbox:
                await sandbox.destroy()

            if created_snapshot:
                await sync_to_async(created_snapshot.delete)()

            if created_snapshot_external_id:
                await SandboxEnvironment.delete_snapshot(created_snapshot_external_id)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_create_snapshot_with_existing_base_snapshot(self, activity_environment, github_integration, ateam):
        """Test snapshot creation with existing base snapshot repos."""
        # Create a base snapshot in the database (using a fake external ID since we're not creating it in Runloop)
        base_snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
            integration=github_integration,
            external_id=f"fake_base_{uuid.uuid4().hex[:8]}",
            repos=["existing-owner/existing-repo"],
            status=SandboxSnapshot.Status.COMPLETE,
        )

        config = SandboxEnvironmentConfig(
            name="test-create-snapshot-with-base",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        created_snapshot = None
        created_snapshot_external_id = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            input_data = CreateSnapshotInput(
                sandbox_id=sandbox.id,
                github_integration_id=github_integration.id,
                team_id=ateam.id,
                repository="new-owner/new-repo",
                task_id="test-task-with-base",
                distinct_id="test-user-id",
            )

            result = await activity_environment.run(create_snapshot, input_data)

            # Verify new snapshot includes both repos
            created_snapshot = await sync_to_async(SandboxSnapshot.objects.get)(id=result)
            created_snapshot_external_id = created_snapshot.external_id
            assert created_snapshot.external_id is not None
            assert "existing-owner/existing-repo" in created_snapshot.repos
            assert "new-owner/new-repo" in created_snapshot.repos
            assert len(created_snapshot.repos) == 2

            # Verify the snapshot actually exists in Runloop
            snapshot_status = await SandboxEnvironment.get_snapshot_status(created_snapshot.external_id)
            assert snapshot_status.value == "complete"

        finally:
            await sync_to_async(base_snapshot.delete)()
            if sandbox:
                await sandbox.destroy()
            if created_snapshot:
                await sync_to_async(created_snapshot.delete)()
            if created_snapshot_external_id:
                await SandboxEnvironment.delete_snapshot(created_snapshot_external_id)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_create_snapshot_sandbox_not_found(self, activity_environment, github_integration, ateam):
        input_data = CreateSnapshotInput(
            sandbox_id="non-existent-sandbox-id",
            github_integration_id=github_integration.id,
            team_id=ateam.id,
            repository="test-owner/test-repo",
            task_id="test-task-not-found",
            distinct_id="test-user-id",
        )

        with pytest.raises(SandboxNotFoundError):
            await activity_environment.run(create_snapshot, input_data)
@@ -0,0 +1,213 @@
import os

import pytest
from unittest.mock import patch

from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import SandboxNotFoundError, TaskExecutionFailedError
from products.tasks.backend.temporal.process_task.activities.clone_repository import (
    CloneRepositoryInput,
    clone_repository,
)
from products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox import (
    ExecuteTaskInput,
    execute_task_in_sandbox,
)


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestExecuteTaskInSandboxActivity:
    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_execute_task_success(self, activity_environment, github_integration):
        """Test successful task execution in sandbox."""
        config = SandboxEnvironmentConfig(
            name="test-execute-task",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            clone_input = CloneRepositoryInput(
                sandbox_id=sandbox.id,
                repository="PostHog/posthog-js",
                github_integration_id=github_integration.id,
                task_id="test-task-123",
                distinct_id="test-user-id",
            )

            with patch(
                "products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
            ) as mock_get_token:
                mock_get_token.return_value = ""
                await activity_environment.run(clone_repository, clone_input)

            # We mock the _get_task_command to run a simple command instead of the code agent
            with patch(
                "products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
            ) as mock_task_cmd:
                mock_task_cmd.return_value = "echo 'Task executed successfully'"

                input_data = ExecuteTaskInput(
                    sandbox_id=sandbox.id,
                    task_id="test-task-123",
                    repository="PostHog/posthog-js",
                    distinct_id="test-user-id",
                )

                await activity_environment.run(execute_task_in_sandbox, input_data)

                mock_task_cmd.assert_called_once_with("test-task-123")

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_execute_task_failure(self, activity_environment, github_integration):
        """Test task execution failure handling."""
        config = SandboxEnvironmentConfig(
            name="test-execute-task-fail",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            clone_input = CloneRepositoryInput(
                sandbox_id=sandbox.id,
                repository="PostHog/posthog-js",
                github_integration_id=github_integration.id,
                task_id="test-task-fail",
                distinct_id="test-user-id",
            )

            with patch(
                "products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
            ) as mock_get_token:
                mock_get_token.return_value = ""
                await activity_environment.run(clone_repository, clone_input)

            # We mock the _get_task_command to run a failing command
            with patch(
                "products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
            ) as mock_task_cmd:
                mock_task_cmd.return_value = "exit 1"  # Command that fails

                input_data = ExecuteTaskInput(
                    sandbox_id=sandbox.id,
                    task_id="test-task-fail",
                    repository="PostHog/posthog-js",
                    distinct_id="test-user-id",
                )

                with pytest.raises(TaskExecutionFailedError):
                    await activity_environment.run(execute_task_in_sandbox, input_data)

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_execute_task_repository_not_found(self, activity_environment):
        """Test task execution when repository doesn't exist in sandbox."""
        config = SandboxEnvironmentConfig(
            name="test-execute-task-no-repo",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            # We don't clone any repository, just try to execute task
            with patch(
                "products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
            ) as mock_task_cmd:
                mock_task_cmd.return_value = "ls -la"

                input_data = ExecuteTaskInput(
                    sandbox_id=sandbox.id,
                    task_id="test-task-no-repo",
                    repository="PostHog/posthog-js",
                    distinct_id="test-user-id",
                )

                with pytest.raises(TaskExecutionFailedError):
                    await activity_environment.run(execute_task_in_sandbox, input_data)

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_execute_task_sandbox_not_found(self, activity_environment):
        input_data = ExecuteTaskInput(
            sandbox_id="non-existent-sandbox-id",
            task_id="test-task",
            repository="PostHog/posthog-js",
            distinct_id="test-user-id",
        )

        with pytest.raises(SandboxNotFoundError):
            await activity_environment.run(execute_task_in_sandbox, input_data)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_execute_task_with_different_repositories(self, activity_environment, github_integration):
        config = SandboxEnvironmentConfig(
            name="test-execute-different-repos",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            repos_to_test = ["PostHog/posthog-js", "PostHog/posthog.com"]

            with patch(
                "products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
            ) as mock_get_token:
                mock_get_token.return_value = ""

                for repo in repos_to_test:
                    clone_input = CloneRepositoryInput(
                        sandbox_id=sandbox.id,
                        repository=repo,
                        github_integration_id=github_integration.id,
                        task_id=f"test-task-{repo.split('/')[1]}",
                        distinct_id="test-user-id",
                    )
                    await activity_environment.run(clone_repository, clone_input)

                    # Execute task in each repository
                    with patch(
                        "products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
                    ) as mock_task_cmd:
                        mock_task_cmd.return_value = f"echo 'Working in {repo}'"

                        input_data = ExecuteTaskInput(
                            sandbox_id=sandbox.id,
                            task_id=f"test-task-{repo.split('/')[1]}",
                            repository=repo,
                            distinct_id="test-user-id",
                        )

                        await activity_environment.run(execute_task_in_sandbox, input_data)

                        mock_task_cmd.assert_called_once_with(f"test-task-{repo.split('/')[1]}")

        finally:
            if sandbox:
                await sandbox.destroy()
@@ -0,0 +1,165 @@
import os
import uuid

import pytest

from asgiref.sync import sync_to_async

from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment
from products.tasks.backend.temporal.process_task.activities.get_sandbox_for_setup import (
    GetSandboxForSetupInput,
    get_sandbox_for_setup,
)
from products.tasks.backend.temporal.process_task.utils import get_sandbox_name_for_task

from .constants import BASE_SNAPSHOT


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestGetSandboxForSetupActivity:
    """Test suite for the get_sandbox_for_setup activity."""

    async def _create_snapshot(self, github_integration, external_id=None, status=SandboxSnapshot.Status.COMPLETE):
        """Helper method to create a snapshot."""
        if external_id is None:
            external_id = str(uuid.uuid4())
        return await sync_to_async(SandboxSnapshot.objects.create)(
            integration=github_integration,
            external_id=external_id,
            status=status,
        )

    async def _cleanup_snapshot(self, snapshot):
        """Helper method to clean up a snapshot."""
        await sync_to_async(snapshot.delete)()

    async def _cleanup_sandbox(self, sandbox_id):
        """Helper method to clean up a sandbox."""
        sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
        await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_get_sandbox_for_setup_with_existing_snapshot(self, activity_environment, github_integration, ateam):
        snapshot = await self._create_snapshot(github_integration, external_id=BASE_SNAPSHOT["external_id"])

        task_id = "test-task-123"
        sandbox_id = None

        try:
            input_data = GetSandboxForSetupInput(
                github_integration_id=github_integration.id,
                team_id=ateam.id,
                task_id=task_id,
                distinct_id="test-user-id",
            )
            sandbox_id = await activity_environment.run(get_sandbox_for_setup, input_data)

            assert isinstance(sandbox_id, str)
            assert len(sandbox_id) > 0

            # Verify sandbox was created
            sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
            assert sandbox.id == sandbox_id
            assert sandbox.status in ["pending", "initializing", "running"]

        finally:
            await self._cleanup_snapshot(snapshot)
            if sandbox_id:
                await self._cleanup_sandbox(sandbox_id)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_get_sandbox_for_setup_without_existing_snapshot(
        self, activity_environment, github_integration, ateam
    ):
        task_id = "test-task-456"
        sandbox_id = None

        try:
            input_data = GetSandboxForSetupInput(
                github_integration_id=github_integration.id,
                team_id=ateam.id,
                task_id=task_id,
                distinct_id="test-user-id",
            )
            sandbox_id = await activity_environment.run(get_sandbox_for_setup, input_data)

            assert isinstance(sandbox_id, str)
            assert len(sandbox_id) > 0

            # Verify sandbox was created
            sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
            assert sandbox.id == sandbox_id

            assert sandbox.status in ["pending", "initializing", "running"]

        finally:
            if sandbox_id:
                await self._cleanup_sandbox(sandbox_id)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_get_sandbox_for_setup_ignores_incomplete_snapshots(
        self, activity_environment, github_integration, ateam
    ):
        # Create snapshots with incomplete status
        in_progress_snapshot = await self._create_snapshot(
            github_integration, status=SandboxSnapshot.Status.IN_PROGRESS
        )
        error_snapshot = await self._create_snapshot(github_integration, status=SandboxSnapshot.Status.ERROR)

        task_id = "test-task-789"
        sandbox_id = None

        try:
            input_data = GetSandboxForSetupInput(
                github_integration_id=github_integration.id,
                team_id=ateam.id,
                task_id=task_id,
                distinct_id="test-user-id",
            )
            sandbox_id = await activity_environment.run(get_sandbox_for_setup, input_data)

            assert isinstance(sandbox_id, str)
            assert len(sandbox_id) > 0

            # Verify sandbox was created (should not use incomplete snapshots as base)
            sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
            assert sandbox.id == sandbox_id

        finally:
            await self._cleanup_snapshot(in_progress_snapshot)
            await self._cleanup_snapshot(error_snapshot)
            if sandbox_id:
                await self._cleanup_sandbox(sandbox_id)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_get_sandbox_for_setup_sandbox_name_generation(self, activity_environment, github_integration, ateam):
        task_id = "special-task-id-with-uuid-abc123"
        sandbox_id = None

        try:
            input_data = GetSandboxForSetupInput(
                github_integration_id=github_integration.id,
                team_id=ateam.id,
                task_id=task_id,
                distinct_id="test-user-id",
            )
            sandbox_id = await activity_environment.run(get_sandbox_for_setup, input_data)

            assert isinstance(sandbox_id, str)
            assert len(sandbox_id) > 0

            # Verify sandbox exists
            sandbox = await SandboxEnvironment.get_by_id(sandbox_id)

            assert sandbox.id == sandbox_id
            assert sandbox.name == get_sandbox_name_for_task(task_id)

        finally:
            if sandbox_id:
                await self._cleanup_sandbox(sandbox_id)
@@ -0,0 +1,96 @@
import pytest

from django.core.exceptions import ValidationError

from asgiref.sync import sync_to_async

from products.tasks.backend.models import Task
from products.tasks.backend.temporal.exceptions import TaskInvalidStateError, TaskNotFoundError
from products.tasks.backend.temporal.process_task.activities.get_task_details import TaskDetails, get_task_details


class TestGetTaskDetailsActivity:
    async def _create_task_with_repo(self, ateam, auser, task_workflow, github_integration, repo_config):
        workflow, stages = task_workflow
        backlog_stage = stages[0]

        return await sync_to_async(Task.objects.create)(
            team=ateam,
            title="Test Task",
            description="Test task description",
            origin_product=Task.OriginProduct.USER_CREATED,
            workflow=workflow,
            current_stage=backlog_stage,
            position=0,
            github_integration=github_integration,
            repository_config=repo_config,
            created_by=auser,
        )

    async def _cleanup_task(self, task):
        await sync_to_async(task.delete)()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_get_task_details_success(self, activity_environment, test_task):
        result = await activity_environment.run(get_task_details, str(test_task.id))

        assert isinstance(result, TaskDetails)
        assert result.task_id == str(test_task.id)
        assert result.team_id == test_task.team_id
        assert result.github_integration_id == test_task.github_integration_id
        assert result.repository == "posthog/posthog-js"

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_get_task_details_task_not_found(self, activity_environment):
        non_existent_task_id = "550e8400-e29b-41d4-a716-446655440000"

        with pytest.raises(TaskNotFoundError):
            await activity_environment.run(get_task_details, non_existent_task_id)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_get_task_details_invalid_uuid(self, activity_environment):
        invalid_task_id = "not-a-uuid"

        with pytest.raises(ValidationError):
            await activity_environment.run(get_task_details, invalid_task_id)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_get_task_details_with_different_repository(
        self, activity_environment, ateam, auser, task_workflow, github_integration
    ):
        task = await self._create_task_with_repo(
            ateam, auser, task_workflow, github_integration, {"organization": "posthog", "repository": "posthog-js"}
        )

        try:
            result = await activity_environment.run(get_task_details, str(task.id))

            assert result.task_id == str(task.id)
            assert result.team_id == task.team_id
            assert result.github_integration_id == github_integration.id
            assert result.repository == "posthog/posthog-js"
        finally:
            await self._cleanup_task(task)

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_get_task_details_with_missing_repository(
        self, activity_environment, ateam, auser, task_workflow, github_integration
    ):
        task = await self._create_task_with_repo(
            ateam,
            auser,
            task_workflow,
            github_integration,
            {"organization": "test-org"},  # Missing "repository" key
        )

        try:
            with pytest.raises(TaskInvalidStateError):
                await activity_environment.run(get_task_details, str(task.id))
        finally:
            await self._cleanup_task(task)

@@ -0,0 +1,72 @@
import os

import pytest
from unittest.mock import patch

from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import SandboxNotFoundError
from products.tasks.backend.temporal.process_task.activities.inject_github_token import (
    InjectGitHubTokenInput,
    inject_github_token,
)


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestInjectGitHubTokenActivity:
    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_inject_github_token_success(self, activity_environment, github_integration):
        config = SandboxEnvironmentConfig(
            name="test-inject-token-success",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            input_data = InjectGitHubTokenInput(
                sandbox_id=sandbox.id,
                github_integration_id=github_integration.id,
                task_id="test-task-123",
                distinct_id="test-user-id",
            )

            test_token = "ghp_test_token_12345"

            with patch(
                "products.tasks.backend.temporal.process_task.activities.inject_github_token.get_github_token"
            ) as mock_get_token:
                mock_get_token.return_value = test_token

                await activity_environment.run(inject_github_token, input_data)

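            # The activity exports the token via ~/.bashrc inside the sandbox; source it to read the value back.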
            check_result = await sandbox.execute("bash -c 'source ~/.bashrc && echo $GITHUB_TOKEN'")
            assert check_result.exit_code == 0
            assert test_token in check_result.stdout

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_inject_github_token_sandbox_not_found(self, activity_environment, github_integration):
        input_data = InjectGitHubTokenInput(
            sandbox_id="non-existent-sandbox-id",
            github_integration_id=github_integration.id,
            task_id="test-task-not-found",
            distinct_id="test-user-id",
        )

        with patch(
            "products.tasks.backend.temporal.process_task.activities.inject_github_token.get_github_token"
        ) as mock_get_token:
            mock_get_token.return_value = "test_token"

            with pytest.raises(SandboxNotFoundError):
                await activity_environment.run(inject_github_token, input_data)

@@ -0,0 +1,102 @@
import os

import pytest

from asgiref.sync import sync_to_async

from posthog.models import PersonalAPIKey

from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import SandboxNotFoundError, TaskInvalidStateError
from products.tasks.backend.temporal.process_task.activities.inject_personal_api_key import (
    InjectPersonalAPIKeyInput,
    InjectPersonalAPIKeyOutput,
    inject_personal_api_key,
)


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestInjectPersonalAPIKeyActivity:
    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_inject_personal_api_key_success(self, activity_environment, test_task):
        config = SandboxEnvironmentConfig(
            name="test-inject-api-key-success",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            input_data = InjectPersonalAPIKeyInput(
                sandbox_id=sandbox.id,
                task_id=str(test_task.id),
                distinct_id="test-user-id",
            )

            result: InjectPersonalAPIKeyOutput = await activity_environment.run(inject_personal_api_key, input_data)

            assert result.personal_api_key_id is not None

            api_key = await sync_to_async(PersonalAPIKey.objects.get)(id=result.personal_api_key_id)
            assert api_key.user_id == test_task.created_by_id
            assert api_key.scopes is not None
            assert len(api_key.scopes) > 0
            assert api_key.scoped_teams == [test_task.team_id]
            assert f"Task Agent - {test_task.title[:20]}" == api_key.label

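            # The raw key value is only available inside the sandbox environment; read it back via bash to verify the "phx_" prefix.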
            check_result = await sandbox.execute("bash -c 'source ~/.bashrc && echo $POSTHOG_PERSONAL_API_KEY'")
            assert check_result.exit_code == 0
            api_key_value = check_result.stdout.strip()
            assert api_key_value.startswith("phx_")

            await sync_to_async(api_key.delete)()

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_inject_personal_api_key_no_user(self, activity_environment, test_task):
        config = SandboxEnvironmentConfig(
            name="test-inject-api-key-no-user",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            test_task.created_by = None
            await sync_to_async(test_task.save)()

            input_data = InjectPersonalAPIKeyInput(
                sandbox_id=sandbox.id,
                task_id=str(test_task.id),
                distinct_id="test-user-id",
            )

            with pytest.raises(TaskInvalidStateError):
                await activity_environment.run(inject_personal_api_key, input_data)

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_inject_personal_api_key_sandbox_not_found(self, activity_environment, test_task):
        input_data = InjectPersonalAPIKeyInput(
            sandbox_id="non-existent-sandbox-id",
            task_id=str(test_task.id),
            distinct_id="test-user-id",
        )

        with pytest.raises(SandboxNotFoundError):
            await activity_environment.run(inject_personal_api_key, input_data)

@@ -0,0 +1,118 @@
import os

import pytest
from unittest.mock import patch

from products.tasks.backend.services.sandbox_environment import (
    SandboxEnvironment,
    SandboxEnvironmentConfig,
    SandboxEnvironmentTemplate,
)
from products.tasks.backend.temporal.exceptions import RepositorySetupError, SandboxNotFoundError
from products.tasks.backend.temporal.process_task.activities.clone_repository import (
    CloneRepositoryInput,
    clone_repository,
)
from products.tasks.backend.temporal.process_task.activities.setup_repository import (
    SetupRepositoryInput,
    setup_repository,
)


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestSetupRepositoryActivity:
    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_setup_repository_success(self, activity_environment, github_integration):
        config = SandboxEnvironmentConfig(
            name="test-setup-repository",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            clone_input = CloneRepositoryInput(
                sandbox_id=sandbox.id,
                repository="posthog/posthog-js",
                github_integration_id=github_integration.id,
                task_id="test-task-123",
                distinct_id="test-user-id",
            )

            with patch(
                "products.tasks.backend.temporal.process_task.activities.clone_repository.get_github_token"
            ) as mock_get_token:
                mock_get_token.return_value = ""
                await activity_environment.run(clone_repository, clone_input)

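            # Sanity check: right after cloning, no dependencies should be installed yet.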
            check_before = await sandbox.execute(
                "ls -la /tmp/workspace/repos/posthog/posthog-js/ | grep node_modules || echo 'no node_modules'"
            )
            assert "no node_modules" in check_before.stdout

            # We mock the _get_setup_command inside the setup_repository activity to just run pnpm install for the test, instead of using the coding agent
            with patch(
                "products.tasks.backend.temporal.process_task.activities.setup_repository.SandboxAgent._get_setup_command"
            ) as mock_setup_cmd:
                mock_setup_cmd.return_value = "pnpm install"

                setup_input = SetupRepositoryInput(
                    sandbox_id=sandbox.id,
                    repository="posthog/posthog-js",
                    task_id="test-task-123",
                    distinct_id="test-user-id",
                )

                result = await activity_environment.run(setup_repository, setup_input)

                assert result is not None

                check_after = await sandbox.execute(
                    "ls -la /tmp/workspace/repos/posthog/posthog-js/ | grep node_modules || echo 'no node_modules'"
                )
                assert "node_modules" in check_after.stdout
                assert "no node_modules" not in check_after.stdout

        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_setup_repository_without_clone(self, activity_environment):
        config = SandboxEnvironmentConfig(
            name="test-setup-no-clone",
            template=SandboxEnvironmentTemplate.DEFAULT_BASE,
        )

        sandbox = None
        try:
            sandbox = await SandboxEnvironment.create(config)

            setup_input = SetupRepositoryInput(
                sandbox_id=sandbox.id,
                repository="posthog/posthog-js",
                task_id="test-task-no-clone",
                distinct_id="test-user-id",
            )

            with pytest.raises(RepositorySetupError):
                await activity_environment.run(setup_repository, setup_input)
        finally:
            if sandbox:
                await sandbox.destroy()

    @pytest.mark.asyncio
    @pytest.mark.django_db
    async def test_setup_repository_sandbox_not_found(self, activity_environment):
        setup_input = SetupRepositoryInput(
            sandbox_id="non-existent-sandbox-id",
            repository="posthog/posthog-js",
            task_id="test-task-not-found",
            distinct_id="test-user-id",
        )

        with pytest.raises(SandboxNotFoundError):
            await activity_environment.run(setup_repository, setup_input)

@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import Any

import posthoganalytics
from temporalio import activity

from posthog.temporal.common.logger import get_logger

logger = get_logger(__name__)


@dataclass
class TrackWorkflowEventInput:
    event_name: str
    distinct_id: str
    properties: dict[str, Any]


@activity.defn
def track_workflow_event(input: TrackWorkflowEventInput) -> None:
    """Track workflow-level events to PostHog."""
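    # Analytics are best-effort: a failed capture is logged but must never fail the activity.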
    try:
        posthoganalytics.capture(
            distinct_id=input.distinct_id,
            event=input.event_name,
            properties=input.properties,
        )
    except Exception:
        logger.exception(
            "Failed to track workflow event",
            event_name=input.event_name,
            distinct_id=input.distinct_id,
            properties=input.properties,
        )

@@ -0,0 +1,284 @@
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import cast

import pytest
from unittest.mock import patch

from asgiref.sync import sync_to_async
from temporalio.common import RetryPolicy
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import UnsandboxedWorkflowRunner, Worker

from posthog import constants

from products.tasks.backend.models import SandboxSnapshot
from products.tasks.backend.services.sandbox_environment import SandboxEnvironment, SandboxEnvironmentStatus
from products.tasks.backend.temporal.process_task.activities.check_snapshot_exists_for_repository import (
    check_snapshot_exists_for_repository,
)
from products.tasks.backend.temporal.process_task.activities.cleanup_personal_api_key import cleanup_personal_api_key
from products.tasks.backend.temporal.process_task.activities.cleanup_sandbox import cleanup_sandbox
from products.tasks.backend.temporal.process_task.activities.clone_repository import clone_repository
from products.tasks.backend.temporal.process_task.activities.create_sandbox_from_snapshot import (
    create_sandbox_from_snapshot,
)
from products.tasks.backend.temporal.process_task.activities.create_snapshot import create_snapshot
from products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox import execute_task_in_sandbox
from products.tasks.backend.temporal.process_task.activities.get_sandbox_for_setup import get_sandbox_for_setup
from products.tasks.backend.temporal.process_task.activities.get_task_details import get_task_details
from products.tasks.backend.temporal.process_task.activities.inject_github_token import inject_github_token
from products.tasks.backend.temporal.process_task.activities.inject_personal_api_key import inject_personal_api_key
from products.tasks.backend.temporal.process_task.activities.setup_repository import setup_repository
from products.tasks.backend.temporal.process_task.activities.tests.constants import POSTHOG_JS_SNAPSHOT
from products.tasks.backend.temporal.process_task.activities.track_workflow_event import track_workflow_event
from products.tasks.backend.temporal.process_task.workflow import ProcessTaskOutput, ProcessTaskWorkflow

pytestmark = [pytest.mark.asyncio, pytest.mark.django_db]


@pytest.mark.skipif(not os.environ.get("RUNLOOP_API_KEY"), reason="RUNLOOP_API_KEY environment variable not set")
class TestProcessTaskWorkflow:
    """
    End-to-end workflow tests using real Runloop sandboxes.

    These tests create actual sandboxes and snapshots, only mocking the task execution command
    to avoid running the full AI agent. This allows us to verify:
    - Snapshot creation and reuse
    - Sandbox lifecycle management
    - Proper cleanup on success and failure
    """

    async def _run_workflow(self, task_id: str, mock_task_command: str = "echo 'task complete'") -> ProcessTaskOutput:
        workflow_id = str(uuid.uuid4())

        with (
            patch(
                "products.tasks.backend.temporal.process_task.activities.setup_repository.SandboxAgent._get_setup_command"
            ) as mock_setup,
            patch(
                "products.tasks.backend.temporal.process_task.activities.execute_task_in_sandbox.SandboxAgent._get_task_command"
            ) as mock_task,
        ):
            mock_setup.return_value = "pnpm install"
            mock_task.return_value = mock_task_command

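            # Temporal's time-skipping test server fast-forwards timers, so generous workflow timeouts don't slow the test down.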
            async with (
                await WorkflowEnvironment.start_time_skipping() as env,
                Worker(
                    env.client,
                    task_queue=constants.TASKS_TASK_QUEUE,
                    workflows=[ProcessTaskWorkflow],
                    activities=[
                        get_task_details,
                        check_snapshot_exists_for_repository,
                        get_sandbox_for_setup,
                        clone_repository,
                        setup_repository,
                        create_snapshot,
                        create_sandbox_from_snapshot,
                        inject_github_token,
                        inject_personal_api_key,
                        execute_task_in_sandbox,
                        cleanup_sandbox,
                        cleanup_personal_api_key,
                        track_workflow_event,
                    ],
                    workflow_runner=UnsandboxedWorkflowRunner(),
                    activity_executor=ThreadPoolExecutor(max_workers=10),
                ),
            ):
                result = await env.client.execute_workflow(
                    ProcessTaskWorkflow.run,
                    task_id,
                    id=workflow_id,
                    task_queue=constants.TASKS_TASK_QUEUE,
                    retry_policy=RetryPolicy(maximum_attempts=1),
                    execution_timeout=timedelta(minutes=60),
                )

                return result

    async def _verify_file_in_sandbox(self, sandbox_id: str, filepath: str) -> bool:
        """Verify a file exists in a sandbox."""
        sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
        result = await sandbox.execute(f"test -f {filepath} && echo 'exists' || echo 'missing'")
        return "exists" in result.stdout

    async def test_workflow_with_existing_snapshot_reuses_snapshot(self, test_task, github_integration):
        snapshot, _ = await sync_to_async(SandboxSnapshot.objects.get_or_create)(
            external_id=POSTHOG_JS_SNAPSHOT["external_id"],
            defaults={
                "integration": github_integration,
                "repos": POSTHOG_JS_SNAPSHOT["repos"],
                "status": SandboxSnapshot.Status.COMPLETE,
            },
        )

        try:
            result = await self._run_workflow(test_task.id)

            assert result.success is True
            assert result.task_result is not None
            assert result.task_result.exit_code == 0
            assert "task complete" in result.task_result.stdout

            snapshots_query = SandboxSnapshot.objects.filter(integration=github_integration).order_by("-created_at")
            snapshots = cast(list[SandboxSnapshot], await sync_to_async(list)(snapshots_query))  # type: ignore[call-arg]
            assert len(snapshots) == 1
            assert snapshots[0].id == snapshot.id
            assert "posthog/posthog-js" in snapshots[0].repos

        finally:
            await sync_to_async(snapshot.delete)()

    async def test_workflow_creates_snapshot_for_new_repository(self, test_task, github_integration):
        created_snapshots = []

        try:
            result = await self._run_workflow(test_task.id)

            assert result.success is True
            assert result.task_result is not None
            assert result.task_result.exit_code == 0

            snapshots_query = SandboxSnapshot.objects.filter(
                integration=github_integration, status=SandboxSnapshot.Status.COMPLETE
            ).order_by("-created_at")
            snapshots = cast(list[SandboxSnapshot], await sync_to_async(list)(snapshots_query))  # type: ignore[call-arg]

            assert len(snapshots) >= 1
            latest_snapshot = snapshots[0]
            assert "posthog/posthog-js" in latest_snapshot.repos
            assert latest_snapshot.status == SandboxSnapshot.Status.COMPLETE
            assert latest_snapshot.external_id is not None

            created_snapshots.append(latest_snapshot)

        finally:
            for snapshot in created_snapshots:
                try:
                    if snapshot.external_id:
                        await SandboxEnvironment.delete_snapshot(snapshot.external_id)
                    await sync_to_async(snapshot.delete)()
                except Exception:
                    pass

    async def test_workflow_executes_task_in_sandbox(self, test_task, github_integration):
        snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
            integration=github_integration,
            external_id=POSTHOG_JS_SNAPSHOT["external_id"],
            repos=POSTHOG_JS_SNAPSHOT["repos"],
            status=SandboxSnapshot.Status.COMPLETE,
        )

        custom_message = f"workflow_test_{uuid.uuid4().hex[:8]}"

        try:
            result = await self._run_workflow(test_task.id, mock_task_command=f"echo '{custom_message}'")

            assert result.success is True
            assert result.task_result is not None
            assert result.task_result.exit_code == 0
            assert custom_message in result.task_result.stdout

        finally:
            await sync_to_async(snapshot.delete)()

    async def test_workflow_cleans_up_sandbox_on_success(self, test_task, github_integration):
        snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
            integration=github_integration,
            external_id=POSTHOG_JS_SNAPSHOT["external_id"],
            repos=POSTHOG_JS_SNAPSHOT["repos"],
            status=SandboxSnapshot.Status.COMPLETE,
        )

        try:
            result = await self._run_workflow(test_task.id)

            assert result.success is True
            assert result.task_result is not None
            assert result.sandbox_id is not None

            sandbox = await SandboxEnvironment.get_by_id(result.sandbox_id)
            assert sandbox.status == SandboxEnvironmentStatus.SHUTDOWN.value

        finally:
            await sync_to_async(snapshot.delete)()

    async def test_workflow_cleans_up_sandbox_on_failure(self, test_task, github_integration):
        snapshot = await sync_to_async(SandboxSnapshot.objects.create)(
            integration=github_integration,
            external_id=POSTHOG_JS_SNAPSHOT["external_id"],
            repos=POSTHOG_JS_SNAPSHOT["repos"],
            status=SandboxSnapshot.Status.COMPLETE,
        )

        try:
            result = await self._run_workflow(test_task.id, mock_task_command="exit 1")

            assert result.success is False
            assert result.error is not None
            assert result.task_result is None

            assert result.sandbox_id is not None
            sandbox_id = result.sandbox_id

            sandbox = await SandboxEnvironment.get_by_id(sandbox_id)
            assert sandbox.status == SandboxEnvironmentStatus.SHUTDOWN.value

        finally:
            await sync_to_async(snapshot.delete)()

    async def test_workflow_handles_missing_task(self):
        fake_task_id = str(uuid.uuid4())

        result = await self._run_workflow(fake_task_id)

        assert result.success is False
        assert result.error is not None
        assert "activity task failed" in result.error.lower() or "failed" in result.error.lower()

    async def test_workflow_full_cycle_no_snapshot(self, test_task, github_integration):
        created_snapshots = []

        try:
            result = await self._run_workflow(test_task.id)

            assert result.success is True
            assert result.task_result is not None
            assert result.task_result.exit_code == 0

            snapshots_query = SandboxSnapshot.objects.filter(
                integration=github_integration, status=SandboxSnapshot.Status.COMPLETE
            ).order_by("-created_at")
            snapshots = cast(list[SandboxSnapshot], await sync_to_async(list)(snapshots_query))  # type: ignore[call-arg]

            assert len(snapshots) >= 1
            latest_snapshot = snapshots[0]
            assert "posthog/posthog-js" in latest_snapshot.repos
            assert latest_snapshot.status == SandboxSnapshot.Status.COMPLETE

            created_snapshots.append(latest_snapshot)

            result2 = await self._run_workflow(test_task.id)

            assert result2.success is True
            assert result2.task_result is not None

            snapshots_after_query = SandboxSnapshot.objects.filter(
                integration=github_integration, status=SandboxSnapshot.Status.COMPLETE
            ).order_by("-created_at")
            snapshots_after = cast(list[SandboxSnapshot], await sync_to_async(list)(snapshots_after_query))  # type: ignore[call-arg]
            assert len(snapshots_after) == len(snapshots)

        finally:
            for snapshot in created_snapshots:
                try:
                    if snapshot.external_id:
                        await SandboxEnvironment.delete_snapshot(snapshot.external_id)
                    await sync_to_async(snapshot.delete)()
                except Exception:
                    pass

19 products/tasks/backend/temporal/process_task/utils.py Normal file
@@ -0,0 +1,19 @@
from typing import Optional

from posthog.models.integration import GitHubIntegration, Integration
from posthog.temporal.common.utils import asyncify


@asyncify
def get_github_token(github_integration_id: int) -> Optional[str]:
    integration = Integration.objects.get(id=github_integration_id)
    github_integration = GitHubIntegration(integration)

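    # GitHub App installation tokens are short-lived, so refresh an expired token before handing it out.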
    if github_integration.access_token_expired():
        github_integration.refresh_access_token()

    return github_integration.integration.access_token or None


def get_sandbox_name_for_task(task_id: str) -> str:
    return f"task-sandbox-{task_id}"

318 products/tasks/backend/temporal/process_task/workflow.py Normal file
@@ -0,0 +1,318 @@
import json
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional

import temporalio
from temporalio import workflow
from temporalio.common import RetryPolicy

from posthog.temporal.common.base import PostHogWorkflow
from posthog.temporal.common.logger import get_logger

from .activities.check_snapshot_exists_for_repository import (
    CheckSnapshotExistsForRepositoryInput,
    check_snapshot_exists_for_repository,
)
from .activities.cleanup_personal_api_key import cleanup_personal_api_key
from .activities.cleanup_sandbox import CleanupSandboxInput, cleanup_sandbox
from .activities.clone_repository import CloneRepositoryInput, clone_repository
from .activities.create_sandbox_from_snapshot import CreateSandboxFromSnapshotInput, create_sandbox_from_snapshot
from .activities.create_snapshot import CreateSnapshotInput, create_snapshot
from .activities.execute_task_in_sandbox import ExecuteTaskInput, ExecuteTaskOutput, execute_task_in_sandbox
from .activities.get_sandbox_for_setup import GetSandboxForSetupInput, get_sandbox_for_setup
from .activities.get_task_details import TaskDetails, get_task_details
from .activities.inject_github_token import InjectGitHubTokenInput, inject_github_token
from .activities.inject_personal_api_key import (
    InjectPersonalAPIKeyInput,
    InjectPersonalAPIKeyOutput,
    inject_personal_api_key,
)
from .activities.setup_repository import SetupRepositoryInput, setup_repository
from .activities.track_workflow_event import TrackWorkflowEventInput, track_workflow_event

logger = get_logger(__name__)


@dataclass
class ProcessTaskOutput:
    success: bool
    task_result: Optional[ExecuteTaskOutput] = None
    error: Optional[str] = None
    sandbox_id: Optional[str] = None


@temporalio.workflow.defn(name="process-task")
class ProcessTaskWorkflow(PostHogWorkflow):
    def __init__(self) -> None:
        self._task_details: Optional[TaskDetails] = None

    @property
    def task_details(self) -> TaskDetails:
        if self._task_details is None:
            raise RuntimeError("task_details accessed before being set")
        return self._task_details

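    # Workflow inputs arrive as a JSON payload; extract just the task id.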
    @staticmethod
    def parse_inputs(inputs: list[str]) -> str:
        loaded = json.loads(inputs[0])
        return loaded["task_id"]

    @temporalio.workflow.run
    async def run(self, task_id: str) -> ProcessTaskOutput:
        sandbox_id = None
        personal_api_key_id = None

        try:
            self._task_details = await self._get_task_details(task_id)

            await self._track_workflow_event(
                "process_task_workflow_started",
                {
                    "task_id": self.task_details.task_id,
                    "repository": self.task_details.repository,
                    "team_id": self.task_details.team_id,
                },
            )

            snapshot_id = await self._get_snapshot_for_repository()

            sandbox_id = await self._create_sandbox_from_snapshot(snapshot_id)

            await self._inject_github_token(sandbox_id)

            api_key_output = await self._inject_personal_api_key(sandbox_id)
            personal_api_key_id = api_key_output.personal_api_key_id

            result = await self._execute_task_in_sandbox(sandbox_id)

            await self._track_workflow_event(
                "process_task_workflow_completed",
                {
                    "task_id": self.task_details.task_id,
                    "sandbox_id": sandbox_id,
                    "exit_code": result.exit_code,
                },
            )

            return ProcessTaskOutput(
                success=True,
                task_result=result,
                error=None,
                sandbox_id=sandbox_id,
            )

        except Exception as e:
            if self._task_details:
                await self._track_workflow_event(
                    "process_task_workflow_failed",
                    {
                        "task_id": self.task_details.task_id,
                        "error_type": type(e).__name__,
                        "error_message": str(e)[:500],
                        "sandbox_id": sandbox_id,
                    },
                )

            return ProcessTaskOutput(
                success=False,
                task_result=None,
                error=str(e),
                sandbox_id=sandbox_id,
            )

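        # Cleanup runs on success and failure alike, so credentials and sandboxes never outlive the workflow.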
        finally:
            if personal_api_key_id:
                await self._cleanup_personal_api_key(personal_api_key_id)
            if sandbox_id:
                await self._cleanup_sandbox(sandbox_id)

    async def _get_task_details(self, task_id: str) -> TaskDetails:
        return await workflow.execute_activity(
            get_task_details,
            task_id,
            start_to_close_timeout=timedelta(seconds=30),
            retry_policy=RetryPolicy(maximum_attempts=3),
        )

    async def _get_snapshot_for_repository(self) -> str:
        check_input = CheckSnapshotExistsForRepositoryInput(
            github_integration_id=self.task_details.github_integration_id,
            repository=self.task_details.repository,
        )

        check_result = await workflow.execute_activity(
            check_snapshot_exists_for_repository,
            check_input,
            start_to_close_timeout=timedelta(seconds=30),
            retry_policy=RetryPolicy(maximum_attempts=3),
        )

        if check_result.snapshot_id:
            return check_result.snapshot_id

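        # No reusable snapshot for this repository yet, so build one from scratch.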
        return await self._setup_snapshot_with_repository()

    async def _get_sandbox_for_setup(self) -> str:
        get_sandbox_input = GetSandboxForSetupInput(
            github_integration_id=self.task_details.github_integration_id,
            team_id=self.task_details.team_id,
            task_id=self.task_details.task_id,
            distinct_id=self.task_details.distinct_id,
        )
        return await workflow.execute_activity(
            get_sandbox_for_setup,
            get_sandbox_input,
            start_to_close_timeout=timedelta(minutes=10),
            retry_policy=RetryPolicy(maximum_attempts=2),
        )

    async def _clone_repository_in_sandbox(self, sandbox_id: str) -> None:
        clone_input = CloneRepositoryInput(
            sandbox_id=sandbox_id,
            repository=self.task_details.repository,
            github_integration_id=self.task_details.github_integration_id,
            task_id=self.task_details.task_id,
            distinct_id=self.task_details.distinct_id,
        )
        await workflow.execute_activity(
            clone_repository,
            clone_input,
            start_to_close_timeout=timedelta(minutes=10),
            retry_policy=RetryPolicy(maximum_attempts=2),
        )

    async def _setup_repository_in_sandbox(self, sandbox_id: str) -> None:
        setup_repo_input = SetupRepositoryInput(
            sandbox_id=sandbox_id,
            repository=self.task_details.repository,
            task_id=self.task_details.task_id,
            distinct_id=self.task_details.distinct_id,
        )
        await workflow.execute_activity(
            setup_repository,
            setup_repo_input,
            start_to_close_timeout=timedelta(minutes=15),
            retry_policy=RetryPolicy(maximum_attempts=1),
        )

    async def _snapshot_sandbox(self, sandbox_id: str) -> str:
        snapshot_input = CreateSnapshotInput(
            sandbox_id=sandbox_id,
            github_integration_id=self.task_details.github_integration_id,
            team_id=self.task_details.team_id,
            repository=self.task_details.repository,
            task_id=self.task_details.task_id,
            distinct_id=self.task_details.distinct_id,
        )
        return await workflow.execute_activity(
            create_snapshot,
            snapshot_input,
            start_to_close_timeout=timedelta(minutes=25),
            retry_policy=RetryPolicy(maximum_attempts=3),
        )

    async def _cleanup_sandbox(self, sandbox_id: str) -> None:
        cleanup_input = CleanupSandboxInput(sandbox_id=sandbox_id)
        await workflow.execute_activity(
            cleanup_sandbox,
            cleanup_input,
            start_to_close_timeout=timedelta(minutes=5),
            retry_policy=RetryPolicy(maximum_attempts=3),
        )

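    # Build-once path: boot a setup sandbox, clone and set up the repo, snapshot the result for reuse, and always tear the setup sandbox down.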
    async def _setup_snapshot_with_repository(self) -> str:
        setup_sandbox_id = None

        try:
            setup_sandbox_id = await self._get_sandbox_for_setup()

            await self._clone_repository_in_sandbox(setup_sandbox_id)

            await self._setup_repository_in_sandbox(setup_sandbox_id)

            snapshot_id = await self._snapshot_sandbox(setup_sandbox_id)

            return snapshot_id

        finally:
            if setup_sandbox_id:
                await self._cleanup_sandbox(setup_sandbox_id)

    async def _create_sandbox_from_snapshot(self, snapshot_id: str) -> str:
        create_sandbox_input = CreateSandboxFromSnapshotInput(
            snapshot_id=snapshot_id,
            task_id=self.task_details.task_id,
            distinct_id=self.task_details.distinct_id,
        )
        return await workflow.execute_activity(
            create_sandbox_from_snapshot,
            create_sandbox_input,
            start_to_close_timeout=timedelta(minutes=5),
            retry_policy=RetryPolicy(maximum_attempts=2),
        )

    async def _inject_github_token(self, sandbox_id: str) -> None:
        inject_token_input = InjectGitHubTokenInput(
            sandbox_id=sandbox_id,
            github_integration_id=self.task_details.github_integration_id,
            task_id=self.task_details.task_id,
            distinct_id=self.task_details.distinct_id,
        )
        await workflow.execute_activity(
            inject_github_token,
            inject_token_input,
            start_to_close_timeout=timedelta(minutes=5),
            retry_policy=RetryPolicy(maximum_attempts=3),
        )

    async def _inject_personal_api_key(self, sandbox_id: str) -> InjectPersonalAPIKeyOutput:
        inject_key_input = InjectPersonalAPIKeyInput(
            sandbox_id=sandbox_id,
            task_id=self.task_details.task_id,
            distinct_id=self.task_details.distinct_id,
        )
        return await workflow.execute_activity(
            inject_personal_api_key,
            inject_key_input,
            start_to_close_timeout=timedelta(minutes=5),
            retry_policy=RetryPolicy(maximum_attempts=3),
        )

    async def _cleanup_personal_api_key(self, personal_api_key_id: str) -> None:
        try:
            await workflow.execute_activity(
                cleanup_personal_api_key,
                personal_api_key_id,
                start_to_close_timeout=timedelta(minutes=5),
                retry_policy=RetryPolicy(maximum_attempts=3),
            )
        except Exception as e:
            logger.warning(f"Failed to cleanup personal API key {personal_api_key_id}: {e}")

    async def _execute_task_in_sandbox(self, sandbox_id: str) -> ExecuteTaskOutput:
        execute_input = ExecuteTaskInput(
            sandbox_id=sandbox_id,
            task_id=self.task_details.task_id,
            repository=self.task_details.repository,
            distinct_id=self.task_details.distinct_id,
        )
        return await workflow.execute_activity(
            execute_task_in_sandbox,
            execute_input,
            start_to_close_timeout=timedelta(minutes=30),
            retry_policy=RetryPolicy(maximum_attempts=1),
        )

    async def _track_workflow_event(self, event_name: str, properties: dict) -> None:
        track_input = TrackWorkflowEventInput(
            event_name=event_name,
            distinct_id=self.task_details.distinct_id,
            properties=properties,
        )
        await workflow.execute_activity(
            track_workflow_event,
            track_input,
            start_to_close_timeout=timedelta(seconds=10),
            retry_policy=RetryPolicy(maximum_attempts=1),
        )

@@ -513,7 +513,7 @@ class TestTask(TestCase):
         self.assertEqual(repo_list[0]["org"], "PostHog")
         self.assertEqual(repo_list[0]["repo"], "posthog")
         self.assertEqual(repo_list[0]["integration_id"], integration.id)
-        self.assertEqual(repo_list[0]["full_name"], "PostHog/posthog")
+        self.assertEqual(repo_list[0]["full_name"], "posthog/posthog")

     def test_repository_list_empty(self):
         task = Task.objects.create(