working phoenix code

This commit is contained in:
christineastoria
2025-12-03 20:23:30 -05:00
parent 5de2df5719
commit 390e3a3d58
8 changed files with 863 additions and 1 deletions
+5
View File
@@ -13,6 +13,11 @@ ARIZE_API_KEY = os.getenv("ARIZE_API_KEY")
ARIZE_SPACE_ID = os.getenv("ARIZE_SPACE_ID")
ARIZE_PROJECT_NAMES = [p.strip() for p in os.getenv("ARIZE_PROJECT_NAMES", "").split(",") if p.strip()]
# Arize: Phoenix configuration
# API key for Phoenix, read from the environment; None when PHOENIX_API_KEY is unset.
PHOENIX_API_KEY = os.getenv("PHOENIX_API_KEY")
# Base URL of the Phoenix instance; falls back to https://app.phoenix.arize.com when unset.
PHOENIX_HOST = os.getenv("PHOENIX_HOST", "https://app.phoenix.arize.com")
# Optional Phoenix space name; empty string when PHOENIX_SPACE is unset.
PHOENIX_SPACE = os.getenv("PHOENIX_SPACE", "") # e.g., "christine"
# LangSmith configuration
LS_API_KEY = os.getenv("LANGSMITH_API_KEY")
LS_ORG_ID = os.environ["LANGSMITH_ORGANIZATION_ID"]
+14 -1
View File
@@ -1,12 +1,15 @@
from providers.langfuse.main import migrate_langfuse
from providers.arize.main import migrate_arize
from providers.phoenix.main import migrate_phoenix
from config import INCLUDE_MODEL_IN_PROMPTS, NUM_TRACES_TO_REPLAY
from utils.langfuse import lf_get_projects
from utils.arize import arize_get_projects
from utils.phoenix import phoenix_get_projects
# Identifiers of the source providers this tool has a migration path for.
AVAILABLE_PROVIDERS = ["langfuse", "arize", "phoenix"]
## ------------------------------------------------------------
@@ -66,7 +69,17 @@ def migrate(provider: str):
migrate = capture_user_selection("Arize", "projects")
if not migrate:
return
migrate_arize(projects)
migrate_arize(projects)
elif provider == "phoenix":
display_config("Phoenix", "project")
projects = phoenix_get_projects()
if not projects:
print("No projects found in Phoenix.")
return
migrate = capture_user_selection("Phoenix", "projects")
if not migrate:
return
migrate_phoenix(projects)
def prompt_for_provider() -> str:
+162
View File
@@ -0,0 +1,162 @@
import traceback
import json
from utils.phoenix import phoenix_get_datasets, phoenix_get_dataset_examples
from utils.langsmith import ls_create_dataset, ls_upload_examples
def _unpack_dotted_keys(record: dict) -> dict:
"""Unpack keys with dots into nested dicts.
Example: {'output.tool_calls': '...'} -> {'output': {'tool_calls': '...'}}
"""
result = {}
for key, value in record.items():
if '.' in key:
parts = key.split('.')
current = result
for part in parts[:-1]:
if part not in current:
current[part] = {}
elif not isinstance(current[part], dict):
# If it's not a dict, wrap the existing value
current[part] = {"_value": current[part]}
current = current[part]
current[parts[-1]] = value
else:
if key in result and isinstance(result[key], dict):
# Merge with existing dict
if isinstance(value, dict):
result[key].update(value)
else:
result[key]["_value"] = value
else:
result[key] = value
return result
def _try_parse_json(value):
"""Try to parse a JSON string, return original if not JSON."""
if isinstance(value, str):
try:
return json.loads(value)
except (json.JSONDecodeError, ValueError):
pass
return value
def _extract_expected_value(record: dict) -> str:
keys = [
"expected",
"expected_output",
"expectedOutput",
"reference_output",
"referenceOutput",
"output",
]
for k in keys:
if k in record and record.get(k) not in (None, ""):
return record.get(k)
for v in record.values():
if isinstance(v, dict):
for k in keys:
if k in v and v.get(k) not in (None, ""):
return v.get(k)
return ""
def phoenix_example_conversion(records: list) -> list[dict]:
    """Convert Phoenix dataset examples to LangSmith format.

    Args:
        records: Phoenix examples given as dicts, or as objects exposing
            ``to_dict()`` or ``__dict__``. Any other value is stringified and
            used as the example input.

    Returns:
        One dict per record with ``inputs``, ``outputs`` (the expected value
        stored under ``reference_output``) and, when non-empty, ``metadata``.
    """
    examples = []
    for r in records:
        # Handle both dict and object types. Dicts are checked first so a dict
        # subclass is not replaced by its (unrelated) instance attributes.
        if not isinstance(r, dict) and hasattr(r, '__dict__'):
            r = r.to_dict() if hasattr(r, 'to_dict') else vars(r)
        if not isinstance(r, dict):
            r = {"input": str(r)}

        # Unpack dotted keys into nested dicts
        r = _unpack_dotted_keys(r)

        # Decode JSON strings at the top level and one level down. Nested
        # dicts are rebuilt rather than edited so caller data is left intact.
        for key in list(r.keys()):
            value = _try_parse_json(r[key])
            if isinstance(value, dict):
                value = {k2: _try_parse_json(v2) for k2, v2 in value.items()}
            r[key] = value

        # Inputs - handle list of messages
        inputs = r.get("input") or r.get("inputs") or {}
        if isinstance(inputs, list):
            # List of messages like [{"role": "user", "content": "..."}]
            inputs = {"messages": inputs}
        elif not isinstance(inputs, dict):
            inputs = {"input": inputs}

        # Metadata values were already decoded by the pass above; decoding
        # them a second time would corrupt strings that merely look like JSON.
        meta = r.get("metadata") or {}

        # Expected output - handle list of messages
        expected = r.get("output") or r.get("outputs") or _extract_expected_value(r)
        if isinstance(expected, list):
            # Collect the text of every message; fall back to the list's repr.
            contents = []
            for msg in expected:
                if isinstance(msg, dict) and msg.get("content"):
                    content = msg["content"]
                    # Structured (non-string) content would make join() raise.
                    contents.append(content if isinstance(content, str) else str(content))
                elif isinstance(msg, str):
                    contents.append(msg)
            expected = "\n".join(contents) if contents else str(expected)
        elif isinstance(expected, dict):
            expected = expected.get("content") or expected.get("output") or expected

        ex = {
            "inputs": inputs,
            "outputs": {"reference_output": expected},
        }
        if isinstance(meta, dict) and meta:
            ex["metadata"] = meta
        examples.append(ex)
    return examples
