feat: Improvements to Chroma loader (#673)

* feat: Added where and where_document filters.

* feat: Several improvements

- Added a good number of tests
- Fixed an issue where no results were returned
- Fixed an issue where embeddings are not returned by default and need `include` to return them
- Added Chroma client dependency injection
- Fixed iteration to work with multiple results
- Fixed some typing issues

* fix: Adding chromadb in test_requirements.txt

* fix: Added test_requirements.txt to test workflow.

* fix: Running the pip install from within poetry run
This commit is contained in:
Trayan Azarov
2023-11-28 20:58:39 +02:00
committed by GitHub
parent 2a556ad4cb
commit dabcdaae28
6 changed files with 220 additions and 20 deletions
+1
View File
@@ -28,5 +28,6 @@ jobs:
- name: Install deps
run: |
poetry install
poetry run pip install -r test_requirements.txt
- name: Run testing
run: poetry run pytest tests
+6 -1
View File
@@ -21,7 +21,12 @@ reader = ChromaReader(
query_vector=[n1, n2, n3, ...]
documents = reader.load_data(collection_name="demo", query_vector=query_vector, limit=5)
where={"metadata_field": "metadata_value"}
where_document={"$contains":"word"}
documents = reader.load_data(collection_name="demo", query_vector=query_vector, limit=5, where=where, where_document=where_document)
```
This loader is designed to be used as a way to load data into [LlamaIndex](https://github.com/run-llama/llama_index/tree/main/llama_index) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent. See [here](https://github.com/emptycrown/llama-hub/tree/main) for examples.
> **Note**: For more information on the metadata and document filters `where` and `where_document`, see the official ChromaDB documentation [here](https://docs.trychroma.com/reference/Collection#query) and the examples [here](https://github.com/chroma-core/chroma/blob/main/examples/basic_functionality/where_filtering.ipynb).
+40 -19
View File
@@ -1,6 +1,6 @@
"""Chroma Reader."""
from typing import Any
from typing import Any, Dict, List, Optional, Union
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
@@ -20,43 +20,64 @@ class ChromaReader(BaseReader):
def __init__(
    self,
    collection_name: str,
    persist_directory: Optional[str] = None,
    client: Optional[Any] = None,
) -> None:
    """Initialize the reader with a collection and a way to reach Chroma.

    Args:
        collection_name (str): Name of the collection fetched from the
            client via ``get_collection``.
        persist_directory (Optional[str]): Directory of a persistent Chroma
            store. Only used when ``client`` is not given.
        client (Optional[Any]): Ready-made Chroma client. When given it is
            used as-is and ``persist_directory`` is ignored.

    Raises:
        ValueError: If ``collection_name`` is None, or if neither
            ``persist_directory`` nor ``client`` is provided.
    """
    # Validate first so bad arguments are reported before any import work.
    if collection_name is None or (persist_directory is None and client is None):
        raise ValueError(
            "Please provide a collection name and persist directory or Chroma client."
        )

    if client is not None:
        # Dependency injection: the caller owns the client configuration.
        self._client = client
    else:
        # chromadb is only needed when this reader builds the client itself.
        import chromadb
        from chromadb.config import Settings

        self._client = chromadb.Client(
            Settings(is_persistent=True, persist_directory=persist_directory)
        )

    self._collection = self._client.get_collection(collection_name)
def load_data(
    self,
    query_vector: List[Union[List[float], List[int]]],
    limit: int = 10,
    where: Optional[Dict[Any, Any]] = None,
    where_document: Optional[Dict[Any, Any]] = None,
) -> Any:
    """Load data from Chroma.

    Args:
        query_vector (List[Union[List[float], List[int]]]): One embedding
            per query; each is passed to the collection as a query embedding.
        limit (int): Number of results to return per query embedding.
        where (Optional[Dict]): Metadata where filter.
        where_document (Optional[Dict]): Document where filter.

    Returns:
        List[Document]: One document per returned match, in query order.
            Results of several query embeddings are concatenated, so the
            same stored item may appear more than once. The list is empty
            when nothing matches.
    """
    # Embeddings must be requested explicitly via `include`.
    # NOTE(review): metadatas are requested too but are not attached to the
    # Document objects built below - confirm whether they should be.
    results = self._collection.query(
        query_embeddings=query_vector,
        include=["documents", "metadatas", "embeddings"],
        n_results=limit,
        where=where,
        where_document=where_document,
    )

    documents: List[Document] = []
    # Early return if the query produced nothing at all.
    if not results or not results.get("ids"):
        return documents

    # Each result field holds one inner list per query embedding; walk them
    # in lockstep. A query without matches contributes an empty inner list.
    per_query = zip(results["ids"], results["documents"], results["embeddings"])
    for ids, texts, embeddings in per_query:
        for doc_id, text, embedding in zip(ids, texts, embeddings):
            documents.append(
                Document(doc_id=doc_id, text=text, embedding=embedding)
            )

    return documents
+1
View File
@@ -13,6 +13,7 @@ llama-index>=0.6.9
atlassian-python-api
html2text
olefile
chromadb
# hotfix
psutil
View File
+172
View File
@@ -0,0 +1,172 @@
import shutil
from typing import Any, Generator
import pytest
import tempfile
from llama_hub.chroma import ChromaReader
@pytest.fixture
def chroma_persist_dir() -> Generator[str, None, None]:
    """Yield a fresh temporary directory and remove it once the test is done."""
    persist_dir = tempfile.mkdtemp()
    yield persist_dir
    # Best-effort cleanup: errors while deleting the tree are ignored.
    shutil.rmtree(persist_dir, ignore_errors=True)
@pytest.fixture
def chroma_client(chroma_persist_dir: str) -> Generator[Any, None, None]:
    """Yield a persistent Chroma client rooted at the temporary directory."""
    import chromadb
    from chromadb.config import Settings

    # The client settings must align with ChromaReader's settings otherwise
    # an exception will be raised.
    settings = Settings(
        is_persistent=True,
        persist_directory=chroma_persist_dir,
    )
    yield chromadb.Client(settings)
def test_chroma_with_client(chroma_client: Any) -> None:
    """A reader built from an injected client loads the one stored item."""
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(ids=["1"], documents=["test"], embeddings=[[1, 2, 3]])

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    loaded = reader.load_data(query_vector=[[1, 2, 3]], limit=5)
    assert len(loaded) == 1
def test_chroma_with_persist_dir(chroma_client: Any, chroma_persist_dir: str) -> None:
    """A reader built from a persist directory loads the one stored item."""
    # Seed the store through the fixture client; the reader then opens the
    # same directory on its own.
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(ids=["1"], documents=["test"], embeddings=[[1, 2, 3]])

    reader = ChromaReader(
        collection_name="test_collection",
        persist_directory=chroma_persist_dir,
    )
    assert reader is not None

    loaded = reader.load_data(query_vector=[[1, 2, 3]], limit=5)
    assert len(loaded) == 1
def test_chroma_with_where_filter(chroma_client: Any) -> None:
    """A metadata ``where`` filter that matches keeps the stored item."""
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(
        ids=["1"],
        documents=["test"],
        embeddings=[[1, 2, 3]],
        metadatas=[{"test": "test"}],
    )

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    metadata_filter = {"test": "test"}
    loaded = reader.load_data(
        query_vector=[[1, 2, 3]], limit=5, where=metadata_filter
    )
    assert len(loaded) == 1
def test_chroma_with_where_filter_no_match(chroma_client: Any) -> None:
    """A metadata ``where`` filter that matches nothing yields no documents."""
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(
        ids=["1"],
        documents=["test"],
        embeddings=[[1, 2, 3]],
        metadatas=[{"test": "test"}],
    )

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    # The stored metadata value is "test", so "test1" must not match.
    metadata_filter = {"test": "test1"}
    loaded = reader.load_data(query_vector=[[1, 2, 3]], where=metadata_filter)
    assert len(loaded) == 0
def test_chroma_with_where_document_filter(chroma_client: Any) -> None:
    """A ``where_document`` filter that matches keeps the stored item."""
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(
        ids=["1"],
        documents=["this is my test document"],
        embeddings=[[1, 2, 3]],
        metadatas=[{"test": "test"}],
    )

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    document_filter = {"$contains": "test"}
    loaded = reader.load_data(
        query_vector=[[1, 2, 3]], limit=5, where_document=document_filter
    )
    assert len(loaded) == 1
def test_chroma_with_where_document_filter_no_match(chroma_client: Any) -> None:
    """A ``where_document`` filter that matches nothing yields no documents."""
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(
        ids=["1"],
        documents=["this is my test document"],
        embeddings=[[1, 2, 3]],
        metadatas=[{"test": "test"}],
    )

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    # The stored text does not contain "test1".
    document_filter = {"$contains": "test1"}
    loaded = reader.load_data(
        query_vector=[[1, 2, 3]], limit=5, where_document=document_filter
    )
    assert len(loaded) == 0
def test_chroma_with_multiple_docs(chroma_client: Any) -> None:
    """A single query over two stored items loads both of them."""
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(
        ids=["1", "2"],
        documents=["test", "another test doc"],
        embeddings=[[1, 2, 3], [1, 2, 3]],
    )

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    loaded = reader.load_data(query_vector=[[1, 2, 3]], limit=5)
    assert len(loaded) == 2
def test_chroma_with_multiple_docs_multiple_queries(chroma_client: Any) -> None:
    """Two queries over two stored items load four documents in total."""
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(
        ids=["1", "2"],
        documents=["test", "another test doc"],
        embeddings=[[1, 2, 3], [3, 2, 1]],
    )

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    queries = [[1, 2, 3], [3, 2, 1]]
    loaded = reader.load_data(query_vector=queries, limit=5)
    assert len(loaded) == 4  # there are duplicates in this result
def test_chroma_with_multiple_docs_with_limit(chroma_client: Any) -> None:
    """With ``limit=1`` only one of the two stored items is loaded."""
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(
        ids=["1", "2"],
        documents=["test", "another test doc"],
        embeddings=[[1, 2, 3], [3, 2, 1]],
    )

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    loaded = reader.load_data(query_vector=[[1, 2, 3]], limit=1)
    assert len(loaded) == 1