Compare commits

...

3 Commits

Author SHA1 Message Date
Eugene Yurtsev 1bdcb70e73 x 2024-03-14 13:30:28 -04:00
Eugene Yurtsev 1ad5619598 x 2024-03-14 13:30:15 -04:00
Eugene Yurtsev 588e7975c2 x 2024-03-14 13:19:37 -04:00
4 changed files with 42 additions and 9 deletions
+5 -3
View File
@@ -13,14 +13,16 @@ See:
* https://fastapi.tiangolo.com/tutorial/security/
"""
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import APIKeyHeader
from langchain_core.runnables import RunnableLambda
from typing_extensions import Annotated
from langserve import add_routes
XToken = APIKeyHeader(name="x-token")
async def verify_token(x_token: Annotated[str, Header()]) -> None:
async def verify_token(x_token: str = Depends(XToken)) -> None:
"""Verify the token is valid."""
# Replace this with your actual authentication logic
if x_token != "secret-token":
+5 -3
View File
@@ -13,14 +13,16 @@ To implement proper auth, please see the FastAPI docs:
* https://fastapi.tiangolo.com/tutorial/security/
""" # noqa: E501
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import APIKeyHeader
from langchain_core.runnables import RunnableLambda
from typing_extensions import Annotated
from langserve import add_routes
XToken = APIKeyHeader(name="x-token")
async def verify_token(x_token: Annotated[str, Header()]) -> None:
async def verify_token(x_token: str = Depends(XToken)) -> None:
"""Verify the token is valid."""
# Replace this with your actual authentication logic
if x_token != "secret-token":
+5 -2
View File
@@ -36,7 +36,7 @@ from starlette.responses import JSONResponse, Response
from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict
from langserve.lzstring import LZString
from langserve.playground import serve_playground
from langserve.playground import PlaygroundConfig, serve_playground
from langserve.pydantic_v1 import BaseModel, Field, ValidationError, create_model
from langserve.schema import (
BatchResponseMetadata,
@@ -468,6 +468,7 @@ class APIHandler:
per_req_config_modifier: Optional[PerRequestConfigModifier] = None,
stream_log_name_allow_list: Optional[Sequence[str]] = None,
playground_type: Literal["default", "chat"] = "default",
playground_config: Optional[PlaygroundConfig] = None,
) -> None:
"""Create an API handler for the given runnable.
@@ -561,6 +562,7 @@ class APIHandler:
self._enable_feedback_endpoint = enable_feedback_endpoint
self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint
self._names_in_stream_allow_list = stream_log_name_allow_list
self._playground_config = playground_config
# Client is patched using mock.patch, if changing the names
# remember to make relevant updates in the unit tests.
@@ -1366,7 +1368,8 @@ class APIHandler:
file_path,
feedback_enabled,
public_trace_link_enabled,
playground_type=self.playground_type,
self.playground_type,
playground_config=self._playground_config,
)
async def create_feedback(
+27 -1
View File
@@ -2,10 +2,12 @@ import json
import mimetypes
import os
from string import Template
from typing import Literal, Sequence, Type
from typing import Literal, Optional, Sequence, Type, Union
from fastapi.responses import Response
from fastapi.security import APIKeyCookie, APIKeyHeader, APIKeyQuery
from langchain.schema.runnable import Runnable
from typing_extensions import TypedDict
from langserve.pydantic_v1 import BaseModel
@@ -47,6 +49,15 @@ def _get_mimetype(path: str) -> str:
return mime_type
SupportedSecurityScheme = Union[APIKeyHeader, APIKeyQuery, APIKeyCookie]
class PlaygroundConfig(TypedDict, total=False):
"""Configuration for the playground."""
security_scheme: Optional[SupportedSecurityScheme]
async def serve_playground(
runnable: Runnable,
input_schema: Type[BaseModel],
@@ -56,8 +67,20 @@ async def serve_playground(
feedback_enabled: bool,
public_trace_link_enabled: bool,
playground_type: Literal["default", "chat"],
*,
playground_config: Optional[PlaygroundConfig] = None,
) -> Response:
"""Serve the playground."""
security_scheme = (
playground_config.get("security_scheme") if playground_config else None
)
if not isinstance(
security_scheme, (APIKeyHeader, APIKeyQuery, APIKeyCookie, type(None))
):
raise NotImplementedError(
"Only APIKeyHeader, APIKeyQuery, APIKeyCookie, and None are supported."
)
if playground_type == "default":
path_to_dist = "./playground/dist"
elif playground_type == "chat":
@@ -98,6 +121,9 @@ async def serve_playground(
LANGSERVE_PUBLIC_TRACE_LINK_ENABLED=json.dumps(
"true" if public_trace_link_enabled else "false"
),
SECURITY_SCHEME=security_scheme.model.json()
if security_scheme
else json.dumps({}),
)
else:
response = f.buffer.read()