refactor: migrate embedding config to screen based navigation

- Replace multiple separate apps with single app using screen navigation
- (WIP): Introduce screen hierarchy (BaseScreen, ConfigurationScreen..)
- Add EmbeddingConfig dataclass
- Centralize keyboard bindings
- Simplify pipeline creation flow in create_llama_cloud_index.py for better separation of concerns

The purpose of this change is to improve the UI flow and code organization (WIP) by using Textual's screen management system instead of multiple separate apps.
This commit is contained in:
Nick Galluzzo
2025-07-10 13:15:02 +07:00
parent 27a7fa7ce6
commit 840ca05331
2 changed files with 208 additions and 152 deletions
+175 -32
View File
@@ -1,46 +1,189 @@
import os
from textual import on
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Select, Input, Label, Footer
from dataclasses import dataclass
from typing import Optional
class BaseEmbeddingApp(App):
CSS_PATH = "stylesheets/base.tcss"
# Key bindings shared by every screen (assigned to BaseScreen.BINDINGS).
COMMON_BINDINGS = [
    Binding("ctrl+q", "quit", "Exit", key_display="ctrl+q"),
    # The description is user-facing text, so spell it out instead of the
    # stray "t" placeholder.
    Binding("ctrl+d", "toggle_dark", "Toggle Dark Theme", key_display="ctrl+d"),
]
# Extra binding added on top of the common ones by ConfigurationScreen.
SUBMIT_BINDING = Binding("ctrl+s", "submit", "Submit", key_display="ctrl+s")
BINDINGS = [
Binding(
key="ctrl+q", action="quit", description="Submit", key_display="ctrl+q"
),
Binding(
key="ctrl+d",
action="toggle_dark",
description="Toggle Dark Theme",
key_display="ctrl+d",
),
]
def __init__(self):
super().__init__()
self.form_data = {}
@dataclass
class EmbeddingConfig:
    """Settings collected by the embedding setup screens.

    Only ``provider`` is required at construction time; the remaining fields
    default to ``None`` and are filled in as the user moves through the
    screens.
    """

    provider: str
    api_key: Optional[str] = None
    model: Optional[str] = None
    region: Optional[str] = None
    key_id: Optional[str] = None
    # Choice made on the initial screen ("With Default Settings" or
    # "With Custom Settings"). Declared here because the screens read and
    # write it; without the field it only existed as an ad-hoc attribute and
    # reading it before assignment raised AttributeError.
    setup_type: Optional[str] = None
class BaseScreen(Screen):
    """Common base for the setup screens: shared bindings, actions and layout.

    Subclasses customise the screen by overriding ``get_title`` and
    ``get_form_elements``; ``compose`` arranges them in a single container.
    """

    BINDINGS = COMMON_BINDINGS

    def action_toggle_dark(self) -> None:
        """Flip the app theme between "textual-light" and "textual-dark"."""
        # The theme is set on the app, not on the screen. The comparison must
        # be ``==``: with ``!=`` a dark theme stayed dark and a light theme
        # stayed light, so the key binding never changed anything.
        self.app.theme = (
            "textual-dark" if self.app.theme == "textual-light" else "textual-light"
        )

    def action_quit(self) -> None:
        """Exit the application without passing a result to ``exit``."""
        self.app.exit()

    def compose(self) -> ComposeResult:
        """Yield one container holding the title, form elements and footer."""
        yield Container(
            Label(self.get_title(), classes="form-title"),
            *self.get_form_elements(),
            Footer(),
            classes="form-container",
        )

    def get_title(self) -> str:
        """Return the heading shown at the top of the screen."""
        return "Base Screen"

    def get_form_elements(self) -> list:
        """Return the widgets placed between the title and the footer."""
        return []