def migrate_datasets(workspace_id: str):
    """Copy every Phoenix dataset, with its examples, into a LangSmith workspace.

    Failures are printed (with a traceback) and never raised: a failed
    listing aborts the whole step, a failed dataset only skips that dataset.
    A dataset for which ``ls_create_dataset`` returns ``None`` is reported as
    already existing and left untouched.

    Args:
        workspace_id: Target LangSmith workspace id.
    """
    try:
        datasets = phoenix_get_datasets()
    except Exception as exc:
        print(f" x failed to fetch Phoenix datasets: {exc}")
        traceback.print_exc()
        return

    if not datasets:
        print(" - no datasets found in Phoenix")
        return

    for ds in datasets:
        ds_id = ds.get("id")
        # Unnamed datasets get a synthetic name derived from their id.
        ds_name = ds.get("name") or f"phoenix-dataset-{ds_id}"
        print(f" - migrating dataset: {ds_name}")
        try:
            target_id = ls_create_dataset(workspace_id, ds_name)
            if target_id is None:
                print(" • skipped (already exists)")
                continue

            source_examples = phoenix_get_dataset_examples(dataset_id=ds_id, dataset_name=ds_name)
            if not source_examples:
                print(" • no examples found")
                continue

            converted = phoenix_example_conversion(source_examples)
            ls_upload_examples(workspace_id, target_id, converted)
            print(f" • uploaded {len(converted)} examples")
        except Exception as exc:
            print(f" x dataset '{ds_name}' failed: {exc}")
            traceback.print_exc()
+221
View File
@@ -0,0 +1,221 @@
import re
import json
import traceback
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.runnables import RunnableSequence
from langsmith.utils import LangSmithConflictError
from utils.phoenix import phoenix_get_prompts, phoenix_get_prompt
from utils.langsmith import ls_push_prompt
from config import INCLUDE_MODEL_IN_PROMPTS
def string_to_chat_template(template: str) -> ChatPromptTemplate:
    """Detect 'role: content' lines and build a ChatPromptTemplate; fall back.

    A line starting with ``system:``, ``user:``, ``assistant:`` or ``tool:``
    (case-insensitive, surrounding whitespace ignored) opens a new message.
    Every following line, up to the next role line, is kept as part of that
    message, so multi-line message content is preserved. Text that appears
    before the first role line is ignored.

    When no role line is found at all, the whole string is turned into a
    template with ``ChatPromptTemplate.from_template``.
    """
    role_line = re.compile(r"^(system|user|assistant|tool):\s*(.*)$", re.I)

    # Each entry is (role, [content lines]).
    messages: list[tuple[str, list[str]]] = []
    for raw_line in template.splitlines():
        m = role_line.match(raw_line.strip())
        if m:
            messages.append((m.group(1).lower(), [m.group(2)]))
        elif messages:
            # Continuation of the current message; blank lines are kept so
            # paragraph breaks inside a message survive.
            messages[-1][1].append(raw_line.rstrip())

    if not messages:
        return ChatPromptTemplate.from_template(template)

    # strip() removes the blank separator lines left between messages.
    pairs = [(role, "\n".join(body).strip()) for role, body in messages]
    return ChatPromptTemplate.from_messages(pairs)
def detect_model_provider(model_name: str) -> str:
    """Classify a model name as ``"anthropic"`` or ``"openai"``.

    Names containing "claude" or "anthropic" (case-insensitive) map to
    ``"anthropic"``. Every other name, including an empty or missing one,
    maps to ``"openai"``.
    """
    if model_name:
        lowered = model_name.lower()
        for marker in ("claude", "anthropic"):
            if marker in lowered:
                return "anthropic"
    # OpenAI-style names (gpt, o1, text-davinci, ...) and unrecognised names
    # both resolve to the OpenAI default.
    return "openai"
