mirror of
https://github.com/langchain-ai/langserve.git
synced 2026-07-01 20:14:01 -04:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8b20a37353 | |||
| 1c4fc3171c | |||
| e481e8429b |
+85
-29
@@ -2,15 +2,18 @@ from inspect import isclass
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from fastapi import Request
|
||||
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable import Runnable
|
||||
@@ -35,11 +38,36 @@ except ImportError:
|
||||
# [server] extra not installed
|
||||
APIRouter = FastAPI = Any
|
||||
|
||||
# A function that that takes a config and a raw request
|
||||
# and updates the config based on the request.
|
||||
ConfigUpdater = Callable[[Dict[str, Any], Request], Dict[str, Any]]
|
||||
|
||||
def _unpack_config(d: Union[BaseModel, Mapping], keys: Sequence[str]) -> Dict[str, Any]:
|
||||
"""Project the given keys from the given dict."""
|
||||
_d = d.dict() if isinstance(d, BaseModel) else d
|
||||
return {k: _d[k] for k in keys if k in _d}
|
||||
|
||||
def _unpack_config(
|
||||
raw_config: Union[BaseModel, Mapping],
|
||||
keys: Sequence[str],
|
||||
request: Request,
|
||||
*,
|
||||
config_updater: Optional[ConfigUpdater] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Project the given keys from the given dict.
|
||||
|
||||
Args:
|
||||
raw_config: The raw config a pydantic model
|
||||
keys: The keys to project from the config.
|
||||
request: The raw fast api request.
|
||||
config_updater: optional function that can be used to update the config
|
||||
based on the raw request.
|
||||
|
||||
Returns:
|
||||
A finalized config.
|
||||
"""
|
||||
_d = raw_config.dict() if isinstance(raw_config, BaseModel) else raw_config
|
||||
projected_config = {k: _d[k] for k in keys if k in _d}
|
||||
if config_updater is None:
|
||||
return projected_config
|
||||
|
||||
return config_updater(projected_config, request)
|
||||
|
||||
|
||||
class InvokeResponse(BaseModel):
|
||||
@@ -145,6 +173,7 @@ def add_routes(
|
||||
path: str = "",
|
||||
input_type: Union[Type, Literal["auto"], BaseModel] = "auto",
|
||||
config_keys: Sequence[str] = (),
|
||||
config_updater: Optional[ConfigUpdater] = None,
|
||||
) -> None:
|
||||
"""Register the routes on the given FastAPI app or APIRouter.
|
||||
|
||||
@@ -157,6 +186,8 @@ def add_routes(
|
||||
User is free to provide a custom type annotation.
|
||||
config_keys: list of config keys that will be accepted, by default
|
||||
no config keys are accepted.
|
||||
config_updater: optional function that can be used to update the config
|
||||
based on the raw request before it is passed to the runnable.
|
||||
"""
|
||||
try:
|
||||
from sse_starlette import EventSourceResponse
|
||||
@@ -185,54 +216,73 @@ def add_routes(
|
||||
StreamLogRequest = create_stream_log_request_model(
|
||||
model_namespace, input_type_, config
|
||||
)
|
||||
from fastapi import Request
|
||||
|
||||
@app.post(
|
||||
f"{namespace}/invoke",
|
||||
response_model=InvokeResponse,
|
||||
)
|
||||
async def invoke(
|
||||
request: Annotated[InvokeRequest, InvokeRequest]
|
||||
invoke_request: Annotated[InvokeRequest, InvokeRequest],
|
||||
request: Request,
|
||||
) -> InvokeResponse:
|
||||
"""Invoke the runnable with the given input and config."""
|
||||
# Request is first validated using InvokeRequest which takes into account
|
||||
# config_keys as well as input_type.
|
||||
config = _unpack_config(request.config, config_keys)
|
||||
config = _unpack_config(
|
||||
invoke_request.config, config_keys, request, config_updater=config_updater
|
||||
)
|
||||
output = await runnable.ainvoke(
|
||||
_unpack_input(request.input), config=config, **request.kwargs
|
||||
_unpack_input(invoke_request.input), config=config, **invoke_request.kwargs
|
||||
)
|
||||
|
||||
return InvokeResponse(output=simple_dumpd(output))
|
||||
|
||||
#
|
||||
@app.post(f"{namespace}/batch", response_model=BatchResponse)
|
||||
async def batch(request: Annotated[BatchRequest, BatchRequest]) -> BatchResponse:
|
||||
async def batch(
|
||||
batch_request: Annotated[BatchRequest, BatchRequest], request: Request
|
||||
) -> BatchResponse:
|
||||
"""Invoke the runnable with the given inputs and config."""
|
||||
if isinstance(request.config, list):
|
||||
config = [_unpack_config(config, config_keys) for config in request.config]
|
||||
if isinstance(batch_request.config, list):
|
||||
config = [
|
||||
_unpack_config(
|
||||
config, config_keys, request, config_updater=config_updater
|
||||
)
|
||||
for config in batch_request.config
|
||||
]
|
||||
else:
|
||||
config = _unpack_config(request.config, config_keys)
|
||||
inputs = [_unpack_input(input_) for input_ in request.inputs]
|
||||
output = await runnable.abatch(inputs, config=config, **request.kwargs)
|
||||
config = _unpack_config(
|
||||
batch_request.config,
|
||||
config_keys,
|
||||
request,
|
||||
config_updater=config_updater,
|
||||
)
|
||||
inputs = [_unpack_input(input_) for input_ in batch_request.inputs]
|
||||
output = await runnable.abatch(inputs, config=config, **batch_request.kwargs)
|
||||
|
||||
return BatchResponse(output=simple_dumpd(output))
|
||||
|
||||
@app.post(f"{namespace}/stream")
|
||||
async def stream(
|
||||
request: Annotated[StreamRequest, StreamRequest],
|
||||
stream_request: Annotated[StreamRequest, StreamRequest],
|
||||
request: Request,
|
||||
) -> EventSourceResponse:
|
||||
"""Invoke the runnable stream the output."""
|
||||
# Request is first validated using InvokeRequest which takes into account
|
||||
# config_keys as well as input_type.
|
||||
# After validation, the input is loaded using LangChain's load function.
|
||||
input_ = _unpack_input(request.input)
|
||||
config = _unpack_config(request.config, config_keys)
|
||||
input_ = _unpack_input(stream_request.input)
|
||||
config = _unpack_config(
|
||||
stream_request.config, config_keys, request, config_updater=config_updater
|
||||
)
|
||||
|
||||
async def _stream() -> AsyncIterator[dict]:
|
||||
"""Stream the output of the runnable."""
|
||||
async for chunk in runnable.astream(
|
||||
input_,
|
||||
config=config,
|
||||
**request.kwargs,
|
||||
**stream_request.kwargs,
|
||||
):
|
||||
yield {"data": simple_dumps(chunk), "event": "data"}
|
||||
yield {"event": "end"}
|
||||
@@ -241,30 +291,36 @@ def add_routes(
|
||||
|
||||
@app.post(f"{namespace}/stream_log")
|
||||
async def stream_log(
|
||||
request: Annotated[StreamLogRequest, StreamLogRequest],
|
||||
stream_log_request: Annotated[StreamLogRequest, StreamLogRequest],
|
||||
request: Request,
|
||||
) -> EventSourceResponse:
|
||||
"""Invoke the runnable stream the output."""
|
||||
# Request is first validated using InvokeRequest which takes into account
|
||||
# config_keys as well as input_type.
|
||||
# After validation, the input is loaded using LangChain's load function.
|
||||
input_ = _unpack_input(request.input)
|
||||
config = _unpack_config(request.config, config_keys)
|
||||
input_ = _unpack_input(stream_log_request.input)
|
||||
config = _unpack_config(
|
||||
stream_log_request.config,
|
||||
config_keys,
|
||||
request,
|
||||
config_updater=config_updater,
|
||||
)
|
||||
|
||||
async def _stream_log() -> AsyncIterator[dict]:
|
||||
"""Stream the output of the runnable."""
|
||||
async for chunk in runnable.astream_log(
|
||||
input_,
|
||||
config=config,
|
||||
diff=request.diff,
|
||||
include_names=request.include_names,
|
||||
include_types=request.include_types,
|
||||
include_tags=request.include_tags,
|
||||
exclude_names=request.exclude_names,
|
||||
exclude_types=request.exclude_types,
|
||||
exclude_tags=request.exclude_tags,
|
||||
**request.kwargs,
|
||||
diff=stream_log_request.diff,
|
||||
include_names=stream_log_request.include_names,
|
||||
include_types=stream_log_request.include_types,
|
||||
include_tags=stream_log_request.include_tags,
|
||||
exclude_names=stream_log_request.exclude_names,
|
||||
exclude_types=stream_log_request.exclude_types,
|
||||
exclude_tags=stream_log_request.exclude_tags,
|
||||
**stream_log_request.kwargs,
|
||||
):
|
||||
if request.diff: # Run log patch
|
||||
if stream_log_request.diff: # Run log patch
|
||||
if not isinstance(chunk, RunLogPatch):
|
||||
raise AssertionError(
|
||||
f"Expected a RunLog instance got {type(chunk)}"
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||
@@ -405,6 +405,44 @@ async def test_multiple_runnables(event_loop: AbstractEventLoop) -> None:
|
||||
assert await composite_runnable_2.ainvoke(3) == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_updater(
|
||||
event_loop: AbstractEventLoop, mocker: MockerFixture
|
||||
) -> None:
|
||||
"""Test updating the config based on the raw request object."""
|
||||
|
||||
async def add_one(x: int) -> int:
|
||||
"""Add one to simulate a valid function"""
|
||||
return x + 1
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
def config_updater(config: Dict[str, Any], request: Request) -> Dict[str, Any]:
|
||||
"""Update the config"""
|
||||
config = config.copy()
|
||||
if "metadata" in config:
|
||||
config["metadata"] = config["metadata"].copy()
|
||||
else:
|
||||
config["metadata"] = {}
|
||||
config["metadata"]["headers"] = request.headers
|
||||
return config
|
||||
|
||||
server_runnable = RunnableLambda(add_one)
|
||||
|
||||
add_routes(app, server_runnable, path="/add_one", config_updater=config_updater)
|
||||
|
||||
invoke_spy_1 = mocker.spy(server_runnable, "ainvoke")
|
||||
# Verify config is handled correctly
|
||||
async with get_async_client(app, path="/add_one") as runnable1:
|
||||
# Verify that can be invoked with valid input
|
||||
# Config ignored for runnable1
|
||||
assert await runnable1.ainvoke(1, config={}) == 2
|
||||
assert (
|
||||
invoke_spy_1.call_args[1]["config"]["metadata"]["headers"]["content-type"]
|
||||
== "application/json"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_validation(
|
||||
event_loop: AbstractEventLoop, mocker: MockerFixture
|
||||
@@ -541,7 +579,7 @@ async def test_async_client_close() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openapi_docs_with_identical_runnables(
|
||||
event_loop: AbstractEventLoop, mocker: MockerFixture
|
||||
event_loop: AbstractEventLoop,
|
||||
) -> None:
|
||||
"""Test client side and server side exceptions."""
|
||||
|
||||
|
||||
@@ -156,6 +156,7 @@ def test_invoke_request_with_runnables() -> None:
|
||||
input={"name": "bob"},
|
||||
).config,
|
||||
[],
|
||||
None, # type: ignore # This is the raw request (not used for test)
|
||||
)
|
||||
== {}
|
||||
)
|
||||
@@ -177,6 +178,11 @@ def test_invoke_request_with_runnables() -> None:
|
||||
"template": "goodbye {name}",
|
||||
}
|
||||
|
||||
assert _unpack_config(request.config, ["configurable"]) == {
|
||||
raw_request = None # Not used for test
|
||||
config = _unpack_config(
|
||||
request.config, ["configurable"], raw_request # type: ignore
|
||||
)
|
||||
|
||||
assert config == {
|
||||
"configurable": {"template": "goodbye {name}"},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user