class InitialScreen(BaseScreen):
    """First screen: choose between the default and a custom embedding setup."""

    def get_title(self) -> str:
        """Return the heading for this screen."""
        return "How do you wish to proceed?"

    def get_form_elements(self) -> list:
        """Return the single select offering the two setup modes."""
        return [
            Select(
                options=[
                    ("With Default Settings", "With Default Settings"),
                    ("With Custom Settings", "With Custom Settings"),
                ],
                prompt="Please select one of the following",
                id="setup_type",
                classes="form-control",
            )
        ]

    @on(Select.Changed, "#setup_type")
    def handle_selection(self, event: Select.Changed) -> None:
        """Record the chosen setup mode and move on to the next step."""
        # Every option value is a string. A non-string value means the
        # selection was cleared back to the prompt, which must not be treated
        # as a choice (it previously fell through to the custom-settings path).
        if not isinstance(event.value, str):
            return
        app = self.app
        if isinstance(app, EmbeddingSetupApp):
            app.config.setup_type = event.value
            self.handle_next()

    def handle_next(self) -> None:
        """Run the default setup or open the provider selection screen."""
        app = self.app
        if not isinstance(app, EmbeddingSetupApp):
            return
        if app.config.setup_type == "With Default Settings":
            app.handle_default_setup()
        else:
            # No constructor argument: the inherited Screen constructor would
            # take a positional argument as the screen *name*, and the screens
            # reach the shared config through ``self.app.config`` anyway.
            app.push_screen(ProviderSelectScreen())
class ProviderSelectScreen(BaseScreen):
    """Screen on which the user picks the embedding provider."""

    def get_title(self) -> str:
        """Return the heading for this screen."""
        return "Select an embedding provider"

    def get_form_elements(self) -> list:
        """Return the select listing every provider option."""
        return [
            Select(
                options=[
                    ("OpenAI", "OpenAI"),
                    ("Cohere", "Cohere"),
                    ("Bedrock", "Bedrock"),
                    ("HuggingFace", "HuggingFace"),
                    ("Azure", "Azure"),
                    ("Gemini", "Gemini"),
                    ("Other", "Other"),
                ],
                prompt="Please select an embedding provider",
                classes="form-control",
                id="provider_select",
            )
        ]

    @on(Select.Changed, "#provider_select")
    def handle_selection(self, event: Select.Changed) -> None:
        """Store the chosen provider and open its configuration screen."""
        # Option values are strings; a non-string value means the selection
        # was cleared, so there is no provider to store.
        if not isinstance(event.value, str):
            return
        app = self.app
        if isinstance(app, EmbeddingSetupApp):
            app.config.provider = event.value
            self.handle_next()

    def handle_next(self) -> None:
        """Push the configuration screen for the selected provider."""
        app = self.app
        if not isinstance(app, EmbeddingSetupApp):
            return
        # TODO: register screens for Cohere, Bedrock, HuggingFace, Azure,
        # Gemini and Other once they are implemented.
        provider_screens = {
            "OpenAI": OpenAIEmbeddingScreen,
        }
        screen_class = provider_screens.get(app.config.provider)
        if screen_class is None:
            # Tell the user instead of silently ignoring the selection.
            self.notify(
                f"{app.config.provider} is not supported yet",
                severity="warning",
            )
            return
        # No constructor argument: a positional argument would be taken as
        # the screen name; the screen reads the config from ``self.app``.
        app.push_screen(screen_class())
class ConfigurationScreen(BaseScreen):
    """Base for provider forms: the inherited bindings plus the submit key."""

    BINDINGS = [*BaseScreen.BINDINGS, SUBMIT_BINDING]

    def action_submit(self) -> None:
        """To be implemented by specific provider screens"""
        return None
class OpenAIEmbeddingScreen(ConfigurationScreen):
    """Form that collects the OpenAI API key and the embedding model name."""

    def get_title(self) -> str:
        """Return the heading for this screen."""
        return "OpenAI Embedding Configuration"

    def get_form_elements(self) -> list:
        """Return the API-key input (masked) followed by the model input."""
        api_key_input = Input(
            placeholder="API key",
            type="text",
            password=True,
            id="api_key",
            classes="form-control",
        )
        model_input = Input(
            placeholder="Model",
            type="text",
            id="model",
            classes="form-control",
        )
        return [api_key_input, model_input]

    def action_submit(self) -> None:
        """Copy both input values into the app config and finish the setup."""
        app = self.app
        settings = app.config
        settings.api_key = self.query_one("#api_key", Input).value
        settings.model = self.query_one("#model", Input).value
        app.handle_completion(settings)