def get_model_instance(model_name: str, model_params: dict = None):
    """Get the appropriate LangChain model instance based on the model name.

    Args:
        model_name: Model identifier; ``detect_model_provider`` decides which
            chat model class is used.
        model_params: Invocation parameters passed to the model constructor.
            Phoenix's nested layout, e.g.
            ``{'type': 'openai', 'openai': {'temperature': 1.0}}``, is
            flattened first. The input dict is not modified.

    Returns:
        A ``ChatAnthropic`` instance for Anthropic models, otherwise a
        ``ChatOpenAI`` instance.
    """
    provider = detect_model_provider(model_name)
    params = dict(model_params or {})

    # Keys whose value is a nested parameter dict: the provider named by
    # 'type' first, then the known provider names.
    nested_keys = [params.pop('type', None), 'openai', 'anthropic']
    for key in nested_keys:
        if not key or not isinstance(key, str) or key not in params:
            continue
        nested = params.pop(key)
        # A non-dict value (e.g. None) carries no parameters. Dropping it
        # avoids a crash in update() and an invalid constructor keyword.
        if isinstance(nested, dict):
            params.update(nested)

    # Remove invalid keys
    params.pop("model", None)
    params.pop("supported_languages", None)

    if provider == "anthropic":
        return ChatAnthropic(model=model_name, **params)
    return ChatOpenAI(model=model_name, **params)
