Initial draft

This commit is contained in:
Eugene Yurtsev
2024-02-29 14:27:26 -05:00
commit f5acd6fbe2
56 changed files with 6859 additions and 0 deletions
+91
View File
@@ -0,0 +1,91 @@
# An action for setting up poetry install with caching.
# Using a custom action since the default action does not
# take poetry install groups into account.
# Action code from:
# https://github.com/actions/setup-python/issues/505#issuecomment-1273013236
name: poetry-install-with-caching
description: Poetry install with support for caching of dependency groups.

inputs:
  python-version:
    description: Python version, supporting MAJOR.MINOR only
    required: true

  poetry-version:
    description: Poetry version
    required: true

  cache-key:
    description: Cache key to use for manual handling of caching
    required: true

  working-directory:
    description: Directory whose poetry.lock file should be cached
    required: true

runs:
  using: composite
  steps:
    - uses: actions/setup-python@v4
      name: Setup python ${{ inputs.python-version }}
      with:
        python-version: ${{ inputs.python-version }}

    - uses: actions/cache@v3
      id: cache-bin-poetry
      name: Cache Poetry binary - Python ${{ inputs.python-version }}
      env:
        SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1"
      with:
        path: |
          /opt/pipx/venvs/poetry
        # This step caches the poetry installation, so make sure it's keyed on the poetry version as well.
        key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }}

    # Runs only on a cache hit: the restored venv needs its links re-created.
    - name: Refresh shell hashtable and fixup softlinks
      if: steps.cache-bin-poetry.outputs.cache-hit == 'true'
      shell: bash
      env:
        POETRY_VERSION: ${{ inputs.poetry-version }}
        PYTHON_VERSION: ${{ inputs.python-version }}
      run: |
        set -eux

        # Refresh the shell hashtable, to ensure correct `which` output.
        hash -r

        # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks.
        # Delete and recreate the softlinks pipx expects to have.
        # `rm -f` / `ln -sf` keep this step from aborting under `set -e` when
        # a link is missing from, or already present after, the cache restore.
        rm -f /opt/pipx/venvs/poetry/bin/python
        cd /opt/pipx/venvs/poetry/bin
        ln -sf "$(which "python$PYTHON_VERSION")" python
        chmod +x python
        cd /opt/pipx_bin/
        ln -sf /opt/pipx/venvs/poetry/bin/poetry poetry
        chmod +x poetry

        # Ensure everything got set up correctly.
        /opt/pipx/venvs/poetry/bin/python --version
        /opt/pipx_bin/poetry --version

    # Runs only on a cache miss: install poetry from scratch with pipx.
    - name: Install poetry
      if: steps.cache-bin-poetry.outputs.cache-hit != 'true'
      shell: bash
      env:
        POETRY_VERSION: ${{ inputs.poetry-version }}
        PYTHON_VERSION: ${{ inputs.python-version }}
      run: pipx install "poetry==$POETRY_VERSION" --python "python$PYTHON_VERSION" --verbose

    - name: Restore pip and poetry cached dependencies
      uses: actions/cache@v3
      env:
        SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4"
        # Fall back to the repository root when no working directory is given.
        WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }}
      with:
        path: |
          ~/.cache/pip
          ~/.cache/pypoetry/virtualenvs
          ~/.cache/pypoetry/cache
          ~/.cache/pypoetry/artifacts
          ${{ env.WORKDIR }}/.venv
        key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles(format('{0}/**/poetry.lock', env.WORKDIR)) }}
+83
View File
@@ -0,0 +1,83 @@
name: lint

# Reusable workflow: callers pass the folder that holds pyproject.toml.
on:
  workflow_call:
    inputs:
      working-directory:
        required: true
        type: string
        description: "From which folder this pipeline executes"

env:
  POETRY_VERSION: "1.7.1"
  # Fall back to the repository root when no working directory is given.
  WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }}

jobs:
  build:
    runs-on: ubuntu-latest
    env:
      # This number is set "by eye": we want it to be big enough
      # so that it's bigger than the number of commits in any reasonable PR,
      # and also as small as possible since increasing the number makes
      # the initial `git fetch` slower.
      # NOTE(review): no step in this workflow references FETCH_DEPTH.
      FETCH_DEPTH: 50
    strategy:
      matrix:
        # Only lint on the min and max supported Python versions.
        # It's extremely unlikely that there's a lint issue on any version in between
        # that doesn't show up on the min or max versions.
        #
        # GitHub rate-limits how many jobs can be running at any one time.
        # Starting new jobs is also relatively slow,
        # so linting on fewer versions makes CI faster.
        python-version:
          - "3.8"
          - "3.11"
    steps:
      - uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
        uses: "./.github/actions/poetry_setup"
        with:
          python-version: ${{ matrix.python-version }}
          poetry-version: ${{ env.POETRY_VERSION }}
          working-directory: ${{ inputs.working-directory }}
          cache-key: lint-with-extras

      - name: Check Poetry File
        shell: bash
        working-directory: ${{ inputs.working-directory }}
        run: |
          poetry check

      - name: Check lock file
        shell: bash
        working-directory: ${{ inputs.working-directory }}
        # `poetry lock --check` is deprecated; `poetry check --lock` is the
        # supported way to verify poetry.lock matches pyproject.toml.
        run: |
          poetry check --lock

      - name: Install dependencies
        # Also installs dev/lint/test/typing dependencies, to ensure we have
        # type hints for as many of our libraries as possible.
        # This helps catch errors that require dependencies to be spotted, for example:
        # https://github.com/langchain-ai/langchain/pull/10249/files#diff-935185cd488d015f026dcd9e19616ff62863e8cde8c0bee70318d3ccbca98341
        #
        # If you change this configuration, make sure to change the `cache-key`
        # in the `poetry_setup` action above to stop using the old cache.
        # It doesn't matter how you change it, any change will cause a cache-bust.
        working-directory: ${{ inputs.working-directory }}
        run: |
          poetry install --with dev,lint,test,typing

      - name: Get .mypy_cache to speed up mypy
        uses: actions/cache@v3
        env:
          SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2"
        with:
          path: |
            ${{ env.WORKDIR }}/.mypy_cache
          key: mypy-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }}

      - name: Analysing the code with our lint
        working-directory: ${{ inputs.working-directory }}
        run: |
          make lint
+57
View File
@@ -0,0 +1,57 @@
name: test

# Reusable workflow: runs the unit tests of one package across the
# supported Python versions.
on:
  workflow_call:
    inputs:
      working-directory:
        required: true
        type: string
        description: "From which folder this pipeline executes"

env:
  POETRY_VERSION: "1.7.1"

jobs:
  build:
    name: Python ${{ matrix.python-version }}
    runs-on: ubuntu-latest
    # Every `run` step executes inside the package folder.
    defaults:
      run:
        working-directory: ${{ inputs.working-directory }}
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10", "3.11"]
    steps:
      - uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
        uses: "./.github/actions/poetry_setup"
        with:
          python-version: ${{ matrix.python-version }}
          poetry-version: ${{ env.POETRY_VERSION }}
          working-directory: ${{ inputs.working-directory }}
          cache-key: core

      - name: Install dependencies
        shell: bash
        run: poetry install

      - name: Run core tests
        shell: bash
        run: make test

      # Fails the job when the test run leaves untracked or modified files.
      - name: Ensure the tests did not create any additional files
        shell: bash
        run: |
          set -eu

          STATUS="$(git status)"
          echo "$STATUS"

          # grep will exit non-zero if the target message isn't found,
          # and `set -e` above will cause the step to fail.
          echo "$STATUS" | grep 'nothing to commit, working tree clean'
+108
View File
@@ -0,0 +1,108 @@
---
name: Run CI Tests

on:
  push:
    branches: [main]
  pull_request:
  workflow_dispatch:  # Allows to trigger the workflow manually in GitHub UI

# If another push to the same PR or branch happens while this workflow is still running,
# cancel the earlier run in favor of the next run.
#
# There's no point in testing an outdated version of the code. GitHub only allows
# a limited number of job runners to be active at the same time, so it's better to cancel
# pointless jobs early so that more useful jobs can run sooner.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

env:
  POETRY_VERSION: "1.7.1"
  WORKDIR: "./backend"

