mirror of
https://github.com/run-llama/llama-hub.git
synced 2026-06-30 20:47:58 -04:00
feat: Improvements to Chroma loader (#673)
* feat: Added where and where_document filters. * feat: Several improvements - Added good amount of tests - Fixed an issue where no results were returned - Fixed an issue where embeddings are not returned by default and need `include` to return them - Added Chroma client dependency injection - Fixed iteration to work with multiple results - Fixed some typing issues * fix: Adding chromadb in test_requirements.txt * fix: Added test_requirements.txt to test workflow. * fix: Running the pip install from within poetry run
This commit is contained in:
@@ -28,5 +28,6 @@ jobs:
|
||||
- name: Install deps
|
||||
run: |
|
||||
poetry install
|
||||
poetry run pip install -r test_requirements.txt
|
||||
- name: Run testing
|
||||
run: poetry run pytest tests
|
||||
@@ -21,7 +21,12 @@ reader = ChromaReader(
|
||||
|
||||
query_vector=[n1, n2, n3, ...]
|
||||
|
||||
documents = reader.load_data(query_vector=query_vector, limit=5)
|
||||
where={"metadata_field": "metadata_value"}
|
||||
where_document={"$contains":"word"}
|
||||
|
||||
documents = reader.load_data(query_vector=query_vector, limit=5, where=where, where_document=where_document)
|
||||
```
|
||||
|
||||
This loader is designed to be used as a way to load data into [LlamaIndex](https://github.com/run-llama/llama_index/tree/main/llama_index) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent. See [here](https://github.com/emptycrown/llama-hub/tree/main) for examples.
|
||||
|
||||
> **Note**: For more information on metadata and document filters `where` and `where_document` check official ChromaDB documentation [here](https://docs.trychroma.com/reference/Collection#query) and examples [here](https://github.com/chroma-core/chroma/blob/main/examples/basic_functionality/where_filtering.ipynb)
|
||||
|
||||
+40
-19
@@ -1,6 +1,6 @@
|
||||
"""Chroma Reader."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from llama_index.readers.base import BaseReader
|
||||
from llama_index.readers.schema.base import Document
|
||||
@@ -20,43 +20,64 @@ class ChromaReader(BaseReader):
|
||||
def __init__(
    self,
    collection_name: str,
    persist_directory: Optional[str] = None,
    client: Optional[Any] = None,
) -> None:
    """Bind the reader to an existing Chroma collection.

    Args:
        collection_name (str): Name of the collection to read from.
        persist_directory (Optional[str]): Directory of a persistent Chroma
            store; used to build a client when ``client`` is not given.
        client (Optional[Any]): Ready-made Chroma client. When supplied it
            is used as-is and ``persist_directory`` is not consulted.

    Raises:
        ValueError: If ``collection_name`` is None, or if neither
            ``persist_directory`` nor ``client`` is provided.
    """
    import chromadb  # noqa: F401
    from chromadb.config import Settings

    # At least one way of reaching a Chroma store must be supplied.
    has_source = client is not None or persist_directory is not None
    if collection_name is None or not has_source:
        raise ValueError(
            "Please provide a collection name and persist directory or Chroma client."
        )

    # An injected client wins; otherwise open the persistent store on disk.
    if client is None:
        client = chromadb.Client(
            Settings(is_persistent=True, persist_directory=persist_directory)
        )
    self._client = client
    self._collection = self._client.get_collection(collection_name)
|
||||
|
||||
def load_data(
    self,
    query_vector: List[Union[List[float], List[int]]],
    limit: int = 10,
    where: Optional[Dict[Any, Any]] = None,
    where_document: Optional[Dict[Any, Any]] = None,
) -> Any:
    """Load data from Chroma.

    Args:
        query_vector (List[Union[List[float], List[int]]]): One or more
            query embeddings; all of them are sent in a single query.
        limit (int): Number of results to request (passed as ``n_results``).
        where (Optional[Dict]): Metadata where filter.
        where_document (Optional[Dict]): Document where filter.

    Returns:
        List[Document]: One document per hit, grouped in the order of the
        query embeddings. The same stored item may appear more than once
        when it matches several query embeddings. An empty list is
        returned when nothing matches.
    """
    # Embeddings are not returned by default, so they are requested
    # explicitly through ``include``.
    # NOTE: metadatas are requested as well but are not attached to the
    # returned documents.
    results = self._collection.query(
        query_embeddings=query_vector,
        include=["documents", "metadatas", "embeddings"],
        n_results=limit,
        where=where,
        where_document=where_document,
    )

    documents: List[Document] = []
    if not results:
        return documents

    # Each entry of ids/documents/embeddings holds the hits for one query
    # embedding. Walking every group (instead of inspecting only the first
    # one) keeps hits of later queries and copes with empty groups.
    for ids, texts, embeddings in zip(
        results["ids"], results["documents"], results["embeddings"]
    ):
        for doc_id, text, embedding in zip(ids, texts, embeddings):
            documents.append(
                Document(doc_id=doc_id, text=text, embedding=embedding)
            )

    return documents
|
||||
|
||||
@@ -13,6 +13,7 @@ llama-index>=0.6.9
|
||||
atlassian-python-api
|
||||
html2text
|
||||
olefile
|
||||
chromadb
|
||||
|
||||
# hotfix
|
||||
psutil
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
import shutil
|
||||
from typing import Any, Generator
|
||||
import pytest
|
||||
import tempfile
|
||||
from llama_hub.chroma import ChromaReader
|
||||
|
||||
|
||||
@pytest.fixture
def chroma_persist_dir() -> Generator[str, None, None]:
    """Yield a fresh temporary directory and remove it after the test."""
    scratch_path = tempfile.mkdtemp()
    yield scratch_path
    # Best-effort cleanup: removal errors are ignored.
    shutil.rmtree(scratch_path, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture
def chroma_client(chroma_persist_dir: str) -> Generator[Any, None, None]:
    """Yield a persistent Chroma client rooted at the temporary directory."""
    import chromadb
    from chromadb.config import Settings

    # These settings have to match the ones ChromaReader builds for itself;
    # otherwise an exception will be raised.
    persistent_settings = Settings(
        is_persistent=True,
        persist_directory=chroma_persist_dir,
    )
    yield chromadb.Client(persistent_settings)
|
||||
|
||||
|
||||
def test_chroma_with_client(chroma_client: Any) -> None:
    """A reader built from an injected client returns the one stored document."""
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(ids=["1"], documents=["test"], embeddings=[[1, 2, 3]])

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    loaded = reader.load_data(query_vector=[[1, 2, 3]], limit=5)
    assert len(loaded) == 1
|
||||
|
||||
|
||||
def test_chroma_with_persist_dir(chroma_client: Any, chroma_persist_dir: str) -> None:
    """A reader built from a persist directory sees data written via the client."""
    # Seed the store through the fixture client, then read it back through a
    # reader that is given only the directory.
    collection = chroma_client.get_or_create_collection("test_collection")
    collection.add(ids=["1"], documents=["test"], embeddings=[[1, 2, 3]])

    reader = ChromaReader(
        collection_name="test_collection",
        persist_directory=chroma_persist_dir,
    )
    assert reader is not None

    loaded = reader.load_data(query_vector=[[1, 2, 3]], limit=5)
    assert len(loaded) == 1
|
||||
|
||||
|
||||
def test_chroma_with_where_filter(chroma_client: Any) -> None:
    """A matching metadata ``where`` filter keeps the stored document."""
    collection = chroma_client.get_or_create_collection("test_collection")
    seed = dict(
        ids=["1"],
        documents=["test"],
        embeddings=[[1, 2, 3]],
        metadatas=[{"test": "test"}],
    )
    collection.add(**seed)

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    metadata_filter = {"test": "test"}
    loaded = reader.load_data(query_vector=[[1, 2, 3]], limit=5, where=metadata_filter)
    assert len(loaded) == 1
|
||||
|
||||
|
||||
def test_chroma_with_where_filter_no_match(chroma_client: Any) -> None:
    """A metadata ``where`` filter that matches nothing yields no documents."""
    collection = chroma_client.get_or_create_collection("test_collection")
    seed = dict(
        ids=["1"],
        documents=["test"],
        embeddings=[[1, 2, 3]],
        metadatas=[{"test": "test"}],
    )
    collection.add(**seed)

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    # "test1" differs from the stored metadata value, so nothing should match.
    metadata_filter = {"test": "test1"}
    loaded = reader.load_data(query_vector=[[1, 2, 3]], where=metadata_filter)
    assert len(loaded) == 0
|
||||
|
||||
|
||||
def test_chroma_with_where_document_filter(chroma_client: Any) -> None:
    """A ``where_document`` filter matching the text keeps the document."""
    collection = chroma_client.get_or_create_collection("test_collection")
    seed = dict(
        ids=["1"],
        documents=["this is my test document"],
        embeddings=[[1, 2, 3]],
        metadatas=[{"test": "test"}],
    )
    collection.add(**seed)

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    text_filter = {"$contains": "test"}
    loaded = reader.load_data(
        query_vector=[[1, 2, 3]],
        limit=5,
        where_document=text_filter,
    )
    assert len(loaded) == 1
|
||||
|
||||
|
||||
def test_chroma_with_where_document_filter_no_match(chroma_client: Any) -> None:
    """A ``where_document`` filter absent from the text yields no documents."""
    collection = chroma_client.get_or_create_collection("test_collection")
    seed = dict(
        ids=["1"],
        documents=["this is my test document"],
        embeddings=[[1, 2, 3]],
        metadatas=[{"test": "test"}],
    )
    collection.add(**seed)

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    # The stored text does not contain "test1".
    text_filter = {"$contains": "test1"}
    loaded = reader.load_data(
        query_vector=[[1, 2, 3]],
        limit=5,
        where_document=text_filter,
    )
    assert len(loaded) == 0
|
||||
|
||||
|
||||
def test_chroma_with_multiple_docs(chroma_client: Any) -> None:
    """Both stored documents are returned for a single query embedding."""
    collection = chroma_client.get_or_create_collection("test_collection")
    seed = dict(
        ids=["1", "2"],
        documents=["test", "another test doc"],
        embeddings=[[1, 2, 3], [1, 2, 3]],
    )
    collection.add(**seed)

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    loaded = reader.load_data(query_vector=[[1, 2, 3]], limit=5)
    assert len(loaded) == 2
|
||||
|
||||
|
||||
def test_chroma_with_multiple_docs_multiple_queries(chroma_client: Any) -> None:
    """Two query embeddings over two documents produce four results."""
    collection = chroma_client.get_or_create_collection("test_collection")
    seed = dict(
        ids=["1", "2"],
        documents=["test", "another test doc"],
        embeddings=[[1, 2, 3], [3, 2, 1]],
    )
    collection.add(**seed)

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    queries = [[1, 2, 3], [3, 2, 1]]
    loaded = reader.load_data(query_vector=queries, limit=5)
    # Each query returns both documents, so the combined list has duplicates.
    assert len(loaded) == 4
|
||||
|
||||
|
||||
def test_chroma_with_multiple_docs_with_limit(chroma_client: Any) -> None:
    """``limit=1`` caps the result at one document although two are stored."""
    collection = chroma_client.get_or_create_collection("test_collection")
    seed = dict(
        ids=["1", "2"],
        documents=["test", "another test doc"],
        embeddings=[[1, 2, 3], [3, 2, 1]],
    )
    collection.add(**seed)

    reader = ChromaReader(collection_name="test_collection", client=chroma_client)
    assert reader is not None

    loaded = reader.load_data(query_vector=[[1, 2, 3]], limit=1)
    assert len(loaded) == 1
|
||||
Reference in New Issue
Block a user