Compare commits

53 Commits

Author SHA1 Message Date
SN dbcb7f932a update 2025-10-02 17:58:38 -07:00
SN ca512ab494 fix again again 2025-10-02 17:55:16 -07:00
SN a0cbcb303e fix again 2025-10-02 17:53:40 -07:00
SN 7737d33370 fix tests 2025-10-02 17:49:35 -07:00
SN 377dc5aaa3 chore: bump langchain version 2025-10-02 17:46:11 -07:00
joaquin-borggio-lc f0d86287ab ci: drop python 3.8 (#825) 2025-09-17 15:38:01 -04:00
joaquin-borggio-lc 168c9ff90e chore: bump patch version (#823)
bump patched version so i can release
2025-09-17 15:31:09 -04:00
joaquin-borggio-lc a87125d0b8 chore: bump some package versions (#821)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2025-09-17 15:03:09 -04:00
Eugene Yurtsev 1de17fa799 Add new contributor to .clabot configuration (#822) 2025-09-17 15:01:18 -04:00
ResearchAI 1a08f2740c Update Chroma import to avoid deprecation warning. (#741) 2025-07-09 17:19:18 -04:00
Yash Jivani 06c3c3691e Update README.md Grammatical Errors (#733) 2025-07-09 17:15:01 -04:00
Eugene Yurtsev d20eab45f2 ci: add workflow dispatch to _release (#805) 2024-12-26 21:37:14 -05:00
Eugene Yurtsev 321b7aa3b1 Release 0.3.1 (#801) 2024-12-19 10:33:17 -05:00
Eugene Yurtsev aa4aea4a81 lint (#803) 2024-12-19 10:19:16 -05:00
Eugene Yurtsev 1aaec1189c Update _lint.yml (#802) 2024-12-18 22:10:43 -05:00
アリス 0d7601781d update ssl verifying part to work with httpx 0.28.0 (#798)
As the new httpx 0.28.0 removes VerifyTypes, we should update to follow
their change
See [issue](https://github.com/langchain-ai/langserve/issues/796)

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2024-12-18 22:08:58 -05:00
Eugene Yurtsev 0ad075fb67 Update .clabot (#800) 2024-12-18 22:04:39 -05:00
Mingqi Hu b007300b06 docs: Remove templates in readme (#792)
docs: Remove templates in readme

Issue: 404 when you click the [LangChain
Templates](https://github.com/langchain-ai/langchain/blob/master/templates/README.md)
in readme. As templates dir has been removed on latest langchain master.
<img width="368" alt="image"
src="https://github.com/user-attachments/assets/8dfa7e82-c6df-458c-ba42-2b9f0fd67d8d">


Fix: Just remove it to avoid 404 as example dirs is enough, look like
(my local vs code)
<img width="423" alt="image"
src="https://github.com/user-attachments/assets/b65fc053-ed62-47e4-8404-fbfbcb4b17d5">

Signed-off-by: Mingqi <mingqi.hu@intel.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2024-11-22 10:21:02 -05:00
Eugene Yurtsev 9df7e88bb9 Update .clabot (#793) 2024-11-22 10:20:11 -05:00
Eugene Yurtsev c626cde08c Update MIGRATION.md (#789) 2024-11-18 11:06:37 -05:00
ccurme c4a8925b00 docs: update local link in readme (#788)
LangChain doc builds are currently
[broken](https://vercel.com/langchain/langchain/ExxE3a6SSZVa7RSRUDB3mzHDpsvV?filter=errors).

- LangChain docs page hosts LangServe readme:
https://python.langchain.com/docs/langserve/;
- The readme is [downloaded from
github](https://github.com/langchain-ai/langchain/blob/22a8652eccd9023cd41ea0f6924575ce87c28756/docs/Makefile#L49)
when docs are built;
- A
[script](https://github.com/langchain-ai/langchain/blob/master/docs/scripts/resolve_local_links.py)
resolves local links in the readme to point to the LangServe github.

This script assumes local links are prefixed with `./`
2024-11-18 11:05:19 -05:00
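The link-resolution step described in the commit above can be sketched as follows. This is a minimal illustration, not the actual `resolve_local_links.py` script: the function name, the regular expression, and the base URL are assumptions chosen for the example. It shows why the `./` prefix matters, since only targets that start with `./` are rewritten.

```python
import re

# Assumed base URL, used only for illustration.
BASE_URL = "https://github.com/langchain-ai/langserve/tree/main/"

# Matches a markdown link target only when it starts with "./".
LOCAL_LINK = re.compile(r"\]\(\./([^)\s]+)\)")


def resolve_local_links(markdown: str, base_url: str = BASE_URL) -> str:
    """Rewrite `](./path)` link targets to absolute URLs under base_url."""
    return LOCAL_LINK.sub(lambda m: f"]({base_url}{m.group(1)})", markdown)


readme = "See the [migration guide](./MIGRATION.md) and [examples](examples)."
resolved = resolve_local_links(readme)
print(resolved)
```

In this sketch the `./MIGRATION.md` link becomes absolute, while the bare `examples` link is left untouched. That is the kind of unprefixed local link the commit fixes in the readme.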
Eugene Yurtsev 80f949b62e migration guide (#787) 2024-11-16 22:36:31 -05:00
Eugene Yurtsev c27923a7d1 Update README.md (#773) 2024-09-18 13:33:52 -04:00
Eugene Yurtsev f3b9c43106 Update README.md (#771) 2024-09-16 13:22:26 -04:00
Eugene Yurtsev 6b1a0d97ef Update README (#769) 2024-09-16 11:56:41 -04:00
Eugene Yurtsev dc04672537 examples: update clients (#768) 2024-09-14 14:32:41 -04:00
Eugene Yurtsev 42b61a664b v0.3 release (#767) 2024-09-14 14:14:45 -04:00
Eugene Yurtsev 1e24edce08 Add ability to specify custom serializer (#764)
Allow users to define a custom serializer
2024-09-14 14:06:08 -04:00
Eugene Yurtsev c747e20c1e Fix issue with callback events sent from server (#765)
This properly propagates the name of the runs from the server to the client if one enables callbacks.
2024-09-14 13:46:10 -04:00
Eugene Yurtsev 8b4d8dff6c propagate astream events from add_route (#763)
Propagate the parameter to APIHandler
2024-09-12 17:02:35 -04:00
Eugene Yurtsev 2c957bdd78 mark include callback events as a beta api (#761) 2024-09-12 14:58:33 -04:00
Eugene Yurtsev 36f945e494 improve error messages when models fail to hash (#762) 2024-09-12 14:58:25 -04:00
Eugene Yurtsev 43683b3671 add ability to control version of astream events API (#760)
add ability to control version of astream events API server side
2024-09-12 14:33:21 -04:00
Eugene Yurtsev 59b3c81189 Release 0.30dev2 (#758) 2024-09-11 21:08:56 -04:00
Eugene Yurtsev 36e9919c17 serialization fix for defaults (#757)
Fix serialization issue when there are defaults
2024-09-11 21:07:30 -04:00
Eugene Yurtsev ff94f96dc8 update serialization to work with older pydantic versions (#756)
* Serialization to work with pydantic 2.7, 2.8
* Add constraint on min pydantic version for langserve to be 2.7 for now
2024-09-11 12:55:53 -04:00
Eugene Yurtsev 8c852935e5 Update doc-string, remove unused function (#754) 2024-09-10 14:05:51 -04:00
Eugene Yurtsev 04236b0cf2 more run time warnings (#753)
This PR resolves most of the remaining run time warnings
2024-09-10 11:41:03 -04:00
Eugene Yurtsev 54eee64faf update claude model in examples (#750) 2024-09-10 10:34:31 -04:00
Eugene Yurtsev 72c200ff81 first pass at removing deprecated usage (#751)
Reduces a significant number of run time warnings when running unit
tests (down to ~350)
2024-09-10 10:34:16 -04:00
Eugene Yurtsev 21c2e3da2a resolve more warnings in unit tests (#752)
Migrated using gritql 

`AsyncClient(app=$y, $x)` => `AsyncClient($x, transport=
httpx.ASGITransport(app=$y))`
2024-09-09 18:40:04 -04:00
Eugene Yurtsev 6e7a9ee3f5 0.3.0.dev1 release (#748) 2024-09-09 16:09:34 -04:00
Eugene Yurtsev 81c0285af2 update default playground to handle oneOf type (#747)
Update the default playground to handle the oneOf type which appears as
the input into language models.

This change stems from the upgrade to pydantic 2.
2024-09-09 16:06:36 -04:00
Eugene Yurtsev b62b925825 chat-playground: update to handle oneOf (resulted from pydantic upgrade) (#746)
Update the chat playground to handle the oneOf type which appears as the
input into language models.

This change stems from the upgrade to pydantic 2.
2024-09-09 16:06:29 -04:00
Eugene Yurtsev b528955b60 migrate examples to pydantic 2 (#745)
Migrate examples to pydantic 2
2024-09-09 11:03:26 -04:00
Eugene Yurtsev 5aedbf7083 upgrade to pydantic 2 (#744)
This PR upgrades langserve to pydantic 2.

* Added a failing unit test that has 2 known failures (that need to be
fixed in langchain-core)
* Deprecation warnings will be resolved separately.
2024-09-09 10:23:16 -04:00
Eugene Yurtsev 41a9d798aa Update unit tests to catch up with langchain-core (#743)
Updating the unit tests to catch up with langchain-core. No adjustments
should be necessary to user code; the issues manifest themselves only with the
given test set up (e.g., snapshots). langchain-core changes were either
additive or self-consistent.
2024-09-06 13:55:23 -04:00
Erick Friis d4704c2b45 infra: release permissions (#738) 2024-09-01 22:09:28 -04:00
Eugene Yurtsev c259ec3e4d Release 0.2.3 (#737) 2024-09-01 21:55:23 -04:00
William FH 62e648a2bf Support correction when creating feedback with token (#736)
Closes https://github.com/langchain-ai/langserve/issues/735

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2024-09-01 21:43:50 -04:00
Eugene Yurtsev a74e072486 Add langgraph compatibility section (#718) 2024-07-26 14:29:38 -04:00
Erick Friis 1487bf1ce5 Add instructions to address pydantic v2 incompatibility (#713) 2024-07-22 14:40:03 -04:00
ccurme 050a0cc674 update readme (#697) 2024-06-28 11:02:32 -04:00
58 changed files with 3551 additions and 2947 deletions
+1 -1
@@ -1,4 +1,4 @@
{
"contributors": ["eyurtsev", "hwchase17", "nfcampos", "efriis", "jacoblee93", "dqbd", "kreneskyp", "adarsh-jha-dev", "harris", "baskaryan", "hinthornw", "bracesproul", "jakerachleff", "craigsdennis", "anhi", "169", "LarchLiu", "PaulLockett", "RCMatthias", "jwynia", "majiayu000", "mpskex", "shivachittamuru", "sinashaloudegi", "sowsan", "akira", "lucianotonet", "JGalego", "nat-n", "dirien", "donbr", "rahilvora", "WarrenTheRabbit", "StreetLamb", "ccurme", "dennisrall"],
"contributors": ["eyurtsev", "hwchase17", "nfcampos", "efriis", "jacoblee93", "dqbd", "kreneskyp", "adarsh-jha-dev", "harris", "baskaryan", "hinthornw", "bracesproul", "jakerachleff", "craigsdennis", "anhi", "169", "LarchLiu", "PaulLockett", "RCMatthias", "jwynia", "majiayu000", "mpskex", "shivachittamuru", "sinashaloudegi", "sowsan", "akira", "lucianotonet", "JGalego", "nat-n", "dirien", "donbr", "rahilvora", "WarrenTheRabbit", "StreetLamb", "ccurme", "dennisrall", "Mingqi2", "xxsl", "joaquin-borggio-lc"],
"message": "Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have the username {{usersWithoutCLA}} on file. In order for us to review and merge your code, please complete the Individual Contributor License Agreement here https://forms.gle/AQFbtkWRoHXUgipM6 .\n\nThis process is done manually on our side, so after signing the form one of the maintainers will add you to the contributors list.\n\nFor more details about why we have a CLA and other contribution guidelines please see: https://github.com/langchain-ai/langserve/blob/main/CONTRIBUTING.md."
}
+1 -1
@@ -31,7 +31,7 @@ jobs:
# Starting new jobs is also relatively slow,
# so linting on fewer versions makes CI faster.
python-version:
- "3.8"
- "3.10"
- "3.11"
steps:
- uses: actions/checkout@v3
@@ -1,94 +0,0 @@
name: pydantic v1/v2 compatibility
on:
workflow_call:
inputs:
working-directory:
required: true
type: string
description: "From which folder this pipeline executes"
env:
POETRY_VERSION: "1.5.1"
jobs:
build:
timeout-minutes: 10
defaults:
run:
working-directory: ${{ inputs.working-directory }}
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
name: Pydantic v1/v2 compatibility - Python ${{ matrix.python-version }}
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
uses: "./.github/actions/poetry_setup"
with:
python-version: ${{ matrix.python-version }}
poetry-version: ${{ env.POETRY_VERSION }}
working-directory: ${{ inputs.working-directory }}
cache-key: pydantic-cross-compat
- name: Install dependencies
shell: bash
run: poetry install
- name: Install the opposite major version of pydantic
# If normal tests use pydantic v1, here we'll use v2, and vice versa.
shell: bash
run: |
# Determine the major part of pydantic version
REGULAR_VERSION=$(poetry run python -c "import pydantic; print(pydantic.__version__)" | cut -d. -f1)
if [[ "$REGULAR_VERSION" == "1" ]]; then
PYDANTIC_DEP=">=2.1,<3"
TEST_WITH_VERSION="2"
elif [[ "$REGULAR_VERSION" == "2" ]]; then
PYDANTIC_DEP="<2"
TEST_WITH_VERSION="1"
else
echo "Unexpected pydantic major version '$REGULAR_VERSION', cannot determine which version to use for cross-compatibility test."
exit 1
fi
# Install via `pip` instead of `poetry add` to avoid changing lockfile,
# which would prevent caching from working: the cache would get saved
# to a different key than where it gets loaded from.
poetry run pip install "pydantic${PYDANTIC_DEP}"
# Ensure that the correct pydantic is installed now.
echo "Checking pydantic version... Expecting ${TEST_WITH_VERSION}"
# Determine the major part of pydantic version
CURRENT_VERSION=$(poetry run python -c "import pydantic; print(pydantic.__version__)" | cut -d. -f1)
# Check that the major part of pydantic version is as expected, if not
# raise an error
if [[ "$CURRENT_VERSION" != "$TEST_WITH_VERSION" ]]; then
echo "Error: expected pydantic version ${TEST_WITH_VERSION} to have been installed, but found: ${CURRENT_VERSION}"
exit 1
fi
echo "Found pydantic version ${CURRENT_VERSION}, as expected"
- name: Run pydantic compatibility tests
shell: bash
run: make test
- name: Ensure the tests did not create any additional files
shell: bash
run: |
set -eu
STATUS="$(git status)"
echo "$STATUS"
# grep will exit non-zero if the target message isn't found,
# and `set -e` above will cause the step to fail.
echo "$STATUS" | grep 'nothing to commit, working tree clean'
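The version-flip step in the removed workflow above can be summarized in a few lines of Python. This is a sketch of the selection logic only, with a hypothetical function name. The requirement strings mirror the `PYDANTIC_DEP` values in the shell script. It does not install anything.

```python
from typing import Tuple


def opposite_pydantic(installed_version: str) -> Tuple[str, str]:
    """Return (pip requirement, expected major) for the other pydantic major."""
    major = installed_version.split(".")[0]
    if major == "1":
        return "pydantic>=2.1,<3", "2"
    if major == "2":
        return "pydantic<2", "1"
    # Mirrors the workflow's `exit 1` branch for unknown majors.
    raise ValueError(f"Unexpected pydantic major version {major!r}")


requirement, expected_major = opposite_pydantic("2.7.4")
print(requirement, expected_major)
```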
+6
@@ -7,6 +7,12 @@ on:
required: true
type: string
description: "From which folder this pipeline executes"
workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI
inputs:
working-directory:
required: true
type: string
description: "From which folder this pipeline executes"
env:
POETRY_VERSION: "1.5.1"
-1
@@ -20,7 +20,6 @@ jobs:
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
-9
@@ -39,13 +39,6 @@ jobs:
with:
working-directory: .
secrets: inherit
pydantic-compatibility:
uses:
./.github/workflows/_pydantic_compatibility.yml
with:
working-directory: .
secrets: inherit
test:
timeout-minutes: 10
runs-on: ubuntu-latest
@@ -55,8 +48,6 @@ jobs:
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
name: Python ${{ matrix.python-version }} tests
+1
@@ -10,4 +10,5 @@ jobs:
./.github/workflows/_release.yml
with:
working-directory: .
permissions: write-all
secrets: inherit
+222
@@ -0,0 +1,222 @@
# LangGraph Platform Migration Guide
We have [recently announced](https://blog.langchain.dev/langgraph-platform-announce/) LangGraph Platform, a ***significantly*** enhanced solution for deploying agentic applications at scale.
LangGraph Platform incorporates [key design patterns and capabilities](https://langchain-ai.github.io/langgraph/concepts/langgraph_platform/#option-2-leveraging-langgraph-platform-for-complex-deployments) essential for production-level deployment of large language model (LLM) applications.
In contrast to LangServe, LangGraph Platform provides comprehensive, out-of-the-box support for [persistence](https://langchain-ai.github.io/langgraph/concepts/application_structure/), [memory](https://langchain-ai.github.io/langgraph/concepts/assistants/), [double-texting handling](https://langchain-ai.github.io/langgraph/concepts/double_texting/), [human-in-the-loop workflows](https://langchain-ai.github.io/langgraph/concepts/assistants/), [cron job scheduling](https://langchain-ai.github.io/langgraph/concepts/langgraph_server/#cron-jobs), [webhooks](https://langchain-ai.github.io/langgraph/concepts/langgraph_server/#webhooks), high-load management, advanced streaming, support for long-running tasks, background task processing, and much more.
The LangGraph Platform ecosystem includes the following components:
- [LangGraph Server](https://langchain-ai.github.io/langgraph/concepts/langgraph_server/): Provides an [Assistants API](https://langchain-ai.github.io/langgraph/cloud/reference/api/api_ref.html) for LLM applications (graphs) built with [LangGraph](https://langchain-ai.github.io/langgraph/). Available in both Python and JavaScript/TypeScript.
- [LangGraph Studio](https://langchain-ai.github.io/langgraph/concepts/langgraph_studio/): A specialized IDE for real-time visualization, debugging, and interaction via a graphical interface. Available as a web application or macOS desktop app, it's a substantial improvement over LangServe's playground.
- [SDK](https://langchain-ai.github.io/langgraph/concepts/sdk/): Enables programmatic interaction with the server, available in Python and JavaScript/TypeScript.
- [RemoteGraph](https://langchain-ai.github.io/langgraph/how-tos/use-remote-graph/): Allows interaction with a remote graph as if it were running locally, serving as LangGraph's equivalent to LangServe's RemoteRunnable. Available in both Python and JavaScript/TypeScript.
## Context
LangServe was built as a deployment solution for LangChain Runnables created using the [LangChain Expression Language (LCEL)](https://python.langchain.com/docs/concepts/lcel). In LangServe, LCEL was the orchestration layer that managed the execution of the Runnable.
[LangGraph](https://langchain-ai.github.io/langgraph/) is an open source library created by the LangChain team that provides a more flexible orchestration layer that's better suited for creating more complex LLM applications. LangGraph Platform
is the deployment solution for LangGraph applications.
## LangServe Support
We recommend using LangGraph Platform rather than LangServe for new projects.
We will continue to accept bug fixes for LangServe from the community; however, we will not be accepting new feature contributions.
## Migration
If you would like to migrate an existing LangServe application to LangGraph Platform, you have two options:
1. You can wrap the existing `Runnable` that you expose in the LangServe application via `add_routes` in a `LangGraph` node. This is the quickest way to migrate your application to LangGraph Platform.
2. You can do a larger refactor to break up the existing LCEL into appropriate `LangGraph` nodes. This is recommended if you want to take advantage of more advanced features in LangGraph Platform.
### Option 1: Wrap Runnable in LangGraph Node
This option is the quickest way to migrate your application to LangGraph Platform. You can wrap the existing `Runnable` that you expose in the LangServe application via `add_routes` in a `LangGraph` node.
Original LangServe code:
```python
from typing import Any, Optional

from fastapi import FastAPI
from pydantic import BaseModel

from langserve import add_routes

app = FastAPI()


# Some input schema
class Input(BaseModel):
    input: str
    foo: Optional[str]


# Some output schema
class Output(BaseModel):
    output: Any


runnable = ...  # Your existing Runnable
runnable_with_types = runnable.with_types(input_type=Input, output_type=Output)

# Add routes to the app for using the chain
add_routes(
    app,
    runnable_with_types,
)
```
Migrated LangGraph Platform code:
```python
from dataclasses import dataclass
from typing import Any, Optional

from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph


@dataclass
class InputState:  # Equivalent to Input in the original code
    """Defines the input state, representing a narrower interface to the outside world.

    This class is used to define the initial state and structure of incoming data.
    See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
    for more information.
    """

    input: str
    foo: Optional[str] = None


@dataclass
class OutputState:  # Equivalent to Output in the original code
    """Defines the output state, representing a narrower interface to the outside world.

    https://langchain-ai.github.io/langgraph/concepts/low_level/#state
    """

    output: Any


@dataclass
class SharedState:
    """The full graph state.

    https://langchain-ai.github.io/langgraph/concepts/low_level/#state
    """

    input: str
    foo: Optional[str] = None
    # A field that follows a defaulted field needs a default too,
    # otherwise @dataclass raises a TypeError.
    output: Any = None


runnable = ...  # Same code as before


async def my_node(state: InputState, config: RunnableConfig) -> OutputState:
    """Each node does work."""
    # Assumes the runnable returns a mapping with an `output` key.
    return await runnable.ainvoke({"input": state.input, "foo": state.foo})


# Define a new graph.
# `Configuration` is assumed to be a config schema defined elsewhere in your app.
builder = StateGraph(
    SharedState, config_schema=Configuration, input=InputState, output=OutputState
)
# Add the node to the graph
builder.add_node("my_node", my_node)
# Set the entrypoint as `my_node`
builder.add_edge("__start__", "my_node")
# Compile the workflow into an executable graph
graph = builder.compile()
graph.name = "New Graph"  # This defines the custom name in LangSmith
```
### Option 2: Refactor LCEL into LangGraph Nodes
This option is recommended if you want to take advantage of more advanced features in LangGraph Platform.
#### Memory (alternative to `RunnableWithMessageHistory`)
For example, LangGraph comes with built-in persistence that is more general than LangChain's `RunnableWithMessageHistory`.
Please refer to the guide on [upgrading to LangGraph memory](https://python.langchain.com/docs/versions/migrating_memory/) for more details.
#### Agents
If you're relying on legacy LangChain agents, you can migrate them into the pre-built
LangGraph agents. Please refer to the guide on [migrating agents](https://python.langchain.com/docs/how_to/migrate_agent/) for more details.
#### Custom Chains
If you created a custom chain and used LCEL to orchestrate it, you will usually be able to refactor it into a LangGraph without too much difficulty.
There isn't a one-size-fits-all guide for this, but generally speaking, consider creating
a separate node for any long-running step in your LCEL chain or any step that you would
want to be able to monitor or debug separately.
For example, if you have a simple Retrieval Augmented Generation (RAG) pipeline, you might have a node for the retrieval step and a node for the generation step.
Original LCEL code:
```python
...
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
# with_types returns a new runnable, so keep the result
rag_chain = rag_chain.with_types(input_type=Input, output_type=Output)
```
Using LangGraph for the same pipeline:
```python
from dataclasses import dataclass, field
from typing import List

from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import StateGraph


@dataclass
class InputState:  # Equivalent to Input in the original code
    """Input question from the user."""

    question: str


@dataclass
class OutputState:  # Equivalent to Output in the original code
    """The output from the graph."""

    answer: str


@dataclass
class SharedState:
    question: str
    docs: List[str] = field(default_factory=list)
    answer: str = ""


async def retriever_node(state: InputState) -> SharedState:
    """Retrieve documents based on the user's question."""
    # `retriever` is the same retriever used in the LCEL chain; it takes the query string.
    documents = await retriever.ainvoke(state.question)
    return {"docs": [doc.page_content for doc in documents]}


async def generator_node(state: SharedState) -> OutputState:
    """Generate an answer using an LLM based on the retrieved documents and question."""
    context = " -- DOCUMENT -- ".join(state.docs)
    prompt = [
        SystemMessage(
            content=(
                "Answer the user's question based on the list of documents "
                "that were retrieved. Here are the documents: \n\n"
                f"{context}"
            )
        ),
        HumanMessage(content=state.question),
    ]
    # `llm` is the same chat model used in the LCEL chain.
    ai_message = await llm.ainvoke(prompt)
    return {"answer": ai_message.content}


# Define a new graph.
# `Configuration` is assumed to be a config schema defined elsewhere in your app.
builder = StateGraph(
    SharedState, config_schema=Configuration, input=InputState, output=OutputState
)
builder.add_node("retriever", retriever_node)
builder.add_node("generator", generator_node)
builder.add_edge("__start__", "retriever")
builder.add_edge("retriever", "generator")
graph = builder.compile()
graph.name = "RAG Graph"
```
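To see the node-per-step idea in isolation, here is a dependency-free sketch in plain `asyncio`. It is not LangGraph's API: `run_pipeline`, the toy corpus, and the keyword matching are hypothetical stand-ins for the retriever and LLM. It only illustrates how each node returns a partial update that is merged into a shared state.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Set

State = Dict[str, Any]
Node = Callable[[State], Awaitable[State]]

# Toy corpus standing in for a real retriever's index.
CORPUS = ["LangServe deploys runnables.", "LangGraph orchestrates nodes."]


def tokens(text: str) -> Set[str]:
    """Lowercase words with surrounding punctuation removed."""
    return {word.strip("?.,!").lower() for word in text.split()}


async def retriever_node(state: State) -> State:
    """Return only the key this step updates: the matching documents."""
    query = tokens(state["question"])
    return {"docs": [doc for doc in CORPUS if tokens(doc) & query]}


async def generator_node(state: State) -> State:
    """Stand-in for the LLM call: summarize what was retrieved."""
    docs = state["docs"]
    return {"answer": f"Based on {len(docs)} document(s): " + " ".join(docs)}


async def run_pipeline(nodes: List[Node], initial: State) -> State:
    """Run nodes in order, merging each partial update into the state."""
    state = dict(initial)
    for node in nodes:
        state.update(await node(state))
    return state


result = asyncio.run(
    run_pipeline([retriever_node, generator_node], {"question": "What is LangGraph?"})
)
print(result["answer"])
```

Because each step is its own function with its own input and output keys, it can be timed, logged, or retried on its own. That is the motivation for giving long-running or debuggable steps separate nodes.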
Please see the [LangGraph tutorials](https://langchain-ai.github.io/langgraph/tutorials/)
for tutorials and examples that will help you get started with LangGraph
and LangGraph Platform.
+27 -25
@@ -5,9 +5,14 @@
[![Open Issues](https://img.shields.io/github/issues-raw/langchain-ai/langserve)](https://github.com/langchain-ai/langserve/issues)
[![](https://dcbadge.vercel.app/api/server/6adMQxSpJS?compact=true&style=flat)](https://discord.com/channels/1038097195422978059/1170024642245832774)
🚩 We will be releasing a hosted version of LangServe for one-click deployments of
LangChain applications. [Sign up here](https://forms.gle/KC13Nzn76UeLaghK7)
to get on the waitlist.
> [!WARNING]
> We recommend using LangGraph Platform rather than LangServe for new projects.
>
> Please see the [LangGraph Platform Migration Guide](./MIGRATION.md) for more information.
>
> We will continue to accept bug fixes for LangServe from the community; however, we
> will not be accepting new feature contributions.
## Overview
@@ -42,19 +47,18 @@ in [LangChain.js](https://js.langchain.com/docs/ecosystem/langserve).
locally (or call the HTTP API directly)
- [LangServe Hub](https://github.com/langchain-ai/langchain/blob/master/templates/README.md)
## ⚠️ LangGraph Compatibility
LangServe is designed to primarily deploy simple Runnables and work with well-known primitives in langchain-core.
If you need a deployment option for LangGraph, you should instead be looking at [LangGraph Cloud (beta)](https://langchain-ai.github.io/langgraph/cloud/) which will
be better suited for deploying LangGraph applications.
## Limitations
- Client callbacks are not yet supported for events that originate on the server
- OpenAPI docs will not be generated when using Pydantic V2. Fast API does not
support [mixing pydantic v1 and v2 namespaces](https://github.com/tiangolo/fastapi/issues/10360).
See section below for more details.
## Hosted LangServe
We will be releasing a hosted version of LangServe for one-click deployments of
LangChain
applications. [Sign up here](https://forms.gle/KC13Nzn76UeLaghK7)
to get on the waitlist.
- Versions of LangServe <= 0.2.0 will not generate OpenAPI docs properly when using Pydantic V2, as FastAPI does not support [mixing pydantic v1 and v2 namespaces](https://github.com/tiangolo/fastapi/issues/10360).
See the section below for more details. Either upgrade to LangServe>=0.3.0 or downgrade to Pydantic 1.
## Security
@@ -96,7 +100,7 @@ langchain app new my-app
add_routes(app, NotImplemented)
```
### 3. Use `poetry` to add 3rd party packages (e.g., langchain-openai, langchain-anthropic, langchain-mistral etc).
### 3. Use `poetry` to add 3rd party packages (e.g., langchain-openai, langchain-anthropic, langchain-mistral, etc).
```sh
poetry add [package-name]  # e.g., `poetry add langchain-openai`
@@ -116,12 +120,7 @@ poetry run langchain serve --port=8100
## Examples
Get your LangServe instance started quickly with
[LangChain Templates](https://github.com/langchain-ai/langchain/blob/master/templates/README.md).
For more examples, see the templates
[index](https://github.com/langchain-ai/langchain/blob/master/templates/docs/INDEX.md)
or the [examples](https://github.com/langchain-ai/langserve/tree/main/examples)
Get your LangServe instances started quickly with the [examples](https://github.com/langchain-ai/langserve/tree/main/examples)
directory.
| Description | Links |
@@ -212,8 +211,9 @@ app.add_middleware(
If you've deployed the server above, you can view the generated OpenAPI docs using:
> ⚠️ If using pydantic v2, docs will not be generated for _invoke_, _batch_, _stream_,
> ⚠️ If using LangServe <= 0.2.0 and pydantic v2, docs will not be generated for _invoke_, _batch_, _stream_,
> _stream_log_. See [Pydantic](#pydantic) section below for more details.
> To resolve please upgrade to LangServe 0.3.0.
```sh
curl localhost:8000/docs
@@ -384,7 +384,7 @@ prompt = ChatPromptTemplate.from_messages(
]
)
chain = prompt | ChatAnthropic(model="claude-2")
chain = prompt | ChatAnthropic(model="claude-2.1")
class InputChat(BaseModel):
@@ -476,10 +476,12 @@ gcloud run deploy [your-service-name] --source . --port 8001 --allow-unauthentic
## Pydantic
LangServe provides support for Pydantic 2 with some limitations.
LangServe>=0.3 fully supports Pydantic 2.
If you're using an earlier version of LangServe (<= 0.2), then please note that support for Pydantic 2 has the following limitations:
1. OpenAPI docs will not be generated for invoke/batch/stream/stream_log when using
Pydantic V2. Fast API does not support [mixing pydantic v1 and v2 namespaces].
Pydantic V2. Fast API does not support [mixing pydantic v1 and v2 namespaces]. To fix this, use `pip install pydantic==1.10.17`.
2. LangChain uses the v1 namespace in Pydantic v2. Please read
the [following guidelines to ensure compatibility with LangChain](https://github.com/langchain-ai/langchain/discussions/9337)
@@ -776,7 +778,7 @@ prompt = ChatPromptTemplate.from_messages(
]
)
chain = prompt | ChatAnthropic(model="claude-2")
chain = prompt | ChatAnthropic(model="claude-2.1")
class MessageListInput(BaseModel):
+1 -1
@@ -23,12 +23,12 @@ from fastapi import FastAPI
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.pydantic_v1 import BaseModel
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_core.utils.function_calling import format_tool_to_openai_function
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import BaseModel
from langserve import add_routes
+1 -1
@@ -57,9 +57,9 @@ from langchain_core.runnables import RunnableLambda
from langchain_core.tools import tool
from langchain_core.utils.function_calling import format_tool_to_openai_tool
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel
prompt = ChatPromptTemplate.from_messages(
[
+1 -1
@@ -36,9 +36,9 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_core.utils.function_calling import format_tool_to_openai_tool
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel, Field
prompt = ChatPromptTemplate.from_messages(
[
+5 -6
@@ -35,7 +35,7 @@ from typing import Any, List, Optional, Union
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from langchain_community.vectorstores.chroma import Chroma
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.runnables import (
ConfigurableField,
@@ -44,10 +44,10 @@ from langchain_core.runnables import (
)
from langchain_core.vectorstores import VectorStore
from langchain_openai import OpenAIEmbeddings
from pydantic import BaseModel, ConfigDict
from typing_extensions import Annotated
from langserve import APIHandler
from langserve.pydantic_v1 import BaseModel
class User(BaseModel):
@@ -150,10 +150,9 @@ class PerUserVectorstore(RunnableSerializable):
user_id: Optional[str]
vectorstore: VectorStore
class Config:
# Allow arbitrary types since VectorStore is an abstract interface
# and not a pydantic model
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def _invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -36,7 +36,7 @@ from typing import Any, Dict, List, Optional, Union
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from langchain_community.vectorstores.chroma import Chroma
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.runnables import (
ConfigurableField,
@@ -45,10 +45,10 @@ from langchain_core.runnables import (
)
from langchain_core.vectorstores import VectorStore
from langchain_openai import OpenAIEmbeddings
from pydantic import BaseModel, ConfigDict
from typing_extensions import Annotated
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel
class User(BaseModel):
@@ -147,10 +147,9 @@ class PerUserVectorstore(RunnableSerializable):
user_id: Optional[str]
vectorstore: VectorStore
class Config:
# Allow arbitrary types since VectorStore is an abstract interface
# and not a pydantic model
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def _invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -8,9 +8,9 @@ from fastapi import FastAPI
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from pydantic import BaseModel, Field
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel, Field
app = FastAPI(
title="LangChain Server",
@@ -28,7 +28,7 @@ prompt = ChatPromptTemplate.from_messages(
]
)
chain = prompt | ChatAnthropic(model="claude-2")
chain = prompt | ChatAnthropic(model="claude-2.1")
class InputChat(BaseModel):
+1 -1
@@ -8,9 +8,9 @@ from fastapi import FastAPI
from langchain_anthropic.chat_models import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from pydantic import BaseModel, Field
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel, Field
app = FastAPI(
title="LangChain Server",
+2 -2
@@ -18,9 +18,9 @@ from langchain_community.chat_message_histories import FileChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from pydantic import BaseModel, Field
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel, Field
def _is_valid_identifier(value: str) -> bool:
@@ -76,7 +76,7 @@ prompt = ChatPromptTemplate.from_messages(
]
)
chain = prompt | ChatAnthropic(model="claude-2")
chain = prompt | ChatAnthropic(model="claude-2.1")
class InputChat(BaseModel):
@@ -20,7 +20,6 @@ from fastapi import FastAPI
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.pydantic_v1 import BaseModel
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
@@ -33,6 +32,7 @@ from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import tool
from langchain_core.utils.function_calling import format_tool_to_openai_function
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import BaseModel
from langserve import add_routes
+9 -53
@@ -23,14 +23,7 @@
"tags": []
},
"outputs": [],
"source": [
"import requests\n",
"\n",
"inputs = {\"input\": {\"topic\": \"sports\"}}\n",
"response = requests.post(\"http://localhost:8000/configurable_temp/invoke\", json=inputs)\n",
"\n",
"response.json()"
]
"source": ["import requests\n\ninputs = {\"input\": {\"topic\": \"sports\"}}\nresponse = requests.post(\"http://localhost:8000/configurable_temp/invoke\", json=inputs)\n\nresponse.json()"]
},
{
"cell_type": "markdown",
@@ -46,11 +39,7 @@
"tags": []
},
"outputs": [],
"source": [
"from langserve import RemoteRunnable\n",
"\n",
"remote_runnable = RemoteRunnable(\"http://localhost:8000/configurable_temp\")"
]
"source": ["from langserve import RemoteRunnable\n\nremote_runnable = RemoteRunnable(\"http://localhost:8000/configurable_temp\")"]
},
{
"cell_type": "markdown",
@@ -66,9 +55,7 @@
"tags": []
},
"outputs": [],
"source": [
"response = await remote_runnable.ainvoke({\"topic\": \"sports\"})"
]
"source": ["response = await remote_runnable.ainvoke({\"topic\": \"sports\"})"]
},
{
"cell_type": "markdown",
@@ -84,11 +71,7 @@
"tags": []
},
"outputs": [],
"source": [
"from langchain.schema.runnable.config import RunnableConfig\n",
"\n",
"remote_runnable.batch([{\"topic\": \"sports\"}, {\"topic\": \"cars\"}])"
]
"source": ["from langchain_core.runnables import RunnableConfig\n\nremote_runnable.batch([{\"topic\": \"sports\"}, {\"topic\": \"cars\"}])"]
},
{
"cell_type": "markdown",
@@ -104,10 +87,7 @@
"tags": []
},
"outputs": [],
"source": [
"async for chunk in remote_runnable.astream({\"topic\": \"bears, but a bit verbose\"}):\n",
" print(chunk, end=\"\", flush=True)"
]
"source": ["async for chunk in remote_runnable.astream({\"topic\": \"bears, but a bit verbose\"}):\n print(chunk, end=\"\", flush=True)"]
},
{
"cell_type": "markdown",
@@ -157,14 +137,7 @@
"tags": []
},
"outputs": [],
"source": [
"await remote_runnable.ainvoke(\n",
" {\"topic\": \"sports\"},\n",
" config={\n",
" \"configurable\": {\"prompt\": \"how to say {topic} in french\", \"llm\": \"low_temp\"}\n",
" },\n",
")"
]
"source": ["await remote_runnable.ainvoke(\n {\"topic\": \"sports\"},\n config={\n \"configurable\": {\"prompt\": \"how to say {topic} in french\", \"llm\": \"low_temp\"}\n },\n)"]
},
{
"cell_type": "markdown",
@@ -221,13 +194,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The model will fail with an auth error\n",
"unauthenticated_response = requests.post(\n",
" \"http://localhost:8000/auth_from_header/invoke\", json={\"input\": \"hello\"}\n",
")\n",
"unauthenticated_response.json()"
]
"source": ["# The model will fail with an auth error\nunauthenticated_response = requests.post(\n \"http://localhost:8000/auth_from_header/invoke\", json={\"input\": \"hello\"}\n)\nunauthenticated_response.json()"]
},
{
"cell_type": "markdown",
@@ -244,25 +211,14 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The model will succeed as long as the above shell script is run previously\n",
"import os\n",
"\n",
"test_key = os.environ[\"TEST_API_KEY\"]\n",
"authenticated_response = requests.post(\n",
" \"http://localhost:8000/auth_from_header/invoke\",\n",
" json={\"input\": \"hello\"},\n",
" headers={\"x-api-key\": test_key},\n",
")\n",
"authenticated_response.json()"
]
"source": ["# The model will succeed as long as the above shell script is run previously\nimport os\n\ntest_key = os.environ[\"TEST_API_KEY\"]\nauthenticated_response = requests.post(\n \"http://localhost:8000/auth_from_header/invoke\",\n json={\"input\": \"hello\"},\n headers={\"x-api-key\": test_key},\n)\nauthenticated_response.json()"]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [""]
}
],
"metadata": {
+1 -1
@@ -15,9 +15,9 @@ from langchain_core.runnables import (
)
from langchain_core.vectorstores import VectorStore
from langchain_openai import OpenAIEmbeddings
from pydantic import BaseModel, Field
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel, Field
vectorstore1 = FAISS.from_texts(
["cats like fish", "dogs like sticks"], embedding=OpenAIEmbeddings()
@@ -18,9 +18,9 @@ from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, format_document
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import BaseModel, Field
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel, Field
_TEMPLATE = """Given the following conversation and a follow up question, rephrase the
follow up question to be a standalone question, in its original language.
+1 -1
@@ -15,10 +15,10 @@ allowing one to upload a binary file using the langserve playground UI.
import base64
from fastapi import FastAPI
from langchain.pydantic_v1 import Field
from langchain_community.document_loaders.parsers.pdf import PDFMinerParser
from langchain_core.document_loaders import Blob
from langchain_core.runnables import RunnableLambda
from pydantic import Field
from langserve import CustomUserType, add_routes
+13 -81
@@ -16,9 +16,7 @@
"tags": []
},
"outputs": [],
"source": [
"from langchain.prompts.chat import ChatPromptTemplate"
]
"source": ["from langchain_core.prompts import ChatPromptTemplate"]
},
{
"cell_type": "code",
@@ -27,12 +25,7 @@
"tags": []
},
"outputs": [],
"source": [
"from langserve import RemoteRunnable\n",
"\n",
"openai_llm = RemoteRunnable(\"http://localhost:8000/openai/\")\n",
"anthropic = RemoteRunnable(\"http://localhost:8000/anthropic/\")"
]
"source": ["from langserve import RemoteRunnable\n\nopenai_llm = RemoteRunnable(\"http://localhost:8000/openai/\")\nanthropic = RemoteRunnable(\"http://localhost:8000/anthropic/\")"]
},
{
"cell_type": "markdown",
@@ -48,18 +41,7 @@
"tags": []
},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a highly educated person who loves to use big words. \"\n",
" + \"You are also concise. Never answer in more than three sentences.\",\n",
" ),\n",
" (\"human\", \"Tell me about your favorite novel\"),\n",
" ]\n",
").format_messages()"
]
"source": ["prompt = ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are a highly educated person who loves to use big words. \"\n + \"You are also concise. Never answer in more than three sentences.\",\n ),\n (\"human\", \"Tell me about your favorite novel\"),\n ]\n).format_messages()"]
},
{
"cell_type": "markdown",
@@ -86,9 +68,7 @@
"output_type": "execute_result"
}
],
"source": [
"anthropic.invoke(prompt)"
]
"source": ["anthropic.invoke(prompt)"]
},
{
"cell_type": "code",
@@ -97,9 +77,7 @@
"tags": []
},
"outputs": [],
"source": [
"openai_llm.invoke(prompt)"
]
"source": ["openai_llm.invoke(prompt)"]
},
{
"cell_type": "markdown",
@@ -126,9 +104,7 @@
"output_type": "execute_result"
}
],
"source": [
"await openai_llm.ainvoke(prompt)"
]
"source": ["await openai_llm.ainvoke(prompt)"]
},
{
"cell_type": "code",
@@ -149,9 +125,7 @@
"output_type": "execute_result"
}
],
"source": [
"anthropic.batch([prompt, prompt])"
]
"source": ["anthropic.batch([prompt, prompt])"]
},
{
"cell_type": "code",
@@ -172,9 +146,7 @@
"output_type": "execute_result"
}
],
"source": [
"await anthropic.abatch([prompt, prompt])"
]
"source": ["await anthropic.abatch([prompt, prompt])"]
},
{
"cell_type": "markdown",
@@ -198,10 +170,7 @@
]
}
],
"source": [
"for chunk in anthropic.stream(prompt):\n",
" print(chunk.content, end=\"\", flush=True)"
]
"source": ["for chunk in anthropic.stream(prompt):\n print(chunk.content, end=\"\", flush=True)"]
},
{
"cell_type": "code",
@@ -218,19 +187,14 @@
]
}
],
"source": [
"async for chunk in anthropic.astream(prompt):\n",
" print(chunk.content, end=\"\", flush=True)"
]
"source": ["async for chunk in anthropic.astream(prompt):\n print(chunk.content, end=\"\", flush=True)"]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.runnable import RunnablePassthrough"
]
"source": ["from langchain_core.runnables import RunnablePassthrough"]
},
{
"cell_type": "code",
@@ -239,37 +203,7 @@
"tags": []
},
"outputs": [],
"source": [
"comedian_chain = (\n",
" ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a comedian that sometimes tells funny jokes and other times you just state facts that are not funny. Please either tell a joke or state fact now but only output one.\",\n",
" ),\n",
" ]\n",
" )\n",
" | openai_llm\n",
")\n",
"\n",
"joke_classifier_chain = (\n",
" ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"Please determine if the joke is funny. Say `funny` if it's funny and `not funny` if not funny. Then repeat the first five words of the joke for reference...\",\n",
" ),\n",
" (\"human\", \"{joke}\"),\n",
" ]\n",
" )\n",
" | anthropic\n",
")\n",
"\n",
"\n",
"chain = {\"joke\": comedian_chain} | RunnablePassthrough.assign(\n",
" classification=joke_classifier_chain\n",
")"
]
"source": ["comedian_chain = (\n ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"You are a comedian that sometimes tells funny jokes and other times you just state facts that are not funny. Please either tell a joke or state fact now but only output one.\",\n ),\n ]\n )\n | openai_llm\n)\n\njoke_classifier_chain = (\n ChatPromptTemplate.from_messages(\n [\n (\n \"system\",\n \"Please determine if the joke is funny. Say `funny` if it's funny and `not funny` if not funny. Then repeat the first five words of the joke for reference...\",\n ),\n (\"human\", \"{joke}\"),\n ]\n )\n | anthropic\n)\n\n\nchain = {\"joke\": comedian_chain} | RunnablePassthrough.assign(\n classification=joke_classifier_chain\n)"]
},
{
"cell_type": "code",
@@ -290,9 +224,7 @@
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({})"
]
"source": ["chain.invoke({})"]
}
],
"metadata": {
+11 -46
@@ -18,9 +18,7 @@
"tags": []
},
"outputs": [],
"source": [
"from langchain.prompts.chat import ChatPromptTemplate"
]
"source": ["from langchain_core.prompts import ChatPromptTemplate"]
},
{
"cell_type": "code",
@@ -29,11 +27,7 @@
"tags": []
},
"outputs": [],
"source": [
"from langserve import RemoteRunnable\n",
"\n",
"model = RemoteRunnable(\"http://localhost:8000/ollama/\")"
]
"source": ["from langserve import RemoteRunnable\n\nmodel = RemoteRunnable(\"http://localhost:8000/ollama/\")"]
},
{
"cell_type": "markdown",
@@ -49,9 +43,7 @@
"tags": []
},
"outputs": [],
"source": [
"prompt = \"Tell me a 3 sentence story about a cat.\""
]
"source": ["prompt = \"Tell me a 3 sentence story about a cat.\""]
},
{
"cell_type": "code",
@@ -71,9 +63,7 @@
"output_type": "execute_result"
}
],
"source": [
"model.invoke(prompt)"
]
"source": ["model.invoke(prompt)"]
},
{
"cell_type": "code",
@@ -93,9 +83,7 @@
"output_type": "execute_result"
}
],
"source": [
"await model.ainvoke(prompt)"
]
"source": ["await model.ainvoke(prompt)"]
},
{
"cell_type": "markdown",
@@ -131,10 +119,7 @@
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"model.batch([prompt, prompt])"
]
"source": ["%%time\nmodel.batch([prompt, prompt])"]
},
{
"cell_type": "code",
@@ -152,11 +137,7 @@
]
}
],
"source": [
"%%time\n",
"for _ in range(2):\n",
" model.invoke(prompt)"
]
"source": ["%%time\nfor _ in range(2):\n model.invoke(prompt)"]
},
{
"cell_type": "code",
@@ -177,9 +158,7 @@
"output_type": "execute_result"
}
],
"source": [
"await model.abatch([prompt, prompt])"
]
"source": ["await model.abatch([prompt, prompt])"]
},
{
"cell_type": "markdown",
@@ -206,10 +185,7 @@
]
}
],
"source": [
"for chunk in model.stream(prompt):\n",
" print(chunk.content, end=\"|\", flush=True)"
]
"source": ["for chunk in model.stream(prompt):\n print(chunk.content, end=\"|\", flush=True)"]
},
{
"cell_type": "code",
@@ -227,10 +203,7 @@
]
}
],
"source": [
"async for chunk in model.astream(prompt):\n",
" print(chunk.content, end=\"|\", flush=True)"
]
"source": ["async for chunk in model.astream(prompt):\n print(chunk.content, end=\"|\", flush=True)"]
},
{
"cell_type": "markdown",
@@ -266,15 +239,7 @@
]
}
],
"source": [
"i = 0\n",
"async for event in model.astream_events(prompt, version='v1'):\n",
" print(event)\n",
" if i > 10:\n",
" print('...')\n",
" break\n",
" i += 1"
]
"source": ["i = 0\nasync for event in model.astream_events(prompt, version='v1'):\n print(event)\n if i > 10:\n print('...')\n break\n i += 1"]
}
],
"metadata": {
+6 -23
@@ -16,9 +16,7 @@
"tags": []
},
"outputs": [],
"source": [
"from langchain.prompts.chat import ChatPromptTemplate"
]
"source": ["from langchain_core.prompts import ChatPromptTemplate"]
},
{
"cell_type": "code",
@@ -27,11 +25,7 @@
"tags": []
},
"outputs": [],
"source": [
"from langserve import RemoteRunnable\n",
"\n",
"chain = RemoteRunnable(\"http://localhost:8000/v1/\")"
]
"source": ["from langserve import RemoteRunnable\n\nchain = RemoteRunnable(\"http://localhost:8000/v1/\")"]
},
{
"cell_type": "markdown",
@@ -59,9 +53,7 @@
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({'thing': 'apple', 'language': 'italian', 'info': {\"user_id\": 42, \"user_info\": {\"address\": 42}}})"
]
"source": ["chain.invoke({'thing': 'apple', 'language': 'italian', 'info': {\"user_id\": 42, \"user_info\": {\"address\": 42}}})"]
},
{
"cell_type": "code",
@@ -82,10 +74,7 @@
]
}
],
"source": [
"for chunk in chain.stream({'thing': 'apple', 'language': 'italian', 'info': {\"user_id\": 42, \"user_info\": {\"address\": 42}}}):\n",
" print(chunk)"
]
"source": ["for chunk in chain.stream({'thing': 'apple', 'language': 'italian', 'info': {\"user_id\": 42, \"user_info\": {\"address\": 42}}}):\n print(chunk)"]
},
{
"cell_type": "code",
@@ -94,11 +83,7 @@
"tags": []
},
"outputs": [],
"source": [
"from langserve import RemoteRunnable\n",
"\n",
"chain = RemoteRunnable(\"http://localhost:8000/v2/\")"
]
"source": ["from langserve import RemoteRunnable\n\nchain = RemoteRunnable(\"http://localhost:8000/v2/\")"]
},
{
"cell_type": "code",
@@ -119,9 +104,7 @@
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({'thing': 'apple', 'language': 'italian', 'info': {\"user_id\": 42, \"user_info\": {\"address\": 42}}})"
]
"source": ["chain.invoke({'thing': 'apple', 'language': 'italian', 'info': {\"user_id\": 42, \"user_info\": {\"address\": 42}}})"]
}
],
"metadata": {
+2 -2
@@ -10,9 +10,9 @@ from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from pydantic import BaseModel, Field
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel, Field
app = FastAPI(
title="LangChain Server",
@@ -40,7 +40,7 @@ prompt = ChatPromptTemplate.from_messages(
]
)
chain = prompt | ChatAnthropic(model="claude-2") | StrOutputParser()
chain = prompt | ChatAnthropic(model="claude-2.1") | StrOutputParser()
class InputChat(BaseModel):
+1 -1
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Tuple
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.pydantic_v1 import BaseModel, Field
from langchain_community.document_loaders.parsers.pdf import PDFMinerParser
from langchain_core.document_loaders import Blob
from langchain_core.messages import (
@@ -17,6 +16,7 @@ from langchain_core.messages import (
)
from langchain_core.runnables import RunnableLambda, RunnableParallel
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from langserve import CustomUserType
from langserve.server import add_routes
+53
@@ -0,0 +1,53 @@
from typing import Any, Dict, Type, cast
from pydantic import BaseModel, ConfigDict, RootModel
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)
def _create_root_model(name: str, type_: Any) -> Type[RootModel]:
"""Create a RootModel subclass named `name` whose root type is `type_`."""
def schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
) -> Dict[str, Any]:
# Complains about schema not being defined in superclass
schema_ = super(cls, cls).schema( # type: ignore[misc]
by_alias=by_alias, ref_template=ref_template
)
schema_["title"] = name
return schema_
def model_json_schema(
cls: Type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = "validation",
) -> Dict[str, Any]:
# Complains about model_json_schema not being defined in superclass
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
by_alias=by_alias,
ref_template=ref_template,
schema_generator=schema_generator,
mode=mode,
)
schema_["title"] = name
return schema_
base_class_attributes = {
"__annotations__": {"root": type_},
"model_config": ConfigDict(arbitrary_types_allowed=True),
"schema": classmethod(schema),
"model_json_schema": classmethod(model_json_schema),
# Should replace __module__ with caller based on stack frame.
"__module__": "langserve._pydantic",
}
custom_root_type = type(name, (RootModel,), base_class_attributes)
return cast(Type[RootModel], custom_root_type)
+119 -55
@@ -29,6 +29,7 @@ from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from langchain_core._api.beta_decorator import warn_beta
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.callbacks.manager import BaseCallbackManager
from langchain_core.load.serializable import Serializable
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.config import (
@@ -40,14 +41,16 @@ from langchain_core.tracers import RunLogPatch
from langsmith import client as ls_client
from langsmith.schemas import FeedbackIngestToken
from langsmith.utils import tracing_is_enabled
from pydantic import BaseModel, Field, RootModel, ValidationError, create_model
from pydantic.v1 import BaseModel as BaseModelV1
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from typing_extensions import TypedDict
from langserve._pydantic import _create_root_model
from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict
from langserve.lzstring import LZString
from langserve.playground import serve_playground
from langserve.pydantic_v1 import BaseModel, Field, ValidationError, create_model
from langserve.schema import (
BatchResponseMetadata,
CustomUserType,
@@ -59,7 +62,7 @@ from langserve.schema import (
PublicTraceLink,
PublicTraceLinkCreateRequest,
)
from langserve.serialization import WellKnownLCSerializer
from langserve.serialization import Serializer, WellKnownLCSerializer
from langserve.validation import (
BatchBaseResponse,
BatchRequestShallowValidator,
@@ -184,11 +187,11 @@ async def _unpack_request_config(
config_dicts = []
for config in client_sent_configs:
if isinstance(config, str):
config_dicts.append(model(**_config_from_hash(config)).dict())
config_dicts.append(model(**_config_from_hash(config)).model_dump())
elif isinstance(config, BaseModel):
config_dicts.append(config.dict())
config_dicts.append(config.model_dump())
elif isinstance(config, Mapping):
config_dicts.append(model(**config).dict())
config_dicts.append(model(**config).model_dump())
else:
raise TypeError(f"Expected a string, dict or BaseModel got {type(config)}")
config = merge_configs(*config_dicts)
@@ -255,10 +258,12 @@ def _update_config_with_defaults(
}
metadata.update(hosted_metadata)
non_overridable_default_config = RunnableConfig(
run_name=run_name,
metadata=metadata,
)
non_overridable_default_config: RunnableConfig = {
"metadata": metadata,
}
if run_name:
non_overridable_default_config["run_name"] = run_name
# merge_configs is last-writer-wins, so we specifically pass in the
# overridable configs first, then the user provided configs, then
@@ -279,8 +284,8 @@ def _update_config_with_defaults(
def _unpack_input(validated_model: BaseModel) -> Any:
"""Unpack the decoded input from the validated model."""
if hasattr(validated_model, "__root__"):
model = validated_model.__root__
if isinstance(validated_model, RootModel):
model = validated_model.root
else:
model = validated_model
@@ -294,7 +299,7 @@ def _unpack_input(validated_model: BaseModel) -> Any:
# This logic should be applied recursively to nested models.
return {
fieldname: _unpack_input(getattr(model, fieldname))
for fieldname in model.__fields__.keys()
for fieldname in model.model_fields.keys()
}
return model
@@ -304,7 +309,7 @@ def _rename_pydantic_model(model: Type[BaseModel], prefix: str) -> Type[BaseMode
"""Rename the given pydantic model to the given name."""
return create_model(
prefix + model.__name__,
__config__=model.__config__,
__config__=model.model_config,
**{
fieldname: (
_rename_pydantic_model(field.annotation, prefix)
@@ -313,10 +318,10 @@ def _rename_pydantic_model(model: Type[BaseModel], prefix: str) -> Type[BaseMode
Field(
field.default,
title=fieldname,
description=field.field_info.description,
description=field.description,
),
)
for fieldname, field in model.__fields__.items()
for fieldname, field in model.model_fields.items()
},
)
@@ -326,6 +331,11 @@ def _replace_non_alphanumeric_with_underscores(s: str) -> str:
return re.sub(r"[^a-zA-Z0-9]", "_", s)
def _schema_json(model: Type[BaseModel]) -> str:
"""Return the JSON representation of the model schema."""
return json.dumps(model.model_json_schema(), sort_keys=True, indent=False)
def _resolve_model(
type_: Union[Type, BaseModel], default_name: str, namespace: str
) -> Type[BaseModel]:
@@ -333,15 +343,15 @@ def _resolve_model(
if isclass(type_) and issubclass(type_, BaseModel):
model = type_
else:
model = create_model(default_name, __root__=(type_, ...))
model = _create_root_model(default_name, type_)
hash_ = model.schema_json()
hash_ = _schema_json(model)
if model.__name__ in _SEEN_NAMES and hash_ not in _MODEL_REGISTRY:
# If the model name has been seen before, but the model itself is different
# generate a new name for the model.
model_to_use = _rename_pydantic_model(model, namespace)
hash_ = model_to_use.schema_json()
hash_ = _schema_json(model_to_use)
else:
model_to_use = model
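The name-collision rule in this hunk can be sketched with the standard library alone. This is a minimal illustration, not the langserve implementation: `resolve_name`, `_SEEN` and `_REGISTRY` are hypothetical stand-ins, plain dicts take the place of pydantic schemas, and the registration step after the `else` branch is an assumption because the diff does not show it.

```python
import json
from typing import Any, Dict, Set

# Hypothetical stand-ins for the module-level registries in the diff.
_SEEN: Set[str] = set()
_REGISTRY: Dict[str, str] = {}


def schema_key(schema: Dict[str, Any]) -> str:
    """Serialize a schema deterministically so equal schemas share one key."""
    return json.dumps(schema, sort_keys=True)


def resolve_name(schema: Dict[str, Any], namespace: str) -> str:
    """Return the name to register for `schema`, prefixing it on a collision."""
    key = schema_key(schema)
    name = schema["title"]
    if name in _SEEN and key not in _REGISTRY:
        # Same name, different schema: rename and recompute the key.
        schema = {**schema, "title": namespace + name}
        key = schema_key(schema)
    if key not in _REGISTRY:
        # Assumed bookkeeping; the diff does not show this step.
        _SEEN.add(schema["title"])
        _REGISTRY[key] = schema["title"]
    return _REGISTRY[key]
```

Registering the same schema twice returns the original name, while a different schema under an already-seen name is prefixed with the namespace.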
@@ -366,11 +376,7 @@ def _add_namespace_to_model(namespace: str, model: Type[BaseModel]) -> Type[Base
A new model with name prepended with the given namespace.
"""
model_with_unique_name = _rename_pydantic_model(model, namespace)
if "run_id" in model_with_unique_name.__annotations__:
# Help resolve reference by providing namespace references
model_with_unique_name.update_forward_refs(uuid=uuid)
else:
model_with_unique_name.update_forward_refs()
model_with_unique_name.model_rebuild()
return model_with_unique_name
@@ -403,7 +409,7 @@ def _with_validation_error_translation() -> Generator[None, None, None]:
try:
yield
except ValidationError as e:
raise RequestValidationError(e.errors(), body=e.model)
raise RequestValidationError(e.errors())
def _json_encode_response(model: BaseModel) -> JSONResponse:
@@ -423,27 +429,27 @@ def _json_encode_response(model: BaseModel) -> JSONResponse:
if isinstance(model, InvokeBaseResponse):
# Invoke Response
# Collapse '__root__' from output field if it exists. This is done
# Collapse 'root' from output field if it exists. This is done
# automatically by fastapi when annotating request and response with
# We need to do this manually since we're using vanilla JSONResponse
if isinstance(obj["output"], dict) and "__root__" in obj["output"]:
obj["output"] = obj["output"]["__root__"]
if isinstance(obj["output"], dict) and "root" in obj["output"]:
obj["output"] = obj["output"]["root"]
if "callback_events" in obj:
for idx, callback_event in enumerate(obj["callback_events"]):
if isinstance(callback_event, dict) and "__root__" in callback_event:
obj["callback_events"][idx] = callback_event["__root__"]
if isinstance(callback_event, dict) and "root" in callback_event:
obj["callback_events"][idx] = callback_event["root"]
elif isinstance(model, BatchBaseResponse):
if not isinstance(obj["output"], list):
raise AssertionError("Expected output to be a list")
# Collapse '__root__' from output field if it exists. This is done
# Collapse 'root' from output field if it exists. This is done
# automatically by fastapi when annotating request and response with
# We need to do this manually since we're using vanilla JSONResponse
outputs = obj["output"]
for idx, output in enumerate(outputs):
if isinstance(output, dict) and "__root__" in output:
outputs[idx] = output["__root__"]
if isinstance(output, dict) and "root" in output:
outputs[idx] = output["root"]
if "callback_events" in obj:
if not isinstance(obj["callback_events"], list):
@@ -451,11 +457,8 @@ def _json_encode_response(model: BaseModel) -> JSONResponse:
for callback_events in obj["callback_events"]:
for idx, callback_event in enumerate(callback_events):
if (
isinstance(callback_event, dict)
and "__root__" in callback_event
):
callback_events[idx] = callback_event["__root__"]
if isinstance(callback_event, dict) and "root" in callback_event:
callback_events[idx] = callback_event["root"]
else:
raise AssertionError(
f"Expected an InvokeBaseResponse or BatchBaseResponse got: {type(model)}"
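The `"__root__"` to `"root"` change above boils down to unwrapping a single key from already-serialized dictionaries. A rough sketch of that unwrapping on plain dicts follows; `collapse_root` and `collapse_invoke_payload` are hypothetical helper names and are not part of langserve.

```python
from typing import Any, Dict


def collapse_root(value: Any) -> Any:
    """Return value["root"] when value is a dict carrying a "root" key."""
    if isinstance(value, dict) and "root" in value:
        return value["root"]
    return value


def collapse_invoke_payload(obj: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap "root" from the output and from each callback event, if any."""
    obj["output"] = collapse_root(obj["output"])
    if "callback_events" in obj:
        obj["callback_events"] = [
            collapse_root(event) for event in obj["callback_events"]
        ]
    return obj
```

Values that are not dicts, or dicts without a `"root"` key, pass through unchanged.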
@@ -470,7 +473,12 @@ def _add_callbacks(
"""Add the callback aggregator to the config."""
if "callbacks" not in config:
config["callbacks"] = []
config["callbacks"].extend(callbacks)
if "callbacks" in config:
if isinstance(config["callbacks"], list):
config["callbacks"].extend(callbacks)
elif isinstance(config["callbacks"], BaseCallbackManager):
for callback in callbacks:
config["callbacks"].add_handler(callback, inherit=True)
_MODEL_REGISTRY = {}
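The `_add_callbacks` change handles a `callbacks` entry that is either a list or a callback manager. The sketch below shows the same branching using only the standard library; `FakeCallbackManager` is a hypothetical stand-in for `BaseCallbackManager`, and `add_callbacks` is an illustrative name, so this is not the code from the diff.

```python
from typing import Any, Dict, List


class FakeCallbackManager:
    """Hypothetical stand-in for a callback manager; not the langchain_core class."""

    def __init__(self) -> None:
        self.handlers: List[Any] = []

    def add_handler(self, handler: Any, inherit: bool = True) -> None:
        self.handlers.append(handler)


def add_callbacks(config: Dict[str, Any], callbacks: List[Any]) -> None:
    """Attach callbacks whether config["callbacks"] is a list or a manager."""
    existing = config.setdefault("callbacks", [])
    if isinstance(existing, list):
        existing.extend(callbacks)
    elif isinstance(existing, FakeCallbackManager):
        for callback in callbacks:
            existing.add_handler(callback, inherit=True)
```

Branching on the container type avoids calling `.extend` on an object that is not a list.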
@@ -527,6 +535,8 @@ class APIHandler:
per_req_config_modifier: Optional[PerRequestConfigModifier] = None,
stream_log_name_allow_list: Optional[Sequence[str]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
serializer: Optional[Serializer] = None,
) -> None:
"""Create an API handler for the given runnable.
@@ -561,6 +571,7 @@ class APIHandler:
If true, the client will be able to show trace information
including events that occurred on the server side.
Be sure not to include any sensitive information in the callback events.
This is a **beta** API.
enable_feedback_endpoint: Whether to enable an endpoint for logging feedback
to LangSmith. Disabled by default. If this flag is disabled or LangSmith
tracing is not enabled for the runnable, then 4xx errors will be thrown
@@ -588,6 +599,10 @@ class APIHandler:
If not provided, then all logs will be allowed to be streamed.
Use to also limit the events that can be streamed by the stream_events.
TODO: Introduce deprecation for this parameter to rename it
astream_events_version: version of the stream events endpoint to use.
By default "v2".
serializer: optional serializer to use for serializing the output.
If not provided, the default serializer will be used.
"""
if importlib.util.find_spec("sse_starlette") is None:
raise ImportError(
@@ -619,12 +634,18 @@ class APIHandler:
# and when tracing information is logged, we'll be able to see
# traces for the path /foo/bar.
self._run_name = self._base_url
if include_callback_events:
warn_beta(
message="Including callback events in the response is in beta. "
"This API may change in the future."
)
self._include_callback_events = include_callback_events
self._per_req_config_modifier = per_req_config_modifier
self._serializer = WellKnownLCSerializer()
self._serializer = serializer or WellKnownLCSerializer()
self._enable_feedback_endpoint = enable_feedback_endpoint
self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint
self._names_in_stream_allow_list = stream_log_name_allow_list
self._astream_events_version = astream_events_version
if token_feedback_config:
if len(token_feedback_config["key_configs"]) != 1:
@@ -664,15 +685,57 @@ class APIHandler:
model_namespace = _replace_non_alphanumeric_with_underscores(path.strip("/"))
input_type_ = _resolve_model(
runnable.get_input_schema(), "Input", model_namespace
)
try:
input_type_ = _resolve_model(
runnable.get_input_schema(), "Input", model_namespace
)
except Exception as e:
# Attempt to surface a more informative user facing error
raise_original_error = True
try:
if isinstance(runnable.get_input_schema(), BaseModelV1):
raise_original_error = False
raise ValueError(
"Found an input type which is a pydantic v1 model. "
"Please use pydantic.BaseModel rather than "
"pydantic.v1.BaseModel."
)
finally: # noqa
if raise_original_error:
print(
"Encountered an error while resolving the inputs of "
"the Runnable. Try specifying the input type explicitly "
"using the `with_types` method on the runnable.\n"
"See https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.Runnable.html " # noqa: E501
)
raise e
output_type_ = _resolve_model(
runnable.get_output_schema(),
"Output",
model_namespace,
)
try:
output_type_ = _resolve_model(
runnable.get_output_schema(),
"Output",
model_namespace,
)
except Exception as e:
# Attempt to surface a more informative user facing error
raise_original_error = True
try:
if isinstance(runnable.get_output_schema(), BaseModelV1):
raise_original_error = False
raise ValueError(
"Found an output type which is a pydantic v1 model. "
"Please use pydantic.BaseModel rather than "
"pydantic.v1.BaseModel."
)
finally: # noqa
if raise_original_error:
print(
"Encountered an error while resolving the outputs of "
"the Runnable. Try specifying the output type explicitly "
"using the `with_types` method on the runnable.\n"
"See https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.Runnable.html " # noqa: E501
)
raise e
self._ConfigPayload = _add_namespace_to_model(
model_namespace, runnable.config_schema(include=config_keys)
@@ -753,7 +816,7 @@ class APIHandler:
except json.JSONDecodeError:
raise RequestValidationError(errors=["Invalid JSON body"])
try:
body = InvokeRequestShallowValidator.validate(body)
body = InvokeRequestShallowValidator.model_validate(body)
# Merge the config from the path with the config from the body.
user_provided_config = await _unpack_request_config(
@@ -775,7 +838,7 @@ class APIHandler:
# This takes into account changes in the input type when
# using configuration.
schema = self._runnable.with_config(config).input_schema
input_ = schema.validate(body.input)
input_ = schema.model_validate(body.input)
return config, _unpack_input(input_)
except ValidationError as e:
raise RequestValidationError(e.errors(), body=body)
@@ -885,7 +948,7 @@ class APIHandler:
raise RequestValidationError(errors=["Invalid JSON body"])
with _with_validation_error_translation():
body = BatchRequestShallowValidator.validate(body)
body = BatchRequestShallowValidator.model_validate(body)
config = body.config
# First unpack the config
@@ -936,7 +999,7 @@ class APIHandler:
inputs = [
_unpack_input(
self._runnable.with_config(config_).input_schema.validate(input_)
self._runnable.with_config(config_).input_schema.model_validate(input_)
)
for config_, input_ in zip(configs_, inputs_)
]
@@ -1336,7 +1399,7 @@ class APIHandler:
exclude_names=stream_events_request.exclude_names,
exclude_types=stream_events_request.exclude_types,
exclude_tags=stream_events_request.exclude_tags,
version="v1",
version=self._astream_events_version,
):
if (
self._names_in_stream_allow_list is None
@@ -1405,7 +1468,7 @@ class APIHandler:
self._run_name, user_provided_config, request
)
return self._runnable.get_input_schema(config).schema()
return self._runnable.get_input_schema(config).model_json_schema()
async def output_schema(
self,
@@ -1432,7 +1495,7 @@ class APIHandler:
config = _update_config_with_defaults(
self._run_name, user_provided_config, request
)
return self._runnable.get_output_schema(config).schema()
return self._runnable.get_output_schema(config).model_json_schema()
async def config_schema(
self,
@@ -1462,7 +1525,7 @@ class APIHandler:
return (
self._runnable.with_config(config)
.config_schema(include=self._config_keys)
.schema()
.model_json_schema()
)
async def playground(
@@ -1590,6 +1653,7 @@ class APIHandler:
score=create_request.score,
value=create_request.value,
comment=create_request.comment,
correction=create_request.correction,
metadata=metadata,
)
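Reviewer note: the hunks above replace the pydantic v1 class methods `validate` and `schema` with their v2 names, `model_validate` and `model_json_schema`. A minimal sketch of the same rename follows, assuming pydantic v2 is installed. `InvokeRequest`, `parse_request`, and `request_schema` are hypothetical names for illustration only and are not part of langserve.

```python
from typing import Any, Dict

from pydantic import BaseModel, Field, ValidationError


class InvokeRequest(BaseModel):
    """Hypothetical request model, used only for this illustration."""

    input: str
    config: Dict[str, Any] = Field(default_factory=dict)


def parse_request(body: Dict[str, Any]) -> InvokeRequest:
    # pydantic v2 spelling of the v1 call `InvokeRequest.validate(body)`.
    return InvokeRequest.model_validate(body)


def request_schema() -> Dict[str, Any]:
    # pydantic v2 spelling of the v1 call `InvokeRequest.schema()`.
    return InvokeRequest.model_json_schema()
```

A missing `input` key still raises `ValidationError`, so the surrounding `except ValidationError` handlers in the diff should not need to change.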
+21 -7
@@ -48,7 +48,7 @@ class AsyncEventAggregatorCallback(AsyncCallbackHandler):
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
@@ -73,7 +73,7 @@ class AsyncEventAggregatorCallback(AsyncCallbackHandler):
async def on_chain_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
inputs: Dict[str, Any],
*,
run_id: UUID,
@@ -138,7 +138,7 @@ class AsyncEventAggregatorCallback(AsyncCallbackHandler):
async def on_retriever_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
query: str,
*,
run_id: UUID,
@@ -202,7 +202,7 @@ class AsyncEventAggregatorCallback(AsyncCallbackHandler):
async def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
input_str: str,
*,
run_id: UUID,
@@ -306,7 +306,7 @@ class AsyncEventAggregatorCallback(AsyncCallbackHandler):
async def on_llm_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
prompts: List[str],
*,
run_id: UUID,
@@ -445,7 +445,14 @@ async def ahandle_callbacks(
if event["parent_run_id"] is None: # How do we make sure it's None!?
event["parent_run_id"] = callback_manager.run_id
event_data = {key: value for key, value in event.items() if key != "type"}
event_data = {
key: value
for key, value in event.items()
if key != "type" and key != "kwargs"
}
if "kwargs" in event:
event_data.update(event["kwargs"])
await ahandle_event(
# Unpacking like this may not work
@@ -467,7 +474,14 @@ def handle_callbacks(
if event["parent_run_id"] is None: # How do we make sure it's None!?
event["parent_run_id"] = callback_manager.run_id
event_data = {key: value for key, value in event.items() if key != "type"}
event_data = {
key: value
for key, value in event.items()
if key != "type" and key != "kwargs"
}
if "kwargs" in event:
event_data.update(event["kwargs"])
handle_event(
# Unpacking like this may not work
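Reviewer note: both `ahandle_callbacks` and `handle_callbacks` now drop the `type` and `kwargs` keys from the event and then merge the contents of `kwargs` into the top level. The sketch below isolates that transformation using only the standard library. `flatten_event` is a hypothetical helper name, and the sample event is made up.

```python
from typing import Any, Dict


def flatten_event(event: Dict[str, Any]) -> Dict[str, Any]:
    """Copy an event, drop 'type' and 'kwargs', then merge kwargs in."""
    event_data = {
        key: value
        for key, value in event.items()
        if key != "type" and key != "kwargs"
    }
    if "kwargs" in event:
        # Keys inside kwargs overwrite top-level keys with the same name.
        event_data.update(event["kwargs"])
    return event_data
```

One consequence worth keeping in mind: because `update` runs last, a key that appears both at the top level and inside `kwargs` takes the `kwargs` value.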
File diff suppressed because one or more lines are too long
+1 -1
@@ -5,7 +5,7 @@
<link rel="icon" href="/____LANGSERVE_BASE_URL/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Chat Playground</title>
<script type="module" crossorigin src="/____LANGSERVE_BASE_URL/assets/index-86d4d9c0.js"></script>
<script type="module" crossorigin src="/____LANGSERVE_BASE_URL/assets/index-53ad47d4.js"></script>
<link rel="stylesheet" href="/____LANGSERVE_BASE_URL/assets/index-434ff580.css">
</head>
<body>
+6 -1
@@ -46,7 +46,12 @@
"postcss": "^8.4.31",
"tailwindcss": "^3.3.3",
"typescript": "^5.0.2",
"vite": "^4.4.5",
"vite": "^4.5.2",
"vite-plugin-svgr": "^4.1.0"
},
"resolutions": {
"braces": "^3.0.3",
"cross-spawn": "^7.0.5",
"rollup": "^3.29.5"
}
}
+1
@@ -30,6 +30,7 @@ export function App() {
);
const outputSchemaSupported = (
outputDataSchema?.anyOf?.find((option) => option.properties?.type?.enum?.includes("ai")) ||
outputDataSchema?.oneOf?.find((option) => option.properties?.type?.enum?.includes("ai")) ||
outputDataSchema?.type === "string"
);
const isSupported = isLoading || (inputSchemaSupported && outputSchemaSupported);
+21 -21
@@ -1311,12 +1311,12 @@ brace-expansion@^1.1.7:
balanced-match "^1.0.0"
concat-map "0.0.1"
braces@^3.0.2, braces@~3.0.2:
version "3.0.2"
resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107"
integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==
braces@^3.0.2, braces@^3.0.3, braces@~3.0.2:
version "3.0.3"
resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.3.tgz#490332f40919452272d55a8480adc0c441358789"
integrity sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==
dependencies:
fill-range "^7.0.1"
fill-range "^7.1.1"
browserslist@^4.21.10, browserslist@^4.21.9:
version "4.22.1"
@@ -1460,10 +1460,10 @@ cosmiconfig@^8.1.3:
parse-json "^5.2.0"
path-type "^4.0.0"
cross-spawn@^7.0.2:
version "7.0.3"
resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6"
integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==
cross-spawn@^7.0.2, cross-spawn@^7.0.5:
version "7.0.6"
resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.6.tgz#8a58fe78f00dcd70c370451759dfbfaf03e8ee9f"
integrity sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==
dependencies:
path-key "^3.1.0"
shebang-command "^2.0.0"
@@ -1750,10 +1750,10 @@ file-entry-cache@^6.0.1:
dependencies:
flat-cache "^3.0.4"
fill-range@^7.0.1:
version "7.0.1"
resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40"
integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==
fill-range@^7.1.1:
version "7.1.1"
resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.1.1.tgz#44265d3cac07e3ea7dc247516380643754a05292"
integrity sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==
dependencies:
to-regex-range "^5.0.1"
@@ -2483,10 +2483,10 @@ rimraf@^3.0.2:
dependencies:
glob "^7.1.3"
rollup@^3.27.1:
version "3.29.4"
resolved "https://registry.yarnpkg.com/rollup/-/rollup-3.29.4.tgz#4d70c0f9834146df8705bfb69a9a19c9e1109981"
integrity sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==
rollup@^3.27.1, rollup@^3.29.5:
version "3.29.5"
resolved "https://registry.yarnpkg.com/rollup/-/rollup-3.29.5.tgz#8a2e477a758b520fb78daf04bca4c522c1da8a54"
integrity sha512-GVsDdsbJzzy4S/v3dqWPJ7EfvZJfCHiDqe80IyrF59LYuP+e6U1LJoUqeuqRbwAWoMNoXivMNeNAOf5E22VA1w==
optionalDependencies:
fsevents "~2.3.2"
@@ -2770,10 +2770,10 @@ vite-plugin-svgr@^4.1.0:
"@svgr/core" "^8.1.0"
"@svgr/plugin-jsx" "^8.1.0"
vite@^4.4.5:
version "4.5.0"
resolved "https://registry.yarnpkg.com/vite/-/vite-4.5.0.tgz#ec406295b4167ac3bc23e26f9c8ff559287cff26"
integrity sha512-ulr8rNLA6rkyFAlVWw2q5YJ91v098AFQ2R0PRFwPzREXOUJQPtFUG0t+/ZikhaOCDqFoDhN6/v8Sq0o4araFAw==
vite@^4.5.2:
version "4.5.14"
resolved "https://registry.yarnpkg.com/vite/-/vite-4.5.14.tgz#2e652bc1d898265d987d6543ce866ecd65fa4086"
integrity sha512-+v57oAaoYNnO3hIu5Z/tJRZjq5aHM2zDve9YZ8HngVHbhk66RStobhb1sqPMIPEleV6cNKYK4eGrAbE9Ulbl2g==
dependencies:
esbuild "^0.18.10"
postcss "^8.4.27"
+34 -10
@@ -8,6 +8,7 @@ import weakref
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
@@ -21,7 +22,7 @@ from typing import (
from urllib.parse import urljoin
import httpx
from httpx._types import AuthTypes, CertTypes, CookieTypes, HeaderTypes, VerifyTypes
from httpx._types import AuthTypes, CertTypes, CookieTypes, HeaderTypes
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@@ -49,6 +50,10 @@ from langserve.server_sent_events import aconnect_sse, connect_sse
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
# For type checking httpx types
import ssl
def _is_json_serializable(obj: Any) -> bool:
"""Return True if the object is json serializable."""
@@ -120,6 +125,12 @@ def _log_error_message_once(error_message: str) -> None:
logger.error(error_message)
@lru_cache(maxsize=1_000) # Will accommodate up to 1_000 different error messages
def _log_info_message_once(error_message: str) -> None:
"""Log an info message once."""
logger.info(error_message)
def _sanitize_request(request: httpx.Request) -> httpx.Request:
"""Remove sensitive headers from the request."""
accept_headers = {
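Reviewer note: `_log_info_message_once` relies on `functools.lru_cache` to deduplicate log lines: a repeated call with the same string is a cache hit, so the function body (and the log call) is skipped. A minimal standard-library sketch of that technique follows; the logger name `log_once_demo` and the function name `log_info_once` are assumptions made for the example.

```python
import logging
from functools import lru_cache

logger = logging.getLogger("log_once_demo")


@lru_cache(maxsize=1_000)  # remembers up to 1_000 distinct messages
def log_info_once(message: str) -> None:
    """Emit `message` at INFO level the first time it is seen."""
    logger.info(message)
```

The cache is bounded, so after more than 1_000 distinct messages the least recently used entries are evicted and such a message may be logged again.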
@@ -275,10 +286,11 @@ class RemoteRunnable(Runnable[Input, Output]):
auth: Optional[AuthTypes] = None,
headers: Optional[HeaderTypes] = None,
cookies: Optional[CookieTypes] = None,
verify: VerifyTypes = True,
verify: ssl.SSLContext | str | bool = True,
cert: Optional[CertTypes] = None,
client_kwargs: Optional[Dict[str, Any]] = None,
use_server_callback_events: bool = True,
serializer: Optional[Serializer] = None,
) -> None:
"""Initialize the client.
@@ -294,6 +306,8 @@ class RemoteRunnable(Runnable[Input, Output]):
and async httpx clients
use_server_callback_events: Whether to invoke callbacks on any
callback events returned by the server.
serializer: The serializer to use for serializing and deserializing
data. If not provided, a default serializer will be used.
"""
_client_kwargs = client_kwargs or {}
# Enforce trailing slash
@@ -321,7 +335,7 @@ class RemoteRunnable(Runnable[Input, Output]):
# Register cleanup handler once RemoteRunnable is garbage collected
weakref.finalize(self, _close_clients, self.sync_client, self.async_client)
self._lc_serializer = WellKnownLCSerializer()
self._lc_serializer = serializer or WellKnownLCSerializer()
self._use_server_callback_events = use_server_callback_events
def _invoke(
@@ -431,11 +445,15 @@ class RemoteRunnable(Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[RunnableConfig] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Output]:
if kwargs:
raise NotImplementedError("kwargs not implemented yet.")
return self._batch_with_config(self._batch, inputs, config)
raise NotImplementedError(f"kwargs not implemented yet. Got {kwargs}")
return self._batch_with_config(
self._batch, inputs, config, return_exceptions=return_exceptions
)
async def _abatch(
self,
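Reviewer note: `batch` now accepts `return_exceptions` and forwards it to `_batch_with_config`. The sketch below shows the usual meaning of such a flag (exceptions are returned in place of results instead of aborting the whole batch). It is a generic illustration under that assumption, not the `RemoteRunnable` implementation; `run_batch` is a hypothetical name.

```python
from typing import Any, Callable, List


def run_batch(
    func: Callable[[Any], Any],
    inputs: List[Any],
    *,
    return_exceptions: bool = False,
) -> List[Any]:
    """Apply `func` to each input, optionally collecting exceptions."""
    results: List[Any] = []
    for item in inputs:
        try:
            results.append(func(item))
        except Exception as exc:
            if not return_exceptions:
                raise  # abort the whole batch on the first failure
            results.append(exc)  # keep the failure at the input's position
    return results
```

With `return_exceptions=True` the output list keeps the same length and order as the inputs, so callers can pair each input with either its result or its exception.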
@@ -748,7 +766,7 @@ class RemoteRunnable(Runnable[Input, Output]):
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1"],
version: Literal["v1", "v2", None] = None,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
@@ -771,7 +789,8 @@ class RemoteRunnable(Runnable[Input, Output]):
input: The input to the runnable
config: The config to use for the runnable
version: The version of the astream_events to use.
Currently only "v1" is supported.
Currently, this input is IGNORED on the client.
The server will return whatever format it's configured with.
include_names: The names of the events to include
include_types: The types of the events to include
include_tags: The tags of the events to include
@@ -779,13 +798,18 @@ class RemoteRunnable(Runnable[Input, Output]):
exclude_types: The types of the events to exclude
exclude_tags: The tags of the events to exclude
"""
if version != "v1":
raise ValueError(f"Unsupported version: {version}. Use 'v1'")
# Create a stream handler that will emit Log objects
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
if version is not None:
_log_info_message_once(
"Versioning of the astream_events API is not supported on the client "
"side currently. The server will return events in whatever format "
"it was configured with in add_routes or APIHandler. "
"To stop seeing this message, remove the `version` argument."
)
events = []
run_manager = await callback_manager.on_chain_start(
+6 -5
@@ -6,8 +6,7 @@ from typing import Literal, Sequence, Type
from fastapi.responses import Response
from langchain_core.runnables import Runnable
from langserve.pydantic_v1 import BaseModel
from pydantic import BaseModel
class PlaygroundTemplate(Template):
@@ -90,10 +89,12 @@ async def serve_playground(
if base_url.startswith("/")
else base_url,
LANGSERVE_CONFIG_SCHEMA=json.dumps(
runnable.config_schema(include=config_keys).schema()
runnable.config_schema(include=config_keys).model_json_schema()
),
LANGSERVE_INPUT_SCHEMA=json.dumps(input_schema.model_json_schema()),
LANGSERVE_OUTPUT_SCHEMA=json.dumps(
output_schema.model_json_schema()
),
LANGSERVE_INPUT_SCHEMA=json.dumps(input_schema.schema()),
LANGSERVE_OUTPUT_SCHEMA=json.dumps(output_schema.schema()),
LANGSERVE_FEEDBACK_ENABLED=json.dumps(
"true" if feedback_enabled else "false"
),
File diff suppressed because one or more lines are too long
+1 -1
@@ -5,7 +5,7 @@
<link rel="icon" href="/____LANGSERVE_BASE_URL/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Playground</title>
<script type="module" crossorigin src="/____LANGSERVE_BASE_URL/assets/index-dbc96538.js"></script>
<script type="module" crossorigin src="/____LANGSERVE_BASE_URL/assets/index-400979f0.js"></script>
<link rel="stylesheet" href="/____LANGSERVE_BASE_URL/assets/index-52e8ab2f.css">
</head>
<body>
+6 -1
@@ -48,7 +48,12 @@
"postcss": "^8.4.31",
"tailwindcss": "^3.3.3",
"typescript": "^5.0.2",
"vite": "^4.4.5",
"vite": "^4.5.2",
"vite-plugin-svgr": "^4.1.0"
},
"resolutions": {
"braces": "^3.0.3",
"cross-spawn": "^7.0.5",
"rollup": "^3.29.5"
}
}
@@ -6,12 +6,30 @@ import {
schemaMatches,
Paths,
isControl,
JsonSchema,
} from "@jsonforms/core";
import { useStreamCallback } from "../useStreamCallback";
import { isJsonSchemaExtra } from "../utils/schema";
import { MessageFields, ChatMessageInput } from "./ChatMessageInput";
import { useEffect } from "react";
function checkItemSchema(schema: JsonSchema) {
const isObjectMessage =
schema.type === "object" &&
(schema.title?.endsWith("Message") ||
schema.title?.endsWith("MessageChunk"));
const isTupleMessage =
schema.type === "array" &&
schema.minItems === 2 &&
schema.maxItems === 2 &&
Array.isArray(schema.items) &&
schema.items.length === 2 &&
schema.items.every((schema) => schema.type === "string");
return isObjectMessage || isTupleMessage;
}
export const chatMessagesTester = rankWith(
12,
and(
@@ -34,22 +52,11 @@ export const chatMessagesTester = rankWith(
}
if ("anyOf" in schema.items && schema.items.anyOf != null) {
return schema.items.anyOf.every((schema) => {
const isObjectMessage =
schema.type === "object" &&
(schema.title?.endsWith("Message") ||
schema.title?.endsWith("MessageChunk"));
return schema.items.anyOf.every(checkItemSchema);
}
const isTupleMessage =
schema.type === "array" &&
schema.minItems === 2 &&
schema.maxItems === 2 &&
Array.isArray(schema.items) &&
schema.items.length === 2 &&
schema.items.every((schema) => schema.type === "string");
return isObjectMessage || isTupleMessage;
});
if ("oneOf" in schema.items && schema.items.oneOf != null) {
return schema.items.oneOf.every(checkItemSchema);
}
return false;
@@ -64,10 +71,14 @@ export const ChatMessagesControlRenderer = withJsonFormsControlProps(
useEffect(() => {
if (!isJsonSchemaExtra(props.schema)) return;
if (props.schema.extra.widget.type !== "chat") return;
setTimeout(() => props.handleChange(props.path, [
...data,
{ content: "", type: "human" },
]), 10);
setTimeout(
() =>
props.handleChange(props.path, [
...data,
{ content: "", type: "human" },
]),
10
);
}, []);
useStreamCallback("onStart", () => {
@@ -81,7 +92,10 @@ export const ChatMessagesControlRenderer = withJsonFormsControlProps(
if (props.schema.extra.widget.type !== "chat") return;
if (aggregatedState?.final_output !== undefined) {
const msgPath = Paths.compose(props.path, `${data.length - 1}`);
if ((aggregatedState.final_output as MessageFields)?.type === "AIMessageChunk") {
if (
(aggregatedState.final_output as MessageFields)?.type ===
"AIMessageChunk"
) {
props.handleChange(
Paths.compose(msgPath, "content"),
(aggregatedState.final_output as MessageFields)?.content
@@ -140,7 +154,7 @@ export const ChatMessagesControlRenderer = withJsonFormsControlProps(
props.path,
data.filter((_, i) => i !== index)
);
}
};
return (
<ChatMessageInput
message={message}
@@ -148,7 +162,7 @@ export const ChatMessagesControlRenderer = withJsonFormsControlProps(
handleRemoval={handleChatMessageRemoval}
path={props.path}
key={index}
></ChatMessageInput>
></ChatMessageInput>
);
})}
</div>
+21 -21
@@ -1338,12 +1338,12 @@ brace-expansion@^1.1.7:
balanced-match "^1.0.0"
concat-map "0.0.1"
braces@^3.0.2, braces@~3.0.2:
version "3.0.2"
resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107"
integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==
braces@^3.0.2, braces@^3.0.3, braces@~3.0.2:
version "3.0.3"
resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.3.tgz#490332f40919452272d55a8480adc0c441358789"
integrity sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==
dependencies:
fill-range "^7.0.1"
fill-range "^7.1.1"
browserslist@^4.21.10, browserslist@^4.21.9:
version "4.22.1"
@@ -1482,10 +1482,10 @@ cosmiconfig@^8.1.3:
parse-json "^5.2.0"
path-type "^4.0.0"
cross-spawn@^7.0.2:
version "7.0.3"
resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6"
integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==
cross-spawn@^7.0.2, cross-spawn@^7.0.5:
version "7.0.6"
resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.6.tgz#8a58fe78f00dcd70c370451759dfbfaf03e8ee9f"
integrity sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==
dependencies:
path-key "^3.1.0"
shebang-command "^2.0.0"
@@ -1777,10 +1777,10 @@ file-entry-cache@^6.0.1:
dependencies:
flat-cache "^3.0.4"
fill-range@^7.0.1:
version "7.0.1"
resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40"
integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==
fill-range@^7.1.1:
version "7.1.1"
resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.1.1.tgz#44265d3cac07e3ea7dc247516380643754a05292"
integrity sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==
dependencies:
to-regex-range "^5.0.1"
@@ -2503,10 +2503,10 @@ rimraf@^3.0.2:
dependencies:
glob "^7.1.3"
rollup@^3.27.1:
version "3.29.4"
resolved "https://registry.yarnpkg.com/rollup/-/rollup-3.29.4.tgz#4d70c0f9834146df8705bfb69a9a19c9e1109981"
integrity sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==
rollup@^3.27.1, rollup@^3.29.5:
version "3.29.5"
resolved "https://registry.yarnpkg.com/rollup/-/rollup-3.29.5.tgz#8a2e477a758b520fb78daf04bca4c522c1da8a54"
integrity sha512-GVsDdsbJzzy4S/v3dqWPJ7EfvZJfCHiDqe80IyrF59LYuP+e6U1LJoUqeuqRbwAWoMNoXivMNeNAOf5E22VA1w==
optionalDependencies:
fsevents "~2.3.2"
@@ -2790,10 +2790,10 @@ vite-plugin-svgr@^4.1.0:
"@svgr/core" "^8.1.0"
"@svgr/plugin-jsx" "^8.1.0"
vite@^4.4.5:
version "4.5.0"
resolved "https://registry.yarnpkg.com/vite/-/vite-4.5.0.tgz#ec406295b4167ac3bc23e26f9c8ff559287cff26"
integrity sha512-ulr8rNLA6rkyFAlVWw2q5YJ91v098AFQ2R0PRFwPzREXOUJQPtFUG0t+/ZikhaOCDqFoDhN6/v8Sq0o4araFAw==
vite@^4.5.2:
version "4.5.14"
resolved "https://registry.yarnpkg.com/vite/-/vite-4.5.14.tgz#2e652bc1d898265d987d6543ce866ecd65fa4086"
integrity sha512-+v57oAaoYNnO3hIu5Z/tJRZjq5aHM2zDve9YZ8HngVHbhk66RStobhb1sqPMIPEleV6cNKYK4eGrAbE9Ulbl2g==
dependencies:
esbuild "^0.18.10"
postcss "^8.4.27"
-33
@@ -1,33 +0,0 @@
from importlib import metadata
## Create namespaces for pydantic v1 and v2.
# This code must stay at the top of the file before other modules may
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
#
# This hack is done for the following reasons:
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
# both dependencies and dependents may be stuck on either version of v1 or v2.
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
# unambiguously uses either v1 or v2 API.
# * This change is easier to roll out and roll back.
try:
# F401: imported but unused
from pydantic.v1 import ( # noqa: F401
BaseModel,
Field,
ValidationError,
create_model,
)
except ImportError:
from pydantic import BaseModel, Field, ValidationError, create_model # noqa: F401
# This is not a pydantic v1 thing, but it feels too small to create a new module for.
PYDANTIC_VERSION = metadata.version("pydantic")
try:
_PYDANTIC_MAJOR_VERSION: int = int(PYDANTIC_VERSION.split(".")[0])
except metadata.PackageNotFoundError:
_PYDANTIC_MAJOR_VERSION = -1
+5 -4
@@ -2,10 +2,11 @@ from datetime import datetime
from typing import Dict, List, Optional, Union
from uuid import UUID
from pydantic import BaseModel # Floats between v1 and v2
from langserve.pydantic_v1 import BaseModel as BaseModelV1
from langserve.pydantic_v1 import Field
from pydantic import (
BaseModel,
Field,
)
from pydantic import BaseModel as BaseModelV1
class CustomUserType(BaseModelV1):
+94 -68
@@ -13,7 +13,7 @@ sensitive information from the server to the client.
import abc
import logging
from functools import lru_cache
from typing import Any, Dict, List, Union
from typing import Annotated, Any, Dict, List, Union
import orjson
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
@@ -29,6 +29,8 @@ from langchain_core.messages import (
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.outputs import (
ChatGeneration,
@@ -38,8 +40,8 @@ from langchain_core.outputs import (
)
from langchain_core.prompt_values import ChatPromptValueConcrete
from langchain_core.prompts.base import StringPromptValue
from pydantic import BaseModel, Discriminator, Field, RootModel, Tag, ValidationError
from langserve.pydantic_v1 import BaseModel, ValidationError
from langserve.validation import CallbackEvent
logger = logging.getLogger(__name__)
@@ -51,42 +53,58 @@ def _log_error_message_once(error_message: str) -> None:
logger.error(error_message)
class WellKnownLCObject(BaseModel):
"""A well known LangChain object.
def _get_type(v: Any) -> str:
"""Get the type associated with the object for serialization purposes."""
if isinstance(v, dict) and "type" in v:
return v["type"]
elif hasattr(v, "type"):
return v.type
else:
raise TypeError(
f"Expected either a dictionary with a 'type' key or an object "
f"with a 'type' attribute. Instead got type {type(v)}."
)
A pydantic model that defines what constitutes a well known LangChain object.
All well-known objects are allowed to be serialized and de-serialized.
"""
# A well known LangChain object.
# A pydantic model that defines what constitutes a well known LangChain object.
# All well-known objects are allowed to be serialized and de-serialized.
__root__: Union[
Document,
HumanMessage,
SystemMessage,
ChatMessage,
FunctionMessage,
AIMessage,
HumanMessageChunk,
SystemMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
AIMessageChunk,
StringPromptValue,
ChatPromptValueConcrete,
AgentAction,
AgentFinish,
AgentActionMessageLog,
LLMResult,
ChatGeneration,
Generation,
ChatGenerationChunk,
WellKnownLCObject = RootModel[
Annotated[
Union[
Annotated[AIMessage, Tag(tag="ai")],
Annotated[HumanMessage, Tag(tag="human")],
Annotated[ChatMessage, Tag(tag="chat")],
Annotated[SystemMessage, Tag(tag="system")],
Annotated[FunctionMessage, Tag(tag="function")],
Annotated[ToolMessage, Tag(tag="tool")],
Annotated[AIMessageChunk, Tag(tag="AIMessageChunk")],
Annotated[HumanMessageChunk, Tag(tag="HumanMessageChunk")],
Annotated[ChatMessageChunk, Tag(tag="ChatMessageChunk")],
Annotated[SystemMessageChunk, Tag(tag="SystemMessageChunk")],
Annotated[FunctionMessageChunk, Tag(tag="FunctionMessageChunk")],
Annotated[ToolMessageChunk, Tag(tag="ToolMessageChunk")],
Annotated[Document, Tag(tag="Document")],
Annotated[StringPromptValue, Tag(tag="StringPromptValue")],
Annotated[ChatPromptValueConcrete, Tag(tag="ChatPromptValueConcrete")],
Annotated[AgentAction, Tag(tag="AgentAction")],
Annotated[AgentFinish, Tag(tag="AgentFinish")],
Annotated[AgentActionMessageLog, Tag(tag="AgentActionMessageLog")],
Annotated[ChatGeneration, Tag(tag="ChatGeneration")],
Annotated[Generation, Tag(tag="Generation")],
Annotated[ChatGenerationChunk, Tag(tag="ChatGenerationChunk")],
Annotated[LLMResult, Tag(tag="LLMResult")],
],
Field(discriminator=Discriminator(_get_type)),
]
]
def default(obj) -> Any:
"""Default serialization for well known objects."""
if isinstance(obj, BaseModel):
return obj.dict()
return obj.model_dump()
return super().default(obj)
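Reviewer note: the new `WellKnownLCObject` selects a union member by calling `_get_type` on the raw value and matching the result against a tag. The sketch below reproduces that dispatch idea with plain dataclasses and a dictionary, without pydantic. `HumanNote`, `AINote`, `REGISTRY`, and `decode` are hypothetical stand-ins; only the body of `get_type` mirrors the function in the diff.

```python
from dataclasses import dataclass
from typing import Any, Dict


def get_type(v: Any) -> str:
    """Return the 'type' tag from a dict key or an object attribute."""
    if isinstance(v, dict) and "type" in v:
        return v["type"]
    elif hasattr(v, "type"):
        return v.type
    raise TypeError(f"No 'type' key or attribute on {type(v)}.")


@dataclass
class HumanNote:  # hypothetical stand-in for a message class
    content: str
    type: str = "human"


@dataclass
class AINote:  # hypothetical stand-in for a message class
    content: str
    type: str = "ai"


REGISTRY: Dict[str, Any] = {"human": HumanNote, "ai": AINote}


def decode(value: Any) -> Any:
    """Build the tagged class from a dict, or return the value unchanged."""
    try:
        cls = REGISTRY[get_type(value)]
        return cls(**value) if isinstance(value, dict) else value
    except (KeyError, TypeError):
        # Unknown tag, missing tag, or unexpected fields: leave as-is.
        return value
```

Falling back to the untouched value matches the widened `except (ValidationError, ValueError, TypeError)` in `_decode_lc_objects`, where a value without a usable tag is passed through rather than raising.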
@@ -96,12 +114,10 @@ def _decode_lc_objects(value: Any) -> Any:
v = {key: _decode_lc_objects(v) for key, v in value.items()}
try:
obj = WellKnownLCObject.parse_obj(v)
parsed = obj.__root__
if set(parsed.dict()) != set(value):
raise ValueError("Invalid object")
obj = WellKnownLCObject.model_validate(v)
parsed = obj.root
return parsed
except (ValidationError, ValueError):
except (ValidationError, ValueError, TypeError):
return v
elif isinstance(value, list):
return [_decode_lc_objects(item) for item in value]
@@ -123,12 +139,12 @@ def _decode_event_data(value: Any) -> Any:
"""Decode the event data from a JSON object representation."""
if isinstance(value, dict):
try:
obj = CallbackEvent.parse_obj(value)
return obj.__root__
obj = CallbackEvent.model_validate(value)
return obj.root
except ValidationError:
try:
obj = WellKnownLCObject.parse_obj(value)
return obj.__root__
obj = WellKnownLCObject.model_validate(value)
return obj.root
except ValidationError:
return {key: _decode_event_data(v) for key, v in value.items()}
elif isinstance(value, list):
@@ -141,44 +157,54 @@ def _decode_event_data(value: Any) -> Any:
class Serializer(abc.ABC):
@abc.abstractmethod
def dumpd(self, obj: Any) -> Any:
"""Convert the given object to a JSON serializable object."""
@abc.abstractmethod
def dumps(self, obj: Any) -> bytes:
"""Dump the given object as a JSON string."""
@abc.abstractmethod
def loads(self, s: bytes) -> Any:
"""Load the given JSON string."""
@abc.abstractmethod
def loadd(self, obj: Any) -> Any:
"""Load the given object."""
class WellKnownLCSerializer(Serializer):
def dumpd(self, obj: Any) -> Any:
"""Convert the given object to a JSON serializable object."""
return orjson.loads(orjson.dumps(obj, default=default))
def dumps(self, obj: Any) -> bytes:
"""Dump the given object as a JSON string."""
return orjson.dumps(obj, default=default)
def loadd(self, obj: Any) -> Any:
"""Load the given object."""
return _decode_lc_objects(obj)
return orjson.loads(self.dumps(obj))
def loads(self, s: bytes) -> Any:
"""Load the given JSON string."""
return self.loadd(orjson.loads(s))
@abc.abstractmethod
def dumps(self, obj: Any) -> bytes:
"""Dump the given object to a JSON byte string."""
@abc.abstractmethod
def loadd(self, obj: Any) -> Any:
"""Given a python object, load it into a well known object.
The obj represents content that was json loaded from a string, but
not yet validated or converted into a well known object.
"""
class WellKnownLCSerializer(Serializer):
"""A pre-defined serializer for well known LangChain objects.
This is the default serializer used by LangServe for serializing and
de-serializing well known LangChain objects.
If you need to extend the serialization capabilities for your own application,
feel free to create a new instance of the Serializer class and implement
the abstract methods dumps and loadd.
"""
def dumps(self, obj: Any) -> bytes:
"""Dump the given object to a JSON byte string."""
return orjson.dumps(obj, default=default)
def loadd(self, obj: Any) -> Any:
"""Given a python object, load it into a well known object.
The obj represents content that was json loaded from a string, but
not yet validated or converted into a well known object.
"""
return _decode_lc_objects(obj)
def _project_top_level(model: BaseModel) -> Dict[str, Any]:
"""Project the top level of the model as dict."""
return {key: getattr(model, key) for key in model.__fields__}
return {key: getattr(model, key) for key in model.model_fields}
def load_events(events: Any) -> List[Dict[str, Any]]:
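Reviewer note: after this change `Serializer` keeps `dumps` and `loadd` abstract and derives `dumpd` and `loads` from them, so a custom serializer only has to supply two methods. The sketch below illustrates that shape with a local stand-in base class built on the standard `json` module instead of `orjson`. `SetAwareSerializer` and the `"__set__"` marker are invented for the example and are not part of langserve.

```python
import abc
import json
from typing import Any


class Serializer(abc.ABC):
    """Local stand-in mirroring the two-abstract-method layout."""

    def dumpd(self, obj: Any) -> Any:
        return json.loads(self.dumps(obj))

    def loads(self, s: bytes) -> Any:
        return self.loadd(json.loads(s))

    @abc.abstractmethod
    def dumps(self, obj: Any) -> bytes:
        """Dump the given object to a JSON byte string."""

    @abc.abstractmethod
    def loadd(self, obj: Any) -> Any:
        """Convert already-parsed JSON content into richer objects."""


class SetAwareSerializer(Serializer):
    """Hypothetical serializer that round-trips Python sets."""

    def dumps(self, obj: Any) -> bytes:
        def default(o: Any) -> Any:
            if isinstance(o, (set, frozenset)):
                return {"__set__": sorted(o)}
            raise TypeError(f"Cannot serialize {type(o)}")

        return json.dumps(obj, default=default).encode("utf-8")

    def loadd(self, obj: Any) -> Any:
        if isinstance(obj, dict):
            if set(obj) == {"__set__"}:
                return set(obj["__set__"])
            return {key: self.loadd(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [self.loadd(item) for item in obj]
        return obj
```

An instance of a class like this is what the new `serializer` argument of `RemoteRunnable` would receive in place of the default `WellKnownLCSerializer`.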
@@ -209,7 +235,7 @@ def load_events(events: Any) -> List[Dict[str, Any]]:
# Then validate the event
try:
full_event = CallbackEvent.parse_obj(decoded_event_data)
full_event = CallbackEvent.model_validate(decoded_event_data)
except ValidationError as e:
msg = f"Encountered an invalid event: {e}"
if "type" in decoded_event_data:
@@ -217,7 +243,7 @@ def load_events(events: Any) -> List[Dict[str, Any]]:
_log_error_message_once(msg)
continue
decoded_event_data = _project_top_level(full_event.__root__)
decoded_event_data = _project_top_level(full_event.root)
if decoded_event_data["type"].endswith("_error"):
# Data is validated by this point, so we can assume that the shape
+310 -340
@@ -5,6 +5,7 @@ This code contains integration for langchain runnables with FastAPI.
The main entry point is the `add_routes` function which adds the routes to an existing
FastAPI app or APIRouter.
"""
import warnings
import weakref
from typing import (
Any,
@@ -16,6 +17,7 @@ from typing import (
)
from langchain_core.runnables import Runnable
from pydantic import BaseModel
from typing_extensions import Annotated
from langserve.api_handler import (
@@ -24,11 +26,7 @@ from langserve.api_handler import (
TokenFeedbackConfig,
_is_hosted,
)
from langserve.pydantic_v1 import (
_PYDANTIC_MAJOR_VERSION,
PYDANTIC_VERSION,
BaseModel,
)
from langserve.serialization import Serializer
try:
from fastapi import APIRouter, Depends, FastAPI, Request, Response
@@ -205,51 +203,43 @@ def _register_path_for_app(
def _setup_global_app_handlers(
app: Union[FastAPI, APIRouter], endpoint_configuration: _EndpointConfiguration
) -> None:
@app.on_event("startup")
async def startup_event():
LANGSERVE = r"""
__ ___ .__ __. _______ _______. _______ .______ ____ ____ _______
| | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____|
| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__
| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __|
| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____
|_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______|
""" # noqa: E501
with warnings.catch_warnings():
# We are using deprecated functionality here.
# We should re-write to use lifetime events at some point, and yielding
# an APIRouter instance to the caller.
warnings.filterwarnings(
"ignore",
"[\\s.]*on_event is deprecated[\\s.]*",
category=DeprecationWarning,
)
def green(text: str) -> str:
"""Return the given text in green."""
return "\x1b[1;32;40m" + text + "\x1b[0m"
@app.on_event("startup")
async def startup_event():
LANGSERVE = r"""
__ ___ .__ __. _______ _______. _______ .______ ____ ____ _______
| | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____|
| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__
| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __|
| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____
|_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______|
""" # noqa: E501
def orange(text: str) -> str:
"""Return the given text in orange."""
return "\x1b[1;31;40m" + text + "\x1b[0m"
def green(text: str) -> str:
"""Return the given text in green."""
return "\x1b[1;32;40m" + text + "\x1b[0m"
paths = _APP_TO_PATHS[app]
print(LANGSERVE)
for path in paths:
if endpoint_configuration.is_playground_enabled:
print(
f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" is '
f"live at:"
)
print(f'{green("LANGSERVE:")}')
print(f'{green("LANGSERVE:")} └──> {path}/playground/')
print(f'{green("LANGSERVE:")}')
print(f'{green("LANGSERVE:")} See all available routes at {app.docs_url}/')
if _PYDANTIC_MAJOR_VERSION == 2:
print()
print(f'{orange("LANGSERVE:")} ', end="")
print(
f"⚠️ Using pydantic {PYDANTIC_VERSION}. "
f"OpenAPI docs for invoke, batch, stream, stream_log "
f"endpoints will not be generated. API endpoints and playground "
f"should work as expected. "
f"If you need to see the docs, you can downgrade to pydantic 1. "
"For example, `pip install pydantic==1.10.13`. "
f"See https://github.com/tiangolo/fastapi/issues/10360 for details."
)
print()
paths = _APP_TO_PATHS[app]
print(LANGSERVE)
for path in paths:
if endpoint_configuration.is_playground_enabled:
print(
f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" '
f'is live at:'
)
print(f'{green("LANGSERVE:")}')
print(f'{green("LANGSERVE:")} └──> {path}/playground/')
print(f'{green("LANGSERVE:")}')
print(f'{green("LANGSERVE:")} See all available routes at {app.docs_url}/')
# PUBLIC API
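Review note on the warning filter in the hunk above: the `message` argument of `warnings.filterwarnings` is a regular expression matched against the start of the warning text, so the leading `[\s.]*` lets the pattern skip any whitespace (or dots) that precede "on_event is deprecated". A minimal standard-library sketch of that behavior follows. `legacy_hook` and `call_quietly` are hypothetical names used only for illustration and are not part of langserve or FastAPI.

```python
import warnings


def legacy_hook() -> str:
    """Hypothetical stand-in for a call that emits a DeprecationWarning."""
    warnings.warn(
        "\n    on_event is deprecated, use lifespan event handlers instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return "registered"


def call_quietly() -> list:
    """Return the messages of warnings that were NOT suppressed."""
    with warnings.catch_warnings(record=True) as caught:
        # Start from a known state: record every warning by default.
        warnings.simplefilter("always")
        # Same pattern as in the diff. Inside [...] the dot is a literal ".",
        # so the prefix skips leading whitespace and dots only.
        warnings.filterwarnings(
            "ignore",
            "[\\s.]*on_event is deprecated[\\s.]*",
            category=DeprecationWarning,
        )
        legacy_hook()
        # A DeprecationWarning with a different message is still recorded.
        warnings.warn("something else is deprecated", DeprecationWarning)
    return [str(w.message) for w in caught]
```

Because the filter is installed inside `catch_warnings()`, it is removed again when the `with` block exits, so unrelated deprecation warnings elsewhere in the process stay visible.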
@@ -273,6 +263,8 @@ def add_routes(
enabled_endpoints: Optional[Sequence[EndpointName]] = None,
dependencies: Optional[Sequence[Depends]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
serializer: Optional[Serializer] = None,
) -> None:
"""Register the routes on the given FastAPI app or APIRouter.
@@ -391,14 +383,18 @@ def add_routes(
- chat: UX is optimized for chat-like interactions. Please review
the README in langserve for more details about constraints (e.g.,
which message types are supported etc.)
astream_events_version: version of the stream events endpoint to use.
By default "v2".
serializer: The serializer to use for serializing the output. If not provided,
the default serializer will be used.
""" # noqa: E501
if not isinstance(runnable, Runnable):
raise TypeError(
f"Expected a Runnable, got {type(runnable)}. "
f"The second argument to add_routes should be a Runnable instance."
f"add_route(app, runnable, ...) is the correct usage."
f"Please make sure that you are using a runnable which is an instance of "
f"langchain_core.runnables.Runnable."
"The second argument to add_routes should be a Runnable instance. "
"add_routes(app, runnable, ...) is the correct usage. "
"Please make sure that you are using a runnable which is an instance of "
"langchain_core.runnables.Runnable."
)
endpoint_configuration = _EndpointConfiguration(
@@ -454,7 +450,10 @@ def add_routes(
per_req_config_modifier=per_req_config_modifier,
stream_log_name_allow_list=stream_log_name_allow_list,
playground_type=playground_type,
astream_events_version=astream_events_version,
serializer=serializer,
)
namespace = path or ""
route_tags = [path.strip("/")] if path else None
@@ -473,35 +472,9 @@ def add_routes(
if hasattr(app, "openapi_tags") and (path or (app not in _APP_SEEN)):
if not path:
_APP_SEEN.add(app)
if _PYDANTIC_MAJOR_VERSION == 1:
# Documentation for the default endpoints
default_endpoint_tags = {
"name": route_tags[0] if route_tags else "default",
}
elif _PYDANTIC_MAJOR_VERSION == 2:
# When using pydantic v2, we cannot generate openapi docs for
# the invoke/batch/stream/stream_log endpoints since the underlying
# models are from the pydantic.v1 namespace and cannot be supported
# by FastAPI's.
# https://github.com/tiangolo/fastapi/issues/10360
default_endpoint_tags = {
"name": route_tags[0] if route_tags else "default",
"description": (
f"⚠️ Using pydantic {PYDANTIC_VERSION}. "
f"OpenAPI docs for `invoke`, `batch`, `stream`, `stream_log` "
f"endpoints will not be generated. API endpoints and playground "
f"should work as expected. "
f"If you need to see the docs, you can downgrade to pydantic 1. "
"For example, `pip install pydantic==1.10.13`"
f"See https://github.com/tiangolo/fastapi/issues/10360 for details."
),
}
else:
raise AssertionError(
f"Expected pydantic major version 1 or 2, got {_PYDANTIC_MAJOR_VERSION}"
)
default_endpoint_tags = {
"name": route_tags[0] if route_tags else "default",
}
if endpoint_configuration.is_config_hash_enabled:
app.openapi_tags = [
*(getattr(app, "openapi_tags", []) or []),
@@ -778,329 +751,326 @@ def add_routes(
# Documentation variants of end points.
#######################################
# At the moment, we only support pydantic 1.x for documentation
if _PYDANTIC_MAJOR_VERSION == 1:
InvokeRequest = api_handler.InvokeRequest
InvokeResponse = api_handler.InvokeResponse
BatchRequest = api_handler.BatchRequest
BatchResponse = api_handler.BatchResponse
StreamRequest = api_handler.StreamRequest
StreamLogRequest = api_handler.StreamLogRequest
StreamEventsRequest = api_handler.StreamEventsRequest
InvokeRequest = api_handler.InvokeRequest
InvokeResponse = api_handler.InvokeResponse
BatchRequest = api_handler.BatchRequest
BatchResponse = api_handler.BatchResponse
StreamRequest = api_handler.StreamRequest
StreamLogRequest = api_handler.StreamLogRequest
StreamEventsRequest = api_handler.StreamEventsRequest
if endpoint_configuration.is_invoke_enabled:
if endpoint_configuration.is_invoke_enabled:
async def _invoke_docs(
invoke_request: Annotated[InvokeRequest, InvokeRequest],
config_hash: str = "",
) -> InvokeResponse:
"""Invoke the runnable with the given input and config."""
raise AssertionError("This endpoint should not be reachable.")
async def _invoke_docs(
invoke_request: Annotated[InvokeRequest, InvokeRequest],
config_hash: str = "",
) -> InvokeResponse:
"""Invoke the runnable with the given input and config."""
raise AssertionError("This endpoint should not be reachable.")
invoke_docs = app.post(
f"{namespace}/invoke",
invoke_docs = app.post(
f"{namespace}/invoke",
response_model=api_handler.InvokeResponse,
tags=route_tags,
name=_route_name("invoke"),
dependencies=dependencies,
)(_invoke_docs)
if endpoint_configuration.is_config_hash_enabled:
app.post(
namespace + "/c/{config_hash}/invoke",
response_model=api_handler.InvokeResponse,
tags=route_tags,
name=_route_name("invoke"),
tags=route_tags_with_config,
name=_route_name_with_config("invoke"),
dependencies=dependencies,
)(_invoke_docs)
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /invoke endpoint without "
"the `c/{config_hash}` path parameter."
),
)(invoke_docs)
if endpoint_configuration.is_config_hash_enabled:
app.post(
namespace + "/c/{config_hash}/invoke",
response_model=api_handler.InvokeResponse,
tags=route_tags_with_config,
name=_route_name_with_config("invoke"),
dependencies=dependencies,
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /invoke endpoint without "
"the `c/{config_hash}` path parameter."
),
)(invoke_docs)
if endpoint_configuration.is_batch_enabled:
if endpoint_configuration.is_batch_enabled:
async def _batch_docs(
batch_request: Annotated[BatchRequest, BatchRequest],
config_hash: str = "",
) -> BatchResponse:
"""Batch invoke the runnable with the given inputs and config."""
raise AssertionError("This endpoint should not be reachable.")
async def _batch_docs(
batch_request: Annotated[BatchRequest, BatchRequest],
config_hash: str = "",
) -> BatchResponse:
"""Batch invoke the runnable with the given inputs and config."""
raise AssertionError("This endpoint should not be reachable.")
batch_docs = app.post(
f"{namespace}/batch",
response_model=BatchResponse,
tags=route_tags,
name=_route_name("batch"),
dependencies=dependencies,
)(_batch_docs)
batch_docs = app.post(
f"{namespace}/batch",
if endpoint_configuration.is_config_hash_enabled:
app.post(
namespace + "/c/{config_hash}/batch",
response_model=BatchResponse,
tags=route_tags,
name=_route_name("batch"),
tags=route_tags_with_config,
name=_route_name_with_config("batch"),
dependencies=dependencies,
)(_batch_docs)
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /batch endpoint without "
"the `c/{config_hash}` path parameter."
),
)(batch_docs)
if endpoint_configuration.is_config_hash_enabled:
app.post(
namespace + "/c/{config_hash}/batch",
response_model=BatchResponse,
tags=route_tags_with_config,
name=_route_name_with_config("batch"),
dependencies=dependencies,
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /batch endpoint without "
"the `c/{config_hash}` path parameter."
),
)(batch_docs)
if endpoint_configuration.is_stream_enabled:
if endpoint_configuration.is_stream_enabled:
async def _stream_docs(
stream_request: Annotated[StreamRequest, StreamRequest],
config_hash: str = "",
) -> EventSourceResponse:
"""Invoke the runnable stream the output.
async def _stream_docs(
stream_request: Annotated[StreamRequest, StreamRequest],
config_hash: str = "",
) -> EventSourceResponse:
"""Invoke the runnable stream the output.
This endpoint allows to stream the output of the runnable.
This endpoint allows to stream the output of the runnable.
The endpoint uses a server sent event stream to stream the output.
The endpoint uses a server sent event stream to stream the output.
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events
Important: Set the "text/event-stream" media type for request headers if
not using an existing SDK.
Important: Set the "text/event-stream" media type for request headers if
not using an existing SDK.
The events that the endpoint uses are the following:
* "data" -- used for streaming the output of the runnable
* "error" -- signaling an error while streaming and ends the stream.
* "end" -- used for signaling the end of the stream
* "metadata" -- used for sending metadata about the run; e.g., run id.
The events that the endpoint uses are the following:
* "data" -- used for streaming the output of the runnable
* "error" -- signaling an error while streaming and ends the stream.
* "end" -- used for signaling the end of the stream
* "metadata" -- used for sending metadata about the run; e.g., run id.
The event type is in the "event" field of the event.
The payload associated with the event is in the "data" field
of the event, and it is JSON encoded.
The event type is in the "event" field of the event.
The payload associated with the event is in the "data" field
of the event, and it is JSON encoded.
Here are some examples of events that the endpoint can send:
Here are some examples of events that the endpoint can send:
Regular streaming event:
{
"event": "data",
"data": {
...
}
}
Internal server error:
{
"event": "error",
"data": {
"status_code": 500,
"message": "Internal Server Error"
}
}
Streaming ended so client should stop listening for events:
{
"event": "end",
}
"""
raise AssertionError("This endpoint should not be reachable.")
stream_docs = app.post(
f"{namespace}/stream",
include_in_schema=True,
tags=route_tags,
name=_route_name("stream"),
dependencies=dependencies,
description=(
"This endpoint allows to stream the output of the runnable. "
"The endpoint uses a server sent event stream to stream the "
"output. "
"https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events"
),
)(_stream_docs)
if endpoint_configuration.is_config_hash_enabled:
app.post(
namespace + "/c/{config_hash}/stream",
include_in_schema=True,
tags=route_tags_with_config,
name=_route_name_with_config("stream"),
dependencies=dependencies,
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /stream endpoint without "
"the `c/{config_hash}` path parameter."
),
)(stream_docs)
if endpoint_configuration.is_stream_log_enabled:
async def _stream_log_docs(
stream_log_request: Annotated[StreamLogRequest, StreamLogRequest],
config_hash: str = "",
) -> EventSourceResponse:
"""Invoke the runnable stream_log the output.
This endpoint allows to stream the output of the runnable, including
the output of all intermediate steps.
The endpoint uses a server sent event stream to stream the output.
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events
Important: Set the "text/event-stream" media type for request headers if
not using an existing SDK.
This endpoint uses two different types of events:
* data - for streaming the output of the runnable
Regular streaming event:
{
"event": "data",
"data": {
...
...
}
}
Internal server error:
{
"event": "error",
"data": {
"status_code": 500,
"message": "Internal Server Error"
}
}
* error - for signaling an error in the stream, also ends the stream.
{
"event": "error",
"data": {
"status_code": 500,
"message": "Internal Server Error"
}
}
* end - for signaling the end of the stream.
This helps the client to know when to stop listening for events and
know that the streaming has ended successfully.
Streaming ended so client should stop listening for events:
{
"event": "end",
}
"""
raise AssertionError("This endpoint should not be reachable.")
"""
raise AssertionError("This endpoint should not be reachable.")
stream_docs = app.post(
f"{namespace}/stream",
include_in_schema=True,
tags=route_tags,
name=_route_name("stream"),
dependencies=dependencies,
description=(
"This endpoint allows to stream the output of the runnable. "
"The endpoint uses a server sent event stream to stream the "
"output. "
"https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events"
),
)(_stream_docs)
if endpoint_configuration.is_config_hash_enabled:
app.post(
namespace + "/c/{config_hash}/stream",
include_in_schema=True,
tags=route_tags_with_config,
name=_route_name_with_config("stream"),
dependencies=dependencies,
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /stream endpoint without "
"the `c/{config_hash}` path parameter."
),
)(stream_docs)
if endpoint_configuration.is_stream_log_enabled:
async def _stream_log_docs(
stream_log_request: Annotated[StreamLogRequest, StreamLogRequest],
config_hash: str = "",
) -> EventSourceResponse:
"""Invoke the runnable stream_log the output.
This endpoint allows to stream the output of the runnable, including
the output of all intermediate steps.
The endpoint uses a server sent event stream to stream the output.
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events
Important: Set the "text/event-stream" media type for request headers if
not using an existing SDK.
This endpoint uses two different types of events:
* data - for streaming the output of the runnable
{
"event": "data",
"data": {
...
}
}
* error - for signaling an error in the stream, also ends the stream.
{
"event": "error",
"data": {
"status_code": 500,
"message": "Internal Server Error"
}
}
* end - for signaling the end of the stream.
This helps the client to know when to stop listening for events and
know that the streaming has ended successfully.
{
"event": "end",
}
"""
raise AssertionError("This endpoint should not be reachable.")
app.post(
f"{namespace}/stream_log",
include_in_schema=True,
tags=route_tags,
name=_route_name("stream_log"),
dependencies=dependencies,
)(_stream_log_docs)
if endpoint_configuration.is_config_hash_enabled:
app.post(
f"{namespace}/stream_log",
namespace + "/c/{config_hash}/stream_log",
include_in_schema=True,
tags=route_tags,
name=_route_name("stream_log"),
tags=route_tags_with_config,
name=_route_name_with_config("stream_log"),
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /stream_log endpoint without "
"the `c/{config_hash}` path parameter."
),
dependencies=dependencies,
)(_stream_log_docs)
if endpoint_configuration.is_config_hash_enabled:
app.post(
namespace + "/c/{config_hash}/stream_log",
include_in_schema=True,
tags=route_tags_with_config,
name=_route_name_with_config("stream_log"),
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /stream_log endpoint without "
"the `c/{config_hash}` path parameter."
),
dependencies=dependencies,
)(_stream_log_docs)
if has_astream_events and endpoint_configuration.is_stream_events_enabled:
if has_astream_events and endpoint_configuration.is_stream_events_enabled:
async def _stream_events_docs(
stream_events_request: Annotated[StreamEventsRequest, StreamEventsRequest],
config_hash: str = "",
) -> EventSourceResponse:
"""Stream events from the given runnable.
async def _stream_events_docs(
stream_events_request: Annotated[
StreamEventsRequest, StreamEventsRequest
],
config_hash: str = "",
) -> EventSourceResponse:
"""Stream events from the given runnable.
This endpoint allows to stream events from the runnable, including
events from all intermediate steps.
This endpoint allows to stream events from the runnable, including
events from all intermediate steps.
**Attention**
**Attention**
This is a new endpoint that only works for langchain-core >= 0.1.14.
This is a new endpoint that only works for langchain-core >= 0.1.14.
It belongs to a Beta API that may change in the future.
It belongs to a Beta API that may change in the future.
**Important**
Specify filters to the events you want to receive by setting
the appropriate filters in the request body.
**Important**
Specify filters to the events you want to receive by setting
the appropriate filters in the request body.
This will help avoid sending too much data over the network.
This will help avoid sending too much data over the network.
It will also prevent serialization issues with
any unsupported types since it won't need to serialize events
that aren't transmitted.
It will also prevent serialization issues with
any unsupported types since it won't need to serialize events
that aren't transmitted.
The endpoint uses a server sent event stream to stream the output.
The endpoint uses a server sent event stream to stream the output.
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events
The encoding of events follows the following format:
The encoding of events follows the following format:
* data - for streaming the output of the runnable
{
"event": "data",
"data": {
...
}
}
* error - for signaling an error in the stream, also ends the stream.
* data - for streaming the output of the runnable
{
"event": "error",
"event": "data",
"data": {
"status_code": 500,
"message": "Internal Server Error"
...
}
}
* end - for signaling the end of the stream.
* error - for signaling an error in the stream, also ends the stream.
This helps the client to know when to stop listening for events and
know that the streaming has ended successfully.
{
"event": "error",
"data": {
"status_code": 500,
"message": "Internal Server Error"
}
}
{
"event": "end",
}
* end - for signaling the end of the stream.
`data` for the `data` event is a JSON object that corresponds
to a serialized representation of a StreamEvent.
This helps the client to know when to stop listening for events and
know that the streaming has ended successfully.
See LangChain documentation for more information about astream_events.
"""
raise AssertionError("This endpoint should not be reachable.")
{
"event": "end",
}
`data` for the `data` event is a JSON object that corresponds
to a serialized representation of a StreamEvent.
See LangChain documentation for more information about astream_events.
"""
raise AssertionError("This endpoint should not be reachable.")
app.post(
f"{namespace}/stream_events",
include_in_schema=True,
tags=route_tags,
name=_route_name("stream_events"),
dependencies=dependencies,
)(_stream_events_docs)
if endpoint_configuration.is_config_hash_enabled:
app.post(
f"{namespace}/stream_events",
namespace + "/c/{config_hash}/stream_events",
include_in_schema=True,
tags=route_tags,
name=_route_name("stream_events"),
tags=route_tags_with_config,
name=_route_name_with_config("stream_events"),
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /stream_events endpoint "
"without the `c/{config_hash}` path parameter."
),
dependencies=dependencies,
)(_stream_events_docs)
if endpoint_configuration.is_config_hash_enabled:
app.post(
namespace + "/c/{config_hash}/stream_events",
include_in_schema=True,
tags=route_tags_with_config,
name=_route_name_with_config("stream_events"),
description=(
"This endpoint is to be used with share links generated by the "
"LangServe playground. "
"The hash is an LZString compressed JSON string. "
"For regular use cases, use the /stream_events endpoint "
"without the `c/{config_hash}` path parameter."
),
dependencies=dependencies,
)(_stream_events_docs)
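Review note on the event format described in the stream docstrings above: each server-sent event carries its type in the `event` field and a JSON-encoded payload in the `data` field, and an `end` event tells the client to stop listening. The sketch below shows one way a client could consume such a stream using only the standard library. It assumes the raw SSE text is already available as a string. `iter_sse_events`, `collect_outputs`, and the sample payloads are hypothetical and are not taken from langserve.

```python
import json
from typing import Any, Dict, Iterator, List


def iter_sse_events(raw: str) -> Iterator[Dict[str, Any]]:
    """Yield {"event": ..., "data": ...} dicts from raw SSE text."""
    for block in raw.replace("\r\n", "\n").split("\n\n"):
        event, data_lines = "message", []
        for line in block.split("\n"):
            if line.startswith("event:"):
                event = line[len("event:"):].strip()
            elif line.startswith("data:"):
                data_lines.append(line[len("data:"):].strip())
        if not data_lines and event == "message":
            continue  # empty block, nothing to report
        payload = "\n".join(data_lines)
        # The payload is JSON encoded; an event such as "end" may have none.
        yield {"event": event, "data": json.loads(payload) if payload else None}


def collect_outputs(raw: str) -> List[Any]:
    """Gather "data" payloads until "end"; raise on an "error" event."""
    outputs: List[Any] = []
    for evt in iter_sse_events(raw):
        if evt["event"] == "data":
            outputs.append(evt["data"])
        elif evt["event"] == "error":
            err = evt["data"]
            raise RuntimeError(f'{err["status_code"]}: {err["message"]}')
        elif evt["event"] == "end":
            break  # streaming finished successfully
        # Other events (e.g. "metadata") are ignored in this sketch.
    return outputs


# Hypothetical sample stream, for illustration only.
SAMPLE = (
    'event: metadata\ndata: {"run_id": "hypothetical-run-id"}\n\n'
    'event: data\ndata: "Hello"\n\n'
    'event: data\ndata: " world"\n\n'
    "event: end\n\n"
)
```

With the sample above, `collect_outputs(SAMPLE)` gathers the two `data` payloads and stops at `end`. A real client would read the stream incrementally rather than from a complete string.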
+2 -1
@@ -1,8 +1,9 @@
"""Adapted from https://github.com/florimondmanca/httpx-sse"""
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncIterator, Iterator, List, Optional, TypedDict
from typing import Any, AsyncIterator, Iterator, List, Optional
import httpx
from typing_extensions import TypedDict
class ServerSentEvent(TypedDict):
+46 -68
@@ -22,16 +22,11 @@ from uuid import UUID
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, Generation, RunInfo
from pydantic import BaseModel, Field, RootModel, create_model
from typing_extensions import Type
from langserve.schema import BatchResponseMetadata, InvokeResponseMetadata
try:
from pydantic.v1 import BaseModel, Field, create_model
except ImportError:
from pydantic import BaseModel, Field, create_model
# Type that is either a python annotation or a pydantic model that can be
# used to validate the input or output of a runnable.
Validator = Union[Type[BaseModel], type]
@@ -66,7 +61,7 @@ def create_invoke_request_model(
),
),
)
invoke_request_type.update_forward_refs()
invoke_request_type.model_rebuild()
return invoke_request_type
@@ -97,7 +92,7 @@ def create_stream_request_model(
),
),
)
stream_request_model.update_forward_refs()
stream_request_model.model_rebuild()
return stream_request_model
@@ -129,7 +124,7 @@ def create_batch_request_model(
),
),
)
batch_request_type.update_forward_refs()
batch_request_type.model_rebuild()
return batch_request_type
@@ -187,7 +182,7 @@ def create_stream_log_request_model(
),
kwargs=(dict, Field(default_factory=dict)),
)
stream_log_request.update_forward_refs()
stream_log_request.model_rebuild()
return stream_log_request
@@ -245,7 +240,7 @@ def create_stream_events_request_model(
),
kwargs=(dict, Field(default_factory=dict)),
)
stream_events_request.update_forward_refs()
stream_events_request.model_rebuild()
return stream_events_request
@@ -297,7 +292,7 @@ def create_invoke_response_model(
__base__=InvokeBaseResponse,
**fields,
)
invoke_response_type.update_forward_refs()
invoke_response_type.model_rebuild()
return invoke_response_type
@@ -347,7 +342,7 @@ def create_batch_response_model(
__base__=BatchBaseResponse,
**fields,
)
batch_response_type.update_forward_refs()
batch_response_type.model_rebuild()
return batch_response_type
@@ -401,27 +396,29 @@ class StreamEventsParameters(BaseModel):
# status code and a message.
class OnChainStart(BaseModel):
"""On Chain Start Callback Event."""
class BaseCallback(BaseModel):
"""Base class for all callback events."""
serialized: Dict[str, Any]
inputs: Any
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
class OnChainStart(BaseCallback):
"""On Chain Start Callback Event."""
serialized: Optional[Dict[str, Any]] = None
inputs: Any
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chain_start"] = "on_chain_start"
class OnChainEnd(BaseModel):
class OnChainEnd(BaseCallback):
"""On Chain End Callback Event."""
outputs: Any
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chain_end"] = "on_chain_end"
@@ -433,38 +430,35 @@ class Error(BaseModel):
type: Literal["error"] = "error"
class OnChainError(BaseModel):
class OnChainError(BaseCallback):
"""On Chain Error Callback Event."""
error: Error
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chain_error"] = "on_chain_error"
class OnToolStart(BaseModel):
class OnToolStart(BaseCallback):
"""On Tool Start Callback Event."""
serialized: Dict[str, Any]
serialized: Optional[Dict[str, Any]] = None
input_str: str
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_tool_start"] = "on_tool_start"
class OnToolEnd(BaseModel):
class OnToolEnd(BaseCallback):
"""On Tool End Callback Event."""
output: str
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_tool_end"] = "on_tool_end"
@@ -472,36 +466,29 @@ class OnToolError(BaseModel):
"""On Tool Error Callback Event."""
error: Error
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_tool_error"] = "on_tool_error"
class OnChatModelStart(BaseModel):
class OnChatModelStart(BaseCallback):
"""On Chat Model Start Callback Event."""
serialized: Dict[str, Any]
serialized: Optional[Dict[str, Any]] = None
messages: List[List[BaseMessage]]
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chat_model_start"] = "on_chat_model_start"
class OnLLMStart(BaseModel):
class OnLLMStart(BaseCallback):
"""On LLM Start Callback Event."""
serialized: Dict[str, Any]
serialized: Optional[Dict[str, Any]] = None
prompts: List[str]
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_llm_start"] = "on_llm_start"
@@ -520,54 +507,44 @@ class LLMResult(BaseModel):
"""List of metadata info for model call for each input."""
class OnLLMEnd(BaseModel):
class OnLLMEnd(BaseCallback):
"""On LLM End Callback Event."""
response: LLMResult
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_llm_end"] = "on_llm_end"
class OnRetrieverStart(BaseModel):
class OnRetrieverStart(BaseCallback):
"""On Retriever Start Callback Event."""
serialized: Dict[str, Any]
serialized: Optional[Dict[str, Any]] = None
query: str
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_retriever_start"] = "on_retriever_start"
class OnRetrieverError(BaseModel):
class OnRetrieverError(BaseCallback):
"""On Retriever Error Callback Event."""
error: Error
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_retriever_error"] = "on_retriever_error"
class OnRetrieverEnd(BaseModel):
class OnRetrieverEnd(BaseCallback):
"""On Retriever End Callback Event."""
documents: Sequence[Document]
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_retriever_end"] = "on_retriever_end"
class CallbackEvent(BaseModel):
__root__: Union[
CallbackEvent = RootModel[
Union[
OnChainStart,
OnChainEnd,
OnChainError,
@@ -581,3 +558,4 @@ class CallbackEvent(BaseModel):
OnRetrieverEnd,
OnRetrieverError,
]
]
Generated
+1617 -1600
File diff suppressed because it is too large
+10 -7
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langserve"
version = "0.2.2"
version = "0.3.2"
description = ""
readme = "README.md"
authors = ["LangChain"]
@@ -10,14 +10,13 @@ exclude = ["langserve/playground,langserve/chat_playground"]
include = ["langserve/playground/dist/**/*", "langserve/chat_playground/dist/**/*"]
[tool.poetry.dependencies]
python = "^3.8.1"
httpx = ">=0.23.0" # May be able to decrease this version
python = "^3.10"
httpx = ">=0.23.0,<1.0"
fastapi = {version = ">=0.90.1,<1", optional = true}
sse-starlette = {version = "^1.3.0", optional = true}
pydantic = ">=1"
langchain-core = ">=0.1,<0.3"
orjson = ">=2"
pyproject-toml = "^0.0.10"
langchain-core = ">=0.3,<2"
orjson = ">=2,<4"
pydantic = "^2.7"
[tool.poetry.group.dev.dependencies]
jupyterlab = "^3.6.1"
@@ -95,3 +94,7 @@ addopts = "--strict-markers --strict-config --durations=5 -vv"
# take more than 5 seconds
timeout = 5
asyncio_mode = "auto"
filterwarnings = [
"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
]
+13 -4
@@ -1,5 +1,6 @@
"""Test the playground API."""
import httpx
from fastapi import APIRouter, FastAPI
from httpx import AsyncClient
from langchain_core.runnables import RunnableLambda
@@ -15,7 +16,9 @@ async def test_serve_playground() -> None:
RunnableLambda(lambda foo: "hello"),
)
async with AsyncClient(app=app, base_url="http://localhost:9999") as client:
async with AsyncClient(
base_url="http://localhost:9999", transport=httpx.ASGITransport(app=app)
) as client:
response = await client.get("/playground/index.html")
assert response.status_code == 200
# Test that we can't access files that do not exist.
@@ -42,7 +45,9 @@ async def test_serve_playground_with_api_router() -> None:
app.include_router(router)
async with AsyncClient(app=app, base_url="http://localhost:9999") as client:
async with AsyncClient(
base_url="http://localhost:9999", transport=httpx.ASGITransport(app=app)
) as client:
response = await client.get("/langserve_runnables/chat/playground/index.html")
assert response.status_code == 200
@@ -64,7 +69,9 @@ async def test_serve_playground_with_api_router_in_api_router() -> None:
# Now add parent router to the app
app.include_router(parent_router)
async with AsyncClient(app=app, base_url="http://localhost:9999") as client:
async with AsyncClient(
base_url="http://localhost:9999", transport=httpx.ASGITransport(app=app)
) as client:
response = await client.get("/parent/bar/foo/playground/index.html")
assert response.status_code == 200
@@ -88,7 +95,9 @@ async def test_root_path_on_playground() -> None:
)
app.include_router(router)
async_client = AsyncClient(app=app, base_url="http://localhost:9999")
async_client = AsyncClient(
base_url="http://localhost:9999", transport=httpx.ASGITransport(app=app)
)
response = await async_client.get("/chat/playground/index.html")
assert response.status_code == 200
+16 -6
@@ -4,15 +4,23 @@ from enum import Enum
from typing import Any
import pytest
from langchain_core.documents.base import Document
from langchain_core.messages import HumanMessage, HumanMessageChunk, SystemMessage
from langchain_core.outputs import ChatGeneration
from pydantic import BaseModel
try:
from pydantic.v1 import BaseModel
except ImportError:
from pydantic import BaseModel
from langserve.serialization import (
WellKnownLCObject,
WellKnownLCSerializer,
load_events,
)
from langserve.serialization import WellKnownLCSerializer, load_events
def test_document_serialization() -> None:
"""Simple test. Exhaustive tests follow below."""
doc = Document(page_content="hello")
d = doc.model_dump()
WellKnownLCObject.model_validate(d)
@pytest.mark.parametrize(
@@ -23,6 +31,8 @@ from langserve.serialization import WellKnownLCSerializer, load_events
[],
{},
{"a": 1},
Document(page_content="Hello"),
[Document(page_content="Hello")],
{"output": [HumanMessage(content="hello")]},
# Test with a single message (HumanMessage)
HumanMessage(content="Hello"),
@@ -77,7 +87,7 @@ def _get_full_representation(data: Any) -> Any:
elif isinstance(data, list):
return [_get_full_representation(value) for value in data]
elif isinstance(data, BaseModel):
return data.schema()
return data.model_json_schema()
else:
return data
File diff suppressed because it is too large
+3 -8
@@ -5,14 +5,9 @@ import pytest
from fastapi import Request
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import ConfigurableField
from pydantic import BaseModel, ValidationError
from langserve.api_handler import _unpack_request_config
try:
from pydantic.v1 import BaseModel, ValidationError
except ImportError:
from pydantic import BaseModel, ValidationError
from langserve.validation import (
create_batch_request_model,
create_invoke_request_model,
@@ -175,11 +170,11 @@ async def test_invoke_request_with_runnables() -> None:
"configurable": {"template": "goodbye {name}"},
},
)
assert request.input == {"name": "bob"}
assert dict(request.input) == {"name": "bob"}
assert request.config.tags == ["hello"]
assert request.config.run_name == "run"
assert isinstance(request.config.configurable, BaseModel)
assert request.config.configurable.dict() == {
assert request.config.configurable.model_dump() == {
"template": "goodbye {name}",
}
+27
@@ -0,0 +1,27 @@
from typing import Any
def recursive_dump(obj: Any) -> Any:
"""Recursively dump the object if encountering any pydantic models."""
if isinstance(obj, dict):
return {
k: recursive_dump(v)
for k, v in obj.items()
if k != "id" # Remove the id field for testing purposes
}
if isinstance(obj, list):
return [recursive_dump(v) for v in obj]
if hasattr(obj, "model_dump"):
# if the object contains an ID field, we'll remove it for testing purposes
d = obj.model_dump()
if "id" in d:
d.pop("id")
return recursive_dump(d)
if hasattr(obj, "dict"):
# if the object contains an ID field, we'll remove it for testing purposes
if hasattr(obj, "id"):
d = obj.dict()
d.pop("id")
return recursive_dump(d)
return recursive_dump(obj.dict())
return obj
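The helper above strips `id` fields so that objects carrying auto-generated ids can be compared in tests. A minimal sketch of the idea follows, in a condensed variant that relies on the dict branch alone to drop `id`. `FakeMessage` is a hypothetical stand-in that only mimics a pydantic-style `model_dump()` method; it is not part of langserve or langchain-core.

```python
from typing import Any


def recursive_dump(obj: Any) -> Any:
    """Condensed variant: convert model-like objects to dicts, dropping "id" keys."""
    if isinstance(obj, dict):
        # The dict branch removes "id" at every nesting level.
        return {k: recursive_dump(v) for k, v in obj.items() if k != "id"}
    if isinstance(obj, list):
        return [recursive_dump(v) for v in obj]
    if hasattr(obj, "model_dump"):
        # Pydantic v2 style objects expose model_dump().
        return recursive_dump(obj.model_dump())
    if hasattr(obj, "dict"):
        # Pydantic v1 style objects expose dict().
        return recursive_dump(obj.dict())
    return obj


class FakeMessage:
    """Hypothetical stand-in for a model; only model_dump() is mimicked."""

    def __init__(self, content: str, id: str) -> None:
        self.content = content
        self.id = id

    def model_dump(self) -> dict:
        return {"content": self.content, "id": self.id}


# Two payloads that differ only in their generated ids.
dumped_a = recursive_dump({"output": [FakeMessage("hello", "run-1")]})
dumped_b = recursive_dump({"output": [FakeMessage("hello", "run-2")]})
```

After dumping, both payloads reduce to the same plain structure, so a test can compare them with `==` without knowing the ids in advance.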
+16
@@ -1,6 +1,22 @@
from typing import Any
from langchain_core.messages import AIMessage, AIMessageChunk
class AnyStr(str):
def __eq__(self, other: Any) -> bool:
return isinstance(other, str)
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
"""Create an AI message whose id compares equal to any string."""
message = AIMessage(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
"""Create an AI message chunk whose id compares equal to any string."""
message = AIMessageChunk(**kwargs)
message.id = AnyStr()
return message
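The `AnyStr` stub above works as a wildcard in equality assertions. A minimal, standard-library-only sketch of that behaviour follows; the dictionaries are hypothetical stand-ins for message payloads and are not taken from the test suite.

```python
from typing import Any


class AnyStr(str):
    """A str subclass that compares equal to any other string."""

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, str)


# Because AnyStr subclasses str and overrides __eq__, Python consults
# AnyStr.__eq__ first even when the AnyStr instance is the right operand.
left_match = AnyStr() == "run-123"
right_match = "run-123" == AnyStr()
non_string = AnyStr() == 42

# Hypothetical payloads: the generated id differs on every run,
# so the expected value uses AnyStr() as a placeholder.
actual = {"content": "hello", "id": "run-8f2c"}
expected = {"content": "hello", "id": AnyStr()}
payload_match = actual == expected
```

The placeholder only relaxes the comparison for string values; a non-string id such as `None` or an integer still fails the assertion.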
+28 -35
@@ -4,11 +4,11 @@ from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.utils.llms import GenericFakeChatModel
from tests.unit_tests.utils.stubs import AnyStr
from tests.unit_tests.utils.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
def test_generic_fake_chat_model_invoke() -> None:
@@ -16,11 +16,11 @@ def test_generic_fake_chat_model_invoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello", id=AnyStr())
assert response == _AnyIdAIMessage(content="hello")
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye", id=AnyStr())
assert response == _AnyIdAIMessage(content="goodbye")
response = model.invoke("meow")
assert response == AIMessage(content="hello", id=AnyStr())
assert response == _AnyIdAIMessage(content="hello")
async def test_generic_fake_chat_model_ainvoke() -> None:
@@ -28,11 +28,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello", id=AnyStr())
assert response == _AnyIdAIMessage(content="hello")
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye", id=AnyStr())
assert response == _AnyIdAIMessage(content="goodbye")
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello", id=AnyStr())
assert response == _AnyIdAIMessage(content="hello")
async def test_generic_fake_chat_model_stream() -> None:
@@ -45,28 +45,26 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
]
chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
]
# Test streaming of additional kwargs.
# Relying on insertion order of the additional kwargs dict
message = AIMessage(
content="", additional_kwargs={"foo": 42, "bar": 24}, id=AnyStr()
)
message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24}, id="1")
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
_AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}),
_AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}),
]
message = AIMessage(
@@ -83,29 +81,25 @@ async def test_generic_fake_chat_model_stream() -> None:
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={"function_call": {"name": "move_file"}},
id=AnyStr(),
),
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'}
},
id=AnyStr(),
),
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={"function_call": {"arguments": ","}},
id=AnyStr(),
),
AIMessageChunk(
_AnyIdAIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'}
},
id=AnyStr(),
),
]
@@ -116,7 +110,7 @@ async def test_generic_fake_chat_model_stream() -> None:
else:
accumulate_chunks += chunk
assert accumulate_chunks == AIMessageChunk(
assert accumulate_chunks == _AnyIdAIMessageChunk(
content="",
additional_kwargs={
"function_call": {
@@ -125,7 +119,6 @@ async def test_generic_fake_chat_model_stream() -> None:
'destination_path": "bar"\n}',
}
},
id=AnyStr(),
)
@@ -138,9 +131,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
]
final = log_patches[-1]
assert final.state["streamed_output"] == [
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
]
@@ -188,8 +181,8 @@ async def test_callback_handlers() -> None:
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),
_AnyIdAIMessageChunk(content="goodbye"),
]
assert tokens == ["hello", " ", "goodbye"]
+29 -1
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Any, Dict, List, Optional
from uuid import UUID
from langchain_core.tracers import BaseTracer
@@ -39,6 +39,34 @@ class FakeTracer(BaseTracer):
}
)
def _create_chain_run(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
if name is None:
# can't raise an exception from here, but can get a breakpoint
# import pdb; pdb.set_trace()
pass
return super()._create_chain_run(
serialized,
inputs,
run_id,
tags,
parent_run_id,
metadata,
run_type,
name,
**kwargs,
)
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
self.runs.append(self._copy_run(run))