jobs:
  lint:
    uses:
      ./.github/workflows/_lint.yml
    with:
      working-directory: ./backend
    secrets: inherit

  test:
    timeout-minutes: 5
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: ${{ env.WORKDIR }}
    services:
      postgres:
        # ensure this postgres version stays in sync with the prod database
        # and with the postgres version used in docker compose
        image: postgres:16
        env:
          # optional (defaults to `postgres`)
          POSTGRES_DB: langchain_test
          # required
          POSTGRES_PASSWORD: langchain
          # optional (defaults to `5432`)
          POSTGRES_PORT: 5432
          # optional (defaults to `postgres`)
          POSTGRES_USER: langchain
        ports:
          # maps tcp port 5432 on service container to the host
          # (quoted so the mapping is always read as a string)
          - "5432:5432"
        # set health checks to wait until postgres has started
        options: >-
          --health-cmd pg_isready
          --health-interval 3s
          --health-timeout 5s
          --health-retries 10
    strategy:
      matrix:
        python-version:
          - "3.8"
          - "3.9"
          - "3.10"
          - "3.11"
    name: Python ${{ matrix.python-version }} tests
    steps:
      - uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
        uses: "./.github/actions/poetry_setup"
        with:
          python-version: ${{ matrix.python-version }}
          poetry-version: ${{ env.POETRY_VERSION }}
          working-directory: ${{ env.WORKDIR }}
          cache-key: langchain-extract-all

      - name: Test database connection
        run: |
          # Set up postgresql-client.
          # Refresh the package index first: the runner image's index can be
          # stale, which makes a bare `apt-get install` fail.
          sudo apt-get update
          sudo apt-get install -y postgresql-client
          # Test psql connection
          psql -h localhost -p 5432 -U langchain -d langchain_test -c "SELECT 1;"
        env:
          # postgres password is required; alternatively, you can run:
          # `PGPASSWORD=postgres_password psql ...`
          PGPASSWORD: langchain

      - name: Install dependencies
        shell: bash
        run: |
          echo "Running tests, installing dependencies with poetry..."
          poetry install --with test,lint,typing,docs

      - name: Run tests
        run: make test

      - name: Ensure the tests did not create any additional files
        shell: bash
        run: |
          set -eu

          STATUS="$(git status)"
          echo "$STATUS"

          # grep will exit non-zero if the target message isn't found,
          # and `set -e` above will cause the step to fail.
          echo "$STATUS" | grep 'nothing to commit, working tree clean'
