mirror of
https://github.com/langchain-ai/langserve.git
synced 2026-07-01 20:14:01 -04:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1bdcb70e73 | |||
| 1ad5619598 | |||
| 588e7975c2 |
@@ -13,14 +13,16 @@ See:
|
||||
* https://fastapi.tiangolo.com/tutorial/security/
|
||||
"""
|
||||
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||
from fastapi import Depends, FastAPI, HTTPException
|
||||
from fastapi.security import APIKeyHeader
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langserve import add_routes
|
||||
|
||||
XToken = APIKeyHeader(name="x-token")
|
||||
|
||||
async def verify_token(x_token: Annotated[str, Header()]) -> None:
|
||||
|
||||
async def verify_token(x_token: str = Depends(XToken)) -> None:
|
||||
"""Verify the token is valid."""
|
||||
# Replace this with your actual authentication logic
|
||||
if x_token != "secret-token":
|
||||
|
||||
@@ -13,14 +13,16 @@ To implement proper auth, please see the FastAPI docs:
|
||||
* https://fastapi.tiangolo.com/tutorial/security/
|
||||
""" # noqa: E501
|
||||
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||
from fastapi import Depends, FastAPI, HTTPException
|
||||
from fastapi.security import APIKeyHeader
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langserve import add_routes
|
||||
|
||||
XToken = APIKeyHeader(name="x-token")
|
||||
|
||||
async def verify_token(x_token: Annotated[str, Header()]) -> None:
|
||||
|
||||
async def verify_token(x_token: str = Depends(XToken)) -> None:
|
||||
"""Verify the token is valid."""
|
||||
# Replace this with your actual authentication logic
|
||||
if x_token != "secret-token":
|
||||
|
||||
@@ -36,7 +36,7 @@ from starlette.responses import JSONResponse, Response
|
||||
|
||||
from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict
|
||||
from langserve.lzstring import LZString
|
||||
from langserve.playground import serve_playground
|
||||
from langserve.playground import PlaygroundConfig, serve_playground
|
||||
from langserve.pydantic_v1 import BaseModel, Field, ValidationError, create_model
|
||||
from langserve.schema import (
|
||||
BatchResponseMetadata,
|
||||
@@ -468,6 +468,7 @@ class APIHandler:
|
||||
per_req_config_modifier: Optional[PerRequestConfigModifier] = None,
|
||||
stream_log_name_allow_list: Optional[Sequence[str]] = None,
|
||||
playground_type: Literal["default", "chat"] = "default",
|
||||
playground_config: Optional[PlaygroundConfig] = None,
|
||||
) -> None:
|
||||
"""Create an API handler for the given runnable.
|
||||
|
||||
@@ -561,6 +562,7 @@ class APIHandler:
|
||||
self._enable_feedback_endpoint = enable_feedback_endpoint
|
||||
self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint
|
||||
self._names_in_stream_allow_list = stream_log_name_allow_list
|
||||
self._playground_config = playground_config
|
||||
|
||||
# Client is patched using mock.patch, if changing the names
|
||||
# remember to make relevant updates in the unit tests.
|
||||
@@ -1366,7 +1368,8 @@ class APIHandler:
|
||||
file_path,
|
||||
feedback_enabled,
|
||||
public_trace_link_enabled,
|
||||
playground_type=self.playground_type,
|
||||
self.playground_type,
|
||||
playground_config=self._playground_config,
|
||||
)
|
||||
|
||||
async def create_feedback(
|
||||
|
||||
+27
-1
@@ -2,10 +2,12 @@ import json
|
||||
import mimetypes
|
||||
import os
|
||||
from string import Template
|
||||
from typing import Literal, Sequence, Type
|
||||
from typing import Literal, Optional, Sequence, Type, Union
|
||||
|
||||
from fastapi.responses import Response
|
||||
from fastapi.security import APIKeyCookie, APIKeyHeader, APIKeyQuery
|
||||
from langchain.schema.runnable import Runnable
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langserve.pydantic_v1 import BaseModel
|
||||
|
||||
@@ -47,6 +49,15 @@ def _get_mimetype(path: str) -> str:
|
||||
return mime_type
|
||||
|
||||
|
||||
SupportedSecurityScheme = Union[APIKeyHeader, APIKeyQuery, APIKeyCookie]
|
||||
|
||||
|
||||
class PlaygroundConfig(TypedDict, total=False):
|
||||
"""Configuration for the playground."""
|
||||
|
||||
security_scheme: Optional[SupportedSecurityScheme]
|
||||
|
||||
|
||||
async def serve_playground(
|
||||
runnable: Runnable,
|
||||
input_schema: Type[BaseModel],
|
||||
@@ -56,8 +67,20 @@ async def serve_playground(
|
||||
feedback_enabled: bool,
|
||||
public_trace_link_enabled: bool,
|
||||
playground_type: Literal["default", "chat"],
|
||||
*,
|
||||
playground_config: Optional[PlaygroundConfig] = None,
|
||||
) -> Response:
|
||||
"""Serve the playground."""
|
||||
security_scheme = (
|
||||
playground_config.get("security_scheme") if playground_config else None
|
||||
)
|
||||
if not isinstance(
|
||||
security_scheme, (APIKeyHeader, APIKeyQuery, APIKeyCookie, type(None))
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Only APIKeyHeader, APIKeyQuery, APIKeyCookie, and None are supported."
|
||||
)
|
||||
|
||||
if playground_type == "default":
|
||||
path_to_dist = "./playground/dist"
|
||||
elif playground_type == "chat":
|
||||
@@ -98,6 +121,9 @@ async def serve_playground(
|
||||
LANGSERVE_PUBLIC_TRACE_LINK_ENABLED=json.dumps(
|
||||
"true" if public_trace_link_enabled else "false"
|
||||
),
|
||||
SECURITY_SCHEME=security_scheme.model.json()
|
||||
if security_scheme
|
||||
else json.dumps({}),
|
||||
)
|
||||
else:
|
||||
response = f.buffer.read()
|
||||
|
||||
Reference in New Issue
Block a user