From dabcdaae28cd3efda0e3292cef6db825600d0f21 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 28 Nov 2023 20:58:39 +0200 Subject: [PATCH] feat: Improvements to Chroma loader (#673) * feat: Added where and where_document filters. * feat: Several improvements - Added good amount of tests - Fixed an issue where no results were returned - Fixed an issue where embeddings are not returned by default and need `include` to return them - Added Chroma client dependency injection - Fixed iteration to work with multiple results - Fixed some typing issues * fix: Adding chromadb in test_requirements.txt * fix: Added test_requirements.txt to test workflow. * fix: Running the pip install from within poetry run --- .github/workflows/tests.yml | 1 + llama_hub/chroma/README.md | 7 +- llama_hub/chroma/base.py | 59 ++++++---- test_requirements.txt | 1 + tests/tests_chroma/__init__.py | 0 tests/tests_chroma/test_chroma.py | 172 ++++++++++++++++++++++++++++++ 6 files changed, 220 insertions(+), 20 deletions(-) create mode 100644 tests/tests_chroma/__init__.py create mode 100644 tests/tests_chroma/test_chroma.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 28172eee..ed044233 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -28,5 +28,6 @@ jobs: - name: Install deps run: | poetry install + poetry run pip install -r test_requirements.txt - name: Run testing run: poetry run pytest tests \ No newline at end of file diff --git a/llama_hub/chroma/README.md b/llama_hub/chroma/README.md index b086cee7..02631987 100644 --- a/llama_hub/chroma/README.md +++ b/llama_hub/chroma/README.md @@ -21,7 +21,12 @@ reader = ChromaReader( query_vector=[n1, n2, n3, ...] -documents = reader.load_data(collection_name="demo", query_vector=query_vector, limit=5) +where={"metadata_field": "metadata_value"} +where_document={"$contains":"word"} + +documents = reader.load_data(collection_name="demo", query_vector=query_vector, limit=5, where=where, where_document=where_document) ``` This loader is designed to be used as a way to load data into [LlamaIndex](https://github.com/run-llama/llama_index/tree/main/llama_index) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent. See [here](https://github.com/emptycrown/llama-hub/tree/main) for examples. + +> **Note**: For more information on metadata and document filters `where` and `where_document` check official ChromaDB documentation [here](https://docs.trychroma.com/reference/Collection#query) and examples [here](https://github.com/chroma-core/chroma/blob/main/examples/basic_functionality/where_filtering.ipynb) diff --git a/llama_hub/chroma/base.py b/llama_hub/chroma/base.py index 10433943..9ac2ddb7 100644 --- a/llama_hub/chroma/base.py +++ b/llama_hub/chroma/base.py @@ -1,6 +1,6 @@ """Chroma Reader.""" -from typing import Any +from typing import Any, Dict, List, Optional, Union from llama_index.readers.base import BaseReader from llama_index.readers.schema.base import Document @@ -20,43 +20,64 @@ class ChromaReader(BaseReader): def __init__( self, collection_name: str, - persist_directory: str, + persist_directory: Optional[str] = None, + client: Optional[Any] = None, ) -> None: """Initialize with parameters.""" import chromadb # noqa: F401 from chromadb.config import Settings - if (collection_name is None) or (persist_directory is None): - raise ValueError("Please provide a collection name and persist directory.") - - self._client = chromadb.Client( - Settings(is_persistent=True, persist_directory=persist_directory) - ) + if (collection_name is None) or (persist_directory is None and client is None): + raise ValueError( + "Please provide a collection name and persist directory or Chroma client." + ) + if client is not None: + self._client = client + else: + self._client = chromadb.Client( + Settings(is_persistent=True, persist_directory=persist_directory) + ) self._collection = self._client.get_collection(collection_name) def load_data( self, - query_vector: Any, + query_vector: List[Union[List[float], List[int]]], limit: int = 10, + where: Optional[Dict[Any, Any]] = None, + where_document: Optional[Dict[Any, Any]] = None, ) -> Any: """Load data from Chroma. Args: query_vector (Any): Query limit (int): Number of results to return. + where (Dict): Metadata where filter. + where_document (Dict): Document where filter. Returns: List[Document]: A list of documents. """ - results = self._collection.query(query_embeddings=query_vector, n_results=limit) - - documents = [] - for result in zip(results["ids"], results["documents"], results["embeddings"]): - document = Document( - doc_id=result[0][0], - text=result[1][0], - embedding=result[2][0], - ) - documents.append(document) + results = self._collection.query( + query_embeddings=query_vector, + include=["documents", "metadatas", "embeddings"], # noqa: E501 + n_results=limit, + where=where, + where_document=where_document, + ) + print(results) + documents: List[Document] = [] + # early return if no results + if results is None or len(results["ids"][0]) == 0: + return documents + for i in range(len(results["ids"])): + for result in zip( + results["ids"][i], results["documents"][i], results["embeddings"][i] + ): + document = Document( + doc_id=result[0], + text=result[1], + embedding=result[2], + ) + documents.append(document) return documents diff --git a/test_requirements.txt b/test_requirements.txt index 215dfaa2..9903f054 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -13,6 +13,7 @@ llama-index>=0.6.9 atlassian-python-api html2text olefile +chromadb # hotfix psutil diff --git a/tests/tests_chroma/__init__.py b/tests/tests_chroma/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_chroma/test_chroma.py b/tests/tests_chroma/test_chroma.py new file mode 100644 index 00000000..4842003e --- /dev/null +++ b/tests/tests_chroma/test_chroma.py @@ -0,0 +1,172 @@ +import shutil +from typing import Any, Generator +import pytest +import tempfile +from llama_hub.chroma import ChromaReader + + +@pytest.fixture +def chroma_persist_dir() -> Generator[str, None, None]: + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.fixture +def chroma_client(chroma_persist_dir: str) -> Generator[Any, None, None]: + import chromadb + from chromadb.config import Settings + + # The client settings must align with ChromaReader's settings otherwise + # an exception will be raised. + client = chromadb.Client( + Settings( + is_persistent=True, + persist_directory=chroma_persist_dir, + ) + ) + yield client + + +def test_chroma_with_client(chroma_client: Any) -> None: + test_collection = chroma_client.get_or_create_collection("test_collection") + test_collection.add(ids=["1"], documents=["test"], embeddings=[[1, 2, 3]]) + chroma = ChromaReader( + collection_name="test_collection", + client=chroma_client, + ) + assert chroma is not None + docs = chroma.load_data(query_vector=[[1, 2, 3]], limit=5) + assert len(docs) == 1 + + +def test_chroma_with_persist_dir(chroma_client: Any, chroma_persist_dir: str) -> None: + test_collection = chroma_client.get_or_create_collection("test_collection") + test_collection.add(ids=["1"], documents=["test"], embeddings=[[1, 2, 3]]) + chroma = ChromaReader( + collection_name="test_collection", persist_directory=chroma_persist_dir + ) + + assert chroma is not None + docs = chroma.load_data(query_vector=[[1, 2, 3]], limit=5) + assert len(docs) == 1 + + +def test_chroma_with_where_filter(chroma_client: Any) -> None: + test_collection = chroma_client.get_or_create_collection("test_collection") + test_collection.add( + ids=["1"], + documents=["test"], + embeddings=[[1, 2, 3]], + metadatas=[{"test": "test"}], + ) + chroma = ChromaReader( + collection_name="test_collection", + client=chroma_client, + ) + assert chroma is not None + docs = chroma.load_data(query_vector=[[1, 2, 3]], limit=5, where={"test": "test"}) + assert len(docs) == 1 + + +def test_chroma_with_where_filter_no_match(chroma_client: Any) -> None: + test_collection = chroma_client.get_or_create_collection("test_collection") + test_collection.add( + ids=["1"], + documents=["test"], + embeddings=[[1, 2, 3]], + metadatas=[{"test": "test"}], + ) + chroma = ChromaReader( + collection_name="test_collection", + client=chroma_client, + ) + assert chroma is not None + docs = chroma.load_data(query_vector=[[1, 2, 3]], where={"test": "test1"}) + assert len(docs) == 0 + + +def test_chroma_with_where_document_filter(chroma_client: Any) -> None: + test_collection = chroma_client.get_or_create_collection("test_collection") + test_collection.add( + ids=["1"], + documents=["this is my test document"], + embeddings=[[1, 2, 3]], + metadatas=[{"test": "test"}], + ) + chroma = ChromaReader( + collection_name="test_collection", + client=chroma_client, + ) + assert chroma is not None + docs = chroma.load_data( + query_vector=[[1, 2, 3]], limit=5, where_document={"$contains": "test"} + ) + assert len(docs) == 1 + + +def test_chroma_with_where_document_filter_no_match(chroma_client: Any) -> None: + test_collection = chroma_client.get_or_create_collection("test_collection") + test_collection.add( + ids=["1"], + documents=["this is my test document"], + embeddings=[[1, 2, 3]], + metadatas=[{"test": "test"}], + ) + chroma = ChromaReader( + collection_name="test_collection", + client=chroma_client, + ) + assert chroma is not None + docs = chroma.load_data( + query_vector=[[1, 2, 3]], limit=5, where_document={"$contains": "test1"} + ) + assert len(docs) == 0 + + +def test_chroma_with_multiple_docs(chroma_client: Any) -> None: + test_collection = chroma_client.get_or_create_collection("test_collection") + test_collection.add( + ids=["1", "2"], + documents=["test", "another test doc"], + embeddings=[[1, 2, 3], [1, 2, 3]], + ) + chroma = ChromaReader( + collection_name="test_collection", + client=chroma_client, + ) + assert chroma is not None + docs = chroma.load_data(query_vector=[[1, 2, 3]], limit=5) + assert len(docs) == 2 + + +def test_chroma_with_multiple_docs_multiple_queries(chroma_client: Any) -> None: + test_collection = chroma_client.get_or_create_collection("test_collection") + test_collection.add( + ids=["1", "2"], + documents=["test", "another test doc"], + embeddings=[[1, 2, 3], [3, 2, 1]], + ) + chroma = ChromaReader( + collection_name="test_collection", + client=chroma_client, + ) + assert chroma is not None + docs = chroma.load_data(query_vector=[[1, 2, 3], [3, 2, 1]], limit=5) + assert len(docs) == 4 # there are duplicates in this result + + +def test_chroma_with_multiple_docs_with_limit(chroma_client: Any) -> None: + test_collection = chroma_client.get_or_create_collection("test_collection") + test_collection.add( + ids=["1", "2"], + documents=["test", "another test doc"], + embeddings=[[1, 2, 3], [3, 2, 1]], + ) + chroma = ChromaReader( + collection_name="test_collection", + client=chroma_client, + ) + assert chroma is not None + docs = chroma.load_data(query_vector=[[1, 2, 3]], limit=1) + assert len(docs) == 1