+162
View File
@@ -0,0 +1,162 @@
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.DS_Store
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024-Present Langchain AI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+63
View File
@@ -0,0 +1,63 @@
🚧 Under Active Development 🚧
Please expect breaking changes!
# 🦜📝 LangChain Extract
[![CI](https://github.com/langchain-ai/langchain-extract/actions/workflows/ci.yml/badge.svg)](https://github.com/langchain-ai/langchain-extract/actions/workflows/ci.yml)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/langchainai.svg?style=social&label=Follow%20%40LangChainAI)](https://twitter.com/langchainai)
[![](https://dcbadge.vercel.app/api/server/6adMQxSpJS?compact=true&style=flat)](https://discord.gg/6adMQxSpJS)
[![Open Issues](https://img.shields.io/github/issues-raw/langchain-ai/langchain-extract)](https://github.com/langchain-ai/langchain-extract/issues)
# Set up
## Services
The root folder contains a docker compose file which will launch a postgres
instance.
```
docker compose up
```
At the time of writing, the app wasn't using postgres yet!
## App
```sh
cd [root]/backend
```
Set up the environment using poetry:
```sh
poetry install --with lint,dev,test
```
Verify that unit tests pass (they probably won't?)
# Test and format
Testing and formatting are done using a Makefile inside `[root]/backend`:
```sh
make format
```
```sh
make test
```
# Launch Server
From `[root]/backend`:
```sh
python -m server.main
```
# Example client
See `docs/source/notebooks` for an example client.
+60
View File
@@ -0,0 +1,60 @@
# Every target in this file is a command, not a file to build.
.PHONY: all lint lint_diff format format_diff test test_watch openapi spell_check spell_fix help

# Default target executed when no arguments are given to make.
all: help

######################
# TESTING AND COVERAGE
######################

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)

test_watch:
	poetry run ptw . -- $(TEST_FILE)

# Dump the FastAPI OpenAPI document; the API key is a placeholder so the
# server module can be imported without real credentials.
openapi:
	OPENAI_API_KEY=placeholder python -c "from server import main; import json; print(json.dumps(main.app.openapi()))" > openapi.json

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
lint format: PYTHON_FILES=.
# The *_diff targets only look at files changed relative to `main`,
# the branch CI is configured to build.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=. --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')

lint lint_diff:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
# [ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)

format format_diff:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

######################
# HELP
######################

help:
	@echo '===================='
	@echo '-- LINTING --'
	@echo 'format                       - run code formatters'
	@echo 'lint                         - run linters'
	@echo 'spell_check                  - run codespell on the project'
	@echo 'spell_fix                    - run codespell on the project and fix the errors'
	@echo '-- TESTS --'
	@echo 'test                         - run unit tests'
	@echo 'test TEST_FILE=<test_file>   - run all tests in file'
	@echo '-- DOCUMENTATION tasks are from the top-level Makefile --'
+1
View File
@@ -0,0 +1 @@
See readme at repo root.
View File
+134
View File
@@ -0,0 +1,134 @@
import uuid
from datetime import datetime
from typing import Generator
from sqlalchemy import Column, DateTime, ForeignKey, String, Text, create_engine
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship, sessionmaker
from sqlalchemy.sql import func
from server.settings import get_postgres_url
# Engine created at import time from the URL returned by get_postgres_url().
ENGINE = create_engine(get_postgres_url())
# Session factory bound to ENGINE; called to produce new Session objects.
SessionClass = sessionmaker(bind=ENGINE)
# Declarative base that the ORM models in this module derive from.
Base = declarative_base()
# TODO(Eugene): Convert to async code
def get_session() -> Generator[Session, None, None]:
    """Yield a new session and always close it afterwards.

    Any exception raised while the caller holds the session triggers a
    rollback and is then re-raised unchanged.

    Yields:
        A session created from ``SessionClass``.
    """
    db_session = SessionClass()
    try:
        yield db_session
    except BaseException:
        # Undo pending work, then let the original error propagate.
        db_session.rollback()
        raise
    finally:
        db_session.close()
class TimestampedModel(Base):
    """An abstract base model that includes the timestamp fields."""

    # Abstract: no table is created for this class itself.
    __abstract__ = True

    created_at = Column(
        DateTime,
        default=datetime.utcnow,
        comment="The time the record was created (UTC).",
    )
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        doc="The time the record was last updated (UTC).",
    )
    # This is our own uuid assigned to the artifact.
    # By construction guaranteed to be unique no matter what.
    uuid = Column(
        UUID(as_uuid=True),
        primary_key=True,
        # The column is declared with as_uuid=True, so the Python-side default
        # is a uuid.UUID object rather than its string form.
        # The lambda defers the lookup of ``uuid`` to call time, where it
        # resolves to the module and not to this class attribute.
        default=lambda: uuid.uuid4(),
        doc="Unique identifier for this model.",
    )
class Extractor(TimestampedModel):
    # ORM model mapped to the "extractors" table. Each row stores a JSON
    # schema plus the instruction (prompt) used for extraction.

    __tablename__ = "extractors"

    name = Column(
        String(100),
        nullable=False,
        server_default="",
        comment="The name of the extractor.",
    )
    # Redefines the ``created_at`` attribute declared on TimestampedModel:
    # here it is timezone-aware and filled in by the database (func.now()).
    created_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        comment="Time when this extracted was originally created.",
    )
    # Only ``onupdate`` is configured; no insert-time default is given here.
    modified_at = Column(
        DateTime(timezone=True),
        onupdate=func.now(),
        comment="Last time this was modified.",
    )
    schema = Column(
        JSONB,
        nullable=False,
        comment="JSON Schema that describes what content will be extracted from the document",
    )
    description = Column(
        String(100),
        nullable=False,
        server_default="",
        comment="Surfaced via UI to the users.",
    )
    instruction = Column(
        Text, nullable=False, comment="The prompt to the language model."
    )  # TODO: This will need to evolve
    # One-to-many link to Example; the backref adds ``Example.extractor``.
    examples = relationship("Example", backref="extractor")

    def __repr__(self) -> str:
        # Short debug representation: primary key and description only.
        return f"<Extractor(id={self.uuid}, description={self.description})>"
class Example(TimestampedModel):
    """A representation of an example.

    Examples consist of content together with the expected output.

    The output is a JSON object that is expected to be extracted from the content.

    The JSON object should be valid according to the schema of the associated extractor.

    The JSON object is defined by the schema of the associated extractor, so
    it's perfectly fine for a given example to represent the extraction
    of multiple instances of some object from the content since
    the JSON schema can represent a list of objects.
    """

    __tablename__ = "examples"

    content = Column(
        Text,
        nullable=False,
        comment="The input portion of the example.",
    )
    output = Column(
        JSONB,
        comment="The output associated with the example.",
    )
    # Deleting an extractor cascades to its examples at the database level.
    extractor_id = Column(
        UUID(as_uuid=True),
        ForeignKey("extractors.uuid", ondelete="CASCADE"),
        nullable=False,
        comment="Foreign key referencing the associated extractor.",
    )

    def __repr__(self) -> str:
        """Return a short debug representation with a 20-character preview."""
        # The closing parenthesis was missing from the original format string.
        return f"<Example(uuid={self.uuid}, content={self.content[:20]})>"
View File
+65
View File
@@ -0,0 +1,65 @@
"""Convert binary input to blobs and parse them using the appropriate parser."""
from __future__ import annotations
from typing import BinaryIO, List
from langchain.document_loaders.parsers import BS4HTMLParser, PDFMinerParser
from langchain.document_loaders.parsers.generic import MimeTypeBasedParser
from langchain.document_loaders.parsers.txt import TextParser
from langchain_community.document_loaders import Blob
from langchain_core.documents import Document
# Maps a mime-type string to the parser instance used for blobs of that type.
HANDLERS = {
    "application/pdf": PDFMinerParser(),
    "text/plain": TextParser(),
    "text/html": BS4HTMLParser(),
    # Disable for now as they rely on unstructured and there's some install
    # issue with unstructured.
    # from langchain.document_loaders.parsers.msword import MsWordParser
    # "application/msword": MsWordParser(),
    # "application/vnd.openxmlformats-officedocument.wordprocessingml.document": (
    #     MsWordParser()
    # ),
}

# Alphabetically sorted list of the mime-types that have a handler above.
SUPPORTED_MIMETYPES = sorted(HANDLERS.keys())
def _guess_mimetype(file_bytes: bytes) -> str:
"""Guess the mime-type of a file."""
try:
import magic
except ImportError as e:
raise ImportError(
"magic package not found, please install it with `pip install python-magic`"
) from e
mime = magic.Magic(mime=True)
mime_type = mime.from_buffer(file_bytes)
return mime_type
# PUBLIC API

# Parser built from HANDLERS; no fallback parser is configured
# (fallback_parser=None).
MIMETYPE_BASED_PARSER = MimeTypeBasedParser(
    handlers=HANDLERS,
    fallback_parser=None,
)
def convert_binary_input_to_blob(data: BinaryIO) -> Blob:
    """Convert ingestion input to blob.

    Reads the whole stream, detects its mime-type from the bytes and wraps
    both in a ``Blob``.

    Args:
        data: Binary stream to read. A ``name`` attribute is used as the blob
            path when present.

    Returns:
        A ``Blob`` holding the bytes, the detected mime-type and the stream
        name (or ``None`` when the stream has no name).
    """
    file_data = data.read()
    mimetype = _guess_mimetype(file_data)
    # Not every binary stream has a ``name`` (io.BytesIO does not), so do
    # not assume the attribute exists.
    file_name = getattr(data, "name", None)
    return Blob.from_data(
        data=file_data,
        path=file_name,
        mime_type=mimetype,
    )
def parse_binary_input(data: BinaryIO) -> List[Document]:
    """Turn a binary stream into documents.

    The stream is converted to a blob and handed to ``MIMETYPE_BASED_PARSER``.

    Args:
        data: Binary stream to parse.

    Returns:
        The documents returned by the parser.
    """
    return MIMETYPE_BASED_PARSER.parse(convert_binary_input_to_blob(data))
+52
View File
@@ -0,0 +1,52 @@
"""Adapters to convert between different formats."""
from __future__ import annotations
from langchain_core.utils.json_schema import dereference_refs
def _rm_titles(kv: dict) -> dict:
"""Remove titles from a dictionary."""
new_kv = {}
for k, v in kv.items():
if k == "title":
continue
elif isinstance(v, dict):
new_kv[k] = _rm_titles(v)
else:
new_kv[k] = v
return new_kv
# PUBLIC API


def convert_json_schema_to_openai_schema(
    schema: dict,
    *,
    rm_titles: bool = True,
    multi: bool = True,
) -> dict:
    """Convert JSON schema to a corresponding OpenAI function call.

    Args:
        schema: The JSON schema describing a single extracted object.
        rm_titles: If True, strip ``title`` annotations from the parameters.
        multi: Must be True; the schema is wrapped so that a list of objects
            can be returned.

    Returns:
        A dictionary with ``name``, ``description`` and ``parameters`` keys.

    Raises:
        NotImplementedError: If ``multi`` is False.
    """
    if not multi:
        raise NotImplementedError("Only multi is supported for now.")

    # Work on a shallow copy so the pop below cannot touch the caller's dict.
    item_schema = dict(dereference_refs(schema))
    # The "definitions" block lives on the item schema, not on the wrapper
    # built below, so this is where it has to be removed.
    item_schema.pop("definitions", None)

    # Wrap the schema in an object called "Root" with a property called: "data"
    # which will be a json array of the original schema.
    schema_ = {
        "type": "object",
        "properties": {
            "data": {
                "type": "array",
                "items": item_schema,
            },
        },
        "required": ["data"],
    }

    return {
        "name": "extractor",
        "description": "Extract information matching the given schema.",
        "parameters": _rm_titles(schema_) if rm_titles else schema_,
    }
+4243
View File
File diff suppressed because it is too large Load Diff
+91
View File
@@ -0,0 +1,91 @@
[tool.poetry]
name = "langchain-extract"
version = "0.0.1"
description = "Sample extraction backend."
authors = ["LangChain AI"]
license = "MIT"
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.8.1"
langchain = "~0.1"
langsmith = ">=0.0.66"
fastapi = "^0.109.2"
langserve = "^0.0.41"
uvicorn = "^0.27.1"
pydantic = "^1.10"
langchain-openai = "^0.0.6"
jsonschema = "^4.21.1"
sse-starlette = "^2.0.0"
alembic = "^1.13.1"
psycopg2 = "^2.9.9"
python-magic = "^0.4.27"
pdfminer-six = "^20231228"
beautifulsoup4 = "^4.12.3"
lxml = "^5.1.0"
faiss-cpu = "^1.7.4"
python-multipart = "^0.0.9"
[tool.poetry.group.dev.dependencies]
jupyterlab = "^3.6.1"
[tool.poetry.group.typing.dependencies]
mypy = "^1.7.0"
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.docs.dependencies]
[tool.poetry.group.test.dependencies]
pytest = "^7.2.1"
pytest-cov = "^4.0.0"
pytest-asyncio = "^0.21.1"
pytest-mock = "^3.11.1"
pytest-socket = "^0.6.0"
pytest-watch = "^4.2.0"
pytest-timeout = "^2.2.0"
[tool.ruff]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
]
extend-include = ["*.ipynb"]
# Same as Black.
line-length = 88
[tool.mypy]
disallow_untyped_defs = "True"
ignore_missing_imports = "True"
[tool.coverage.run]
omit = [
"tests/*",
]
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5 -vv"
# Global timeout for all tests. There should be a good reason for a test to
# take more than 5 seconds.
timeout = 5
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"asyncio: mark tests as requiring asyncio",
]
asyncio_mode = "auto"
View File
+59
View File
@@ -0,0 +1,59 @@
#!/usr/bin/env python
"""Run migrations."""
import click
from db.models import ENGINE, Base
@click.group()
def cli() -> None:
    """Database migration commands."""
    # Command group only: the sub-commands below attach via @cli.command().
@cli.command()
def create() -> None:
    """Create all tables."""
    # Emits CREATE TABLE for every model registered on Base's metadata.
    Base.metadata.create_all(bind=ENGINE)
    click.echo("All tables created successfully.")
@cli.command()
@click.confirmation_option(prompt="Are you sure you want to drop all tables?")
def drop() -> None:
    """Drop all tables."""
    # Destructive: the confirmation option above gates this command.
    Base.metadata.drop_all(bind=ENGINE)
    click.echo("All tables dropped successfully.")
@cli.command()
def create_test_db() -> None:
    """Create a test database called langchain_test used for testing purposes."""
    import psycopg2
    from psycopg2.errors import DuplicateDatabase

    # establishing the connection
    conn = psycopg2.connect(
        database="langchain",
        user="langchain",
        password="langchain",
        host="localhost",
        port="5432",
    )
    try:
        # Autocommit is enabled before the CREATE DATABASE statement is sent.
        conn.autocommit = True

        # Creating a cursor object using the cursor() method
        with conn.cursor() as cursor:
            # Preparing query to create a database
            sql = "CREATE DATABASE langchain_test;"
            # Creating a database
            try:
                cursor.execute(sql)
                click.echo("Database created successfully.")
            except DuplicateDatabase:
                click.echo("Database already exists")
    finally:
        # Always release the connection, including when the query fails.
        conn.close()
# Run the click command group when this file is executed as a script.
if __name__ == "__main__":
    cli()
View File
View File
+74
View File
@@ -0,0 +1,74 @@
"""Endpoints for managing definition of examples.."""
from typing import Any, Dict, List
from uuid import UUID
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing_extensions import Annotated, TypedDict
from db.models import Example, get_session
# Router for the example endpoints; every route below is under "/examples".
router = APIRouter(
    prefix="/examples",
    tags=["example definitions"],
    responses={404: {"description": "Not found"}},
)
class CreateExample(TypedDict):
    """A request to create an example."""

    # Each field carries its human-readable description as Annotated metadata.
    extractor_id: Annotated[UUID, "The extractor ID that this is an example for."]
    content: Annotated[str, "The input portion of the example."]
    # Typed as a list, although the description text says "JSON object".
    output: Annotated[
        List[Any], "JSON object that is expected to be extracted from the content."
    ]
class CreateExampleResponse(TypedDict):
    """Response for creating an example."""

    # Identifier of the example row that was created.
    uuid: UUID
@router.post("")
def create(
    create_request: CreateExample,
    *,
    session: Session = Depends(get_session),
) -> CreateExampleResponse:
    """Endpoint to create an example."""
    # Copy the three request fields onto a new ORM row.
    example = Example(
        output=create_request["output"],
        content=create_request["content"],
        extractor_id=create_request["extractor_id"],
    )
    session.add(example)
    session.commit()
    # Report the identifier of the stored row back to the caller.
    return {"uuid": example.uuid}
@router.get("")
def list(
    extractor_id: UUID,
    *,
    limit: int = 10,
    offset: int = 0,
    session=Depends(get_session),
) -> List[Any]:
    """Endpoint to get all examples."""
    # Restrict to the examples that belong to the requested extractor.
    belonging = session.query(Example).filter(Example.extractor_id == extractor_id)
    # Order by uuid so that limit/offset paging is stable between calls.
    page = belonging.order_by(Example.uuid).limit(limit).offset(offset)
    return page.all()
@router.delete("/{uuid}")
def delete(uuid: UUID, *, session: Session = Depends(get_session)) -> None:
    """Endpoint to delete an example."""
    # Bulk delete of the rows whose uuid matches the path parameter.
    matching = session.query(Example).filter(Example.uuid == str(uuid))
    matching.delete()
    session.commit()
+56
View File
@@ -0,0 +1,56 @@
from typing import Literal, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy.orm import Session
from typing_extensions import Annotated
from db.models import Extractor, get_session
from extraction.parsing import parse_binary_input
from server.extraction_runnable import ExtractResponse, extract_entire_document
from server.retrieval import extract_from_content
# Router for the extraction endpoint; routes below are under "/extract".
router = APIRouter(
    prefix="/extract",
    tags=["extract"],
    responses={404: {"description": "Not found"}},
)
@router.post("", response_model=ExtractResponse)
async def extract_using_existing_extractor(
    *,
    extractor_id: Annotated[UUID, Form()],
    text: Optional[str] = Form(None),
    mode: Literal["entire_document", "retrieval"] = Form("entire_document"),
    file: Optional[UploadFile] = File(None),
    session: Session = Depends(get_session),
) -> ExtractResponse:
    """Endpoint that is used with an existing extractor.

    This endpoint will be expanded to support upload of binary files as well as
    text files.
    """
    # Reject requests that carry neither usable text nor a file. Empty text
    # counts as "no text": without this, an empty string and no file would
    # reach the file branch below and fail on ``file.file``.
    if not text and file is None:
        raise HTTPException(status_code=422, detail="No text or file provided.")

    extractor = session.query(Extractor).filter(Extractor.uuid == extractor_id).scalar()
    if extractor is None:
        raise HTTPException(status_code=404, detail="Extractor not found.")

    if text:
        # Non-empty text takes precedence over an uploaded file.
        text_ = text
    else:
        documents = parse_binary_input(file.file)
        # TODO: Add metadata like location from original file where
        # the text was extracted from
        text_ = "\n".join([document.page_content for document in documents])

    if mode == "entire_document":
        return await extract_entire_document(text_, extractor)
    elif mode == "retrieval":
        return await extract_from_content(text_, extractor)
    else:
        # Defensive: the Literal annotation already limits the accepted values.
        raise ValueError(
            f"Invalid mode {mode}. Expected one of 'entire_document', 'retrieval'."
        )
+93
View File
@@ -0,0 +1,93 @@
"""Endpoints for managing definition of extractors."""
from typing import Any, Dict, List
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, validator
from sqlalchemy.orm import Session
from db.models import Extractor, get_session
from server.validators import validate_json_schema
# Router for the extractor endpoints; routes below are under "/extractors".
router = APIRouter(
    prefix="/extractors",
    tags=["extractor definitions"],
    responses={404: {"description": "Not found"}},
)
class CreateExtractor(BaseModel):
    """A request to create an extractor."""

    name: str = Field(default="", description="The name of the extractor.")
    description: str = Field(
        default="", description="Short description of the extractor."
    )
    # Exposed under the alias "schema"; the attribute is named ``json_schema``.
    json_schema: Dict[str, Any] = Field(
        ..., description="The schema to use for extraction.", alias="schema"
    )
    instruction: str = Field(..., description="The instruction to use for extraction.")

    @validator("json_schema")
    def validate_schema(cls, v: Any) -> Dict[str, Any]:
        """Validate the schema."""
        # Delegates to validate_json_schema and returns the value unchanged.
        validate_json_schema(v)
        return v
class CreateExtractorResponse(BaseModel):
    """Response for creating an extractor."""

    # Identifier of the extractor row that was created.
    uuid: UUID
@router.post("")
def create(
    create_request: CreateExtractor, *, session: Session = Depends(get_session)
) -> CreateExtractorResponse:
    """Endpoint to create an extractor."""
    # Map the request fields onto a new ORM row; ``json_schema`` on the
    # request is stored in the ``schema`` column.
    extractor = Extractor(
        instruction=create_request.instruction,
        description=create_request.description,
        schema=create_request.json_schema,
        name=create_request.name,
    )
    session.add(extractor)
    session.commit()
    new_uuid = extractor.uuid
    return CreateExtractorResponse(uuid=new_uuid)
@router.get("/{uuid}")
def get(
    uuid: UUID, *, session: Session = Depends(get_session)
) -> Dict[str, Any]:
    """Endpoint to get an extractor."""
    matching = session.query(Extractor).filter(Extractor.uuid == str(uuid))
    record = matching.scalar()
    if record is None:
        raise HTTPException(status_code=404, detail="Extractor not found.")
    # Attributes exposed in the response, in output order.
    exposed = ("uuid", "name", "description", "schema", "instruction")
    return {field: getattr(record, field) for field in exposed}
@router.get("")
def list(
    *,
    limit: int = 10,
    offset: int = 0,
    session=Depends(get_session),
) -> List[Any]:
    """Return one page of extractors.

    Args:
        limit: Maximum number of extractors to return.
        offset: Number of extractors to skip before the page starts.
        session: Database session used for the query.
    """
    # NOTE(review): no ORDER BY is applied, so page contents rely on the
    # database's default row order.
    page_query = session.query(Extractor).limit(limit).offset(offset)
    return page_query.all()
@router.delete("/{uuid}")
def delete(uuid: UUID, *, session: Session = Depends(get_session)) -> None:
    """Delete the extractor(s) matching the given UUID and commit."""
    matching = session.query(Extractor).filter(Extractor.uuid == str(uuid))
    matching.delete()
    session.commit()
+208
View File
@@ -0,0 +1,208 @@
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Sequence
from fastapi import HTTPException
from jsonschema import Draft202012Validator, exceptions
from langchain.chains.openai_functions import create_openai_fn_runnable
from langchain.text_splitter import TokenTextSplitter
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import chain
from langserve import CustomUserType
from pydantic import BaseModel, Field, validator
from typing_extensions import TypedDict
from db.models import Example, Extractor
from extraction.utils import (
convert_json_schema_to_openai_schema,
)
from server.settings import CHUNK_SIZE, MODEL_NAME, model
from server.validators import validate_json_schema
class ExtractionExample(BaseModel):
    """A single few-shot example for extraction.

    Pairs an input text with the output the extraction is expected to
    produce for it.
    """

    text: str = Field(..., description="The input text")
    output: List[Dict[str, Any]] = Field(
        ..., description="The expected output of the example. A list of objects."
    )
class ExtractRequest(CustomUserType):
    """Input accepted by the extraction runnable.

    The schema is supplied under the ``schema`` key (an alias of
    ``json_schema``) and is passed through ``validate_json_schema`` when the
    model is validated.
    """

    text: str = Field(..., description="The text to extract from.")
    json_schema: Dict[str, Any] = Field(
        ...,
        description="JSON schema that describes what content should be extracted "
        "from the text.",
        alias="schema",
    )
    instructions: Optional[str] = Field(
        None, description="Supplemental system instructions."
    )
    examples: Optional[List[ExtractionExample]] = Field(
        None, description="Examples of extractions."
    )

    @validator("json_schema")
    def validate_schema(cls, schema_value: Any) -> Dict[str, Any]:
        """Check that the supplied value is a valid JSON schema and return it."""
        validate_json_schema(schema_value)
        return schema_value
class ExtractResponse(TypedDict):
    """Shape of the value produced by an extraction."""

    # Extracted items, one entry per extracted object.
    data: List[Any]
def _deduplicate(
extract_responses: Sequence[ExtractResponse],
) -> ExtractResponse:
"""Deduplicate the results.
The deduplication is done by comparing the serialized JSON of each of the results
and only keeping the unique ones.
"""
unique_extracted = []
seen = set()
for response in extract_responses:
for data_item in response["data"]:
# Serialize the data item for comparison purposes
serialized = json.dumps(data_item, sort_keys=True)
if serialized not in seen:
seen.add(serialized)
unique_extracted.append(data_item)
return {
"data": unique_extracted,
}
def _cast_example_to_dict(example: Example) -> Dict[str, Any]:
"""Cast example record to dictionary."""
return {
"text": example.content,
"output": example.output,
}
def _make_prompt_template(
    instructions: Optional[str],
    examples: Optional[Sequence[ExtractionExample]],
    function_name: str,
) -> ChatPromptTemplate:
    """Assemble the chat prompt used for extraction.

    The prompt consists of a system message (a fixed prefix, followed by the
    supplemental instructions when given), one human/AI message pair per
    few-shot example, and a final human message with a ``{text}`` placeholder.

    Args:
        instructions: Optional text appended to the system message.
        examples: Optional few-shot examples.
        function_name: Name recorded in each example's function call.
    """
    prefix = (
        "You are a top-tier algorithm for extracting information from text. "
        "Only extract information that is relevant to the provided text. "
        "If no information is relevant, use the schema and output "
        "an empty list where appropriate."
    )
    system_text = f"{prefix}\n\n{instructions}" if instructions else prefix
    prompt_components = [("system", system_text)]

    # TODO: We'll need to refactor this at some point to
    # support other encoding strategies. The function calling logic here
    # has some hard-coded assumptions (e.g., name of parameters like `data`).
    for example in examples if examples is not None else ():
        encoded_arguments = json.dumps({"data": example.output})
        expected_call = {"arguments": encoded_arguments, "name": function_name}
        prompt_components.append(HumanMessage(content=example.text))
        prompt_components.append(
            AIMessage(content="", additional_kwargs={"function_call": expected_call})
        )

    final_request = (
        "human",
        "I need to extract information from "
        "the following text: ```\n{text}\n```\n",
    )
    prompt_components.append(final_request)
    return ChatPromptTemplate.from_messages(prompt_components)
# PUBLIC API
def get_examples_from_extractor(extractor: Extractor) -> List[Dict[str, Any]]:
    """Return the extractor's examples as a list of dictionaries."""
    converted = []
    for example in extractor.examples:
        converted.append(_cast_example_to_dict(example))
    return converted
@chain
async def extraction_runnable(extraction_request: ExtractRequest) -> ExtractResponse:
    """Run a single extraction over the text carried by the request.

    The request's JSON schema is converted into an OpenAI function
    definition, a prompt is built from the optional instructions and few-shot
    examples, and the resulting function-calling runnable is invoked on the
    request text.

    Args:
        extraction_request: Text, schema, and optional instructions/examples.

    Returns:
        Whatever the function-calling runnable produces for the text.

    Raises:
        HTTPException: with status 422 when the schema is not a valid JSON
            schema.
    """
    # TODO: Add validation for model context window size
    schema = extraction_request.json_schema
    try:
        Draft202012Validator.check_schema(schema)
    except exceptions.SchemaError as e:
        # `check_schema` signals an invalid schema with `SchemaError`, which is
        # not a subclass of `ValidationError`; catching the latter never fired.
        raise HTTPException(
            status_code=422, detail=f"Invalid schema: {e.message}"
        ) from e

    openai_function = convert_json_schema_to_openai_schema(schema)
    function_name = openai_function["name"]
    prompt = _make_prompt_template(
        extraction_request.instructions,
        extraction_request.examples,
        function_name,
    )
    runnable = create_openai_fn_runnable(
        functions=[openai_function], llm=model, prompt=prompt
    )
    return await runnable.ainvoke({"text": extraction_request.text})
async def extract_entire_document(
    content: str,
    extractor: Extractor,
) -> ExtractResponse:
    """Extract from a whole document by processing it chunk by chunk.

    The content is split into token-based chunks, one extraction request is
    issued per chunk, and the per-chunk results are merged with duplicates
    removed.

    Args:
        content: Full text of the document.
        extractor: Extractor supplying the schema, instruction, and examples.
    """
    json_schema = extractor.schema
    examples = get_examples_from_extractor(extractor)
    splitter = TokenTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=20,
        model_name=MODEL_NAME,
    )
    chunks = splitter.split_text(content)

    extraction_requests = []
    for chunk in chunks:
        request = ExtractRequest(
            text=chunk,
            schema=json_schema,
            instructions=extractor.instruction,  # TODO: consistent naming
            examples=examples,
        )
        extraction_requests.append(request)

    # Chunks overlap, so the same item may be extracted more than once.
    batch_config = {"max_concurrency": 1}
    extract_responses: List[ExtractResponse] = await extraction_runnable.abatch(
        extraction_requests, batch_config
    )
    return _deduplicate(extract_responses)
