Merge pull request #6 from langchain-ai/infra/configurable-fields

feat: add configurable fields
This commit is contained in:
langchain-infra
2024-06-04 20:22:59 -04:00
committed by GitHub
3 changed files with 45 additions and 21 deletions
+21 -10
View File
@@ -6,6 +6,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage, AIMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import ConfigurableField, Runnable
class CustomChatModel(BaseChatModel):
@@ -18,11 +19,11 @@ class CustomChatModel(BaseChatModel):
n: int = 5
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Override the _generate method to implement the chat model logic.
@@ -56,11 +57,11 @@ class CustomChatModel(BaseChatModel):
return ChatResult(generations=[generation])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model.
@@ -112,3 +113,13 @@ class CustomChatModel(BaseChatModel):
# costs for the given LLM.)
"model_name": "CustomChatModel",
}
def with_configurable_fields(self) -> Runnable:
"""Expose fields you want to be configurable in the playground. We will automatically expose these to the
playground. If you don't want to expose any fields, you can remove this method."""
return self.configurable_fields(n=ConfigurableField(
id="n",
name="Num Characters",
description="Number of characters to return from the input prompt.",
))
+20 -10
View File
@@ -3,6 +3,7 @@ from typing import Any, Dict, Iterator, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.runnables import ConfigurableField, Runnable
class CustomLLM(LLM):
@@ -15,11 +16,11 @@ class CustomLLM(LLM):
n: int = 5
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given input.
@@ -40,11 +41,11 @@ class CustomLLM(LLM):
return prompt[: self.n]
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Stream the LLM on the given prompt.
@@ -88,3 +89,12 @@ class CustomLLM(LLM):
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model. Used for logging purposes only."""
return "custom"
def with_configurable_fields(self) -> Runnable:
"""Expose fields you want to be configurable in the playground. We will automatically expose these to the
playground. If you don't want to expose any fields, you can remove this method."""
return self.configurable_fields(n=ConfigurableField(
id="n",
name="Num Characters",
description="Number of characters to return from the input prompt.",
))
+4 -1
View File
@@ -5,7 +5,10 @@ from langserve import add_routes
app = FastAPI()
add_routes(app, CustomChatModel(), path="/chat")
configurable_chat_model = CustomChatModel().with_configurable_fields() if hasattr(CustomChatModel, 'with_configurable_fields') else CustomChatModel()
add_routes(app, configurable_chat_model, path="/chat")
configurable_llm = CustomLLM().with_configurable_fields() if hasattr(CustomLLM, 'with_configurable_fields') else CustomLLM()
add_routes(app, CustomLLM())
if __name__ == "__main__":