def _message_text(content) -> str:
    """Flatten the ``content`` field of a Phoenix message into plain text."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Nested content array: [{"type": "text", "text": "..."}, ...]
        pieces = []
        for part in content:
            if isinstance(part, dict) and part.get('type') == 'text':
                # A null "text" entry must not break the join below.
                pieces.append(str(part.get('text') or ''))
            elif isinstance(part, str):
                pieces.append(part)
        return ''.join(pieces)
    return str(content)


def phoenix_prompt_conversion(phoenix_prompt, prompt_info: dict = None) -> dict:
    """
    Map Phoenix prompt → dict ready for push_langsmith_prompt().

    Args:
        phoenix_prompt: PromptVersion object or dict. Only objects with a
            ``__dict__`` contribute data; it is read from their ``_template``,
            ``_model_name``, ``_model_provider``, ``_template_format``,
            ``_invocation_parameters``, ``_description`` and ``_id`` attributes.
        prompt_info: Original prompt info dict with name/description from list

    Returns:
        Dict with ``name`` (lower-cased, spaces replaced by ``-``),
        ``description``, ``prompt_template`` (one ``role: content`` entry per
        message, separated by blank lines, variables written as ``{var}``),
        ``input_variables`` (sorted, unique) and ``metadata``.
    """
    prompt_info = prompt_info or {}

    # Extract data from PromptVersion object's internal attributes
    d = getattr(phoenix_prompt, '__dict__', None) or {}

    # Get template/messages
    template_data = d.get('_template', {})
    messages_list = template_data.get('messages', []) if isinstance(template_data, dict) else []

    # Extract model info
    model = d.get('_model_name', '')
    model_provider = d.get('_model_provider', '')
    template_format = d.get('_template_format', '')  # MUSTACHE, JINJA, etc.
    invocation_params = d.get('_invocation_parameters', {}) or {}

    # Description from PromptVersion or prompt_info
    description = d.get('_description', '') or prompt_info.get("description", "")

    # Convert messages to "role: content" entries separated by blank lines.
    parts = [
        f"{msg.get('role', 'user')}: {_message_text(msg.get('content', []))}"
        for msg in (messages_list or [])
        if isinstance(msg, dict)
    ]
    tpl = "\n\n".join(parts)

    # Normalize mustache/jinja variants to {var}. F-string templates already
    # use {var}; their doubled braces are escapes and must stay untouched.
    if not str(template_format).upper().endswith("F_STRING"):
        tpl = re.sub(r"\{\{\{\s*(\w+)\s*\}\}\}", r"{\1}", tpl)
        tpl = re.sub(r"\{\{\s*(\w+)\s*\}\}", r"{\1}", tpl)

    # Detect variables, skipping escaped ({{name}}) occurrences.
    variables = re.findall(r"(?<!\{)\{(\w+)\}(?!\})", tpl)

    # Name normalization; a missing id yields "prompt-" rather than "prompt-None".
    prompt_id = d.get('_id') or prompt_info.get("id")
    raw_name = prompt_info.get("name") or f"prompt-{prompt_id or ''}"
    name = str(raw_name).lower().replace(" ", "-")

    out: dict = {
        "name": name,
        "description": description,
        "prompt_template": tpl,
        "input_variables": sorted(set(variables)),
        "metadata": {
            "phoenix_id": prompt_id,
            "model": model,
            "model_provider": model_provider,
            "model_params": invocation_params,
            "template_format": template_format,
            "original_source": "phoenix",
        },
    }
    return out
def prompt_dict_to_obj(prompt_dict: dict, include_model: bool = True) -> object:
    """Build the object to push for a converted prompt.

    Args:
        prompt_dict: Converted prompt; ``prompt_template`` and ``metadata``
            (``model``, ``model_params``) are read from it.
        include_model: When true and a model name is present, the prompt is
            chained with a model instance.

    Returns:
        A ``RunnableSequence`` of prompt and model, or the bare chat prompt
        when no model is named, the model is excluded by ``include_model``,
        or building the model fails (the failure is printed, not raised).
    """
    chat_prompt = string_to_chat_template(prompt_dict["prompt_template"])
    metadata = prompt_dict["metadata"]
    model_name = metadata.get("model")
    model_params = metadata.get("model_params", {})

    if not model_name:
        return chat_prompt

    if not include_model:
        print(f" • prompt only (model {model_name} excluded by flag)")
        return chat_prompt

    try:
        model = get_model_instance(model_name, model_params)
        chain = RunnableSequence(chat_prompt, model)
        provider = detect_model_provider(model_name)
        print(f" ... using {provider} model: {model_name}")
    except Exception as exc:
        print(f" ! failed to create model {model_name}, using prompt only: {exc}")
        return chat_prompt
    return chain
# NOTE: Prompt version history is not migrated or preserved here; versioning can be handled in LangSmith instead.
def migrate_prompts(workspace_id: str):
    """Migrate Phoenix prompts into a LangSmith workspace.

    The prompt listing only carries id/name/description, so the full prompt
    is fetched by name before conversion. A prompt whose full definition
    cannot be fetched is skipped: without it the converted template would be
    empty and an empty prompt would be pushed.

    Errors are printed and never raised. A ``LangSmithConflictError`` is
    reported as "already exists"; any other per-prompt failure is printed
    with a traceback and the loop moves on to the next prompt.

    Args:
        workspace_id: Target LangSmith workspace id.
    """
    try:
        prompts = phoenix_get_prompts()
    except Exception as e:
        print(f" x failed to fetch Phoenix prompts: {e}")
        traceback.print_exc()
        return

    if not prompts:
        print(" (no prompts)")
        return

    print(f" - migrating {len(prompts)} prompt(s)…")
    for prompt_info in prompts:
        # prompt_info is dict with {id, name, description} from REST API list
        prompt_name = prompt_info.get("name")

        # Get full prompt details (PromptVersion object with template)
        full_prompt = None
        if prompt_name:
            try:
                full_prompt = phoenix_get_prompt(prompt_name)
            except Exception as e:
                print(f" ! could not get full prompt: {e}")
        if full_prompt is None:
            print(f" x prompt '{prompt_name}' skipped: no template could be retrieved")
            continue

        try:
            # Pass both: full_prompt (has template) and prompt_info (has name/description)
            ls_p_dict = phoenix_prompt_conversion(full_prompt, prompt_info=prompt_info)
            ls_p_obj = prompt_dict_to_obj(ls_p_dict, include_model=INCLUDE_MODEL_IN_PROMPTS)
            url = ls_push_prompt(ls_p_dict["name"], ls_p_dict["description"], ls_p_obj, workspace_id)
            print(f" • {prompt_name}: {url}")
        except LangSmithConflictError:
            print(f" • prompt '{prompt_name}' already exists, skipping...")
        except Exception as e:
            print(f" x prompt '{prompt_name}' failed: {e}")
            traceback.print_exc()
+293
View File
@@ -0,0 +1,293 @@
from __future__ import annotations
import traceback
import uuid
import json
import datetime as dt
from config import NUM_TRACES_TO_REPLAY
from utils.phoenix import phoenix_get_traces
from utils.langsmith import ls_replay_runs_sdk
def safe_isoformat(dt_obj):
    """Format a datetime as an ISO-8601 string with millisecond precision.

    Args:
        dt_obj: A ``datetime``, an already-formatted string, or ``None``.

    Returns:
        * ``None`` for ``None``, for values that are neither ``str`` nor
          ``datetime``, and for "missing timestamp" sentinels that are
          datetime instances but compare unequal to themselves
          (e.g. ``pandas.NaT``).
        * The input unchanged when it is already a string.
        * Otherwise the ISO string. Naive datetimes are assumed to be UTC,
          and a ``+00:00`` offset is written as ``Z``.
    """
    if dt_obj is None:
        return None
    if isinstance(dt_obj, str):
        return dt_obj
    if not isinstance(dt_obj, dt.datetime):
        return None
    # NaT-like sentinels are never equal to themselves; formatting one would
    # produce a non-timestamp string instead of signalling "no value".
    if dt_obj != dt_obj:
        return None
    if dt_obj.tzinfo is None:
        dt_obj = dt_obj.replace(tzinfo=dt.timezone.utc)
    s = dt_obj.isoformat(timespec='milliseconds')
    return s[:-6] + 'Z' if s.endswith('+00:00') else s
def _compact_ts(ts_val):
if ts_val is None:
return ""
if isinstance(ts_val, str):
try:
s = ts_val[:-1] + '+00:00' if ts_val.endswith('Z') else ts_val
dt_obj = dt.datetime.fromisoformat(s)
except Exception:
return ""
else:
dt_obj = ts_val
if dt_obj.tzinfo is None:
dt_obj = dt_obj.replace(tzinfo=dt.timezone.utc)
return dt_obj.strftime('%Y%m%dT%H%M%S') + f"{dt_obj.microsecond:06d}" + 'Z'
def _span_kind_to_run_type(span_kind: str) -> str:
if not span_kind:
return "chain"
kind_lower = str(span_kind).lower()
mapping = {
"llm": "llm",
"chain": "chain",
"agent": "chain",
"tool": "tool",
"retriever": "retriever",
"embedding": "llm",
"reranker": "chain",
"guardrail": "chain",
}
return mapping.get(kind_lower, "chain")
def _parse_value(value, default_key: str) -> dict:
"""Parse a value into a dict, handling JSON strings, dicts, and lists."""
if value is None:
return {}
if isinstance(value, str):
try:
parsed = json.loads(value)
if isinstance(parsed, dict):
return parsed
elif isinstance(parsed, list):
return {default_key: parsed}
else:
return {default_key: parsed}
except json.JSONDecodeError:
return {default_key: value}
elif isinstance(value, dict):
return value
elif isinstance(value, list):
return {default_key: value}
else:
return {default_key: str(value)}
def _ensure_end_times(runs: list[dict]):
for r in runs:
if isinstance(r, dict) and r.get("end_time") is None:
r["end_time"] = r.get("start_time")
def _children_map(runs: list[dict]) -> dict:
id_to_run = {r["id"]: r for r in runs if isinstance(r, dict)}
cmap: dict[str, list[str]] = {}
for r in runs:
if not isinstance(r, dict):
continue
pid = r.get("parent_run_id")
if pid:
cmap.setdefault(pid, []).append(r["id"])
for pid, kids in list(cmap.items()):
kids.sort(key=lambda k: id_to_run.get(k, {}).get('start_time') or '')
cmap[pid] = kids
return cmap
def _assign_dotted_order(runs: list[dict], root_id: str):
    """Set ``dotted_order`` on every run of a trace, in place.

    A run's dotted order is its parent's dotted order, a ``.``, and the run's
    own segment; a run with no parent has just its segment. A segment is the
    compact timestamp of the run (start time, else end time, else the root's
    start time, else the current UTC time) followed by the run id.

    The tree is walked from ``root_id``. Runs not reachable from it whose
    parent is absent from ``runs`` are then treated as roots of their own
    subtrees, so their descendants still carry the parent chain. Anything
    still unassigned (e.g. a cycle of parent links) gets a single segment.

    Args:
        runs: Run dicts of one trace; each needs an ``id``.
        root_id: Id of the trace's root run; may be ``None``.
    """
    id_to_run = {r["id"]: r for r in runs if isinstance(r, dict)}
    cmap = _children_map(runs)
    root_start = id_to_run.get(root_id, {}).get('start_time')

    def segment(run: dict) -> str:
        ts = run.get('start_time') or run.get('end_time') or root_start
        stamp = _compact_ts(ts)
        if not stamp:
            # Ensure the segment has a timestamp prefix - use current time if missing
            stamp = dt.datetime.now(dt.timezone.utc).strftime('%Y%m%dT%H%M%S') + "000000Z"
        return stamp + run["id"]

    visited: set = set()

    def assign_subtree(top_id, parent_dotted=None):
        # Explicit stack instead of recursion: deep traces cannot exhaust the
        # call stack, and ``visited`` stops cyclic parent links from looping.
        stack = [(top_id, parent_dotted)]
        while stack:
            run_id, prefix = stack.pop()
            run = id_to_run.get(run_id)
            if not run or run_id in visited:
                continue
            visited.add(run_id)
            seg = segment(run)
            dotted = f"{prefix}.{seg}" if prefix else seg
            run["dotted_order"] = dotted
            stack.extend((kid, dotted) for kid in cmap.get(run_id, []))

    assign_subtree(root_id)

    # Orphans: unassigned runs whose parent is missing from this trace.
    for run in runs:
        if not isinstance(run, dict) or run.get("dotted_order"):
            continue
        if run.get("parent_run_id") not in id_to_run:
            assign_subtree(run["id"])

    # Ensure ALL runs have dotted_order (catch anything left over)
    for run in runs:
        if isinstance(run, dict) and not run.get("dotted_order"):
            run["dotted_order"] = segment(run)
def _get_attr(obj, name, default=None):
"""Get attribute from object or dict."""
if isinstance(obj, dict):
return obj.get(name, default)
return getattr(obj, name, default)
# Note: This does not preserve the full original object - it strips out the messages specifically.
# LangSmith does accept a variety of message formats that include additional metadata - this conversion is for simplicity.
def _has_value(value) -> bool:
    """Return True when a span field holds usable data (not None, NaN or empty)."""
    if value is None:
        return False
    # NaN marks a missing cell in a DataFrame; it is truthy, so test it explicitly.
    if isinstance(value, float) and value != value:
        return False
    try:
        # Sized values (str, list, dict, arrays): empty means missing.
        return len(value) > 0
    except TypeError:
        return bool(value)


def _span_text(span: dict, key: str) -> str:
    """Return ``span[key]`` as a string, or ``''`` when the field is missing."""
    value = span.get(key)
    return str(value) if _has_value(value) else ''


def map_phoenix_traces_to_langsmith(traces_dict: dict) -> list[dict]:
    """Transform Phoenix traces (grouped by trace_id) to LangSmith runs format.

    Every span becomes one run with a freshly generated UUID. Parent links
    are rewritten to the new ids, and a parent that is not part of the same
    trace is treated as absent. The first span without a (known) parent is
    the trace root, and its run id becomes the ``trace_id`` of all runs of
    that trace. Missing end times and dotted orders are filled in per trace.

    Args:
        traces_dict: dict[trace_id -> list of span dicts]

    Phoenix DataFrame columns:
        - context.trace_id, context.span_id, parent_id
        - name, span_kind, start_time, end_time
        - attributes.input.value, attributes.output.value
        - attributes.llm.model_name, attributes.openinference.span.kind

    Returns:
        Flat list of run dicts for all traces; empty when ``traces_dict`` is
        empty.
    """
    if not traces_dict:
        return []

    runs = []
    for trace_spans in traces_dict.values():
        # Create new IDs for LangSmith
        span_id_mapping = {}
        for span in trace_spans:
            orig_span_id = _span_text(span, 'context.span_id')
            if orig_span_id:
                span_id_mapping[orig_span_id] = str(uuid.uuid4())

        root_run_id = None
        trace_runs = []
        for span in trace_spans:
            orig_span_id = _span_text(span, 'context.span_id')
            run_id = span_id_mapping.get(orig_span_id, str(uuid.uuid4()))
            orig_parent_id = _span_text(span, 'parent_id')
            parent_run_id = span_id_mapping.get(orig_parent_id) if orig_parent_id else None

            # First span without parent is the root
            if parent_run_id is None and root_run_id is None:
                root_run_id = run_id

            # Get span kind from openinference attribute or span_kind column
            span_kind = (
                _span_text(span, 'attributes.openinference.span.kind')
                or _span_text(span, 'span_kind')
            )
            run_type = _span_kind_to_run_type(span_kind)

            # Parse inputs - first source that holds data wins:
            # tool parameters, general input value, LLM input messages.
            inputs = {}
            for column, default_key in (
                ('attributes.tool.parameters', "parameters"),
                ('attributes.input.value', "input"),
                ('attributes.llm.input_messages', "messages"),
            ):
                candidate = span.get(column)
                if _has_value(candidate):
                    inputs = _parse_value(candidate, default_key)
                    break

            # Parse outputs - general output value, then LLM output messages.
            outputs = {}
            for column, default_key in (
                ('attributes.output.value', "output"),
                ('attributes.llm.output_messages', "messages"),
            ):
                candidate = span.get(column)
                if _has_value(candidate):
                    outputs = _parse_value(candidate, default_key)
                    break

            # Build metadata from attributes
            metadata = {}
            model_name = span.get('attributes.llm.model_name')
            if model_name and isinstance(model_name, str):
                metadata["ls_model_name"] = model_name

            trace_runs.append({
                "id": run_id,
                "trace_id": root_run_id or run_id,
                "name": _span_text(span, 'name') or "span",
                "run_type": run_type,
                "parent_run_id": parent_run_id,
                "inputs": inputs,
                "outputs": outputs,
                "start_time": safe_isoformat(span.get('start_time')),
                "end_time": safe_isoformat(span.get('end_time')),
                "metadata": metadata,
                "tags": [],
            })

        # Fix trace_id for all runs in this trace (runs seen before the root
        # was identified still carry their own id).
        if root_run_id:
            for run in trace_runs:
                run["trace_id"] = root_run_id

        _ensure_end_times(trace_runs)
        _assign_dotted_order(trace_runs, root_run_id)
        runs.extend(trace_runs)
    return runs
def migrate_traces(workspace_id: str, project_name: str):
    """Replay the traces of one Phoenix project into LangSmith.

    Fetches up to ``NUM_TRACES_TO_REPLAY`` traces (100 when that setting is
    unset or not positive), converts them to runs and uploads them under the
    same project name. Failures are printed with a traceback, never raised,
    and a three-line summary is printed on every path.

    Args:
        workspace_id: Target LangSmith workspace id.
        project_name: Phoenix project to read; also the LangSmith project name.
    """

    def print_summary(processed, fetch_failures, transform_failures):
        print(f" • Processed traces: {processed}")
        print(f" • Failed fetching: {fetch_failures}")
        print(f" • Failed transforming: {transform_failures}")

    print(" - migrating traces...")

    try:
        if NUM_TRACES_TO_REPLAY and NUM_TRACES_TO_REPLAY > 0:
            limit = NUM_TRACES_TO_REPLAY
        else:
            limit = 100
        traces_dict = phoenix_get_traces(project_name=project_name, limit=limit)
    except Exception as exc:
        print(f" x failed to fetch traces from Phoenix: {exc}")
        traceback.print_exc()
        print_summary(0, 1, 0)
        return

    if not traces_dict:
        print(" • no traces found")
        print_summary(0, 0, 0)
        return

    trace_count = len(traces_dict)
    span_count = sum(len(spans) for spans in traces_dict.values())
    print(f" • fetched {trace_count} traces ({span_count} spans) from Phoenix")

    transform_failures = 0
    try:
        runs = map_phoenix_traces_to_langsmith(traces_dict)
        if runs:
            ls_replay_runs_sdk(workspace_id, runs, project_name=project_name)
            uploaded_traces = len({run["trace_id"] for run in runs})
            print(f" • uploaded {len(runs)} spans ({uploaded_traces} traces) to project '{project_name}'")
        else:
            print(" • no runs to upload after transformation")
            # Nothing came out of the conversion: count every trace as failed.
            transform_failures = trace_count
    except Exception as exc:
        print(f" x failed to transform/upload traces: {exc}")
        traceback.print_exc()
        transform_failures = trace_count

    print_summary(trace_count, 0, transform_failures)
+29
View File
@@ -0,0 +1,29 @@
from utils.langsmith import ls_get_or_create_workspace
from config import LS_WORKSPACE_ID
from providers.phoenix.data.prompts import migrate_prompts
from providers.phoenix.data.datasets import migrate_datasets
from providers.phoenix.data.traces import migrate_traces
def migrate_phoenix(projects: list[dict]):
    """Migrate data from Phoenix to LangSmith.

    For every project the target workspace is ``LS_WORKSPACE_ID`` when that
    is configured, otherwise a workspace obtained from
    ``ls_get_or_create_workspace(project_name)``. Traces are migrated per
    project. Prompts and datasets are fetched without any project filter, so
    they are migrated only once per target workspace instead of once per
    project.

    Args:
        projects: Phoenix projects as dicts with a ``name`` key.
    """
    # Workspaces that already received the prompts and datasets.
    shared_data_done: set = set()

    for proj in projects:
        project_name = proj.get("name")
        print(f"\n- Project: {project_name}")

        if LS_WORKSPACE_ID:
            ws_id = LS_WORKSPACE_ID
            print(f" + using workspace: {ws_id}")
        else:
            ws = ls_get_or_create_workspace(project_name)
            ws_id = ws["id"]
            print(f" + workspace id: {ws_id}")

        if ws_id not in shared_data_done:
            migrate_prompts(ws_id)
            migrate_datasets(ws_id)
            shared_data_done.add(ws_id)

        migrate_traces(ws_id, project_name)

    print("\n+ Migration complete.")
+1
View File
@@ -9,4 +9,5 @@ langchain-anthropic
requests
tqdm
arize
arize-phoenix-client
pandas
+138
View File
@@ -0,0 +1,138 @@
import os
import requests
from dotenv import load_dotenv
import phoenix as px
from phoenix.client import Client
load_dotenv()
PHOENIX_API_KEY = os.getenv("PHOENIX_API_KEY")
PHOENIX_SPACE = os.getenv("PHOENIX_SPACE")
# Honour PHOENIX_HOST (the setting config.py exposes) so non-default Phoenix
# instances can be targeted; a trailing slash is dropped to keep the URLs
# built from this value well-formed.
PHOENIX_BASE = (os.getenv("PHOENIX_HOST") or "https://app.phoenix.arize.com").rstrip("/")

# Build the base URL with space if provided
PHOENIX_BASE_URL = f"{PHOENIX_BASE}/s/{PHOENIX_SPACE}" if PHOENIX_SPACE else PHOENIX_BASE

# Set environment variables for px.Client() authentication.
# The header is only written when a key exists, so a missing key never turns
# into the literal header value "api_key=None".
if PHOENIX_API_KEY:
    os.environ["PHOENIX_CLIENT_HEADERS"] = f"api_key={PHOENIX_API_KEY}"
os.environ["PHOENIX_COLLECTOR_ENDPOINT"] = PHOENIX_BASE_URL
def get_phoenix_client() -> Client:
    """Get an authenticated Phoenix client (for datasets, prompts, projects).

    A new ``Client`` is constructed on every call from the module-level
    ``PHOENIX_BASE_URL`` and ``PHOENIX_API_KEY``.
    """
    return Client(
        base_url=PHOENIX_BASE_URL,
        api_key=PHOENIX_API_KEY,
    )
def get_px_client():
    """Get a px.Client() for spans/traces (has get_spans_dataframe).

    The client is created without arguments.
    NOTE(review): presumably it picks up the PHOENIX_CLIENT_HEADERS and
    PHOENIX_COLLECTOR_ENDPOINT environment variables set at import time of
    this module - confirm against the installed ``phoenix`` package.
    """
    return px.Client()
def phoenix_get_projects() -> list[dict]:
    """List Phoenix projects.

    Returns:
        One ``{"name": ...}`` dict per project returned by
        ``client.projects.list()``; all other project fields are dropped.
    """
    client = get_phoenix_client()
    projects = client.projects.list()
    # Each listed project must be indexable by "name".
    return [{"name": p["name"]} for p in projects]
def phoenix_get_traces(project_name: str, limit: int = 100) -> dict:
    """Fetch traces from a Phoenix project, grouped by trace_id.

    Args:
        project_name: Phoenix project whose spans are fetched.
        limit: Maximum number of traces to return. ``None`` means no limit;
            zero or a negative number returns no traces (and skips the fetch).

    Returns: dict[trace_id -> list of span dicts], each list sorted by
        ``start_time``. Every span dict carries ``context.span_id``.
    """
    if limit is not None and limit <= 0:
        return {}

    client = get_phoenix_client()

    # Get the project's spans as a DataFrame.
    # NOTE(review): the client may cap the number of spans it returns by
    # default, in which case traces can be incomplete - confirm.
    spans_df = client.spans.get_spans_dataframe(project_name=project_name)

    if spans_df is None or spans_df.empty:
        return {}

    # to_dict(orient='records') drops the index. When the span id is carried
    # by the index rather than a column, move it into a column: the trace
    # conversion reads 'context.span_id' from each record to rebuild the
    # parent/child links.
    if "context.span_id" not in spans_df.columns and spans_df.index.name == "context.span_id":
        spans_df = spans_df.reset_index()

    # Group spans into traces
    traces = {}
    for trace_id, group in spans_df.groupby("context.trace_id"):
        traces[trace_id] = group.sort_values("start_time").to_dict(orient='records')

        # Limit to N traces
        if limit is not None and len(traces) >= limit:
            break

    return traces
def phoenix_get_datasets() -> list[dict]:
    """List datasets in Phoenix.

    Returns:
        One dict per dataset with ``id``, ``name`` and ``description``
        (empty string when the dataset has no ``description`` key).
    """
    client = get_phoenix_client()
    datasets = client.datasets.list()
    # "id" and "name" are required on every listed dataset (a missing key raises KeyError).
    return [{"id": d["id"], "name": d["name"], "description": d.get("description", "")} for d in datasets]
def phoenix_get_dataset_examples(dataset_id: str = None, dataset_name: str = None) -> list[dict]:
    """Fetch examples from a Phoenix dataset.

    Args:
        dataset_id: Dataset id; used only when ``dataset_name`` is empty.
        dataset_name: Dataset name; preferred when given.

    Returns:
        The dataset's examples as a list (DataFrame-like containers are
        converted to record dicts), or ``[]`` when no identifier is given or
        the dataset exposes no examples.
    """
    # NOTE(review): passing an id through the ``dataset`` argument assumes
    # get_dataset resolves ids as well as names - confirm for the installed
    # client version.
    identifier = dataset_name or dataset_id
    if not identifier:
        return []

    client = get_phoenix_client()
    dataset = client.datasets.get_dataset(dataset=identifier)

    # Get examples from _examples_data.examples attribute (or dict entry)
    examples_data = getattr(dataset, '_examples_data', None)
    if examples_data is None:
        return []

    examples = getattr(examples_data, 'examples', None)
    if examples is None and isinstance(examples_data, dict):
        examples = examples_data.get('examples')
    if examples is None:
        return []

    if hasattr(examples, 'to_dict'):
        return examples.to_dict(orient='records')
    if isinstance(examples, list):
        return examples
    return list(examples)
def phoenix_get_prompts() -> list[dict]:
    """List prompts in Phoenix using REST API. Phoenix Client does not support listing prompts.

    Pages are followed through the ``next_cursor`` field of the response
    until it is empty. Every request has a timeout so an unresponsive server
    cannot block the migration indefinitely.

    Returns:
        One dict per prompt with ``id``, ``name`` and ``description``.
        Errors are printed, not raised; whatever was collected before the
        error (possibly nothing) is returned.
    """
    headers = {
        "Authorization": f"Bearer {PHOENIX_API_KEY}",
        "api_key": PHOENIX_API_KEY,
        "x-api-key": PHOENIX_API_KEY,
        "Content-Type": "application/json"
    }
    url = f"{PHOENIX_BASE_URL}/v1/prompts"

    result: list[dict] = []
    cursor = None
    seen_cursors = set()
    try:
        while True:
            params = {"cursor": cursor} if cursor else None
            resp = requests.get(url, headers=headers, params=params, timeout=30)
            if resp.status_code != 200:
                print(f" x Failed to fetch prompts: {resp.status_code} {resp.text}")
                return result

            data = resp.json()
            page = data.get('data', []) if isinstance(data, dict) else data
            for p in page or []:
                if isinstance(p, dict):
                    result.append({
                        "id": p.get("id", ""),
                        "name": p.get("name", ""),
                        "description": p.get("description", "") or ""
                    })

            cursor = data.get('next_cursor') if isinstance(data, dict) else None
            # Stop on the last page; a repeated cursor would loop forever.
            if not cursor or cursor in seen_cursors:
                return result
            seen_cursors.add(cursor)
    except Exception as e:
        print(f" x Error fetching prompts via REST: {e}")
        return result
def phoenix_get_prompt(prompt_name: str):
    """Get a specific prompt by name.

    Args:
        prompt_name: Passed to the client as ``prompt_identifier``.

    Returns:
        Whatever ``client.prompts.get`` returns for that identifier.
        Exceptions raised by the client are not caught here.
    """
    client = get_phoenix_client()
    return client.prompts.get(prompt_identifier=prompt_name)