+65
View File
@@ -0,0 +1,65 @@
"""Entry point into the server."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langserve import add_routes
from sqlalchemy.orm import Session
from typing_extensions import Annotated
from server.api import examples, extract, extractors
from server.extraction_runnable import (
ExtractRequest,
ExtractResponse,
extraction_runnable,
)
# The FastAPI application object; routers and routes are attached to it below.
app = FastAPI(
    title="Extraction Powered by LangChain",
    description="An extraction service powered by LangChain.",
    version="0.0.1",
    openapi_tags=[
        {
            "name": "extraction",
            "description": "Operations related to extracting content from text.",
        }
    ],
)

# Origins allowed to make cross-origin requests to this server.
# NOTE(review): presumably the local frontend dev server -- confirm.
origins = [
    "http://localhost:5173",
]

# CORS: only the origins above are allowed, with any method and any header,
# and credentials permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/ready")
def ready():
    """Readiness endpoint; always responds with the string ``"ok"``."""
    status = "ok"
    return status
# Include API endpoints for extractor definitions, examples and extraction.
app.include_router(extractors.router)
app.include_router(examples.router)
app.include_router(extract.router)

# Expose the extraction runnable through LangServe under `/extract_text`,
# with explicit input/output types and only the listed endpoints enabled.
add_routes(
    app,
    extraction_runnable.with_types(
        input_type=ExtractRequest, output_type=ExtractResponse
    ),
    path="/extract_text",
    enabled_endpoints=["invoke", "playground", "stream_log"],
)

# When executed as a script, serve the app with uvicorn on localhost:8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="localhost", port=8000)
+72
View File
@@ -0,0 +1,72 @@
from operator import itemgetter
from typing import Any, Dict, List, Optional
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda
from langchain_openai import OpenAIEmbeddings
from db.models import Extractor
from server.extraction_runnable import (
ExtractRequest,
ExtractResponse,
extraction_runnable,
get_examples_from_extractor,
)
def _get_top_doc_content(docs: List[Document]) -> str:
if docs:
return docs[0].page_content
else:
return ""
def _make_extract_request(input_dict: Dict[str, Any]) -> ExtractRequest:
    """Build an `ExtractRequest` from a dictionary of its fields."""
    request = ExtractRequest(**input_dict)
    return request
async def extract_from_content(
    content: str,
    extractor: Extractor,
    *,
    text_splitter_kwargs: Optional[Dict[str, Any]] = None,
    multi: bool = True,
) -> ExtractResponse:
    """Extract from potentially long-form content using retrieval.

    The content is split into chunks which are indexed in a FAISS vector
    store. The extractor's description is used as the retrieval query, and
    only the top retrieved chunk is handed to the extraction runnable.

    Args:
        content: Text to extract from.
        extractor: Extractor supplying schema, description, instruction and
            examples.
        text_splitter_kwargs: Keyword arguments for `CharacterTextSplitter`.
            Defaults to splitting on blank lines with a chunk size of 1000
            and an overlap of 50.
        multi: Currently unused by this function.

    Returns:
        The extracted items under the ``data`` key.
    """
    if text_splitter_kwargs is None:
        text_splitter_kwargs = {
            "separator": "\n\n",
            "chunk_size": 1000,
            "chunk_overlap": 50,
        }
    text_splitter = CharacterTextSplitter(**text_splitter_kwargs)
    docs = text_splitter.create_documents([content])
    doc_contents = [doc.page_content for doc in docs]

    vectorstore = FAISS.from_texts(doc_contents, embedding=OpenAIEmbeddings())
    retriever = vectorstore.as_retriever()

    # Retrieve the best-matching chunk for the query, package it together with
    # the schema/instructions/examples, and run the extraction on it.
    runnable = (
        {
            "text": itemgetter("query") | retriever | _get_top_doc_content,
            "schema": itemgetter("schema"),
            "instructions": lambda x: x.get("instructions"),
            "examples": lambda x: x.get("examples"),
        }
        | RunnableLambda(_make_extract_request)
        | extraction_runnable
    )
    schema = extractor.schema
    examples = get_examples_from_extractor(extractor)
    description = extractor.description  # TODO: improve this
    result = await runnable.ainvoke(
        {
            "query": description,
            "schema": schema,
            "examples": examples,
            "instructions": extractor.instruction,
        }
    )
    # `ExtractResponse` is a TypedDict whose only key is `data`, and the
    # runnable's result is a mapping, so read the items by key. The previous
    # `result.extracted` / `extracted=` form matched neither.
    return ExtractResponse(data=result["data"])
+24
View File
@@ -0,0 +1,24 @@
from __future__ import annotations
from langchain_openai import ChatOpenAI
from sqlalchemy.engine import URL
# Name of the chat model used for extraction.
MODEL_NAME = "gpt-3.5-turbo"
# Chunk size: 80% of 4096.
# NOTE(review): presumably leaves headroom within the model's context window
# for the prompt and the response -- confirm.
CHUNK_SIZE = int(4_096 * 0.8)
# Max concurrency for the model.
MAX_CONCURRENCY = 1
def get_postgres_url():
    """Build the SQLAlchemy URL of the `langchain` postgres database.

    The connection details (user, password, host `localhost`, port 5432) are
    hard-coded.
    """
    return URL.create(
        drivername="postgresql",
        username="langchain",
        password="langchain",
        host="localhost",
        database="langchain",
        port=5432,
    )
# Shared chat model instance, created at import time with temperature 0.
model = ChatOpenAI(model=MODEL_NAME, temperature=0)
+15
View File
@@ -0,0 +1,15 @@
from typing import Any, Dict
from fastapi import HTTPException
from jsonschema import exceptions
from jsonschema.validators import Draft202012Validator
def validate_json_schema(schema: Dict[str, Any]) -> None:
    """Validate that `schema` is itself a valid JSON schema.

    Args:
        schema: The candidate JSON schema.

    Raises:
        HTTPException: with status 422 when the schema is invalid.
    """
    try:
        Draft202012Validator.check_schema(schema)
    except exceptions.SchemaError as e:
        # `check_schema` signals an invalid schema with `SchemaError`, which is
        # not a subclass of `ValidationError`; catching the latter never fired.
        raise HTTPException(
            status_code=422, detail=f"Not a valid JSON schema: {e.message}"
        ) from e
View File
+51
View File
@@ -0,0 +1,51 @@
"""Utility code that sets up a test database and client for tests."""
from contextlib import asynccontextmanager
from typing import Generator
from httpx import AsyncClient
from sqlalchemy import URL, create_engine
from sqlalchemy.orm import sessionmaker
from db.models import Base, get_session
from server.main import app
# Connection URL of the database used by the tests (`langchain_test`).
url = URL.create(
    drivername="postgresql",
    username="langchain",
    password="langchain",
    host="localhost",
    database="langchain_test",
    port=5432,
)

# Engine and session factory bound to the test database.
engine = create_engine(url)
TestingSession = sessionmaker(bind=engine)
def override_get_session() -> Generator[TestingSession, None, None]:
    """Yield a session bound to the test database.

    Intended as a replacement for the `get_session` dependency. The session
    is closed once the consumer is done with it.
    """
    # Create the session before entering `try`: if construction fails there is
    # nothing to close, and `finally` must not touch an unbound name.
    session = TestingSession()
    try:
        yield session
    finally:
        session.close()
# Make the app resolve `get_session` to the test-database session provider.
app.dependency_overrides[get_session] = override_get_session
@asynccontextmanager
async def get_async_client() -> AsyncClient:
    """Provide an async HTTP client for the app on a freshly reset database.

    All tables are dropped and re-created before the client is handed out,
    and the client is closed when the context exits.
    """
    # Start every use from an empty schema.
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)

    client = AsyncClient(app=app, base_url="http://test")
    try:
        yield client
    finally:
        await client.aclose()
@@ -0,0 +1,88 @@
"""Makes it easy to run integration tests using a real chat model."""
from contextlib import asynccontextmanager
from typing import Optional
import httpx
from fastapi import FastAPI
from httpx import AsyncClient
from langchain_core.pydantic_v1 import BaseModel
from server.main import app
@asynccontextmanager
async def get_async_test_client(
    server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True
) -> AsyncClient:
    """Provide an async HTTP client that talks to `server` in-process.

    Args:
        server: The application to send requests to.
        path: Optional path appended to the base URL.
        raise_app_exceptions: Passed through to the ASGI transport.
    """
    base = "http://localhost:9999/"
    url = base + path if path else base
    transport = httpx.ASGITransport(
        app=server,
        raise_app_exceptions=raise_app_exceptions,
    )
    client = AsyncClient(app=server, base_url=url, transport=transport)
    try:
        yield client
    finally:
        # Always release the client, even if the caller's block raised.
        await client.aclose()
async def test_extraction_api() -> None:
    """Exercise the `/extract_text/invoke` endpoint end to end.

    Three requests are sent: schema only, schema plus instructions, and
    schema plus instructions and few-shot examples.
    """

    class Person(BaseModel):
        age: Optional[int]
        name: Optional[str]
        alias: Optional[str]

    async with get_async_test_client(app) as client:
        text = """
        My name is Chester. I am young. I love cats. I have two cats. My age
        is the number of cats I have to the power of 5. (Approximately.)
        I also have a friend. His name is Neo. He is older than me. He is
        also a cat lover. He has 3 cats. He is 25 years old.
        """
        result = await client.post(
            "/extract_text/invoke",
            json={"input": {"text": text, "schema": Person.schema()}},
        )
        assert result.status_code == 200, result.text
        response_data = result.json()
        # The payload must carry the extracted items. The earlier
        # `response_data == {}` check contradicted this assertion and has
        # been removed.
        assert isinstance(response_data["output"]["data"], list)

        # Test with instructions
        result = await client.post(
            "/extract_text/invoke",
            json={
                "input": {
                    "text": text,
                    "schema": Person.schema(),
                    "instructions": "Very important: Chester's alias is Neo.",
                }
            },
        )
        # Check the status before decoding the body so a failure reports the
        # server's response text.
        assert result.status_code == 200, result.text
        response_data = result.json()

        # Test with few shot examples
        examples = [
            {
                "text": "My name is Grung. I am 100.",
                "output": [Person(age=100, name="######").dict()],
            },
        ]
        result = await client.post(
            "/extract_text/invoke",
            json={
                "input": {
                    "text": text,
                    # Send the JSON schema; a `Person()` instance is neither a
                    # schema nor JSON-serializable.
                    "schema": Person.schema(),
                    "instructions": "Redact all names using the characters `######`",
                    "examples": examples,
                }
            },
        )
        assert result.status_code == 200, result.text
@@ -0,0 +1,72 @@
"""Code to test API endpoints."""
import uuid
from tests.db import get_async_client
async def test_extractors_api() -> None:
    """Exercise the create, list and delete endpoints of the extractors API."""
    async with get_async_client() as client:
        # First verify that the database is empty
        response = await client.get("/extractors")
        assert response.status_code == 200
        assert response.json() == []

        # Run the same create -> list -> delete -> list round trip twice.
        for _ in range(2):
            # Verify that we can create an extractor
            create_request = {
                "description": "Test Description",
                "schema": {"type": "object"},
                "instruction": "Test Instruction",
            }
            response = await client.post("/extractors", json=create_request)
            assert response.status_code == 200

            # Verify that the extractor was created
            response = await client.get("/extractors")
            assert response.status_code == 200
            assert len(response.json()) == 1

            # Verify that we can delete an extractor
            listed = response.json()
            uuid_str = listed[0]["uuid"]
            _ = uuid.UUID(uuid_str)  # assert valid uuid
            response = await client.delete(f"/extractors/{uuid_str}")
            assert response.status_code == 200

            after_delete = await client.get("/extractors")
            assert after_delete.status_code == 200
            assert after_delete.json() == []

        # Verify that we can create an extractor once more after the deletes
        create_request = {
            "description": "Test Description",
            "schema": {"type": "object"},
            "instruction": "Test Instruction",
        }
        response = await client.post("/extractors", json=create_request)
        assert response.status_code == 200
@@ -0,0 +1,76 @@
"""Code to test API endpoints."""
from tests.db import get_async_client
async def _list_extractors() -> list:
    """Fetch `/extractors` with a fresh client and return the decoded body."""
    async with get_async_client() as api:
        listing = await api.get("/extractors")
        assert listing.status_code == 200
        return listing.json()
async def test_examples_api() -> None:
    """Walk through creating, listing and deleting an example via the API."""
    async with get_async_client() as client:
        # First create an extractor
        extractor_payload = {
            "description": "Test Description",
            "name": "Test Name",
            "schema": {"type": "object"},
            "instruction": "Test Instruction",
        }
        response = await client.post("/extractors", json=extractor_payload)
        assert response.status_code == 200
        # Get the extractor id
        extractor_id = response.json()["uuid"]

        # Let's verify that there are no examples
        response = await client.get("/examples?extractor_id=" + extractor_id)
        assert response.status_code == 200
        assert response.json() == []

        # Now let's create an example
        example_payload = {
            "extractor_id": extractor_id,
            "content": "Test Content",
            "output": [
                {
                    "age": 100,
                    "name": "Grung",
                }
            ],
        }
        response = await client.post("/examples", json=example_payload)
        assert response.status_code == 200
        example_id = response.json()["uuid"]

        # Verify that the example was created
        response = await client.get("/examples?extractor_id=" + extractor_id)
        assert response.status_code == 200
        assert len(response.json()) == 1

        # Project the listed record(s) onto the keys under test.
        keys = ["content", "extractor_id", "output", "uuid"]
        projected_response = {}
        for key in keys:
            for record in response.json():
                projected_response[key] = record[key]

        expected_projection = {
            "content": "Test Content",
            "extractor_id": extractor_id,
            "output": [
                {
                    "age": 100,
                    "name": "Grung",
                }
            ],
            "uuid": example_id,
        }
        assert projected_response == expected_projection

        # Verify that we can delete an example
        response = await client.delete(f"/examples/{example_id}")
        assert response.status_code == 200

        # Verify that the example was deleted
        response = await client.get("/examples?extractor_id=" + extractor_id)
        assert response.status_code == 200
        assert response.json() == []
@@ -0,0 +1,86 @@
"""Code to test API endpoints."""
import tempfile
from unittest.mock import patch
from uuid import UUID
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.runnables import RunnableLambda
from tests.db import get_async_client
def mock_extraction_runnable(*args, **kwargs):
    """Stand-in for the extraction runnable.

    Returns a response whose only item is the first 10 characters of the
    request's text.
    """
    request = args[0]
    preview = request.text[:10]
    return {"data": [preview]}
def mock_text_splitter(*args, **kwargs):
    """Ignore all arguments and return a default `CharacterTextSplitter`."""
    splitter = CharacterTextSplitter()
    return splitter
@patch(
    "server.extraction_runnable.extraction_runnable",
    new=RunnableLambda(mock_extraction_runnable),
)
@patch("server.extraction_runnable.TokenTextSplitter", mock_text_splitter)
async def test_extract_from_file() -> None:
    """Test the `/extract` endpoint with inline text and with an uploaded file.

    The extraction runnable and the token text splitter are replaced by
    mocks, so the expected output is the first 10 characters of the input.
    """
    async with get_async_client() as client:
        # Test with invalid extractor
        extractor_id = UUID(int=1027)  # 1027 is a good number.
        response = await client.post(
            "/extract",
            data={
                "extractor_id": str(extractor_id),
                "text": "Test Content",
            },
        )
        assert response.status_code == 404, response.text

        # First create an extractor
        create_request = {
            "name": "Test Name",
            "description": "Test Description",
            "schema": {"type": "object"},
            "instruction": "Test Instruction",
        }
        response = await client.post("/extractors", json=create_request)
        assert response.status_code == 200, response.text
        # Get the extractor id
        extractor_id = response.json()["uuid"]

        # Run an extraction.
        # We'll use multi-form data here.
        response = await client.post(
            "/extract",
            data={
                "extractor_id": extractor_id,
                "text": "Test Content",
                "mode": "entire_document",
            },
        )
        assert response.status_code == 200
        assert response.json() == {"data": ["Test Conte"]}

        # We'll use multi-form data here.
        # Create a named temporary file. It is only used through the open
        # handle inside this block, so let it be removed on close instead of
        # leaving it behind (previously `delete=False`).
        with tempfile.NamedTemporaryFile(mode="w+t") as f:
            f.write("This is a named temporary file.")
            f.flush()
            f.seek(0)
            response = await client.post(
                "/extract",
                data={
                    "extractor_id": extractor_id,
                    "mode": "entire_document",
                },
                files={"file": f},
            )
        assert response.status_code == 200, response.text
        assert response.json() == {"data": ["This is a "]}
+3
View File
@@ -0,0 +1,3 @@
import os
# Set a placeholder API key for the test process.
# NOTE(review): presumably needed so OpenAI clients can be constructed at
# import time without real credentials -- confirm.
os.environ["OPENAI_API_KEY"] = "placeholder"
@@ -0,0 +1,45 @@
"""Fake Chat Model wrapper for testing purposes."""
from typing import Any, Iterator, List, Optional
from langchain_core.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
class GenericFakeChatModel(BaseChatModel):
    """Fake chat model for tests that replays pre-supplied messages."""

    # Iterator over the messages to respond with; every generation consumes
    # the next one.
    #
    # This can be expanded to accept other types like Callables / dicts /
    # strings to make the interface more generic if needed.
    #
    # Note: to pass a list, convert it to an iterator with `iter`.
    #
    # Streaming is not implemented yet. It could be implemented in the future
    # by delegating to invoke and then breaking the resulting output into
    # message chunks.
    messages: Iterator[AIMessage]

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Respond with the next pre-supplied message, ignoring the input."""
        canned = next(self.messages)
        return ChatResult(generations=[ChatGeneration(message=canned)])

    @property
    def _llm_type(self) -> str:
        """Identifier of this model type."""
        return "generic-fake-chat-model"
