diff --git a/src/notebookllama/Home.py b/src/notebookllama/Home.py index 2e06d69..552340d 100644 --- a/src/notebookllama/Home.py +++ b/src/notebookllama/Home.py @@ -78,7 +78,7 @@ async def run_workflow( end_time = int(time.time() * 1000000) sql_engine.to_sql_database(start_time=st_time, end_time=end_time) - document_manager.import_documents( + document_manager.put_documents( [ ManagedDocument( document_name=document_title, @@ -161,8 +161,6 @@ file_input = st.file_uploader( ) -# Add this after your existing code, before the st.title line: - # Initialize session state if "workflow_results" not in st.session_state: st.session_state.workflow_results = None diff --git a/src/notebookllama/documents.py b/src/notebookllama/documents.py index adcc98f..6179970 100644 --- a/src/notebookllama/documents.py +++ b/src/notebookllama/documents.py @@ -4,6 +4,10 @@ from typing_extensions import Self from typing import Optional, Any, List, cast +def apply_string_correction(string: str) -> str: + return string.replace("''", "'").replace('""', '"') + + class ManagedDocument(BaseModel): document_name: str content: str @@ -11,7 +15,7 @@ class ManagedDocument(BaseModel): q_and_a: str mindmap: str bullet_points: str - is_exported: bool = Field(default=False) + is_exported: bool = Field(default=False, exclude=True) @model_validator(mode="after") def validate_input_for_sql(self) -> Self: @@ -42,14 +46,17 @@ class DocumentManager: else: raise ValueError("One of engine or engine_setup_kwargs must be set") + @property + def connection(self) -> Connection: + if not self._connection: + self._connect() + return cast(Connection, self._connection) + def _connect(self) -> None: self._connection = self._engine.connect() def _create_table(self) -> None: - if not self._connection: - self._connect() - self._connection = cast(Connection, self._connection) - self._connection.execute( + self.connection.execute( text(f""" CREATE TABLE IF NOT EXISTS {self.table_name} ( id SERIAL PRIMARY KEY, @@ -62,17 +69,14 @@ class DocumentManager: ); """) ) - self._connection.commit() + self.connection.commit() self.table_exists = True - def import_documents(self, documents: List[ManagedDocument]) -> None: - if not self._connection: - self._connect() - self._connection = cast(Connection, self._connection) + def put_documents(self, documents: List[ManagedDocument]) -> None: if not self.table_exists: self._create_table() for document in documents: - self._connection.execute( + self.connection.execute( text( f""" INSERT INTO {self.table_name} (document_name, content, summary, q_and_a, mindmap, bullet_points) @@ -87,18 +91,27 @@ class DocumentManager: """ ) ) - self._connection.commit() + self.connection.commit() - def export_documents(self, limit: Optional[int] = None) -> List[ManagedDocument]: - if not limit: - limit = 15 - result = self._execute( - text( - f""" - SELECT * FROM {self.table_name} ORDER BY id LIMIT {limit}; - """ + def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]: + if not self.table_exists: + self._create_table() + if not names: + result = self._execute( + text( + f""" + SELECT * FROM {self.table_name} ORDER BY id; + """ + ) + ) + else: + result = self._execute( + text( + f""" + SELECT * FROM {self.table_name} WHERE document_name = ANY(ARRAY{names}) ORDER BY id; + """ + ) ) - ) rows = result.fetchall() documents = [] for row in rows: @@ -111,33 +124,36 @@ class DocumentManager: bullet_points=row.bullet_points, is_exported=True, ) - document.mindmap = ( - document.mindmap.replace('""', '"') - .replace("''", "'") - .replace("''mynetwork''", "'mynetwork'") - ) - document.document_name = document.document_name.replace('""', '"').replace( - "''", "'" - ) - document.content = document.content.replace('""', '"').replace("''", "'") - document.summary = document.summary.replace('""', '"').replace("''", "'") - document.q_and_a = document.q_and_a.replace('""', '"').replace("''", "'") - document.bullet_points = document.bullet_points.replace('""', '"').replace( - "''", "'" - ) - documents.append(document) + doc_dict = document.model_dump() + for field in doc_dict: + doc_dict[field] = apply_string_correction(doc_dict[field]) + if field == "mindmap": + doc_dict[field] = doc_dict[field].replace( + "''mynetwork''", "'mynetwork'" + ) + documents.append(ManagedDocument.model_validate(doc_dict)) return documents + def get_names(self) -> List[str]: + if not self.table_exists: + self._create_table() + result = self._execute( + text( + f""" + SELECT * FROM {self.table_name} ORDER BY id; + """ + ) + ) + rows = result.fetchall() + return [row.document_name for row in rows] + def _execute( self, statement: Any, parameters: Optional[Any] = None, execution_options: Optional[Any] = None, ) -> Result: - if not self._connection: - self._connect() - self._connection = cast(Connection, self._connection) - return self._connection.execute( + return self.connection.execute( statement=statement, parameters=parameters, execution_options=execution_options, diff --git a/src/notebookllama/pages/1_Document_Management_UI.py b/src/notebookllama/pages/1_Document_Management_UI.py index 9879c4f..0382edf 100644 --- a/src/notebookllama/pages/1_Document_Management_UI.py +++ b/src/notebookllama/pages/1_Document_Management_UI.py @@ -2,7 +2,7 @@ import os import streamlit as st import streamlit.components.v1 as components from dotenv import load_dotenv -from typing import List +from typing import List, Optional from documents import DocumentManager, ManagedDocument @@ -14,9 +14,13 @@ engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_p document_manager = DocumentManager(engine_url=engine_url) -def view_documents(limit: int) -> List[ManagedDocument]: +def fetch_documents(names: Optional[List[str]]) -> List[ManagedDocument]: """Retrieve documents from the database""" - return document_manager.export_documents(limit=limit) + return document_manager.get_documents(names=names) + + +def fetch_document_names() -> List[str]: + return document_manager.get_names() def display_document(document: ManagedDocument) -> None: @@ -57,15 +61,17 @@ def main(): st.markdown("## NotebookLlaMa - Document Management📚") # Slider for number of documents - limit = st.slider( - "Number of documents to display:", min_value=1, max_value=50, value=15, step=1 + names = st.multiselect( + options=fetch_document_names(), + default=None, + label="Select the Documents you want to display", ) # Button to load documents if st.button("Load Documents", type="primary"): with st.spinner("Loading documents..."): try: - documents = view_documents(limit) + documents = fetch_documents(names) if documents: st.success(f"Successfully loaded {len(documents)} document(s)") diff --git a/tests/test_document_management.py b/tests/test_document_management.py index 8226cc3..101e113 100644 --- a/tests/test_document_management.py +++ b/tests/test_document_management.py @@ -67,8 +67,10 @@ def test_document_manager(documents: List[ManagedDocument]) -> None: manager._execute(text("DROP TABLE IF EXISTS test_documents;")) manager._create_table() assert manager.table_exists - manager.import_documents(documents=documents) - docs = manager.export_documents() + manager.put_documents(documents=documents) + names = manager.get_names() + assert names == [doc.document_name for doc in documents] + docs = manager.get_documents() assert docs == documents - docs1 = manager.export_documents(limit=2) + docs1 = manager.get_documents(names=["Project Plan", "Meeting Notes"]) assert len(docs1) == 2 diff --git a/uv.lock b/uv.lock index a245ff5..6fa909d 100644 --- a/uv.lock +++ b/uv.lock @@ -1753,7 +1753,7 @@ wheels = [ [[package]] name = "notebookllama" -version = "0.3.1" +version = "0.4.0" source = { virtual = "." } dependencies = [ { name = "audioop-lts" },