Compare commits

...

3 Commits

Author SHA1 Message Date
Eugene Yurtsev 8b20a37353 Merge branch 'main' into eugene/add_interpreptors 2023-10-06 15:31:38 -04:00
Eugene Yurtsev 1c4fc3171c x 2023-10-06 15:14:10 -04:00
Eugene Yurtsev e481e8429b x 2023-10-06 15:06:16 -04:00
3 changed files with 133 additions and 33 deletions
+85 -29
View File
@@ -2,15 +2,18 @@ from inspect import isclass
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
)
from fastapi import Request
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.load.serializable import Serializable
from langchain.schema.runnable import Runnable
@@ -35,11 +38,36 @@ except ImportError:
# [server] extra not installed
APIRouter = FastAPI = Any
# A function that that takes a config and a raw request
# and updates the config based on the request.
ConfigUpdater = Callable[[Dict[str, Any], Request], Dict[str, Any]]
def _unpack_config(d: Union[BaseModel, Mapping], keys: Sequence[str]) -> Dict[str, Any]:
"""Project the given keys from the given dict."""
_d = d.dict() if isinstance(d, BaseModel) else d
return {k: _d[k] for k in keys if k in _d}
def _unpack_config(
raw_config: Union[BaseModel, Mapping],
keys: Sequence[str],
request: Request,
*,
config_updater: Optional[ConfigUpdater] = None,
) -> Dict[str, Any]:
"""Project the given keys from the given dict.
Args:
raw_config: The raw config a pydantic model
keys: The keys to project from the config.
request: The raw fast api request.
config_updater: optional function that can be used to update the config
based on the raw request.
Returns:
A finalized config.
"""
_d = raw_config.dict() if isinstance(raw_config, BaseModel) else raw_config
projected_config = {k: _d[k] for k in keys if k in _d}
if config_updater is None:
return projected_config
return config_updater(projected_config, request)
class InvokeResponse(BaseModel):
@@ -145,6 +173,7 @@ def add_routes(
path: str = "",
input_type: Union[Type, Literal["auto"], BaseModel] = "auto",
config_keys: Sequence[str] = (),
config_updater: Optional[ConfigUpdater] = None,
) -> None:
"""Register the routes on the given FastAPI app or APIRouter.
@@ -157,6 +186,8 @@ def add_routes(
User is free to provide a custom type annotation.
config_keys: list of config keys that will be accepted, by default
no config keys are accepted.
config_updater: optional function that can be used to update the config
based on the raw request before it is passed to the runnable.
"""
try:
from sse_starlette import EventSourceResponse
@@ -185,54 +216,73 @@ def add_routes(
StreamLogRequest = create_stream_log_request_model(
model_namespace, input_type_, config
)
from fastapi import Request
@app.post(
f"{namespace}/invoke",
response_model=InvokeResponse,
)
async def invoke(
request: Annotated[InvokeRequest, InvokeRequest]
invoke_request: Annotated[InvokeRequest, InvokeRequest],
request: Request,
) -> InvokeResponse:
"""Invoke the runnable with the given input and config."""
# Request is first validated using InvokeRequest which takes into account
# config_keys as well as input_type.
config = _unpack_config(request.config, config_keys)
config = _unpack_config(
invoke_request.config, config_keys, request, config_updater=config_updater
)
output = await runnable.ainvoke(
_unpack_input(request.input), config=config, **request.kwargs
_unpack_input(invoke_request.input), config=config, **invoke_request.kwargs
)
return InvokeResponse(output=simple_dumpd(output))
#
@app.post(f"{namespace}/batch", response_model=BatchResponse)
async def batch(request: Annotated[BatchRequest, BatchRequest]) -> BatchResponse:
async def batch(
batch_request: Annotated[BatchRequest, BatchRequest], request: Request
) -> BatchResponse:
"""Invoke the runnable with the given inputs and config."""
if isinstance(request.config, list):
config = [_unpack_config(config, config_keys) for config in request.config]
if isinstance(batch_request.config, list):
config = [
_unpack_config(
config, config_keys, request, config_updater=config_updater
)
for config in batch_request.config
]
else:
config = _unpack_config(request.config, config_keys)
inputs = [_unpack_input(input_) for input_ in request.inputs]
output = await runnable.abatch(inputs, config=config, **request.kwargs)
config = _unpack_config(
batch_request.config,
config_keys,
request,
config_updater=config_updater,
)
inputs = [_unpack_input(input_) for input_ in batch_request.inputs]
output = await runnable.abatch(inputs, config=config, **batch_request.kwargs)
return BatchResponse(output=simple_dumpd(output))
@app.post(f"{namespace}/stream")
async def stream(
request: Annotated[StreamRequest, StreamRequest],
stream_request: Annotated[StreamRequest, StreamRequest],
request: Request,
) -> EventSourceResponse:
"""Invoke the runnable stream the output."""
# Request is first validated using InvokeRequest which takes into account
# config_keys as well as input_type.
# After validation, the input is loaded using LangChain's load function.
input_ = _unpack_input(request.input)
config = _unpack_config(request.config, config_keys)
input_ = _unpack_input(stream_request.input)
config = _unpack_config(
stream_request.config, config_keys, request, config_updater=config_updater
)
async def _stream() -> AsyncIterator[dict]:
"""Stream the output of the runnable."""
async for chunk in runnable.astream(
input_,
config=config,
**request.kwargs,
**stream_request.kwargs,
):
yield {"data": simple_dumps(chunk), "event": "data"}
yield {"event": "end"}
@@ -241,30 +291,36 @@ def add_routes(
@app.post(f"{namespace}/stream_log")
async def stream_log(
request: Annotated[StreamLogRequest, StreamLogRequest],
stream_log_request: Annotated[StreamLogRequest, StreamLogRequest],
request: Request,
) -> EventSourceResponse:
"""Invoke the runnable stream the output."""
# Request is first validated using InvokeRequest which takes into account
# config_keys as well as input_type.
# After validation, the input is loaded using LangChain's load function.
input_ = _unpack_input(request.input)
config = _unpack_config(request.config, config_keys)
input_ = _unpack_input(stream_log_request.input)
config = _unpack_config(
stream_log_request.config,
config_keys,
request,
config_updater=config_updater,
)
async def _stream_log() -> AsyncIterator[dict]:
"""Stream the output of the runnable."""
async for chunk in runnable.astream_log(
input_,
config=config,
diff=request.diff,
include_names=request.include_names,
include_types=request.include_types,
include_tags=request.include_tags,
exclude_names=request.exclude_names,
exclude_types=request.exclude_types,
exclude_tags=request.exclude_tags,
**request.kwargs,
diff=stream_log_request.diff,
include_names=stream_log_request.include_names,
include_types=stream_log_request.include_types,
include_tags=stream_log_request.include_tags,
exclude_names=stream_log_request.exclude_names,
exclude_types=stream_log_request.exclude_types,
exclude_tags=stream_log_request.exclude_tags,
**stream_log_request.kwargs,
):
if request.diff: # Run log patch
if stream_log_request.diff: # Run log patch
if not isinstance(chunk, RunLogPatch):
raise AssertionError(
f"Expected a RunLog instance got {type(chunk)}"
+41 -3
View File
@@ -2,12 +2,12 @@
import asyncio
from asyncio import AbstractEventLoop
from contextlib import asynccontextmanager
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import httpx
import pytest
import pytest_asyncio
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from httpx import AsyncClient
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
@@ -405,6 +405,44 @@ async def test_multiple_runnables(event_loop: AbstractEventLoop) -> None:
assert await composite_runnable_2.ainvoke(3) == 10
@pytest.mark.asyncio
async def test_config_updater(
event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
"""Test updating the config based on the raw request object."""
async def add_one(x: int) -> int:
"""Add one to simulate a valid function"""
return x + 1
app = FastAPI()
def config_updater(config: Dict[str, Any], request: Request) -> Dict[str, Any]:
"""Update the config"""
config = config.copy()
if "metadata" in config:
config["metadata"] = config["metadata"].copy()
else:
config["metadata"] = {}
config["metadata"]["headers"] = request.headers
return config
server_runnable = RunnableLambda(add_one)
add_routes(app, server_runnable, path="/add_one", config_updater=config_updater)
invoke_spy_1 = mocker.spy(server_runnable, "ainvoke")
# Verify config is handled correctly
async with get_async_client(app, path="/add_one") as runnable1:
# Verify that can be invoked with valid input
# Config ignored for runnable1
assert await runnable1.ainvoke(1, config={}) == 2
assert (
invoke_spy_1.call_args[1]["config"]["metadata"]["headers"]["content-type"]
== "application/json"
)
@pytest.mark.asyncio
async def test_input_validation(
event_loop: AbstractEventLoop, mocker: MockerFixture
@@ -541,7 +579,7 @@ async def test_async_client_close() -> None:
@pytest.mark.asyncio
async def test_openapi_docs_with_identical_runnables(
event_loop: AbstractEventLoop, mocker: MockerFixture
event_loop: AbstractEventLoop,
) -> None:
"""Test client side and server side exceptions."""
+7 -1
View File
@@ -156,6 +156,7 @@ def test_invoke_request_with_runnables() -> None:
input={"name": "bob"},
).config,
[],
None, # type: ignore # This is the raw request (not used for test)
)
== {}
)
@@ -177,6 +178,11 @@ def test_invoke_request_with_runnables() -> None:
"template": "goodbye {name}",
}
assert _unpack_config(request.config, ["configurable"]) == {
raw_request = None # Not used for test
config = _unpack_config(
request.config, ["configurable"], raw_request # type: ignore
)
assert config == {
"configurable": {"template": "goodbye {name}"},
}