@@ -0,0 +1,30 @@
"""Tests for verifying that testing utility code works as expected."""
from itertools import cycle
from langchain_core.messages import AIMessage
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
def test_generic_fake_chat_model_invoke() -> None:
    """`invoke` cycles through the supplied messages in order."""
    # Will alternate between responding with hello and goodbye
    replies = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
    model = GenericFakeChatModel(messages=replies)
    exchanges = (("meow", "hello"), ("kitty", "goodbye"), ("meow", "hello"))
    for prompt, expected_content in exchanges:
        response = model.invoke(prompt)
        assert response == AIMessage(content=expected_content)
async def test_generic_fake_chat_model_ainvoke() -> None:
    """`ainvoke` cycles through the supplied messages in order."""
    # Will alternate between responding with hello and goodbye
    replies = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
    model = GenericFakeChatModel(messages=replies)
    exchanges = (("meow", "hello"), ("kitty", "goodbye"), ("meow", "hello"))
    for prompt, expected_content in exchanges:
        response = await model.ainvoke(prompt)
        assert response == AIMessage(content=expected_content)
@@ -0,0 +1,11 @@
from pathlib import Path
from typing import List
# Directory containing this module; the sample fixtures live next to it.
HERE = Path(__file__).parent


# PUBLIC API


def get_sample_paths() -> List[Path]:
    """Return the paths in this directory whose name matches `sample.*`."""
    return [fixture for fixture in HERE.glob("sample.*")]
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -0,0 +1,42 @@
🦜️ LangChain
Underline
Bold
Italics
Col 1
Col 2
Row 1
1
2
Row 2
3
4
Link: https://www.langchain.com/
* Item 1
* Item 2
* Item 3
* We also love cats 🐱
Image
@@ -0,0 +1,54 @@
from server.extraction_runnable import ExtractResponse, _deduplicate
async def test_deduplication_different_resutls() -> None:
    """Check `_deduplicate` on distinct, overlapping and mixed-type results."""
    # No overlap: everything is kept.
    distinct_people = [
        {"data": [{"name": "Chester", "age": 42}]},
        {"data": [{"name": "Jane", "age": 42}]},
    ]
    result = _deduplicate(distinct_people)
    expected = ExtractResponse(
        data=[
            {"name": "Chester", "age": 42},
            {"name": "Jane", "age": 42},
        ]
    )
    assert expected == result

    # One repeated item across the two responses: it is kept once.
    overlapping = [
        {
            "data": [
                {"field_1": 1, "field_2": "a"},
                {"field_1": 2, "field_2": "b"},
            ]
        },
        {
            "data": [
                {"field_1": 1, "field_2": "a"},
                {"field_1": 2, "field_2": "c"},
            ]
        },
    ]
    result = _deduplicate(overlapping)
    expected = ExtractResponse(
        data=[
            {"field_1": 1, "field_2": "a"},
            {"field_1": 2, "field_2": "b"},
            {"field_1": 2, "field_2": "c"},
        ]
    )
    assert expected == result

    # Test with data being a list of strings
    result = _deduplicate([{"data": ["1", "2"]}, {"data": ["1", "3"]}])
    expected = ExtractResponse(data=["1", "2", "3"])
    assert expected == result

    # Test with data being a mix of integer and string
    result = _deduplicate([{"data": [1, "2"]}, {"data": ["1", "3"]}])
    expected = ExtractResponse(data=[1, "2", "1", "3"])
    assert expected == result
