chore: implementing suggestions

This commit is contained in:
Clelia (Astra) Bertelli
2025-07-14 21:01:20 +02:00
parent d1f3944b83
commit a40a08d64e
5 changed files with 75 additions and 53 deletions
+1 -3
View File
@@ -78,7 +78,7 @@ async def run_workflow(
end_time = int(time.time() * 1000000)
sql_engine.to_sql_database(start_time=st_time, end_time=end_time)
document_manager.import_documents(
document_manager.put_documents(
[
ManagedDocument(
document_name=document_title,
@@ -161,8 +161,6 @@ file_input = st.file_uploader(
)
# Add this after your existing code, before the st.title line:
# Initialize session state
if "workflow_results" not in st.session_state:
st.session_state.workflow_results = None
+56 -40
View File
@@ -4,6 +4,10 @@ from typing_extensions import Self
from typing import Optional, Any, List, cast
def apply_string_correction(string: str) -> str:
    """Collapse doubled quote characters into single ones.

    Every ``''`` becomes ``'`` and every ``""`` becomes ``"``; all other
    characters are returned untouched.
    """
    corrected = string
    # Single quotes are handled first, then double quotes.
    for doubled, single in (("''", "'"), ('""', '"')):
        corrected = corrected.replace(doubled, single)
    return corrected
class ManagedDocument(BaseModel):
document_name: str
content: str
@@ -11,7 +15,7 @@ class ManagedDocument(BaseModel):
q_and_a: str
mindmap: str
bullet_points: str
is_exported: bool = Field(default=False)
is_exported: bool = Field(default=False, exclude=True)
@model_validator(mode="after")
def validate_input_for_sql(self) -> Self:
@@ -42,14 +46,17 @@ class DocumentManager:
else:
raise ValueError("One of engine or engine_setup_kwargs must be set")
@property
def connection(self) -> Connection:
    """Return the stored connection, opening one first if none is set.

    When ``_connection`` is falsy, ``_connect()`` is called to populate it
    before the value is returned.
    """
    if self._connection:
        return cast(Connection, self._connection)
    self._connect()
    # cast() only narrows the static type; it performs no runtime check.
    return cast(Connection, self._connection)
def _connect(self) -> None:
self._connection = self._engine.connect()
def _create_table(self) -> None:
if not self._connection:
self._connect()
self._connection = cast(Connection, self._connection)
self._connection.execute(
self.connection.execute(
text(f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id SERIAL PRIMARY KEY,
@@ -62,17 +69,14 @@ class DocumentManager:
);
""")
)
self._connection.commit()
self.connection.commit()
self.table_exists = True
def import_documents(self, documents: List[ManagedDocument]) -> None:
if not self._connection:
self._connect()
self._connection = cast(Connection, self._connection)
def put_documents(self, documents: List[ManagedDocument]) -> None:
if not self.table_exists:
self._create_table()
for document in documents:
self._connection.execute(
self.connection.execute(
text(
f"""
INSERT INTO {self.table_name} (document_name, content, summary, q_and_a, mindmap, bullet_points)
@@ -87,18 +91,27 @@ class DocumentManager:
"""
)
)
self._connection.commit()
self.connection.commit()
def export_documents(self, limit: Optional[int] = None) -> List[ManagedDocument]:
if not limit:
limit = 15
result = self._execute(
text(
f"""
SELECT * FROM {self.table_name} ORDER BY id LIMIT {limit};
"""
def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]:
if not self.table_exists:
self._create_table()
if not names:
result = self._execute(
text(
f"""
SELECT * FROM {self.table_name} ORDER BY id;
"""
)
)
else:
result = self._execute(
text(
f"""
SELECT * FROM {self.table_name} WHERE document_name = ANY(ARRAY{names}) ORDER BY id;
"""
)
)
)
rows = result.fetchall()
documents = []
for row in rows:
@@ -111,33 +124,36 @@ class DocumentManager:
bullet_points=row.bullet_points,
is_exported=True,
)
document.mindmap = (
document.mindmap.replace('""', '"')
.replace("''", "'")
.replace("''mynetwork''", "'mynetwork'")
)
document.document_name = document.document_name.replace('""', '"').replace(
"''", "'"
)
document.content = document.content.replace('""', '"').replace("''", "'")
document.summary = document.summary.replace('""', '"').replace("''", "'")
document.q_and_a = document.q_and_a.replace('""', '"').replace("''", "'")
document.bullet_points = document.bullet_points.replace('""', '"').replace(
"''", "'"
)
documents.append(document)
doc_dict = document.model_dump()
for field in doc_dict:
doc_dict[field] = apply_string_correction(doc_dict[field])
if field == "mindmap":
doc_dict[field] = doc_dict[field].replace(
"''mynetwork''", "'mynetwork'"
)
documents.append(ManagedDocument.model_validate(doc_dict))
return documents
def get_names(self) -> List[str]:
    """Return the ``document_name`` of every stored row, ordered by ``id``.

    The table is created first when ``table_exists`` is false.
    """
    if not self.table_exists:
        self._create_table()
    # Only the name column is read below, so select just that column;
    # ``SELECT *`` would also fetch the content, summary, Q&A, mind map and
    # bullet points of every row.
    result = self._execute(
        text(
            f"""
            SELECT document_name FROM {self.table_name} ORDER BY id;
            """
        )
    )
    return [row.document_name for row in result.fetchall()]
def _execute(
self,
statement: Any,
parameters: Optional[Any] = None,
execution_options: Optional[Any] = None,
) -> Result:
if not self._connection:
self._connect()
self._connection = cast(Connection, self._connection)
return self._connection.execute(
return self.connection.execute(
statement=statement,
parameters=parameters,
execution_options=execution_options,
@@ -2,7 +2,7 @@ import os
import streamlit as st
import streamlit.components.v1 as components
from dotenv import load_dotenv
from typing import List
from typing import List, Optional
from documents import DocumentManager, ManagedDocument
@@ -14,9 +14,13 @@ engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_p
document_manager = DocumentManager(engine_url=engine_url)
def view_documents(limit: int) -> List[ManagedDocument]:
def fetch_documents(names: Optional[List[str]]) -> List[ManagedDocument]:
"""Retrieve documents from the database"""
return document_manager.export_documents(limit=limit)
return document_manager.get_documents(names=names)
def fetch_document_names() -> List[str]:
    """Retrieve the names of the documents from the database"""
    names = document_manager.get_names()
    return names
def display_document(document: ManagedDocument) -> None:
@@ -57,15 +61,17 @@ def main():
st.markdown("## NotebookLlaMa - Document Management📚")
# Slider for number of documents
limit = st.slider(
"Number of documents to display:", min_value=1, max_value=50, value=15, step=1
names = st.multiselect(
options=fetch_document_names(),
default=None,
label="Select the Documents you want to display",
)
# Button to load documents
if st.button("Load Documents", type="primary"):
with st.spinner("Loading documents..."):
try:
documents = view_documents(limit)
documents = fetch_documents(names)
if documents:
st.success(f"Successfully loaded {len(documents)} document(s)")
+5 -3
View File
@@ -67,8 +67,10 @@ def test_document_manager(documents: List[ManagedDocument]) -> None:
manager._execute(text("DROP TABLE IF EXISTS test_documents;"))
manager._create_table()
assert manager.table_exists
manager.import_documents(documents=documents)
docs = manager.export_documents()
manager.put_documents(documents=documents)
names = manager.get_names()
assert names == [doc.document_name for doc in documents]
docs = manager.get_documents()
assert docs == documents
docs1 = manager.export_documents(limit=2)
docs1 = manager.get_documents(names=["Project Plan", "Meeting Notes"])
assert len(docs1) == 2
Generated
+1 -1
View File
@@ -1753,7 +1753,7 @@ wheels = [
[[package]]
name = "notebookllama"
version = "0.3.1"
version = "0.4.0"
source = { virtual = "." }
dependencies = [
{ name = "audioop-lts" },