class EmbeddingSetupApp(App):
    """Single app that drives the embedding setup through a stack of screens.

    ``run()`` returns the completed ``EmbeddingConfig`` when the flow
    finishes through ``handle_completion``.
    """

    CSS_PATH = "stylesheets/base.tcss"

    def __init__(self):
        super().__init__()
        # Shared, mutable config that the screens fill in step by step.
        self.config = EmbeddingConfig(provider="")

    def on_mount(self) -> None:
        """Show the first screen of the flow."""
        self.push_screen(InitialScreen())

    def handle_completion(self, config: EmbeddingConfig) -> None:
        """Exit the app, handing ``config`` back as the result of ``run()``."""
        self.exit(config)

    def handle_default_setup(self) -> None:
        """Fill in the default OpenAI settings and finish the flow."""
        self.config.provider = "OpenAI"
        self.config.api_key = os.getenv("OPENAI_API_KEY")
        self.config.model = "text-embedding-3-small"
        # Without this call the app stayed on the initial screen forever and
        # the defaults never reached the caller of ``run()``.
        self.handle_completion(self.config)

    def get_input_ids(self) -> list[str]:
        """Override this to define input fields for the app"""
        return []
class DefaultOrCustomApp(BaseEmbeddingApp):
class DefaultOrCustomApp(BaseScreen):
def compose(self) -> ComposeResult:
yield Label("How do you wish to proceed?", classes="form-title")
yield Select(
@@ -58,7 +201,7 @@ class DefaultOrCustomApp(BaseEmbeddingApp):
self.title = str(event.value)
class SelectEmbeddingApp(BaseEmbeddingApp):
class SelectEmbeddingApp(BaseScreen):
def compose(self) -> ComposeResult:
yield Label("Embedding Model Selection", classes="form-title")
yield Select(
@@ -80,7 +223,7 @@ class SelectEmbeddingApp(BaseEmbeddingApp):
self.title = str(event.value)
class BedrockEmbeddingApp(BaseEmbeddingApp):
class BedrockEmbeddingApp(BaseScreen):
def compose(self) -> ComposeResult:
yield Label("Model and API key", classes="form-title")
yield Input(
@@ -99,7 +242,7 @@ class BedrockEmbeddingApp(BaseEmbeddingApp):
return ["region", "key_id"]
class HuggingFaceEmbeddingApp(BaseEmbeddingApp):
class HuggingFaceEmbeddingApp(BaseScreen):
def compose(self) -> ComposeResult:
yield Label("Model and API key", classes="form-title")
yield Input(
@@ -118,7 +261,7 @@ class HuggingFaceEmbeddingApp(BaseEmbeddingApp):
return ["api_key", "model"]
class OpenAIEmbeddingApp(BaseEmbeddingApp):
class OpenAIEmbeddingApp(BaseScreen):
def compose(self) -> ComposeResult:
yield Label("Model and API key", classes="form-title")
yield Input(
@@ -137,7 +280,7 @@ class OpenAIEmbeddingApp(BaseEmbeddingApp):
return ["api_key", "model"]
class CohereEmbeddingApp(BaseEmbeddingApp):
class CohereEmbeddingApp(BaseScreen):
def compose(self) -> ComposeResult:
yield Label("Model and API key", classes="form-title")
yield Input(
@@ -156,7 +299,7 @@ class CohereEmbeddingApp(BaseEmbeddingApp):
return ["api_key", "model"]
class AzureEmbeddingApp(BaseEmbeddingApp):
class AzureEmbeddingApp(BaseScreen):
def compose(self) -> ComposeResult:
yield Label("Model and API key", classes="form-title")
yield Input(
@@ -175,7 +318,7 @@ class AzureEmbeddingApp(BaseEmbeddingApp):
return ["api_key", "target_uri"]
class GeminiEmbeddingApp(BaseEmbeddingApp):
class GeminiEmbeddingApp(BaseScreen):
def compose(self) -> ComposeResult:
yield Label("Model and API key", classes="form-title")
yield Input(
@@ -192,7 +335,7 @@ class GeminiEmbeddingApp(BaseEmbeddingApp):
return ["api_key"]
class OtherEmbeddingApp(BaseEmbeddingApp):
class OtherEmbeddingApp(BaseScreen):
def compose(self) -> ComposeResult:
yield Label("Model and API key", classes="form-title")
yield Input(
+33 -120
View File
@@ -1,21 +1,9 @@
import os
from dotenv import load_dotenv
from cli.utils import EmbeddingSetupApp
from cli.utils import (
DefaultOrCustomApp,
SelectEmbeddingApp,
AzureEmbeddingApp,
GeminiEmbeddingApp,
BedrockEmbeddingApp,
OtherEmbeddingApp,
)
from llama_cloud import (
PipelineCreateEmbeddingConfig_OpenaiEmbedding,
PipelineCreateEmbeddingConfig_AzureEmbedding,
PipelineCreateEmbeddingConfig_BedrockEmbedding,
PipelineCreateEmbeddingConfig_CohereEmbedding,
PipelineCreateEmbeddingConfig_GeminiEmbedding,
PipelineCreateEmbeddingConfig_HuggingfaceApiEmbedding,
PipelineTransformConfig_Advanced,
AdvancedModeTransformConfigChunkingConfig_Sentence,
AdvancedModeTransformConfigSegmentationConfig_Page,
@@ -23,11 +11,6 @@ from llama_cloud import (
)
from llama_cloud.client import LlamaCloud
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.azure_inference import AzureAIEmbeddingsModel
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.embeddings.huggingface_api import HuggingFaceInferenceAPIEmbedding
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.embeddings.bedrock import BedrockEmbedding
def default(client: LlamaCloud):
@@ -73,111 +56,41 @@ def default(client: LlamaCloud):
def _build_embedding_config(config):
    """Translate the collected settings into a LlamaCloud embedding config.

    Returns ``None`` when the selected provider has no mapping yet; only
    OpenAI is wired up at the moment.
    """
    if config.provider != "OpenAI":
        return None
    embed_model = OpenAIEmbedding(model=config.model, api_key=config.api_key)
    return PipelineCreateEmbeddingConfig_OpenaiEmbedding(
        type="OPENAI_EMBEDDING",
        component=embed_model,
    )


def main():
    """Run the setup UI, create the LlamaCloud pipeline and record its id.

    Returns:
        0 when the pipeline was created and its id appended to ``.env``,
        1 when the user quit without a configuration or chose a provider
        that is not supported yet.
    """
    load_dotenv()
    client = LlamaCloud(token=os.getenv("LLAMACLOUD_API_KEY"))

    app = EmbeddingSetupApp()
    setup = app.run()
    if not setup:
        print("No embedding configuration provided")
        return 1

    # The UI returns the local EmbeddingConfig dataclass; the pipeline request
    # needs a LlamaCloud embedding config object, so convert before use.
    embedding_config = _build_embedding_config(setup)
    if embedding_config is None:
        print(f"Embedding provider {setup.provider!r} is not supported yet")
        return 1

    segm_config = AdvancedModeTransformConfigSegmentationConfig_Page(mode="page")
    chunk_config = AdvancedModeTransformConfigChunkingConfig_Sentence(
        chunk_size=1024,
        chunk_overlap=200,
        separator="<whitespace>",
        paragraph_separator="\n\n\n",
        mode="sentence",
    )
    transform_config = PipelineTransformConfig_Advanced(
        segmentation_config=segm_config,
        chunking_config=chunk_config,
        mode="advanced",
    )
    pipeline_request = PipelineCreate(
        name="notebooklm_pipeline",
        embedding_config=embedding_config,
        transform_config=transform_config,
    )
    pipeline = client.pipelines.upsert_pipeline(request=pipeline_request)

    # Append (never overwrite) so existing entries in .env are preserved.
    with open(".env", "a") as f:
        f.write(f'\nLLAMACLOUD_PIPELINE_ID="{pipeline.id}"')
    return 0
if __name__ == "__main__":