+46
View File
@@ -0,0 +1,46 @@
"""Test parsing logic."""
import mimetypes
from langchain.document_loaders import Blob
from extraction.parsing import (
MIMETYPE_BASED_PARSER,
SUPPORTED_MIMETYPES,
)
from tests.unit_tests.fixtures import get_sample_paths
def test_list_of_supported_mimetypes() -> None:
    """This list should generally grow! Protecting against typos in mimetypes."""
    expected_mimetypes = [
        # Two MS Word mimetypes are disabled for now
        # Need to install unstructured to enable them
        # "application/msword",
        "application/pdf",
        # "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/html",
        "text/plain",
    ]
    assert SUPPORTED_MIMETYPES == expected_mimetypes
def test_attempt_to_parse_each_fixture() -> None:
    """Parse every fixture with a supported mimetype and check the result.

    Also checks that the fixtures cover every supported mimetype except the
    known-missing ones.
    """
    parsed_mimetypes = set()
    for sample_path in get_sample_paths():
        guessed = mimetypes.guess_type(sample_path)[0]
        if guessed in SUPPORTED_MIMETYPES:
            parsed_mimetypes.add(guessed)
            parsed = MIMETYPE_BASED_PARSER.parse(Blob.from_path(sample_path))
            try:
                assert len(parsed) == 1
                only_doc = parsed[0]
                assert "source" in only_doc.metadata
                assert only_doc.metadata["source"] == str(sample_path)
                assert "🦜" in only_doc.page_content
            except Exception as e:
                # Re-raise with the offending path for easier diagnosis.
                raise AssertionError(f"Failed to parse {sample_path}") from e

    known_missing = {"application/msword"}
    assert set(SUPPORTED_MIMETYPES) - known_missing == parsed_mimetypes
