From 3432954a54d5017571404b199c664a1aa1d7aa28 Mon Sep 17 00:00:00 2001 From: isaac hershenson Date: Thu, 8 Aug 2024 15:49:04 -0700 Subject: [PATCH] refactoring --- agent.py | 102 ------------------ langgraph.json | 4 +- my_agent/__init__.py | 0 my_agent/agent.py | 52 +++++++++ requirements.txt => my_agent/requirements.txt | 0 my_agent/utils/__init__.py | 0 my_agent/utils/nodes.py | 45 ++++++++ my_agent/utils/state.py | 6 ++ my_agent/utils/tools.py | 3 + 9 files changed, 108 insertions(+), 104 deletions(-) delete mode 100644 agent.py create mode 100644 my_agent/__init__.py create mode 100644 my_agent/agent.py rename requirements.txt => my_agent/requirements.txt (100%) create mode 100644 my_agent/utils/__init__.py create mode 100644 my_agent/utils/nodes.py create mode 100644 my_agent/utils/state.py create mode 100644 my_agent/utils/tools.py diff --git a/agent.py b/agent.py deleted file mode 100644 index b7ce159..0000000 --- a/agent.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import TypedDict, Annotated, Sequence, Literal - -from functools import lru_cache -from langchain_core.messages import BaseMessage -from langchain_anthropic import ChatAnthropic -from langchain_openai import ChatOpenAI -from langchain_community.tools.tavily_search import TavilySearchResults -from langgraph.prebuilt import ToolNode -from langgraph.graph import StateGraph, END, add_messages - -tools = [TavilySearchResults(max_results=1)] - -@lru_cache(maxsize=4) -def _get_model(model_name: str): - if model_name == "openai": - model = ChatOpenAI(temperature=0, model_name="gpt-4o") - elif model_name == "anthropic": - model = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229") - else: - raise ValueError(f"Unsupported model type: {model_name}") - - model = model.bind_tools(tools) - return model - - -class AgentState(TypedDict): - messages: Annotated[Sequence[BaseMessage], add_messages] - - -# Define the function that determines whether to continue or not -def should_continue(state): - messages = state["messages"] - last_message = messages[-1] - # If there are no tool calls, then we finish - if not last_message.tool_calls: - return "end" - # Otherwise if there is, we continue - else: - return "continue" - - -system_prompt = """Be a helpful assistant""" - -# Define the function that calls the model -def call_model(state, config): - messages = state["messages"] - messages = [{"role": "system", "content": system_prompt}] + messages - model_name = config.get('configurable', {}).get("model_name", "anthropic") - model = _get_model(model_name) - response = model.invoke(messages) - # We return a list, because this will get added to the existing list - return {"messages": [response]} - - -# Define the function to execute tools -tool_node = ToolNode(tools) - -# Define the config -class GraphConfig(TypedDict): - model_name: Literal["anthropic", "openai"] - - -# Define a new graph -workflow = StateGraph(AgentState, config_schema=GraphConfig) - -# Define the two nodes we will cycle between -workflow.add_node("agent", call_model) -workflow.add_node("action", tool_node) - -# Set the entrypoint as `agent` -# This means that this node is the first one called -workflow.set_entry_point("agent") - -# We now add a conditional edge -workflow.add_conditional_edges( - # First, we define the start node. We use `agent`. - # This means these are the edges taken after the `agent` node is called. - "agent", - # Next, we pass in the function that will determine which node is called next. - should_continue, - # Finally we pass in a mapping. - # The keys are strings, and the values are other nodes. - # END is a special node marking that the graph should finish. - # What will happen is we will call `should_continue`, and then the output of that - # will be matched against the keys in this mapping. - # Based on which one it matches, that node will then be called. - { - # If `tools`, then we call the tool node. - "continue": "action", - # Otherwise we finish. - "end": END, - }, -) - -# We now add a normal edge from `tools` to `agent`. -# This means that after `tools` is called, `agent` node is called next. -workflow.add_edge("action", "agent") - -# Finally, we compile it! -# This compiles it into a LangChain Runnable, -# meaning you can use it as you would any other runnable -graph = workflow.compile() diff --git a/langgraph.json b/langgraph.json index 6755a10..2d7c855 100644 --- a/langgraph.json +++ b/langgraph.json @@ -1,7 +1,7 @@ { - "dependencies": ["."], + "dependencies": ["./my_agent"], "graphs": { - "agent": "./agent.py:graph" + "agent": "./my_agent/agent.py:graph" }, "env": ".env" } diff --git a/my_agent/__init__.py b/my_agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/my_agent/agent.py b/my_agent/agent.py new file mode 100644 index 0000000..e48c215 --- /dev/null +++ b/my_agent/agent.py @@ -0,0 +1,52 @@ +from typing import TypedDict, Literal + +from langgraph.graph import StateGraph, END +from my_agent.utils.nodes import call_model, should_continue, tool_node +from my_agent.utils.state import AgentState + + +# Define the config +class GraphConfig(TypedDict): + model_name: Literal["anthropic", "openai"] + + +# Define a new graph +workflow = StateGraph(AgentState, config_schema=GraphConfig) + +# Define the two nodes we will cycle between +workflow.add_node("agent", call_model) +workflow.add_node("action", tool_node) + +# Set the entrypoint as `agent` +# This means that this node is the first one called +workflow.set_entry_point("agent") + +# We now add a conditional edge +workflow.add_conditional_edges( + # First, we define the start node. We use `agent`. + # This means these are the edges taken after the `agent` node is called. + "agent", + # Next, we pass in the function that will determine which node is called next. + should_continue, + # Finally we pass in a mapping. + # The keys are strings, and the values are other nodes. + # END is a special node marking that the graph should finish. + # What will happen is we will call `should_continue`, and then the output of that + # will be matched against the keys in this mapping. + # Based on which one it matches, that node will then be called. + { + # If `tools`, then we call the tool node. + "continue": "action", + # Otherwise we finish. + "end": END, + }, +) + +# We now add a normal edge from `tools` to `agent`. +# This means that after `tools` is called, `agent` node is called next. +workflow.add_edge("action", "agent") + +# Finally, we compile it! +# This compiles it into a LangChain Runnable, +# meaning you can use it as you would any other runnable +graph = workflow.compile() diff --git a/requirements.txt b/my_agent/requirements.txt similarity index 100% rename from requirements.txt rename to my_agent/requirements.txt diff --git a/my_agent/utils/__init__.py b/my_agent/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/my_agent/utils/nodes.py b/my_agent/utils/nodes.py new file mode 100644 index 0000000..c44ba6e --- /dev/null +++ b/my_agent/utils/nodes.py @@ -0,0 +1,45 @@ +from functools import lru_cache +from langchain_anthropic import ChatAnthropic +from langchain_openai import ChatOpenAI +from my_agent.utils.tools import tools +from langgraph.prebuilt import ToolNode + + +@lru_cache(maxsize=4) +def _get_model(model_name: str): + if model_name == "openai": + model = ChatOpenAI(temperature=0, model_name="gpt-4o") + elif model_name == "anthropic": + model = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229") + else: + raise ValueError(f"Unsupported model type: {model_name}") + + model = model.bind_tools(tools) + return model + +# Define the function that determines whether to continue or not +def should_continue(state): + messages = state["messages"] + last_message = messages[-1] + # If there are no tool calls, then we finish + if not last_message.tool_calls: + return "end" + # Otherwise if there is, we continue + else: + return "continue" + + +system_prompt = """Be a helpful assistant""" + +# Define the function that calls the model +def call_model(state, config): + messages = state["messages"] + messages = [{"role": "system", "content": system_prompt}] + messages + model_name = config.get('configurable', {}).get("model_name", "anthropic") + model = _get_model(model_name) + response = model.invoke(messages) + # We return a list, because this will get added to the existing list + return {"messages": [response]} + +# Define the function to execute tools +tool_node = ToolNode(tools) \ No newline at end of file diff --git a/my_agent/utils/state.py b/my_agent/utils/state.py new file mode 100644 index 0000000..7d21270 --- /dev/null +++ b/my_agent/utils/state.py @@ -0,0 +1,6 @@ +from langgraph.graph import add_messages +from langchain_core.messages import BaseMessage +from typing import TypedDict, Annotated, Sequence + +class AgentState(TypedDict): + messages: Annotated[Sequence[BaseMessage], add_messages] diff --git a/my_agent/utils/tools.py b/my_agent/utils/tools.py new file mode 100644 index 0000000..5850682 --- /dev/null +++ b/my_agent/utils/tools.py @@ -0,0 +1,3 @@ +from langchain_community.tools.tavily_search import TavilySearchResults + +tools = [TavilySearchResults(max_results=1)] \ No newline at end of file