mirror of
https://github.com/run-llama/notebookllama.git
synced 2026-06-30 22:17:57 -04:00
chore: implementing suggestions
This commit is contained in:
@@ -78,7 +78,7 @@ async def run_workflow(
|
||||
|
||||
end_time = int(time.time() * 1000000)
|
||||
sql_engine.to_sql_database(start_time=st_time, end_time=end_time)
|
||||
document_manager.import_documents(
|
||||
document_manager.put_documents(
|
||||
[
|
||||
ManagedDocument(
|
||||
document_name=document_title,
|
||||
@@ -161,8 +161,6 @@ file_input = st.file_uploader(
|
||||
)
|
||||
|
||||
|
||||
# Add this after your existing code, before the st.title line:
|
||||
|
||||
# Initialize session state
|
||||
if "workflow_results" not in st.session_state:
|
||||
st.session_state.workflow_results = None
|
||||
|
||||
@@ -4,6 +4,10 @@ from typing_extensions import Self
|
||||
from typing import Optional, Any, List, cast
|
||||
|
||||
|
||||
def apply_string_correction(string: str) -> str:
    """Collapse doubled quote characters in *string* into single ones.

    Doubled single quotes are collapsed first, then doubled double quotes.
    Each pass is a plain left-to-right, non-overlapping replacement, so a
    run of four identical quotes ends up as two.
    """
    single_quotes_fixed = string.replace("''", "'")
    return single_quotes_fixed.replace('""', '"')
|
||||
|
||||
|
||||
class ManagedDocument(BaseModel):
|
||||
document_name: str
|
||||
content: str
|
||||
@@ -11,7 +15,7 @@ class ManagedDocument(BaseModel):
|
||||
q_and_a: str
|
||||
mindmap: str
|
||||
bullet_points: str
|
||||
is_exported: bool = Field(default=False)
|
||||
is_exported: bool = Field(default=False, exclude=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_input_for_sql(self) -> Self:
|
||||
@@ -42,14 +46,17 @@ class DocumentManager:
|
||||
else:
|
||||
raise ValueError("One of engine or engine_setup_kwargs must be set")
|
||||
|
||||
@property
def connection(self) -> Connection:
    """Return the stored connection, opening one first if none is set.

    When ``self._connection`` is falsy, ``self._connect()`` is called
    before the attribute is returned.
    """
    if self._connection:
        return cast(Connection, self._connection)
    self._connect()
    # cast() only narrows the type for the checker; it does nothing at runtime.
    return cast(Connection, self._connection)
|
||||
|
||||
def _connect(self) -> None:
    """Open a connection from ``self._engine`` and store it on the instance."""
    opened = self._engine.connect()
    self._connection = opened
|
||||
|
||||
def _create_table(self) -> None:
|
||||
if not self._connection:
|
||||
self._connect()
|
||||
self._connection = cast(Connection, self._connection)
|
||||
self._connection.execute(
|
||||
self.connection.execute(
|
||||
text(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id SERIAL PRIMARY KEY,
|
||||
@@ -62,17 +69,14 @@ class DocumentManager:
|
||||
);
|
||||
""")
|
||||
)
|
||||
self._connection.commit()
|
||||
self.connection.commit()
|
||||
self.table_exists = True
|
||||
|
||||
def import_documents(self, documents: List[ManagedDocument]) -> None:
|
||||
if not self._connection:
|
||||
self._connect()
|
||||
self._connection = cast(Connection, self._connection)
|
||||
def put_documents(self, documents: List[ManagedDocument]) -> None:
|
||||
if not self.table_exists:
|
||||
self._create_table()
|
||||
for document in documents:
|
||||
self._connection.execute(
|
||||
self.connection.execute(
|
||||
text(
|
||||
f"""
|
||||
INSERT INTO {self.table_name} (document_name, content, summary, q_and_a, mindmap, bullet_points)
|
||||
@@ -87,18 +91,27 @@ class DocumentManager:
|
||||
"""
|
||||
)
|
||||
)
|
||||
self._connection.commit()
|
||||
self.connection.commit()
|
||||
|
||||
def export_documents(self, limit: Optional[int] = None) -> List[ManagedDocument]:
|
||||
if not limit:
|
||||
limit = 15
|
||||
result = self._execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT * FROM {self.table_name} ORDER BY id LIMIT {limit};
|
||||
"""
|
||||
def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]:
|
||||
if not self.table_exists:
|
||||
self._create_table()
|
||||
if not names:
|
||||
result = self._execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT * FROM {self.table_name} ORDER BY id;
|
||||
"""
|
||||
)
|
||||
)
|
||||
else:
|
||||
result = self._execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT * FROM {self.table_name} WHERE document_name = ANY(ARRAY{names}) ORDER BY id;
|
||||
"""
|
||||
)
|
||||
)
|
||||
)
|
||||
rows = result.fetchall()
|
||||
documents = []
|
||||
for row in rows:
|
||||
@@ -111,33 +124,36 @@ class DocumentManager:
|
||||
bullet_points=row.bullet_points,
|
||||
is_exported=True,
|
||||
)
|
||||
document.mindmap = (
|
||||
document.mindmap.replace('""', '"')
|
||||
.replace("''", "'")
|
||||
.replace("''mynetwork''", "'mynetwork'")
|
||||
)
|
||||
document.document_name = document.document_name.replace('""', '"').replace(
|
||||
"''", "'"
|
||||
)
|
||||
document.content = document.content.replace('""', '"').replace("''", "'")
|
||||
document.summary = document.summary.replace('""', '"').replace("''", "'")
|
||||
document.q_and_a = document.q_and_a.replace('""', '"').replace("''", "'")
|
||||
document.bullet_points = document.bullet_points.replace('""', '"').replace(
|
||||
"''", "'"
|
||||
)
|
||||
documents.append(document)
|
||||
doc_dict = document.model_dump()
|
||||
for field in doc_dict:
|
||||
doc_dict[field] = apply_string_correction(doc_dict[field])
|
||||
if field == "mindmap":
|
||||
doc_dict[field] = doc_dict[field].replace(
|
||||
"''mynetwork''", "'mynetwork'"
|
||||
)
|
||||
documents.append(ManagedDocument.model_validate(doc_dict))
|
||||
return documents
|
||||
|
||||
def get_names(self) -> List[str]:
    """Return the ``document_name`` of every stored row, ordered by ``id``.

    The table is created first when ``self.table_exists`` is falsy.

    Returns:
        One name per row, in ascending ``id`` order.
    """
    if not self.table_exists:
        self._create_table()
    # Only the name column is read below, so select just that column instead
    # of transferring every column of every row.
    result = self._execute(
        text(
            f"""
            SELECT document_name FROM {self.table_name} ORDER BY id;
            """
        )
    )
    return [row.document_name for row in result.fetchall()]
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
statement: Any,
|
||||
parameters: Optional[Any] = None,
|
||||
execution_options: Optional[Any] = None,
|
||||
) -> Result:
|
||||
if not self._connection:
|
||||
self._connect()
|
||||
self._connection = cast(Connection, self._connection)
|
||||
return self._connection.execute(
|
||||
return self.connection.execute(
|
||||
statement=statement,
|
||||
parameters=parameters,
|
||||
execution_options=execution_options,
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import streamlit as st
|
||||
import streamlit.components.v1 as components
|
||||
from dotenv import load_dotenv
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from documents import DocumentManager, ManagedDocument
|
||||
|
||||
@@ -14,9 +14,13 @@ engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_p
|
||||
document_manager = DocumentManager(engine_url=engine_url)
|
||||
|
||||
|
||||
def view_documents(limit: int) -> List[ManagedDocument]:
|
||||
def fetch_documents(names: Optional[List[str]]) -> List[ManagedDocument]:
|
||||
"""Retrieve documents from the database"""
|
||||
return document_manager.export_documents(limit=limit)
|
||||
return document_manager.get_documents(names=names)
|
||||
|
||||
|
||||
def fetch_document_names() -> List[str]:
    """Return the document names reported by the module-level document manager."""
    available_names = document_manager.get_names()
    return available_names
|
||||
|
||||
|
||||
def display_document(document: ManagedDocument) -> None:
|
||||
@@ -57,15 +61,17 @@ def main():
|
||||
st.markdown("## NotebookLlaMa - Document Management📚")
|
||||
|
||||
# Slider for number of documents
|
||||
limit = st.slider(
|
||||
"Number of documents to display:", min_value=1, max_value=50, value=15, step=1
|
||||
names = st.multiselect(
|
||||
options=fetch_document_names(),
|
||||
default=None,
|
||||
label="Select the Documents you want to display",
|
||||
)
|
||||
|
||||
# Button to load documents
|
||||
if st.button("Load Documents", type="primary"):
|
||||
with st.spinner("Loading documents..."):
|
||||
try:
|
||||
documents = view_documents(limit)
|
||||
documents = fetch_documents(names)
|
||||
|
||||
if documents:
|
||||
st.success(f"Successfully loaded {len(documents)} document(s)")
|
||||
|
||||
@@ -67,8 +67,10 @@ def test_document_manager(documents: List[ManagedDocument]) -> None:
|
||||
manager._execute(text("DROP TABLE IF EXISTS test_documents;"))
|
||||
manager._create_table()
|
||||
assert manager.table_exists
|
||||
manager.import_documents(documents=documents)
|
||||
docs = manager.export_documents()
|
||||
manager.put_documents(documents=documents)
|
||||
names = manager.get_names()
|
||||
assert names == [doc.document_name for doc in documents]
|
||||
docs = manager.get_documents()
|
||||
assert docs == documents
|
||||
docs1 = manager.export_documents(limit=2)
|
||||
docs1 = manager.get_documents(names=["Project Plan", "Meeting Notes"])
|
||||
assert len(docs1) == 2
|
||||
|
||||
Reference in New Issue
Block a user