+22
View File
@@ -0,0 +1,22 @@
from extraction.parsing import _guess_mimetype
from tests.unit_tests.fixtures import get_sample_paths
async def test_mimetype_guessing() -> None:
    """Verify mimetype guessing for all fixtures."""
    guessed_by_name = {
        sample.name: _guess_mimetype(sample.read_bytes())
        for sample in sorted(get_sample_paths())
    }
    expected_by_name = {
        "sample.docx": (
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        ),
        "sample.epub": "application/epub+zip",
        "sample.html": "text/html",
        "sample.odt": "application/vnd.oasis.opendocument.text",
        "sample.pdf": "application/pdf",
        "sample.rtf": "text/rtf",
        "sample.txt": "text/plain",
    }
    assert expected_by_name == guessed_by_name
+104
View File
@@ -0,0 +1,104 @@
from typing import List
import pytest
from langchain.pydantic_v1 import BaseModel, Field
from extraction.utils import (
convert_json_schema_to_openai_schema,
)
from server.extraction_runnable import ExtractionExample, _make_prompt_template
def test_convert_json_schema_to_openai_schema() -> None:
    """Check the OpenAI function schema produced from a model's JSON schema."""

    class Person(BaseModel):
        name: str = Field(..., description="The name of the person.")
        age: int = Field(..., description="The age of the person.")

    schema = Person.schema()
    expected_json_schema = {
        "properties": {
            "age": {
                "description": "The age of the person.",
                "title": "Age",
                "type": "integer",
            },
            "name": {
                "description": "The name of the person.",
                "title": "Name",
                "type": "string",
            },
        },
        "required": ["name", "age"],
        "title": "Person",
        "type": "object",
    }
    assert schema == expected_json_schema

    # The converted schema wraps the object schema in a `data` array.
    expected_item_schema = {
        "properties": {
            "age": {
                "description": "The age of the person.",
                "type": "integer",
            },
            "name": {
                "description": "The name of the person.",
                "type": "string",
            },
        },
        "required": ["name", "age"],
        "type": "object",
    }
    expected_openai_schema = {
        "description": "Extract information matching the given schema.",
        "name": "extractor",
        "parameters": {
            "properties": {
                "data": {
                    "items": expected_item_schema,
                    "type": "array",
                }
            },
            "required": ["data"],
            "type": "object",
        },
    }
    openai_schema = convert_json_schema_to_openai_schema(schema)
    assert openai_schema == expected_openai_schema
