mirror of
https://github.com/run-llama/llama_viz.git
synced 2026-07-01 21:24:01 -04:00
reorganize code
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Dict, List, Type
|
||||
import dash_bootstrap_components as dbc
|
||||
import plotly.graph_objs as go
|
||||
from dash import dash_table, dcc, html
|
||||
from dash.dependencies import Component
|
||||
from dash.development.base_component import Component
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from .utils import MissingType
|
||||
|
||||
+119
-1
@@ -1,7 +1,12 @@
|
||||
from typing import Any, Dict, List
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
import dash_bootstrap_components as dbc
|
||||
import pandas as pd
|
||||
from llama_index.core.workflow import StopEvent, Workflow
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import HttpUrl
|
||||
|
||||
THEMES = {
|
||||
"bootstrap": dbc.themes.BOOTSTRAP,
|
||||
@@ -63,3 +68,116 @@ def get_external_stylesheets(theme_name: str) -> List[str | Dict[str, Any]]:
|
||||
if stylesheet is None:
|
||||
raise ValueError(f"Unknown theme: {theme_name}")
|
||||
return [stylesheet]
|
||||
|
||||
|
||||
def parse_input_value(value: Any, type_hint: Type) -> Any:
|
||||
"""
|
||||
Parse the input value based on the expected type.
|
||||
|
||||
Args:
|
||||
value: The raw input value from the dash component
|
||||
type_hint: The expected type
|
||||
|
||||
Returns:
|
||||
The parsed value
|
||||
"""
|
||||
if value is None or value == "":
|
||||
if type_hint is bool:
|
||||
return False
|
||||
return None
|
||||
|
||||
if type_hint is str:
|
||||
return str(value)
|
||||
elif type_hint is int:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return 0
|
||||
elif type_hint is float:
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return 0.0
|
||||
elif type_hint is bool:
|
||||
return bool(value)
|
||||
elif type_hint is datetime.date:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
return datetime.date.today()
|
||||
return value
|
||||
elif (
|
||||
type_hint is list
|
||||
or type_hint is List
|
||||
or hasattr(type_hint, "__origin__")
|
||||
and type_hint.__origin__ is list
|
||||
):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
elif (
|
||||
type_hint is dict
|
||||
or type_hint is Dict
|
||||
or hasattr(type_hint, "__origin__")
|
||||
and type_hint.__origin__ is dict
|
||||
):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
elif issubclass(type_hint, BaseModel):
|
||||
try:
|
||||
return type_hint.parse_raw(value)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
# For unknown types, return as is
|
||||
return value
|
||||
|
||||
|
||||
def format_output_value(value: Any, type_hint: Type) -> Any:
|
||||
"""
|
||||
Format the output value based on the component type.
|
||||
|
||||
Args:
|
||||
value: The raw output value from the workflow
|
||||
type_hint: The expected type
|
||||
|
||||
Returns:
|
||||
The formatted value appropriate for the dash component
|
||||
"""
|
||||
if value is None:
|
||||
return "" if type_hint is str else None
|
||||
|
||||
if type_hint is str or type_hint is int or type_hint is float or type_hint is bool:
|
||||
return str(value)
|
||||
elif type_hint is HttpUrl or type_hint.__name__ == "HttpUrl":
|
||||
return str(value)
|
||||
elif type_hint is pd.DataFrame:
|
||||
if isinstance(value, pd.DataFrame):
|
||||
return value.to_dict("records")
|
||||
return []
|
||||
elif (
|
||||
type_hint.__name__ == "Figure"
|
||||
or hasattr(type_hint, "__name__")
|
||||
and "Figure" in type_hint.__name__
|
||||
):
|
||||
return value
|
||||
elif isinstance(value, (dict, list)) or type_hint is dict or type_hint is list:
|
||||
try:
|
||||
return json.dumps(value, indent=2, default=str)
|
||||
except Exception:
|
||||
return str(value)
|
||||
else:
|
||||
# For complex objects, try JSON serialization
|
||||
try:
|
||||
if hasattr(value, "json"):
|
||||
return value.json(indent=2)
|
||||
elif hasattr(value, "model_dump_json"):
|
||||
return value.model_dump_json(indent=2)
|
||||
else:
|
||||
return json.dumps(value, indent=2, default=str)
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
+12
-127
@@ -1,22 +1,24 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
import dash
|
||||
import dash_bootstrap_components as dbc
|
||||
import diskcache
|
||||
import pandas as pd
|
||||
from dash import Dash, DiskcacheManager, Input, Output, State, html, set_props
|
||||
from dash.dependencies import Component
|
||||
from dash.development.base_component import Component
|
||||
from dash.exceptions import PreventUpdate
|
||||
from llama_index.core import __version__ as llama_index_version
|
||||
from llama_index.core.workflow import Context, StopEvent, Workflow
|
||||
from llama_index.core.workflow.events import HumanResponseEvent, InputRequiredEvent
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from .components import get_input_component, get_output_component
|
||||
from .utils import get_external_stylesheets, get_workflow_inputs, get_workflow_outputs
|
||||
from .utils import (
|
||||
format_output_value,
|
||||
get_external_stylesheets,
|
||||
get_workflow_inputs,
|
||||
get_workflow_outputs,
|
||||
parse_input_value,
|
||||
)
|
||||
|
||||
|
||||
class Viz:
|
||||
@@ -193,122 +195,6 @@ class Viz:
|
||||
className="p-5",
|
||||
)
|
||||
|
||||
def _parse_input_value(self, value: Any, type_hint: Type) -> Any:
|
||||
"""
|
||||
Parse the input value based on the expected type.
|
||||
|
||||
Args:
|
||||
value: The raw input value from the dash component
|
||||
type_hint: The expected type
|
||||
|
||||
Returns:
|
||||
The parsed value
|
||||
"""
|
||||
if value is None or value == "":
|
||||
if type_hint is bool:
|
||||
return False
|
||||
return None
|
||||
|
||||
if type_hint is str:
|
||||
return str(value)
|
||||
elif type_hint is int:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return 0
|
||||
elif type_hint is float:
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return 0.0
|
||||
elif type_hint is bool:
|
||||
return bool(value)
|
||||
elif type_hint is datetime.date:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
return datetime.date.today()
|
||||
return value
|
||||
elif (
|
||||
type_hint is list
|
||||
or type_hint is List
|
||||
or hasattr(type_hint, "__origin__")
|
||||
and type_hint.__origin__ is list
|
||||
):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
elif (
|
||||
type_hint is dict
|
||||
or type_hint is Dict
|
||||
or hasattr(type_hint, "__origin__")
|
||||
and type_hint.__origin__ is dict
|
||||
):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
elif issubclass(type_hint, BaseModel):
|
||||
try:
|
||||
return type_hint.parse_raw(value)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
# For unknown types, return as is
|
||||
return value
|
||||
|
||||
def _format_output_value(self, value: Any, type_hint: Type) -> Any:
|
||||
"""
|
||||
Format the output value based on the component type.
|
||||
|
||||
Args:
|
||||
value: The raw output value from the workflow
|
||||
type_hint: The expected type
|
||||
|
||||
Returns:
|
||||
The formatted value appropriate for the dash component
|
||||
"""
|
||||
if value is None:
|
||||
return "" if type_hint is str else None
|
||||
|
||||
if (
|
||||
type_hint is str
|
||||
or type_hint is int
|
||||
or type_hint is float
|
||||
or type_hint is bool
|
||||
):
|
||||
return str(value)
|
||||
elif type_hint is HttpUrl or type_hint.__name__ == "HttpUrl":
|
||||
return str(value)
|
||||
elif type_hint is pd.DataFrame:
|
||||
if isinstance(value, pd.DataFrame):
|
||||
return value.to_dict("records")
|
||||
return []
|
||||
elif (
|
||||
type_hint.__name__ == "Figure"
|
||||
or hasattr(type_hint, "__name__")
|
||||
and "Figure" in type_hint.__name__
|
||||
):
|
||||
return value
|
||||
elif isinstance(value, (dict, list)) or type_hint is dict or type_hint is list:
|
||||
try:
|
||||
return json.dumps(value, indent=2, default=str)
|
||||
except Exception:
|
||||
return str(value)
|
||||
else:
|
||||
# For complex objects, try JSON serialization
|
||||
try:
|
||||
if hasattr(value, "json"):
|
||||
return value.json(indent=2)
|
||||
elif hasattr(value, "model_dump_json"):
|
||||
return value.model_dump_json(indent=2)
|
||||
else:
|
||||
return json.dumps(value, indent=2, default=str)
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
def _create_callback(self):
|
||||
"""Create the main callback for the workflow"""
|
||||
|
||||
@@ -344,7 +230,7 @@ class Viz:
|
||||
# Parse input values
|
||||
run_params = {}
|
||||
for i, (input_name, input_type) in enumerate(self._inputs.items()):
|
||||
parsed_value = self._parse_input_value(args[i], input_type)
|
||||
parsed_value = parse_input_value(args[i], input_type)
|
||||
if parsed_value is not None: # Only add non-None values
|
||||
run_params[input_name] = parsed_value
|
||||
|
||||
@@ -355,6 +241,7 @@ class Viz:
|
||||
run_params["ctx"] = self._ctx
|
||||
handler = self._workflow.run(**run_params)
|
||||
if modal_input_value:
|
||||
assert handler._ctx
|
||||
handler._ctx.send_event(
|
||||
HumanResponseEvent(response=modal_input_value)
|
||||
)
|
||||
@@ -381,7 +268,7 @@ class Viz:
|
||||
if len(self._outputs) == 1 and "result" in self._outputs:
|
||||
# Special case for simple workflows with just a "result" output
|
||||
output_values.append(
|
||||
self._format_output_value(result, self._outputs["result"])
|
||||
format_output_value(result, self._outputs["result"])
|
||||
)
|
||||
else:
|
||||
# For more complex workflows with multiple outputs
|
||||
@@ -393,9 +280,7 @@ class Viz:
|
||||
else:
|
||||
output_value = result # Use the whole result if we can't find a specific attribute
|
||||
|
||||
output_values.append(
|
||||
self._format_output_value(output_value, output_type)
|
||||
)
|
||||
output_values.append(format_output_value(output_value, output_type))
|
||||
|
||||
if result is None:
|
||||
# Workflow didn't finish, show the modal
|
||||
|
||||
Reference in New Issue
Block a user