mirror of
https://github.com/run-llama/notebookllama.git
synced 2026-07-01 22:14:04 -04:00
refactor: migrate embedding config to screen based navigation
- Replace multiple separate apps with a single app using screen navigation - (WIP) Introduce screen hierarchy (BaseScreen, ConfigurationScreen, ...) - Add EmbeddingConfig dataclass - Centralize keyboard bindings - Simplify the pipeline creation flow in create_llama_cloud_index.py for separation of concerns. The purpose of this change is to improve the UI flow and code organization (WIP) by using Textual's screen management system instead of multiple separate apps.
This commit is contained in:
+175
-32
@@ -1,46 +1,189 @@
|
||||
import os
|
||||
from textual import on
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Container
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Select, Input, Label, Footer
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BaseEmbeddingApp(App):
|
||||
CSS_PATH = "stylesheets/base.tcss"
|
||||
# Key bindings shared by every screen of the setup flow.
COMMON_BINDINGS = [
    Binding("ctrl+q", "quit", "Exit", key_display="ctrl+q"),
    # The description is the label shown in the footer; spell it out rather
    # than leaving a one-letter placeholder.
    Binding(
        "ctrl+d",
        "toggle_dark",
        "Toggle Dark Theme",
        key_display="ctrl+d",
    ),
]
# Extra binding appended by screens that have a form to submit.
SUBMIT_BINDING = Binding("ctrl+s", "submit", "Submit", key_display="ctrl+s")
|
||||
|
||||
BINDINGS = [
|
||||
Binding(
|
||||
key="ctrl+q", action="quit", description="Submit", key_display="ctrl+q"
|
||||
),
|
||||
Binding(
|
||||
key="ctrl+d",
|
||||
action="toggle_dark",
|
||||
description="Toggle Dark Theme",
|
||||
key_display="ctrl+d",
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.form_data = {}
|
||||
@dataclass
class EmbeddingConfig:
    """Embedding settings gathered by the setup screens.

    Only ``provider`` is required; every other field defaults to ``None``
    and is filled in by whichever screen collects it.
    """

    # Provider label chosen on the provider screen (e.g. "OpenAI").
    provider: str
    api_key: Optional[str] = None
    model: Optional[str] = None
    # NOTE(review): region/key_id are presumably for a Bedrock-style
    # provider; no screen in this file sets them yet.
    region: Optional[str] = None
    key_id: Optional[str] = None
    # Choice made on the initial screen. InitialScreen assigns this
    # attribute, so it is declared here (last, to keep the positional
    # order of the existing fields unchanged).
    setup_type: Optional[str] = None
|
||||
|
||||
|
||||
class BaseScreen(Screen):
    """Common base for the screens of the embedding setup flow.

    Renders a titled form container with a footer and provides the shared
    quit and theme-toggle actions. Subclasses customise the screen by
    overriding ``get_title`` and ``get_form_elements``.
    """

    BINDINGS = COMMON_BINDINGS

    def action_toggle_dark(self) -> None:
        """Flip the application theme between light and dark."""
        # The theme is set on the App, not on the screen. Light switches to
        # dark; anything else switches to light, so repeated presses
        # alternate between the two.
        if self.app.theme == "textual-light":
            self.app.theme = "textual-dark"
        else:
            self.app.theme = "textual-light"

    def action_quit(self) -> None:
        """Exit the application without a result."""
        self.app.exit()

    def compose(self) -> ComposeResult:
        """Yield one container holding the title, form widgets and footer."""
        yield Container(
            Label(self.get_title(), classes="form-title"),
            *self.get_form_elements(),
            Footer(),
            classes="form-container",
        )

    def get_title(self) -> str:
        """Return the heading shown at the top of the screen."""
        return "Base Screen"

    def get_form_elements(self) -> list:
        """Return the widgets placed between the title and the footer."""
        return []
|
||||
|
||||
|
||||
class InitialScreen(BaseScreen):
    """First screen: choose between default and custom embedding settings."""

    def get_title(self) -> str:
        """Return the heading shown at the top of the screen."""
        return "How do you wish to proceed?"

    def get_form_elements(self) -> list:
        """Return the single select widget offering the two setup modes."""
        return [
            Select(
                options=[
                    ("With Default Settings", "With Default Settings"),
                    ("With Custom Settings", "With Custom Settings"),
                ],
                prompt="Please select one of the following",
                id="setup_type",
                classes="form-control",
            )
        ]

    @on(Select.Changed, "#setup_type")
    def handle_selection(self, event: Select.Changed) -> None:
        """Record the chosen setup mode on the app config, then advance."""
        app = self.app
        if isinstance(app, EmbeddingSetupApp):
            app.config.setup_type = event.value
            self.handle_next()

    def handle_next(self) -> None:
        """Continue according to the setup mode stored on the app config.

        Does nothing when the app is not an ``EmbeddingSetupApp`` or when
        the stored value is neither of the two offered options.
        """
        app = self.app
        if not isinstance(app, EmbeddingSetupApp):
            return

        setup_type = app.config.setup_type
        if setup_type == "With Default Settings":
            app.handle_default_setup()
        elif setup_type == "With Custom Settings":
            # Only an explicit "custom" choice moves on; a cleared/blank
            # selection must not open the provider screen. The screen reads
            # the config from the app, so nothing is passed to its
            # constructor (the first positional argument of a Screen is its
            # name, not application data).
            app.push_screen(ProviderSelectScreen())
|
||||
|
||||
|
||||
class ProviderSelectScreen(BaseScreen):
    """Screen on which the user picks the embedding provider."""

    def get_title(self) -> str:
        """Return the heading shown at the top of the screen."""
        return "Select an embedding provider"

    def get_form_elements(self) -> list:
        """Return the select widget listing the known providers."""
        providers = [
            "OpenAI",
            "Cohere",
            "Bedrock",
            "HuggingFace",
            "Azure",
            "Gemini",
            "Other",
        ]
        return [
            Select(
                # Label and value are the same string for every provider.
                options=[(provider, provider) for provider in providers],
                prompt="Please select an embedding provider",
                classes="form-control",
                id="provider_select",
            )
        ]

    @on(Select.Changed, "#provider_select")
    def handle_selection(self, event: Select.Changed) -> None:
        """Store the chosen provider on the app config, then advance."""
        app = self.app
        if isinstance(app, EmbeddingSetupApp):
            app.config.provider = event.value
            self.handle_next()

    def handle_next(self) -> None:
        """Open the configuration screen for the selected provider.

        Providers without a configuration screen produce a warning
        notification instead of being silently ignored.
        """
        app = self.app
        if not isinstance(app, EmbeddingSetupApp):
            return

        # TODO: add screens for Cohere, Bedrock, HuggingFace, Azure, Gemini
        # and Other; only OpenAI is wired up so far.
        provider_screens = {
            "OpenAI": OpenAIEmbeddingScreen,
        }

        provider = app.config.provider
        screen_class = provider_screens.get(provider)
        if screen_class is not None:
            # The screen reads the config from the app; a Screen's first
            # positional argument is its name, so nothing is passed here.
            app.push_screen(screen_class())
        elif isinstance(provider, str) and provider:
            # A real provider was picked but has no screen yet. A blank
            # selection is not a string and stays silent.
            app.notify(
                f"{provider} is not supported yet",
                severity="warning",
            )
|
||||
|
||||
|
||||
class ConfigurationScreen(BaseScreen):
    """Base for provider screens that collect values and can be submitted."""

    # Same bindings as the base screen, followed by the submit shortcut.
    BINDINGS = [*BaseScreen.BINDINGS, SUBMIT_BINDING]

    def action_submit(self) -> None:
        """To be implemented by specific provider screens"""
|
||||
|
||||
|
||||
class OpenAIEmbeddingScreen(ConfigurationScreen):
    """Form asking for the OpenAI API key and embedding model name."""

    def get_title(self) -> str:
        """Return the heading shown at the top of the screen."""
        return "OpenAI Embedding Configuration"

    def get_form_elements(self) -> list:
        """Return the API-key input (masked) followed by the model input."""
        api_key_input = Input(
            placeholder="API key",
            type="text",
            password=True,
            id="api_key",
            classes="form-control",
        )
        model_input = Input(
            placeholder="Model",
            type="text",
            id="model",
            classes="form-control",
        )
        return [api_key_input, model_input]

    def action_submit(self) -> None:
        """Copy both input values onto the app config and finish the flow."""
        app = self.app
        config = app.config
        config.api_key = self.query_one("#api_key", Input).value
        config.model = self.query_one("#model", Input).value
        app.handle_completion(config)
|
||||
|
||||
|
||||
class EmbeddingSetupApp(App):
    """Single Textual app that drives the embedding setup screens.

    The app owns one ``EmbeddingConfig`` which the screens fill in. When
    the flow completes, the config is passed to ``exit`` as the app's
    result.
    """

    CSS_PATH = "stylesheets/base.tcss"

    def __init__(self):
        super().__init__()
        # Starts with an empty provider; the screens fill in the rest.
        self.config = EmbeddingConfig(provider="")

    def on_mount(self) -> None:
        """Show the initial screen as soon as the app is mounted."""
        self.push_screen(InitialScreen())

    def handle_completion(self, config: EmbeddingConfig) -> None:
        """Exit the app, handing ``config`` back as the result."""
        self.exit(config)

    def handle_default_setup(self) -> None:
        """Fill the config with the default OpenAI settings and finish.

        The API key is read from the ``OPENAI_API_KEY`` environment
        variable and is ``None`` when that variable is not set.
        """
        self.config.provider = "OpenAI"
        self.config.api_key = os.getenv("OPENAI_API_KEY")
        self.config.model = "text-embedding-3-small"
        # Without this the app stays on the initial screen forever and the
        # default path never returns a config to the caller.
        self.handle_completion(self.config)

    def get_input_ids(self) -> list[str]:
        """Override this to define input fields for the app"""
        return []
|
||||
|
||||
|
||||
class DefaultOrCustomApp(BaseEmbeddingApp):
|
||||
class DefaultOrCustomApp(BaseScreen):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label("How do you wish to proceed?", classes="form-title")
|
||||
yield Select(
|
||||
@@ -58,7 +201,7 @@ class DefaultOrCustomApp(BaseEmbeddingApp):
|
||||
self.title = str(event.value)
|
||||
|
||||
|
||||
class SelectEmbeddingApp(BaseEmbeddingApp):
|
||||
class SelectEmbeddingApp(BaseScreen):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label("Embedding Model Selection", classes="form-title")
|
||||
yield Select(
|
||||
@@ -80,7 +223,7 @@ class SelectEmbeddingApp(BaseEmbeddingApp):
|
||||
self.title = str(event.value)
|
||||
|
||||
|
||||
class BedrockEmbeddingApp(BaseEmbeddingApp):
|
||||
class BedrockEmbeddingApp(BaseScreen):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label("Model and API key", classes="form-title")
|
||||
yield Input(
|
||||
@@ -99,7 +242,7 @@ class BedrockEmbeddingApp(BaseEmbeddingApp):
|
||||
return ["region", "key_id"]
|
||||
|
||||
|
||||
class HuggingFaceEmbeddingApp(BaseEmbeddingApp):
|
||||
class HuggingFaceEmbeddingApp(BaseScreen):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label("Model and API key", classes="form-title")
|
||||
yield Input(
|
||||
@@ -118,7 +261,7 @@ class HuggingFaceEmbeddingApp(BaseEmbeddingApp):
|
||||
return ["api_key", "model"]
|
||||
|
||||
|
||||
class OpenAIEmbeddingApp(BaseEmbeddingApp):
|
||||
class OpenAIEmbeddingApp(BaseScreen):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label("Model and API key", classes="form-title")
|
||||
yield Input(
|
||||
@@ -137,7 +280,7 @@ class OpenAIEmbeddingApp(BaseEmbeddingApp):
|
||||
return ["api_key", "model"]
|
||||
|
||||
|
||||
class CohereEmbeddingApp(BaseEmbeddingApp):
|
||||
class CohereEmbeddingApp(BaseScreen):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label("Model and API key", classes="form-title")
|
||||
yield Input(
|
||||
@@ -156,7 +299,7 @@ class CohereEmbeddingApp(BaseEmbeddingApp):
|
||||
return ["api_key", "model"]
|
||||
|
||||
|
||||
class AzureEmbeddingApp(BaseEmbeddingApp):
|
||||
class AzureEmbeddingApp(BaseScreen):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label("Model and API key", classes="form-title")
|
||||
yield Input(
|
||||
@@ -175,7 +318,7 @@ class AzureEmbeddingApp(BaseEmbeddingApp):
|
||||
return ["api_key", "target_uri"]
|
||||
|
||||
|
||||
class GeminiEmbeddingApp(BaseEmbeddingApp):
|
||||
class GeminiEmbeddingApp(BaseScreen):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label("Model and API key", classes="form-title")
|
||||
yield Input(
|
||||
@@ -192,7 +335,7 @@ class GeminiEmbeddingApp(BaseEmbeddingApp):
|
||||
return ["api_key"]
|
||||
|
||||
|
||||
class OtherEmbeddingApp(BaseEmbeddingApp):
|
||||
class OtherEmbeddingApp(BaseScreen):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label("Model and API key", classes="form-title")
|
||||
yield Input(
|
||||
|
||||
@@ -1,21 +1,9 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from cli.utils import EmbeddingSetupApp
|
||||
|
||||
from cli.utils import (
|
||||
DefaultOrCustomApp,
|
||||
SelectEmbeddingApp,
|
||||
AzureEmbeddingApp,
|
||||
GeminiEmbeddingApp,
|
||||
BedrockEmbeddingApp,
|
||||
OtherEmbeddingApp,
|
||||
)
|
||||
from llama_cloud import (
|
||||
PipelineCreateEmbeddingConfig_OpenaiEmbedding,
|
||||
PipelineCreateEmbeddingConfig_AzureEmbedding,
|
||||
PipelineCreateEmbeddingConfig_BedrockEmbedding,
|
||||
PipelineCreateEmbeddingConfig_CohereEmbedding,
|
||||
PipelineCreateEmbeddingConfig_GeminiEmbedding,
|
||||
PipelineCreateEmbeddingConfig_HuggingfaceApiEmbedding,
|
||||
PipelineTransformConfig_Advanced,
|
||||
AdvancedModeTransformConfigChunkingConfig_Sentence,
|
||||
AdvancedModeTransformConfigSegmentationConfig_Page,
|
||||
@@ -23,11 +11,6 @@ from llama_cloud import (
|
||||
)
|
||||
from llama_cloud.client import LlamaCloud
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.embeddings.azure_inference import AzureAIEmbeddingsModel
|
||||
from llama_index.embeddings.cohere import CohereEmbedding
|
||||
from llama_index.embeddings.huggingface_api import HuggingFaceInferenceAPIEmbedding
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
from llama_index.embeddings.bedrock import BedrockEmbedding
|
||||
|
||||
|
||||
def default(client: LlamaCloud):
|
||||
@@ -73,111 +56,41 @@ def default(client: LlamaCloud):
|
||||
def _build_embedding_config(config):
    """Translate the setup UI's config into a LlamaCloud embedding config.

    Args:
        config: The ``EmbeddingConfig`` returned by ``EmbeddingSetupApp``.

    Returns:
        A ``PipelineCreateEmbeddingConfig_OpenaiEmbedding`` for the OpenAI
        provider, or ``None`` for any other provider (no other provider
        has a configuration screen yet).
    """
    if config.provider != "OpenAI":
        return None

    # An input left empty yields "", so fall back to the same defaults the
    # "With Default Settings" path uses.
    model = config.model or "text-embedding-3-small"
    api_key = config.api_key or os.getenv("OPENAI_API_KEY")
    embed_model = OpenAIEmbedding(model=model, api_key=api_key)
    return PipelineCreateEmbeddingConfig_OpenaiEmbedding(
        type="OPENAI_EMBEDDING",
        component=embed_model,
    )


def main():
    """Run the embedding setup UI and create the LlamaCloud pipeline.

    Appends the new pipeline id to ``.env`` as ``LLAMACLOUD_PIPELINE_ID``.

    Returns:
        0 when the pipeline was created, 1 when the user quit without
        providing a configuration or chose an unsupported provider.
    """
    load_dotenv()
    client = LlamaCloud(token=os.getenv("LLAMACLOUD_API_KEY"))

    app = EmbeddingSetupApp()
    ui_config = app.run()
    if not ui_config:
        print("No embedding configuration provided")
        return 1

    # PipelineCreate expects a LlamaCloud embedding config, not the UI
    # dataclass, so convert before building the request.
    embedding_config = _build_embedding_config(ui_config)
    if embedding_config is None:
        print(f"Unsupported embedding provider: {ui_config.provider}")
        return 1

    segm_config = AdvancedModeTransformConfigSegmentationConfig_Page(mode="page")
    chunk_config = AdvancedModeTransformConfigChunkingConfig_Sentence(
        chunk_size=1024,
        chunk_overlap=200,
        separator="<whitespace>",
        paragraph_separator="\n\n\n",
        mode="sentence",
    )

    transform_config = PipelineTransformConfig_Advanced(
        segmentation_config=segm_config,
        chunking_config=chunk_config,
        mode="advanced",
    )

    pipeline_request = PipelineCreate(
        name="notebooklm_pipeline",
        embedding_config=embedding_config,
        transform_config=transform_config,
    )

    pipeline = client.pipelines.upsert_pipeline(request=pipeline_request)

    # Persist the id so later runs can find the pipeline.
    with open(".env", "a") as f:
        f.write(f'\nLLAMACLOUD_PIPELINE_ID="{pipeline.id}"')

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user