def test_make_prompt_template() -> None:
    """Check the messages produced by `_make_prompt_template`."""
    instructions = "Test instructions."
    single_example = ExtractionExample(
        text="Test text.",
        output=[
            {"name": "Test Name", "age": 0},
            {"name": "Test Name 2", "age": 1},
        ],
    )
    examples = [single_example]
    prefix = (
        "You are a top-tier algorithm for extracting information from text. "
        "Only extract information that is relevant to the provided text. "
        "If no information is relevant, use the schema and output "
        "an empty list where appropriate."
    )

    # Instructions and examples: system, example input, example output, human.
    prompt = _make_prompt_template(instructions, examples, "name")
    messages = prompt.messages
    assert len(messages) == 4
    system_message, example_input, example_output = messages[:3]

    system = system_message.prompt.template
    assert system.startswith(prefix)
    assert system.endswith(instructions)

    assert example_input.content == "Test text."

    assert "function_call" in example_output.additional_kwargs
    assert example_output.additional_kwargs["function_call"]["name"] == "name"

    # Without examples only the system and human messages remain.
    prompt = _make_prompt_template(instructions, None, "name")
    assert len(prompt.messages) == 2

    # Without instructions the message count is unchanged.
    prompt = _make_prompt_template(None, examples, "name")
    assert len(prompt.messages) == 4
@@ -0,0 +1,16 @@
import pytest
from server.validators import validate_json_schema
def test_validate_json_schema() -> None:
    """Invalid schemas must raise; a valid schema must pass silently."""
    # TODO: Validate more extensively to make sure that it actually validates
    # the schema as expected.
    for invalid_schema in ({"type": "meow"}, {"type": "str"}):
        with pytest.raises(Exception):
            validate_json_schema(invalid_schema)

    validate_json_schema({"type": "string"})
+25
View File
@@ -0,0 +1,25 @@
from contextlib import asynccontextmanager
from typing import Optional
import httpx
from fastapi import FastAPI
from httpx import AsyncClient
@asynccontextmanager
async def get_async_test_client(
    server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True
) -> AsyncClient:
    """Provide an async HTTP client that talks to `server` in-process.

    Args:
        server: The application to send requests to.
        path: Optional path appended to the base URL.
        raise_app_exceptions: Passed through to the ASGI transport.
    """
    base = "http://localhost:9999/"
    url = base + path if path else base
    transport = httpx.ASGITransport(
        app=server,
        raise_app_exceptions=raise_app_exceptions,
    )
    client = AsyncClient(app=server, base_url=url, transport=transport)
    try:
        yield client
    finally:
        # Always release the client, even if the caller's block raised.
        await client.aclose()
+30
View File
@@ -0,0 +1,30 @@
version: "3"
name: langchain-extract
services:
postgres:
# Careful if bumping postgres version.
# Make sure to keep in sync with CI
# version if being tested on CI.
image: postgres:16
ports:
- "5432:5432"
environment:
POSTGRES_DB: langchain
POSTGRES_USER: langchain
POSTGRES_PASSWORD: langchain
volumes:
- postgres_data:/var/lib/postgresql/data
# For now, rely on docker compose only to spin up postgres;
# the backend is not run through docker compose yet.
# Add backend when we actually need it
# backend:
# build: ./backend
# ports:
# - "8000:8000"
# depends_on:
# - postgres
volumes:
postgres_data: