mirror of
https://github.com/langchain-ai/langchain-extract.git
synced 2026-07-01 20:24:03 -04:00
Initial draft
This commit is contained in:
@@ -0,0 +1,91 @@
|
||||
# An action for setting up poetry install with caching.
|
||||
# Using a custom action since the default action does not
|
||||
# take poetry install groups into account.
|
||||
# Action code from:
|
||||
# https://github.com/actions/setup-python/issues/505#issuecomment-1273013236
|
||||
name: poetry-install-with-caching
|
||||
description: Poetry install with support for caching of dependency groups.
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
description: Python version, supporting MAJOR.MINOR only
|
||||
required: true
|
||||
|
||||
poetry-version:
|
||||
description: Poetry version
|
||||
required: true
|
||||
|
||||
cache-key:
|
||||
description: Cache key to use for manual handling of caching
|
||||
required: true
|
||||
|
||||
working-directory:
|
||||
description: Directory whose poetry.lock file should be cached
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- uses: actions/setup-python@v4
|
||||
name: Setup python ${{ inputs.python-version }}
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- uses: actions/cache@v3
|
||||
id: cache-bin-poetry
|
||||
name: Cache Poetry binary - Python ${{ inputs.python-version }}
|
||||
env:
|
||||
SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1"
|
||||
with:
|
||||
path: |
|
||||
/opt/pipx/venvs/poetry
|
||||
# This step caches the poetry installation, so make sure it's keyed on the poetry version as well.
|
||||
key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }}
|
||||
|
||||
- name: Refresh shell hashtable and fixup softlinks
|
||||
if: steps.cache-bin-poetry.outputs.cache-hit == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
POETRY_VERSION: ${{ inputs.poetry-version }}
|
||||
PYTHON_VERSION: ${{ inputs.python-version }}
|
||||
run: |
|
||||
set -eux
|
||||
|
||||
# Refresh the shell hashtable, to ensure correct `which` output.
|
||||
hash -r
|
||||
|
||||
# `actions/cache@v3` doesn't always seem able to correctly unpack softlinks.
|
||||
# Delete and recreate the softlinks pipx expects to have.
|
||||
rm /opt/pipx/venvs/poetry/bin/python
|
||||
cd /opt/pipx/venvs/poetry/bin
|
||||
ln -s "$(which "python$PYTHON_VERSION")" python
|
||||
chmod +x python
|
||||
cd /opt/pipx_bin/
|
||||
ln -s /opt/pipx/venvs/poetry/bin/poetry poetry
|
||||
chmod +x poetry
|
||||
|
||||
# Ensure everything got set up correctly.
|
||||
/opt/pipx/venvs/poetry/bin/python --version
|
||||
/opt/pipx_bin/poetry --version
|
||||
|
||||
- name: Install poetry
|
||||
if: steps.cache-bin-poetry.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
env:
|
||||
POETRY_VERSION: ${{ inputs.poetry-version }}
|
||||
PYTHON_VERSION: ${{ inputs.python-version }}
|
||||
run: pipx install "poetry==$POETRY_VERSION" --python "python$PYTHON_VERSION" --verbose
|
||||
|
||||
- name: Restore pip and poetry cached dependencies
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4"
|
||||
WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }}
|
||||
with:
|
||||
path: |
|
||||
~/.cache/pip
|
||||
~/.cache/pypoetry/virtualenvs
|
||||
~/.cache/pypoetry/cache
|
||||
~/.cache/pypoetry/artifacts
|
||||
${{ env.WORKDIR }}/.venv
|
||||
key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles(format('{0}/**/poetry.lock', env.WORKDIR)) }}
|
||||
@@ -0,0 +1,83 @@
|
||||
name: lint
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
working-directory:
|
||||
required: true
|
||||
type: string
|
||||
description: "From which folder this pipeline executes"
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.7.1"
|
||||
WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
# This number is set "by eye": we want it to be big enough
|
||||
# so that it's bigger than the number of commits in any reasonable PR,
|
||||
# and also as small as possible since increasing the number makes
|
||||
# the initial `git fetch` slower.
|
||||
FETCH_DEPTH: 50
|
||||
strategy:
|
||||
matrix:
|
||||
# Only lint on the min and max supported Python versions.
|
||||
# It's extremely unlikely that there's a lint issue on any version in between
|
||||
# that doesn't show up on the min or max versions.
|
||||
#
|
||||
# GitHub rate-limits how many jobs can be running at any one time.
|
||||
# Starting new jobs is also relatively slow,
|
||||
# so linting on fewer versions makes CI faster.
|
||||
python-version:
|
||||
- "3.8"
|
||||
- "3.11"
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
|
||||
uses: "./.github/actions/poetry_setup"
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
cache-key: lint-with-extras
|
||||
|
||||
- name: Check Poetry File
|
||||
shell: bash
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
poetry check
|
||||
|
||||
- name: Check lock file
|
||||
shell: bash
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
poetry lock --check
|
||||
|
||||
- name: Install dependencies
|
||||
# Also installs dev/lint/test/typing dependencies, to ensure we have
|
||||
# type hints for as many of our libraries as possible.
|
||||
# This helps catch errors that require dependencies to be spotted, for example:
|
||||
# https://github.com/langchain-ai/langchain/pull/10249/files#diff-935185cd488d015f026dcd9e19616ff62863e8cde8c0bee70318d3ccbca98341
|
||||
#
|
||||
# If you change this configuration, make sure to change the `cache-key`
|
||||
# in the `poetry_setup` action above to stop using the old cache.
|
||||
# It doesn't matter how you change it, any change will cause a cache-bust.
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
poetry install --with dev,lint,test,typing
|
||||
|
||||
- name: Get .mypy_cache to speed up mypy
|
||||
uses: actions/cache@v3
|
||||
env:
|
||||
SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2"
|
||||
with:
|
||||
path: |
|
||||
${{ env.WORKDIR }}/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }}
|
||||
|
||||
- name: Analysing the code with our lint
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
make lint
|
||||
@@ -0,0 +1,57 @@
|
||||
name: test
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
working-directory:
|
||||
required: true
|
||||
type: string
|
||||
description: "From which folder this pipeline executes"
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.7.1"
|
||||
|
||||
jobs:
|
||||
build:
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.8"
|
||||
- "3.9"
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
name: Python ${{ matrix.python-version }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
|
||||
uses: "./.github/actions/poetry_setup"
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
cache-key: core
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: poetry install
|
||||
|
||||
- name: Run core tests
|
||||
shell: bash
|
||||
run: make test
|
||||
|
||||
- name: Ensure the tests did not create any additional files
|
||||
shell: bash
|
||||
run: |
|
||||
set -eu
|
||||
|
||||
STATUS="$(git status)"
|
||||
echo "$STATUS"
|
||||
|
||||
# grep will exit non-zero if the target message isn't found,
|
||||
# and `set -e` above will cause the step to fail.
|
||||
echo "$STATUS" | grep 'nothing to commit, working tree clean'
|
||||
@@ -0,0 +1,108 @@
|
||||
---
|
||||
name: Run CI Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
|
||||
|
||||
# If another push to the same PR or branch happens while this workflow is still running,
|
||||
# cancel the earlier run in favor of the next run.
|
||||
#
|
||||
# There's no point in testing an outdated version of the code. GitHub only allows
|
||||
# a limited number of job runners to be active at the same time, so it's better to cancel
|
||||
# pointless jobs early so that more useful jobs can run sooner.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.7.1"
|
||||
WORKDIR: "./backend"
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
uses:
|
||||
./.github/workflows/_lint.yml
|
||||
with:
|
||||
working-directory: ./backend
|
||||
secrets: inherit
|
||||
test:
|
||||
timeout-minutes: 5
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ${{ env.WORKDIR }}
|
||||
services:
|
||||
postgres:
|
||||
# ensure postgres version this stays in sync with prod database
|
||||
# and with postgres version used in docker compose
|
||||
image: postgres:16
|
||||
env:
|
||||
# optional (defaults to `postgres`)
|
||||
POSTGRES_DB: langchain_test
|
||||
# required
|
||||
POSTGRES_PASSWORD: langchain
|
||||
# optional (defaults to `5432`)
|
||||
POSTGRES_PORT: 5432
|
||||
# optional (defaults to `postgres`)
|
||||
POSTGRES_USER: langchain
|
||||
ports:
|
||||
# maps tcp port 5432 on service container to the host
|
||||
- 5432:5432
|
||||
# set health checks to wait until postgres has started
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 3s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.8"
|
||||
- "3.9"
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
name: Python ${{ matrix.python-version }} tests
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
|
||||
uses: "./.github/actions/poetry_setup"
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
working-directory: ${{ env.WORKDIR }}
|
||||
cache-key: langchain-extract-all
|
||||
- name: Test database connection
|
||||
run: |
|
||||
# Set up postgresql-client
|
||||
sudo apt-get install -y postgresql-client
|
||||
# Test psql connection
|
||||
psql -h localhost -p 5432 -U langchain -d langchain_test -c "SELECT 1;"
|
||||
env:
|
||||
# postgress password is required; alternatively, you can run:
|
||||
# `PGPASSWORD=postgres_password psql ...`
|
||||
PGPASSWORD: langchain
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
echo "Running tests, installing dependencies with poetry..."
|
||||
poetry install --with test,lint,typing,docs
|
||||
|
||||
- name: Run tests
|
||||
run: make test
|
||||
|
||||
- name: Ensure the tests did not create any additional files
|
||||
shell: bash
|
||||
run: |
|
||||
set -eu
|
||||
|
||||
STATUS="$(git status)"
|
||||
echo "$STATUS"
|
||||
|
||||
# grep will exit non-zero if the target message isn't found,
|
||||
# and `set -e` above will cause the step to fail.
|
||||
echo "$STATUS" | grep 'nothing to commit, working tree clean'
|
||||
+162
@@ -0,0 +1,162 @@
|
||||
### Python template
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.DS_Store
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024-Present Langchain AI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,63 @@
|
||||
🚧 Under Active Development 🚧
|
||||
|
||||
Please expect breaking changes!
|
||||
|
||||
# 🦜📝 LangChain Extract
|
||||
|
||||
[](https://github.com/langchain-ai/langchain-extract/actions/workflows/ci.yml)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://twitter.com/langchainai)
|
||||
[](https://discord.gg/6adMQxSpJS)
|
||||
[](https://github.com/langchain-ai/langchain-extract/issues)
|
||||
|
||||
|
||||
# Set up
|
||||
|
||||
## Services
|
||||
|
||||
The root folder contains a docker compose file which will a launch a postgres
|
||||
instance.
|
||||
|
||||
```
|
||||
docker compose up
|
||||
```
|
||||
|
||||
At the time of writing, the app wasn't using postgres yet!
|
||||
|
||||
## App
|
||||
|
||||
```sh
|
||||
cd [root]/backend
|
||||
```
|
||||
|
||||
Set up the environment using poetry:
|
||||
|
||||
```sh
|
||||
poetry install --with lint,dev,test
|
||||
```
|
||||
|
||||
Verify that unit tests pass (they probably wont?)
|
||||
|
||||
# Test and format
|
||||
|
||||
Testing and formatting is done using a Makefile inside `[root]/backend`
|
||||
|
||||
```sh
|
||||
make format
|
||||
```
|
||||
|
||||
```sh
|
||||
make test
|
||||
```
|
||||
|
||||
# Launch Server
|
||||
|
||||
From `[root]/backend`:
|
||||
|
||||
```sh
|
||||
python -m server.main
|
||||
```
|
||||
|
||||
# Example client
|
||||
|
||||
See `docs/source/notebooks` for an example client.
|
||||
@@ -0,0 +1,60 @@
|
||||
.PHONY: all lint format test help
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
######################
|
||||
# TESTING AND COVERAGE
|
||||
######################
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
test:
|
||||
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
|
||||
test_watch:
|
||||
poetry run ptw . -- $(TEST_FILE)
|
||||
|
||||
openapi:
|
||||
OPENAI_API_KEY=placeholder python -c "from server import main; import json; print(json.dumps(main.app.openapi()))" > openapi.json
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=. --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
|
||||
lint lint_diff:
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
|
||||
# [ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
|
||||
|
||||
format format_diff:
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '===================='
|
||||
@echo '-- LINTING --'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'spell_check - run codespell on the project'
|
||||
@echo 'spell_fix - run codespell on the project and fix the errors'
|
||||
@echo '-- TESTS --'
|
||||
@echo 'coverage - run unit tests and generate coverage report'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
||||
@echo '-- DOCUMENTATION tasks are from the top-level Makefile --'
|
||||
@@ -0,0 +1 @@
|
||||
See readme at repo root.
|
||||
@@ -0,0 +1,134 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, String, Text, create_engine
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, relationship, sessionmaker
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from server.settings import get_postgres_url
|
||||
|
||||
ENGINE = create_engine(get_postgres_url())
|
||||
SessionClass = sessionmaker(bind=ENGINE)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# TODO(Eugene): Convert to async code
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""Create a new session."""
|
||||
session = SessionClass()
|
||||
|
||||
try:
|
||||
yield session
|
||||
except:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
class TimestampedModel(Base):
|
||||
"""An abstract base model that includes the timestamp fields."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
created_at = Column(
|
||||
DateTime,
|
||||
default=datetime.utcnow,
|
||||
comment="The time the record was created (UTC).",
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow,
|
||||
doc="The time the record was last updated (UTC).",
|
||||
)
|
||||
|
||||
# This is our own uuid assigned to the artifact.
|
||||
# By construction guaranteed to be unique no matter what.
|
||||
uuid = Column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
doc="Unique identifier for this model.",
|
||||
)
|
||||
|
||||
|
||||
class Extractor(TimestampedModel):
|
||||
__tablename__ = "extractors"
|
||||
|
||||
name = Column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
server_default="",
|
||||
comment="The name of the extractor.",
|
||||
)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
comment="Time when this extracted was originally created.",
|
||||
)
|
||||
modified_at = Column(
|
||||
DateTime(timezone=True),
|
||||
onupdate=func.now(),
|
||||
comment="Last time this was modified.",
|
||||
)
|
||||
schema = Column(
|
||||
JSONB,
|
||||
nullable=False,
|
||||
comment="JSON Schema that describes what content will be extracted from the document",
|
||||
)
|
||||
description = Column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
server_default="",
|
||||
comment="Surfaced via UI to the users.",
|
||||
)
|
||||
instruction = Column(
|
||||
Text, nullable=False, comment="The prompt to the language model."
|
||||
) # TODO: This will need to evolve
|
||||
|
||||
examples = relationship("Example", backref="extractor")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Extractor(id={self.uuid}, description={self.description})>"
|
||||
|
||||
|
||||
class Example(TimestampedModel):
|
||||
"""A representation of an example.
|
||||
|
||||
Examples consist of content together with the expected output.
|
||||
|
||||
The output is a JSON object that is expected to be extracted from the content.
|
||||
|
||||
The JSON object should be valid according to the schema of the associated extractor.
|
||||
|
||||
The JSON object is defined by the schema of the associated extractor, so
|
||||
it's perfectly fine for a given example to represent the extraction
|
||||
of multiple instances of some object from the content since
|
||||
the JSON schema can represent a list of objects.
|
||||
"""
|
||||
|
||||
__tablename__ = "examples"
|
||||
|
||||
content = Column(
|
||||
Text,
|
||||
nullable=False,
|
||||
comment="The input portion of the example.",
|
||||
)
|
||||
output = Column(
|
||||
JSONB,
|
||||
comment="The output associated with the example.",
|
||||
)
|
||||
extractor_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("extractors.uuid", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
comment="Foreign key referencing the associated extractor.",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Example(uuid={self.uuid}, content={self.content[:20]}>"
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Convert binary input to blobs and parse them using the appropriate parser."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import BinaryIO, List
|
||||
|
||||
from langchain.document_loaders.parsers import BS4HTMLParser, PDFMinerParser
|
||||
from langchain.document_loaders.parsers.generic import MimeTypeBasedParser
|
||||
from langchain.document_loaders.parsers.txt import TextParser
|
||||
from langchain_community.document_loaders import Blob
|
||||
from langchain_core.documents import Document
|
||||
|
||||
HANDLERS = {
|
||||
"application/pdf": PDFMinerParser(),
|
||||
"text/plain": TextParser(),
|
||||
"text/html": BS4HTMLParser(),
|
||||
# Disable for now as they rely on unstructured and there's some install
|
||||
# issue with unstructured.
|
||||
# from langchain.document_loaders.parsers.msword import MsWordParser
|
||||
# "application/msword": MsWordParser(),
|
||||
# "application/vnd.openxmlformats-officedocument.wordprocessingml.document": (
|
||||
# MsWordParser()
|
||||
# ),
|
||||
}
|
||||
|
||||
SUPPORTED_MIMETYPES = sorted(HANDLERS.keys())
|
||||
|
||||
|
||||
def _guess_mimetype(file_bytes: bytes) -> str:
|
||||
"""Guess the mime-type of a file."""
|
||||
try:
|
||||
import magic
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"magic package not found, please install it with `pip install python-magic`"
|
||||
) from e
|
||||
|
||||
mime = magic.Magic(mime=True)
|
||||
mime_type = mime.from_buffer(file_bytes)
|
||||
return mime_type
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
MIMETYPE_BASED_PARSER = MimeTypeBasedParser(
|
||||
handlers=HANDLERS,
|
||||
fallback_parser=None,
|
||||
)
|
||||
|
||||
|
||||
def convert_binary_input_to_blob(data: BinaryIO) -> Blob:
|
||||
"""Convert ingestion input to blob."""
|
||||
file_data = data.read()
|
||||
mimetype = _guess_mimetype(file_data)
|
||||
file_name = data.name
|
||||
return Blob.from_data(
|
||||
data=file_data,
|
||||
path=file_name,
|
||||
mime_type=mimetype,
|
||||
)
|
||||
|
||||
|
||||
def parse_binary_input(data: BinaryIO) -> List[Document]:
|
||||
"""Parse binary input."""
|
||||
blob = convert_binary_input_to_blob(data)
|
||||
return MIMETYPE_BASED_PARSER.parse(blob)
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Adapters to convert between different formats."""
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
|
||||
|
||||
def _rm_titles(kv: dict) -> dict:
|
||||
"""Remove titles from a dictionary."""
|
||||
new_kv = {}
|
||||
for k, v in kv.items():
|
||||
if k == "title":
|
||||
continue
|
||||
elif isinstance(v, dict):
|
||||
new_kv[k] = _rm_titles(v)
|
||||
else:
|
||||
new_kv[k] = v
|
||||
return new_kv
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def convert_json_schema_to_openai_schema(
|
||||
schema: dict,
|
||||
*,
|
||||
rm_titles: bool = True,
|
||||
multi: bool = True,
|
||||
) -> dict:
|
||||
"""Convert JSON schema to a corresponding OpenAI function call."""
|
||||
if multi:
|
||||
# Wrap the schema in an object called "Root" with a property called: "data"
|
||||
# which will be a json array of the original schema.
|
||||
schema_ = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": dereference_refs(schema),
|
||||
},
|
||||
},
|
||||
"required": ["data"],
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError("Only multi is supported for now.")
|
||||
|
||||
schema_.pop("definitions", None)
|
||||
|
||||
return {
|
||||
"name": "extractor",
|
||||
"description": "Extract information matching the given schema.",
|
||||
"parameters": _rm_titles(schema_) if rm_titles else schema_,
|
||||
}
|
||||
Generated
+4243
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,91 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-extract"
|
||||
version = "0.0.1"
|
||||
description = "Sample extraction backend."
|
||||
authors = ["LangChain AI"]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8.1"
|
||||
langchain = "~0.1"
|
||||
langsmith = ">=0.0.66"
|
||||
fastapi = "^0.109.2"
|
||||
langserve = "^0.0.41"
|
||||
uvicorn = "^0.27.1"
|
||||
pydantic = "^1.10"
|
||||
langchain-openai = "^0.0.6"
|
||||
jsonschema = "^4.21.1"
|
||||
sse-starlette = "^2.0.0"
|
||||
alembic = "^1.13.1"
|
||||
psycopg2 = "^2.9.9"
|
||||
python-magic = "^0.4.27"
|
||||
pdfminer-six = "^20231228"
|
||||
beautifulsoup4 = "^4.12.3"
|
||||
lxml = "^5.1.0"
|
||||
faiss-cpu = "^1.7.4"
|
||||
python-multipart = "^0.0.9"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
jupyterlab = "^3.6.1"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^1.7.0"
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.2.1"
|
||||
pytest-cov = "^4.0.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
pytest-mock = "^3.11.1"
|
||||
pytest-socket = "^0.6.0"
|
||||
pytest-watch = "^4.2.0"
|
||||
pytest-timeout = "^2.2.0"
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
extend-include = ["*.ipynb"]
|
||||
|
||||
# Same as Black.
|
||||
line-length = 88
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"tests/*",
|
||||
]
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
addopts = "--strict-markers --strict-config --durations=5 -vv"
|
||||
# Global timeout for all tests. There shuold be a good reason for a test to
|
||||
# take more than 5 second
|
||||
timeout = 5
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
Executable
+59
@@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python
|
||||
"""Run migrations."""
|
||||
import click
|
||||
|
||||
from db.models import ENGINE, Base
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
"""Database migration commands."""
|
||||
pass
|
||||
|
||||
|
||||
@cli.command()
|
||||
def create():
|
||||
"""Create all tables."""
|
||||
Base.metadata.create_all(ENGINE)
|
||||
click.echo("All tables created successfully.")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.confirmation_option(prompt="Are you sure you want to drop all tables?")
|
||||
def drop():
|
||||
"""Drop all tables."""
|
||||
Base.metadata.drop_all(ENGINE)
|
||||
click.echo("All tables dropped successfully.")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def create_test_db():
|
||||
"""Create a test database called langchain_test used for testing purposes."""
|
||||
import psycopg2
|
||||
from psycopg2.errors import DuplicateDatabase
|
||||
|
||||
# establishing the connection
|
||||
conn = psycopg2.connect(
|
||||
database="langchain",
|
||||
user="langchain",
|
||||
password="langchain",
|
||||
host="localhost",
|
||||
port="5432",
|
||||
)
|
||||
conn.autocommit = True
|
||||
|
||||
# Creating a cursor object using the cursor() method
|
||||
with conn.cursor() as cursor:
|
||||
# Preparing query to create a database
|
||||
sql = "CREATE DATABASE langchain_test;"
|
||||
|
||||
# Creating a database
|
||||
try:
|
||||
cursor.execute(sql)
|
||||
print("Database created successfully.")
|
||||
except DuplicateDatabase:
|
||||
print("Database already exists")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Endpoints for managing definition of examples.."""
|
||||
from typing import Any, Dict, List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from db.models import Example, get_session
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/examples",
|
||||
tags=["example definitions"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
class CreateExample(TypedDict):
|
||||
"""A request to create an example."""
|
||||
|
||||
extractor_id: Annotated[UUID, "The extractor ID that this is an example for."]
|
||||
content: Annotated[str, "The input portion of the example."]
|
||||
output: Annotated[
|
||||
List[Any], "JSON object that is expected to be extracted from the content."
|
||||
]
|
||||
|
||||
|
||||
class CreateExampleResponse(TypedDict):
|
||||
"""Response for creating an example."""
|
||||
|
||||
uuid: UUID
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create(
|
||||
create_request: CreateExample,
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
) -> CreateExampleResponse:
|
||||
"""Endpoint to create an example."""
|
||||
instance = Example(
|
||||
extractor_id=create_request["extractor_id"],
|
||||
content=create_request["content"],
|
||||
output=create_request["output"],
|
||||
)
|
||||
session.add(instance)
|
||||
session.commit()
|
||||
return {"uuid": instance.uuid}
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list(
|
||||
extractor_id: UUID,
|
||||
*,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
session=Depends(get_session),
|
||||
) -> List[Any]:
|
||||
"""Endpoint to get all examples."""
|
||||
return (
|
||||
session.query(Example)
|
||||
.filter(Example.extractor_id == extractor_id)
|
||||
.order_by(Example.uuid)
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{uuid}")
|
||||
def delete(uuid: UUID, *, session: Session = Depends(get_session)) -> None:
|
||||
"""Endpoint to delete an example."""
|
||||
session.query(Example).filter(Example.uuid == str(uuid)).delete()
|
||||
session.commit()
|
||||
@@ -0,0 +1,56 @@
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from db.models import Extractor, get_session
|
||||
from extraction.parsing import parse_binary_input
|
||||
from server.extraction_runnable import ExtractResponse, extract_entire_document
|
||||
from server.retrieval import extract_from_content
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/extract",
|
||||
tags=["extract"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=ExtractResponse)
|
||||
async def extract_using_existing_extractor(
|
||||
*,
|
||||
extractor_id: Annotated[UUID, Form()],
|
||||
text: Optional[str] = Form(None),
|
||||
mode: Literal["entire_document", "retrieval"] = Form("entire_document"),
|
||||
file: Optional[UploadFile] = File(None),
|
||||
session: Session = Depends(get_session),
|
||||
) -> ExtractResponse:
|
||||
"""Endpoint that is used with an existing extractor.
|
||||
|
||||
This endpoint will be expanded to support upload of binary files as well as
|
||||
text files.
|
||||
"""
|
||||
if text is None and file is None:
|
||||
raise HTTPException(status_code=422, detail="No text or file provided.")
|
||||
|
||||
extractor = session.query(Extractor).filter(Extractor.uuid == extractor_id).scalar()
|
||||
if extractor is None:
|
||||
raise HTTPException(status_code=404, detail="Extractor not found.")
|
||||
|
||||
if text:
|
||||
text_ = text
|
||||
else:
|
||||
documents = parse_binary_input(file.file)
|
||||
# TODO: Add metadata like location from original file where
|
||||
# the text was extracted from
|
||||
text_ = "\n".join([document.page_content for document in documents])
|
||||
|
||||
if mode == "entire_document":
|
||||
return await extract_entire_document(text_, extractor)
|
||||
elif mode == "retrieval":
|
||||
return await extract_from_content(text_, extractor)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid mode {mode}. Expected one of 'entire_document', 'retrieval'."
|
||||
)
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Endpoints for managing definition of extractors."""
|
||||
from typing import Any, Dict, List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db.models import Extractor, get_session
|
||||
from server.validators import validate_json_schema
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/extractors",
|
||||
tags=["extractor definitions"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
class CreateExtractor(BaseModel):
|
||||
"""A request to create an extractor."""
|
||||
|
||||
name: str = Field(default="", description="The name of the extractor.")
|
||||
|
||||
description: str = Field(
|
||||
default="", description="Short description of the extractor."
|
||||
)
|
||||
json_schema: Dict[str, Any] = Field(
|
||||
..., description="The schema to use for extraction.", alias="schema"
|
||||
)
|
||||
instruction: str = Field(..., description="The instruction to use for extraction.")
|
||||
|
||||
@validator("json_schema")
|
||||
def validate_schema(cls, v: Any) -> Dict[str, Any]:
|
||||
"""Validate the schema."""
|
||||
validate_json_schema(v)
|
||||
return v
|
||||
|
||||
|
||||
class CreateExtractorResponse(BaseModel):
|
||||
"""Response for creating an extractor."""
|
||||
|
||||
uuid: UUID
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create(
|
||||
create_request: CreateExtractor, *, session: Session = Depends(get_session)
|
||||
) -> CreateExtractorResponse:
|
||||
"""Endpoint to create an extractor."""
|
||||
instance = Extractor(
|
||||
name=create_request.name,
|
||||
schema=create_request.json_schema,
|
||||
description=create_request.description,
|
||||
instruction=create_request.instruction,
|
||||
)
|
||||
session.add(instance)
|
||||
session.commit()
|
||||
return CreateExtractorResponse(uuid=instance.uuid)
|
||||
|
||||
|
||||
@router.get("/{uuid}")
|
||||
def get(
|
||||
uuid: UUID, *, session: Session = Depends(get_session)
|
||||
) -> Dict[str, Any]:
|
||||
"""Endpoint to get an extractor."""
|
||||
extractor = session.query(Extractor).filter(Extractor.uuid == str(uuid)).scalar()
|
||||
if extractor is None:
|
||||
raise HTTPException(status_code=404, detail="Extractor not found.")
|
||||
return {
|
||||
"uuid": extractor.uuid,
|
||||
"name": extractor.name,
|
||||
"description": extractor.description,
|
||||
"schema": extractor.schema,
|
||||
"instruction": extractor.instruction,
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list(
|
||||
*,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
session=Depends(get_session),
|
||||
) -> List[Any]:
|
||||
"""Endpoint to get all extractors."""
|
||||
return session.query(Extractor).limit(limit).offset(offset).all()
|
||||
|
||||
|
||||
@router.delete("/{uuid}")
|
||||
def delete(uuid: UUID, *, session: Session = Depends(get_session)) -> None:
|
||||
"""Endpoint to delete an extractor."""
|
||||
session.query(Extractor).filter(Extractor.uuid == str(uuid)).delete()
|
||||
session.commit()
|
||||
@@ -0,0 +1,208 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from fastapi import HTTPException
|
||||
from jsonschema import Draft202012Validator, exceptions
|
||||
from langchain.chains.openai_functions import create_openai_fn_runnable
|
||||
from langchain.text_splitter import TokenTextSplitter
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import chain
|
||||
from langserve import CustomUserType
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from db.models import Example, Extractor
|
||||
from extraction.utils import (
|
||||
convert_json_schema_to_openai_schema,
|
||||
)
|
||||
from server.settings import CHUNK_SIZE, MODEL_NAME, model
|
||||
from server.validators import validate_json_schema
|
||||
|
||||
|
||||
class ExtractionExample(BaseModel):
|
||||
"""An example extraction.
|
||||
|
||||
This example consists of a text and the expected output of the extraction.
|
||||
"""
|
||||
|
||||
text: str = Field(..., description="The input text")
|
||||
output: List[Dict[str, Any]] = Field(
|
||||
..., description="The expected output of the example. A list of objects."
|
||||
)
|
||||
|
||||
|
||||
class ExtractRequest(CustomUserType):
|
||||
"""Request body for the extract endpoint."""
|
||||
|
||||
text: str = Field(..., description="The text to extract from.")
|
||||
json_schema: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="JSON schema that describes what content should be extracted "
|
||||
"from the text.",
|
||||
alias="schema",
|
||||
)
|
||||
instructions: Optional[str] = Field(
|
||||
None, description="Supplemental system instructions."
|
||||
)
|
||||
examples: Optional[List[ExtractionExample]] = Field(
|
||||
None, description="Examples of extractions."
|
||||
)
|
||||
|
||||
@validator("json_schema")
|
||||
def validate_schema(cls, v: Any) -> Dict[str, Any]:
|
||||
"""Validate the schema."""
|
||||
validate_json_schema(v)
|
||||
return v
|
||||
|
||||
|
||||
class ExtractResponse(TypedDict):
|
||||
"""Response body for the extract endpoint."""
|
||||
|
||||
data: List[Any]
|
||||
|
||||
|
||||
def _deduplicate(
|
||||
extract_responses: Sequence[ExtractResponse],
|
||||
) -> ExtractResponse:
|
||||
"""Deduplicate the results.
|
||||
|
||||
The deduplication is done by comparing the serialized JSON of each of the results
|
||||
and only keeping the unique ones.
|
||||
"""
|
||||
unique_extracted = []
|
||||
seen = set()
|
||||
for response in extract_responses:
|
||||
for data_item in response["data"]:
|
||||
# Serialize the data item for comparison purposes
|
||||
serialized = json.dumps(data_item, sort_keys=True)
|
||||
if serialized not in seen:
|
||||
seen.add(serialized)
|
||||
unique_extracted.append(data_item)
|
||||
|
||||
return {
|
||||
"data": unique_extracted,
|
||||
}
|
||||
|
||||
|
||||
def _cast_example_to_dict(example: Example) -> Dict[str, Any]:
|
||||
"""Cast example record to dictionary."""
|
||||
return {
|
||||
"text": example.content,
|
||||
"output": example.output,
|
||||
}
|
||||
|
||||
|
||||
def _make_prompt_template(
|
||||
instructions: Optional[str],
|
||||
examples: Optional[Sequence[ExtractionExample]],
|
||||
function_name: str,
|
||||
) -> ChatPromptTemplate:
|
||||
"""Make a system message from instructions and examples."""
|
||||
prefix = (
|
||||
"You are a top-tier algorithm for extracting information from text. "
|
||||
"Only extract information that is relevant to the provided text. "
|
||||
"If no information is relevant, use the schema and output "
|
||||
"an empty list where appropriate."
|
||||
)
|
||||
if instructions:
|
||||
system_message = ("system", f"{prefix}\n\n{instructions}")
|
||||
else:
|
||||
system_message = ("system", prefix)
|
||||
prompt_components = [system_message]
|
||||
if examples is not None:
|
||||
few_shot_prompt = []
|
||||
for example in examples:
|
||||
# TODO: We'll need to refactor this at some point to
|
||||
# support other encoding strategies. The function calling logic here
|
||||
# has some hard-coded assumptions (e.g., name of parameters like `data`).
|
||||
function_call = {
|
||||
"arguments": json.dumps(
|
||||
{
|
||||
"data": example.output,
|
||||
}
|
||||
),
|
||||
"name": function_name,
|
||||
}
|
||||
few_shot_prompt.extend(
|
||||
[
|
||||
HumanMessage(content=example.text),
|
||||
AIMessage(
|
||||
content="", additional_kwargs={"function_call": function_call}
|
||||
),
|
||||
]
|
||||
)
|
||||
prompt_components.extend(few_shot_prompt)
|
||||
|
||||
prompt_components.append(
|
||||
(
|
||||
"human",
|
||||
"I need to extract information from "
|
||||
"the following text: ```\n{text}\n```\n",
|
||||
),
|
||||
)
|
||||
return ChatPromptTemplate.from_messages(prompt_components)
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def get_examples_from_extractor(extractor: Extractor) -> List[Dict[str, Any]]:
|
||||
"""Get examples from an extractor."""
|
||||
return [_cast_example_to_dict(example) for example in extractor.examples]
|
||||
|
||||
|
||||
@chain
|
||||
async def extraction_runnable(extraction_request: ExtractRequest) -> ExtractResponse:
|
||||
"""An end point to extract content from a given text object."""
|
||||
# TODO: Add validation for model context window size
|
||||
schema = extraction_request.json_schema
|
||||
try:
|
||||
Draft202012Validator.check_schema(schema)
|
||||
except exceptions.ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid schema: {e.message}")
|
||||
|
||||
openai_function = convert_json_schema_to_openai_schema(schema)
|
||||
function_name = openai_function["name"]
|
||||
prompt = _make_prompt_template(
|
||||
extraction_request.instructions,
|
||||
extraction_request.examples,
|
||||
function_name,
|
||||
)
|
||||
runnable = create_openai_fn_runnable(
|
||||
functions=[openai_function], llm=model, prompt=prompt
|
||||
)
|
||||
return await runnable.ainvoke({"text": extraction_request.text})
|
||||
|
||||
|
||||
async def extract_entire_document(
|
||||
content: str,
|
||||
extractor: Extractor,
|
||||
) -> ExtractResponse:
|
||||
"""Extract from entire document."""
|
||||
|
||||
json_schema = extractor.schema
|
||||
examples = get_examples_from_extractor(extractor)
|
||||
text_splitter = TokenTextSplitter(
|
||||
chunk_size=CHUNK_SIZE,
|
||||
chunk_overlap=20,
|
||||
model_name=MODEL_NAME,
|
||||
)
|
||||
texts = text_splitter.split_text(content)
|
||||
extraction_requests = [
|
||||
ExtractRequest(
|
||||
text=text,
|
||||
schema=json_schema,
|
||||
instructions=extractor.instruction, # TODO: consistent naming
|
||||
examples=examples,
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
# Run extractions which may potentially yield duplicate results
|
||||
extract_responses: List[ExtractResponse] = await extraction_runnable.abatch(
|
||||
extraction_requests, {"max_concurrency": 1}
|
||||
)
|
||||
# Deduplicate the results
|
||||
return _deduplicate(extract_responses)
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Entry point into the server."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from langserve import add_routes
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from server.api import examples, extract, extractors
|
||||
from server.extraction_runnable import (
|
||||
ExtractRequest,
|
||||
ExtractResponse,
|
||||
extraction_runnable,
|
||||
)
|
||||
|
||||
app = FastAPI(
|
||||
title="Extraction Powered by LangChain",
|
||||
description="An extraction service powered by LangChain.",
|
||||
version="0.0.1",
|
||||
openapi_tags=[
|
||||
{
|
||||
"name": "extraction",
|
||||
"description": "Operations related to extracting content from text.",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
origins = [
|
||||
"http://localhost:5173",
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/ready")
|
||||
def ready():
|
||||
return "ok"
|
||||
|
||||
|
||||
# Include API endpoints for extractor definitions
|
||||
app.include_router(extractors.router)
|
||||
app.include_router(examples.router)
|
||||
app.include_router(extract.router)
|
||||
|
||||
add_routes(
|
||||
app,
|
||||
extraction_runnable.with_types(
|
||||
input_type=ExtractRequest, output_type=ExtractResponse
|
||||
),
|
||||
path="/extract_text",
|
||||
enabled_endpoints=["invoke", "playground", "stream_log"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="localhost", port=8000)
|
||||
@@ -0,0 +1,72 @@
|
||||
from operator import itemgetter
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
from db.models import Extractor
|
||||
from server.extraction_runnable import (
|
||||
ExtractRequest,
|
||||
ExtractResponse,
|
||||
extraction_runnable,
|
||||
get_examples_from_extractor,
|
||||
)
|
||||
|
||||
|
||||
def _get_top_doc_content(docs: List[Document]) -> str:
|
||||
if docs:
|
||||
return docs[0].page_content
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def _make_extract_request(input_dict: Dict[str, Any]) -> ExtractRequest:
|
||||
return ExtractRequest(**input_dict)
|
||||
|
||||
|
||||
async def extract_from_content(
|
||||
content: str,
|
||||
extractor: Extractor,
|
||||
*,
|
||||
text_splitter_kwargs: Optional[Dict[str, Any]] = None,
|
||||
multi: bool = True,
|
||||
) -> ExtractResponse:
|
||||
"""Extract from potentially long-form content."""
|
||||
if text_splitter_kwargs is None:
|
||||
text_splitter_kwargs = {
|
||||
"separator": "\n\n",
|
||||
"chunk_size": 1000,
|
||||
"chunk_overlap": 50,
|
||||
}
|
||||
text_splitter = CharacterTextSplitter(**text_splitter_kwargs)
|
||||
docs = text_splitter.create_documents([content])
|
||||
doc_contents = [doc.page_content for doc in docs]
|
||||
|
||||
vectorstore = FAISS.from_texts(doc_contents, embedding=OpenAIEmbeddings())
|
||||
retriever = vectorstore.as_retriever()
|
||||
|
||||
runnable = (
|
||||
{
|
||||
"text": itemgetter("query") | retriever | _get_top_doc_content,
|
||||
"schema": itemgetter("schema"),
|
||||
"instructions": lambda x: x.get("instructions"),
|
||||
"examples": lambda x: x.get("examples"),
|
||||
}
|
||||
| RunnableLambda(_make_extract_request)
|
||||
| extraction_runnable
|
||||
)
|
||||
schema = extractor.schema
|
||||
examples = get_examples_from_extractor(extractor)
|
||||
description = extractor.description # TODO: improve this
|
||||
result = await runnable.ainvoke(
|
||||
{
|
||||
"query": description,
|
||||
"schema": schema,
|
||||
"examples": examples,
|
||||
"instructions": extractor.instruction,
|
||||
}
|
||||
)
|
||||
return ExtractResponse(extracted=[result.extracted])
|
||||
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from sqlalchemy.engine import URL
|
||||
|
||||
MODEL_NAME = "gpt-3.5-turbo"
|
||||
CHUNK_SIZE = int(4_096 * 0.8)
|
||||
# Max concurrency for the model.
|
||||
MAX_CONCURRENCY = 1
|
||||
|
||||
|
||||
def get_postgres_url():
|
||||
url = URL.create(
|
||||
drivername="postgresql",
|
||||
username="langchain",
|
||||
password="langchain",
|
||||
host="localhost",
|
||||
database="langchain",
|
||||
port=5432,
|
||||
)
|
||||
return url
|
||||
|
||||
|
||||
model = ChatOpenAI(model=MODEL_NAME, temperature=0)
|
||||
@@ -0,0 +1,15 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import HTTPException
|
||||
from jsonschema import exceptions
|
||||
from jsonschema.validators import Draft202012Validator
|
||||
|
||||
|
||||
def validate_json_schema(schema: Dict[str, Any]) -> None:
|
||||
"""Validate a JSON schema."""
|
||||
try:
|
||||
Draft202012Validator.check_schema(schema)
|
||||
except exceptions.ValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=422, detail=f"Not a valid JSON schema: {e.message}"
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Utility code that sets up a test database and client for tests."""
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Generator
|
||||
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import URL, create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from db.models import Base, get_session
|
||||
from server.main import app
|
||||
|
||||
url = URL.create(
|
||||
drivername="postgresql",
|
||||
username="langchain",
|
||||
password="langchain",
|
||||
host="localhost",
|
||||
database="langchain_test",
|
||||
port=5432,
|
||||
)
|
||||
engine = create_engine(url)
|
||||
TestingSession = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def override_get_session() -> Generator[TestingSession, None, None]:
|
||||
"""Override the get_session dependency with a test session.
|
||||
|
||||
This fixture also re-creats the database before each test and drops it after to
|
||||
ensure a clean slate for each test.
|
||||
"""
|
||||
try:
|
||||
session = TestingSession()
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
app.dependency_overrides[get_session] = override_get_session
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_client() -> AsyncClient:
|
||||
"""Get an async client."""
|
||||
# Clear the database before each test
|
||||
Base.metadata.drop_all(engine)
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
async_client = AsyncClient(app=app, base_url="http://test")
|
||||
try:
|
||||
yield async_client
|
||||
finally:
|
||||
await async_client.aclose()
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Makes it easy to run an integration tests using a real chat model."""
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from server.main import app
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_test_client(
|
||||
server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True
|
||||
) -> AsyncClient:
|
||||
"""Get an async client."""
|
||||
url = "http://localhost:9999/"
|
||||
if path:
|
||||
url += path
|
||||
transport = httpx.ASGITransport(
|
||||
app=server,
|
||||
raise_app_exceptions=raise_app_exceptions,
|
||||
)
|
||||
async_client = AsyncClient(app=server, base_url=url, transport=transport)
|
||||
try:
|
||||
yield async_client
|
||||
finally:
|
||||
await async_client.aclose()
|
||||
|
||||
|
||||
async def test_extraction_api() -> None:
|
||||
"""Test the extraction API endpoint."""
|
||||
|
||||
class Person(BaseModel):
|
||||
age: Optional[int]
|
||||
name: Optional[str]
|
||||
alias: Optional[str]
|
||||
|
||||
async with get_async_test_client(app) as client:
|
||||
text = """
|
||||
My name is Chester. I am young. I love cats. I have two cats. My age
|
||||
is the number of cats I have to the power of 5. (Approximately.)
|
||||
I also have a friend. His name is Neo. He is older than me. He is
|
||||
also a cat lover. He has 3 cats. He is 25 years old.
|
||||
"""
|
||||
result = await client.post(
|
||||
"/extract_text/invoke",
|
||||
json={"input": {"text": text, "schema": Person.schema()}},
|
||||
)
|
||||
assert result.status_code == 200, result.text
|
||||
response_data = result.json()
|
||||
assert response_data == {}
|
||||
assert isinstance(response_data["output"]["data"], list)
|
||||
|
||||
# Test with instructions
|
||||
result = await client.post(
|
||||
"/extract_text/invoke",
|
||||
json={
|
||||
"input": {
|
||||
"text": text,
|
||||
"schema": Person.schema(),
|
||||
"instructions": "Very important: Chester's alias is Neo.",
|
||||
}
|
||||
},
|
||||
)
|
||||
response_data = result.json()
|
||||
assert result.status_code == 200, result.text
|
||||
|
||||
# Test with few shot examples
|
||||
examples = [
|
||||
{
|
||||
"text": "My name is Grung. I am 100.",
|
||||
"output": [Person(age=100, name="######").dict()],
|
||||
},
|
||||
]
|
||||
result = await client.post(
|
||||
"/extract_text/invoke",
|
||||
json={
|
||||
"input": {
|
||||
"text": text,
|
||||
"schema": Person(),
|
||||
"instructions": "Redact all names using the characters `######`",
|
||||
"examples": examples,
|
||||
}
|
||||
},
|
||||
)
|
||||
assert result.status_code == 200, result.text
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Code to test API endpoints."""
|
||||
import uuid
|
||||
|
||||
from tests.db import get_async_client
|
||||
|
||||
|
||||
async def test_extractors_api() -> None:
|
||||
"""This will test a few of the extractors API endpoints."""
|
||||
# First verify that the database is empty
|
||||
async with get_async_client() as client:
|
||||
response = await client.get("/extractors")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
# Verify that we can create an extractor
|
||||
create_request = {
|
||||
"description": "Test Description",
|
||||
"schema": {"type": "object"},
|
||||
"instruction": "Test Instruction",
|
||||
}
|
||||
response = await client.post("/extractors", json=create_request)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that the extractor was created
|
||||
response = await client.get("/extractors")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
|
||||
# Verify that we can delete an extractor
|
||||
get_response = response.json()
|
||||
uuid_str = get_response[0]["uuid"]
|
||||
_ = uuid.UUID(uuid_str) # assert valid uuid
|
||||
response = await client.delete(f"/extractors/{uuid_str}")
|
||||
assert response.status_code == 200
|
||||
|
||||
get_response = await client.get("/extractors")
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json() == []
|
||||
|
||||
# Verify that we can create an extractor
|
||||
create_request = {
|
||||
"description": "Test Description",
|
||||
"schema": {"type": "object"},
|
||||
"instruction": "Test Instruction",
|
||||
}
|
||||
response = await client.post("/extractors", json=create_request)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that the extractor was created
|
||||
response = await client.get("/extractors")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
|
||||
# Verify that we can delete an extractor
|
||||
get_response = response.json()
|
||||
uuid_str = get_response[0]["uuid"]
|
||||
_ = uuid.UUID(uuid_str) # assert valid uuid
|
||||
response = await client.delete(f"/extractors/{uuid_str}")
|
||||
assert response.status_code == 200
|
||||
|
||||
get_response = await client.get("/extractors")
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json() == []
|
||||
|
||||
# Verify that we can create an extractor
|
||||
create_request = {
|
||||
"description": "Test Description",
|
||||
"schema": {"type": "object"},
|
||||
"instruction": "Test Instruction",
|
||||
}
|
||||
response = await client.post("/extractors", json=create_request)
|
||||
assert response.status_code == 200
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Code to test API endpoints."""
|
||||
|
||||
from tests.db import get_async_client
|
||||
|
||||
|
||||
async def _list_extractors() -> list:
|
||||
async with get_async_client() as client:
|
||||
response = await client.get("/extractors")
|
||||
assert response.status_code == 200
|
||||
return response.json()
|
||||
|
||||
|
||||
async def test_examples_api() -> None:
|
||||
"""Runs through a set of API calls to test the examples API."""
|
||||
async with get_async_client() as client:
|
||||
# First create an extractor
|
||||
create_request = {
|
||||
"description": "Test Description",
|
||||
"name": "Test Name",
|
||||
"schema": {"type": "object"},
|
||||
"instruction": "Test Instruction",
|
||||
}
|
||||
response = await client.post("/extractors", json=create_request)
|
||||
assert response.status_code == 200
|
||||
# Get the extractor id
|
||||
extractor_id = response.json()["uuid"]
|
||||
|
||||
# Let's verify that there are no examples
|
||||
response = await client.get("/examples?extractor_id=" + extractor_id)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
# Now let's create an example
|
||||
create_request = {
|
||||
"extractor_id": extractor_id,
|
||||
"content": "Test Content",
|
||||
"output": [
|
||||
{
|
||||
"age": 100,
|
||||
"name": "Grung",
|
||||
}
|
||||
],
|
||||
}
|
||||
response = await client.post("/examples", json=create_request)
|
||||
assert response.status_code == 200
|
||||
example_id = response.json()["uuid"]
|
||||
|
||||
# Verify that the example was created
|
||||
response = await client.get("/examples?extractor_id=" + extractor_id)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
|
||||
keys = ["content", "extractor_id", "output", "uuid"]
|
||||
projected_response = {
|
||||
key: record[key] for key in keys for record in response.json()
|
||||
}
|
||||
assert projected_response == {
|
||||
"content": "Test Content",
|
||||
"extractor_id": extractor_id,
|
||||
"output": [
|
||||
{
|
||||
"age": 100,
|
||||
"name": "Grung",
|
||||
}
|
||||
],
|
||||
"uuid": example_id,
|
||||
}
|
||||
|
||||
# Verify that we can delete an example
|
||||
response = await client.delete(f"/examples/{example_id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that the example was deleted
|
||||
response = await client.get("/examples?extractor_id=" + extractor_id)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Code to test API endpoints."""
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
from tests.db import get_async_client
|
||||
|
||||
|
||||
def mock_extraction_runnable(*args, **kwargs):
|
||||
"""Mock the extraction_runnable function."""
|
||||
extract_request = args[0]
|
||||
return {
|
||||
"data": [
|
||||
extract_request.text[:10],
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def mock_text_splitter(*args, **kwargs):
|
||||
return CharacterTextSplitter()
|
||||
|
||||
|
||||
@patch(
|
||||
"server.extraction_runnable.extraction_runnable",
|
||||
new=RunnableLambda(mock_extraction_runnable),
|
||||
)
|
||||
@patch("server.extraction_runnable.TokenTextSplitter", mock_text_splitter)
|
||||
async def test_extract_from_file() -> None:
|
||||
"""Test extract from file API."""
|
||||
async with get_async_client() as client:
|
||||
# Test with invalid extractor
|
||||
extractor_id = UUID(int=1027) # 1027 is a good number.
|
||||
response = await client.post(
|
||||
"/extract",
|
||||
data={
|
||||
"extractor_id": str(extractor_id),
|
||||
"text": "Test Content",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 404, response.text
|
||||
|
||||
# First create an extractor
|
||||
create_request = {
|
||||
"name": "Test Name",
|
||||
"description": "Test Description",
|
||||
"schema": {"type": "object"},
|
||||
"instruction": "Test Instruction",
|
||||
}
|
||||
response = await client.post("/extractors", json=create_request)
|
||||
assert response.status_code == 200, response.text
|
||||
# Get the extractor id
|
||||
extractor_id = response.json()["uuid"]
|
||||
|
||||
# Run an extraction.
|
||||
# We'll use multi-form data here.
|
||||
response = await client.post(
|
||||
"/extract",
|
||||
data={
|
||||
"extractor_id": extractor_id,
|
||||
"text": "Test Content",
|
||||
"mode": "entire_document",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"data": ["Test Conte"]}
|
||||
|
||||
# We'll use multi-form data here.
|
||||
# Create a named temporary file
|
||||
with tempfile.NamedTemporaryFile(mode="w+t", delete=False) as f:
|
||||
f.write("This is a named temporary file.")
|
||||
f.seek(0)
|
||||
f.flush()
|
||||
response = await client.post(
|
||||
"/extract",
|
||||
data={
|
||||
"extractor_id": extractor_id,
|
||||
"mode": "entire_document",
|
||||
},
|
||||
files={"file": f},
|
||||
)
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"data": ["This is a "]}
|
||||
@@ -0,0 +1,3 @@
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "placeholder"
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Fake Chat Model wrapper for testing purposes."""
|
||||
from typing import Any, Iterator, List, Optional
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
|
||||
class GenericFakeChatModel(BaseChatModel):
|
||||
"""A generic fake chat model that can be used to test the chat model interface."""
|
||||
|
||||
messages: Iterator[AIMessage]
|
||||
"""Get an iterator over messages.
|
||||
|
||||
This can be expanded to accept other types like Callables / dicts / strings
|
||||
to make the interface more generic if needed.
|
||||
|
||||
Note: if you want to pass a list, you can use `iter` to convert it to an iterator.
|
||||
|
||||
Please note that streaming is not implemented yet. We should try to implement it
|
||||
in the future by delegating to invoke and then breaking the resulting output
|
||||
into message chunks.
|
||||
"""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
message = next(self.messages)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "generic-fake-chat-model"
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Tests for verifying that testing utility code works as expected."""
|
||||
from itertools import cycle
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
|
||||
|
||||
|
||||
def test_generic_fake_chat_model_invoke() -> None:
|
||||
# Will alternate between responding with hello and goodbye
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
response = model.invoke("kitty")
|
||||
assert response == AIMessage(content="goodbye")
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
# Will alternate between responding with hello and goodbye
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
response = await model.ainvoke("kitty")
|
||||
assert response == AIMessage(content="goodbye")
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
@@ -0,0 +1,11 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
HERE = Path(__file__).parent
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def get_sample_paths() -> List[Path]:
|
||||
"""List all fixtures."""
|
||||
return list(HERE.glob("sample.*"))
|
||||
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -0,0 +1,42 @@
|
||||
🦜️ LangChain
|
||||
|
||||
|
||||
|
||||
|
||||
Underline
|
||||
|
||||
|
||||
Bold
|
||||
|
||||
|
||||
Italics
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Col 1
|
||||
Col 2
|
||||
Row 1
|
||||
1
|
||||
2
|
||||
Row 2
|
||||
3
|
||||
4
|
||||
|
||||
|
||||
|
||||
|
||||
Link: https://www.langchain.com/
|
||||
|
||||
|
||||
|
||||
|
||||
* Item 1
|
||||
* Item 2
|
||||
* Item 3
|
||||
* We also love cats 🐱
|
||||
|
||||
|
||||
Image
|
||||
@@ -0,0 +1,54 @@
|
||||
from server.extraction_runnable import ExtractResponse, _deduplicate
|
||||
|
||||
|
||||
async def test_deduplication_different_resutls() -> None:
|
||||
"""Test deduplication of extraction results."""
|
||||
result = _deduplicate(
|
||||
[
|
||||
{"data": [{"name": "Chester", "age": 42}]},
|
||||
{"data": [{"name": "Jane", "age": 42}]},
|
||||
]
|
||||
)
|
||||
expected = ExtractResponse(
|
||||
data=[
|
||||
{"name": "Chester", "age": 42},
|
||||
{"name": "Jane", "age": 42},
|
||||
]
|
||||
)
|
||||
assert expected == result
|
||||
|
||||
result = _deduplicate(
|
||||
[
|
||||
{
|
||||
"data": [
|
||||
{"field_1": 1, "field_2": "a"},
|
||||
{"field_1": 2, "field_2": "b"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
{"field_1": 1, "field_2": "a"},
|
||||
{"field_1": 2, "field_2": "c"},
|
||||
]
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
expected = ExtractResponse(
|
||||
data=[
|
||||
{"field_1": 1, "field_2": "a"},
|
||||
{"field_1": 2, "field_2": "b"},
|
||||
{"field_1": 2, "field_2": "c"},
|
||||
]
|
||||
)
|
||||
assert expected == result
|
||||
|
||||
# Test with data being a list of strings
|
||||
result = _deduplicate([{"data": ["1", "2"]}, {"data": ["1", "3"]}])
|
||||
expected = ExtractResponse(data=["1", "2", "3"])
|
||||
assert expected == result
|
||||
|
||||
# Test with data being a mix of integer and string
|
||||
result = _deduplicate([{"data": [1, "2"]}, {"data": ["1", "3"]}])
|
||||
expected = ExtractResponse(data=[1, "2", "1", "3"])
|
||||
assert expected == result
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Test parsing logic."""
|
||||
import mimetypes
|
||||
|
||||
from langchain.document_loaders import Blob
|
||||
|
||||
from extraction.parsing import (
|
||||
MIMETYPE_BASED_PARSER,
|
||||
SUPPORTED_MIMETYPES,
|
||||
)
|
||||
from tests.unit_tests.fixtures import get_sample_paths
|
||||
|
||||
|
||||
def test_list_of_supported_mimetypes() -> None:
|
||||
"""This list should generally grow! Protecting against typos in mimetypes."""
|
||||
assert SUPPORTED_MIMETYPES == [
|
||||
# Two MS Word mimetypes are disabled for now
|
||||
# Need to install unstructured to enable them
|
||||
# "application/msword",
|
||||
"application/pdf",
|
||||
# "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"text/html",
|
||||
"text/plain",
|
||||
]
|
||||
|
||||
|
||||
def test_attempt_to_parse_each_fixture() -> None:
|
||||
"""Attempt to parse supported fixtures."""
|
||||
seen_mimetypes = set()
|
||||
for path in get_sample_paths():
|
||||
type_, _ = mimetypes.guess_type(path)
|
||||
if type_ not in SUPPORTED_MIMETYPES:
|
||||
continue
|
||||
seen_mimetypes.add(type_)
|
||||
blob = Blob.from_path(path)
|
||||
documents = MIMETYPE_BASED_PARSER.parse(blob)
|
||||
try:
|
||||
assert len(documents) == 1
|
||||
doc = documents[0]
|
||||
assert "source" in doc.metadata
|
||||
assert doc.metadata["source"] == str(path)
|
||||
assert "🦜" in doc.page_content
|
||||
except Exception as e:
|
||||
raise AssertionError(f"Failed to parse {path}") from e
|
||||
|
||||
known_missing = {"application/msword"}
|
||||
assert set(SUPPORTED_MIMETYPES) - known_missing == seen_mimetypes
|
||||
@@ -0,0 +1,22 @@
|
||||
from extraction.parsing import _guess_mimetype
|
||||
from tests.unit_tests.fixtures import get_sample_paths
|
||||
|
||||
|
||||
async def test_mimetype_guessing() -> None:
|
||||
"""Verify mimetype guessing for all fixtures."""
|
||||
name_to_mime = {}
|
||||
for file in sorted(get_sample_paths()):
|
||||
data = file.read_bytes()
|
||||
name_to_mime[file.name] = _guess_mimetype(data)
|
||||
|
||||
assert {
|
||||
"sample.docx": (
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
),
|
||||
"sample.epub": "application/epub+zip",
|
||||
"sample.html": "text/html",
|
||||
"sample.odt": "application/vnd.oasis.opendocument.text",
|
||||
"sample.pdf": "application/pdf",
|
||||
"sample.rtf": "text/rtf",
|
||||
"sample.txt": "text/plain",
|
||||
} == name_to_mime
|
||||
@@ -0,0 +1,104 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
|
||||
from extraction.utils import (
|
||||
convert_json_schema_to_openai_schema,
|
||||
)
|
||||
from server.extraction_runnable import ExtractionExample, _make_prompt_template
|
||||
|
||||
|
||||
def test_convert_json_schema_to_openai_schema() -> None:
|
||||
"""Test converting a JSON schema to an OpenAI schema."""
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
|
||||
schema = Person.schema()
|
||||
|
||||
assert schema == {
|
||||
"properties": {
|
||||
"age": {
|
||||
"description": "The age of the person.",
|
||||
"title": "Age",
|
||||
"type": "integer",
|
||||
},
|
||||
"name": {
|
||||
"description": "The name of the person.",
|
||||
"title": "Name",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
"title": "Person",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
openai_schema = convert_json_schema_to_openai_schema(schema)
|
||||
assert openai_schema == {
|
||||
"description": "Extract information matching the given schema.",
|
||||
"name": "extractor",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"data": {
|
||||
"items": {
|
||||
"properties": {
|
||||
"age": {
|
||||
"description": "The age of the person.",
|
||||
"type": "integer",
|
||||
},
|
||||
"name": {
|
||||
"description": "The name of the person.",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
"type": "object",
|
||||
},
|
||||
"type": "array",
|
||||
}
|
||||
},
|
||||
"required": ["data"],
|
||||
"type": "object",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_make_prompt_template() -> None:
|
||||
"""Test making a system message from instructions and examples."""
|
||||
instructions = "Test instructions."
|
||||
examples = [
|
||||
ExtractionExample(
|
||||
text="Test text.",
|
||||
output=[
|
||||
{"name": "Test Name", "age": 0},
|
||||
{"name": "Test Name 2", "age": 1},
|
||||
],
|
||||
)
|
||||
]
|
||||
prefix = (
|
||||
"You are a top-tier algorithm for extracting information from text. "
|
||||
"Only extract information that is relevant to the provided text. "
|
||||
"If no information is relevant, use the schema and output "
|
||||
"an empty list where appropriate."
|
||||
)
|
||||
prompt = _make_prompt_template(instructions, examples, "name")
|
||||
messages = prompt.messages
|
||||
assert 4 == len(messages)
|
||||
system = messages[0].prompt.template
|
||||
assert system.startswith(prefix)
|
||||
assert system.endswith(instructions)
|
||||
|
||||
example_input = messages[1]
|
||||
assert example_input.content == "Test text."
|
||||
example_output = messages[2]
|
||||
assert "function_call" in example_output.additional_kwargs
|
||||
assert example_output.additional_kwargs["function_call"]["name"] == "name"
|
||||
|
||||
prompt = _make_prompt_template(instructions, None, "name")
|
||||
assert 2 == len(prompt.messages)
|
||||
|
||||
prompt = _make_prompt_template(None, examples, "name")
|
||||
assert 4 == len(prompt.messages)
|
||||
@@ -0,0 +1,16 @@
|
||||
import pytest
|
||||
|
||||
from server.validators import validate_json_schema
|
||||
|
||||
|
||||
def test_validate_json_schema() -> None:
|
||||
"""Test validate_json_schema."""
|
||||
# TODO: Validate more extensively to make sure that it actually validates
|
||||
# the schema as expected.
|
||||
with pytest.raises(Exception):
|
||||
validate_json_schema({"type": "meow"})
|
||||
|
||||
with pytest.raises(Exception):
|
||||
validate_json_schema({"type": "str"})
|
||||
|
||||
validate_json_schema({"type": "string"})
|
||||
@@ -0,0 +1,25 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_test_client(
|
||||
server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True
|
||||
) -> AsyncClient:
|
||||
"""Get an async client."""
|
||||
url = "http://localhost:9999/"
|
||||
if path:
|
||||
url += path
|
||||
transport = httpx.ASGITransport(
|
||||
app=server,
|
||||
raise_app_exceptions=raise_app_exceptions,
|
||||
)
|
||||
async_client = AsyncClient(app=server, base_url=url, transport=transport)
|
||||
try:
|
||||
yield async_client
|
||||
finally:
|
||||
await async_client.aclose()
|
||||
@@ -0,0 +1,30 @@
|
||||
version: "3"
|
||||
name: langchain-extract
|
||||
|
||||
services:
|
||||
postgres:
|
||||
# Careful if bumping postgres version.
|
||||
# Make sure to keep in sync with CI
|
||||
# version if being tested on CI.
|
||||
image: postgres:16
|
||||
ports:
|
||||
- "5432:5432"
|
||||
environment:
|
||||
POSTGRES_DB: langchain
|
||||
POSTGRES_USER: langchain
|
||||
POSTGRES_PASSWORD: langchain
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
|
||||
# For rely on docker compose to spin up postgres
|
||||
# but developer using docker
|
||||
# Add backend when we actually need it
|
||||
# backend:
|
||||
# build: ./backend
|
||||
# ports:
|
||||
# - "8000:8000"
|
||||
# depends_on:
|
||||
# - postgres
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
Reference in New Issue
Block a user