Logan/update llama index (#29)

* bump llama agi to use llamaindex v0.6.13

* linting

* bump llama-agi to v0.2.0
This commit is contained in:
Logan
2023-05-29 18:07:45 -06:00
committed by GitHub
parent a288fca2a3
commit 3364c9eae1
18 changed files with 123 additions and 112 deletions
+6 -11
View File
@@ -5,7 +5,7 @@ The goal of this is to simulate conversation between two agents.
"""
from llama_index import (
GPTSimpleVectorIndex, GPTListIndex, Document, ServiceContext
GPTVectorStoreIndex, GPTListIndex, Document, ServiceContext
)
from llama_index.indices.base import BaseGPTIndex
from llama_index.data_structs import Node
@@ -68,7 +68,7 @@ class ConvoAgent(BaseModel):
) -> "ConvoAgent":
name = name or "Agent"
st_memory = st_memory or deque()
lt_memory = lt_memory or GPTSimpleVectorIndex([])
lt_memory = lt_memory or GPTVectorStoreIndex([])
service_context = service_context or ServiceContext.from_defaults()
return cls(
name=name,
@@ -94,12 +94,9 @@ class ConvoAgent(BaseModel):
prev_message = self.st_memory[-1]
st_memory_text = "\n".join([l for l in self.st_memory])
summary_response = self.lt_memory.query(
summary_response = self.lt_memory.as_query_engine(**self.lt_memory_query_kwargs).query(
f"Tell me a bit more about any context that's relevant "
f"to the current messages: \n{st_memory_text}",
# similarity_top_k=10,
response_mode="compact",
**self.lt_memory_query_kwargs
f"to the current messages: \n{st_memory_text}"
)
# add both the long-term memory summary and the short-term conversation
@@ -114,9 +111,7 @@ class ConvoAgent(BaseModel):
)
qa_prompt = QuestionAnswerPrompt(full_qa_prompt_tmpl)
response = list_builder.query(
"Generate the next message in the conversation.",
text_qa_template=qa_prompt,
response_mode="compact"
response = list_builder.as_query_engine(text_qa_template=qa_prompt).query(
"Generate the next message in the conversation."
)
return str(response)
+1 -1
View File
@@ -1 +1 @@
llama-index==0.5.22
llama-index==0.6.13
@@ -15,22 +15,23 @@ class SimpleExecutionAgent(BaseExecutionAgent):
This agent uses an LLM to execute a basic action without tools.
The LlamaAgentPrompts.execution_prompt defines how this execution agent
behaves.
behaves.
Usually, this is used for simple tasks, like generating the initial list of tasks.
The execution template kwargs are automatically extracted and expected to be
The execution template kwargs are automatically extracted and expected to be
specified in execute_task().
Args:
llm (Union[BaseLLM, BaseChatModel]): The langchain LLM class to use.
model_name: (str): The name of the OpenAI model to use, if the LLM is
model_name: (str): The name of the OpenAI model to use, if the LLM is
not provided.
max_tokens: (int): The maximum number of tokens the LLM can generate.
prompts: (LlamaAgentPrompts): The prompt templates used during execution.
The only prompt used byt the SimpleExecutionAgent is
prompts: (LlamaAgentPrompts): The prompt templates used during execution.
The only prompt used byt the SimpleExecutionAgent is
LlamaAgentPrompts.execution_prompt.
"""
def __init__(
self,
llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
@@ -12,30 +12,31 @@ from llama_agi.execution_agent.base import BaseExecutionAgent, LlamaAgentPrompts
class ToolExecutionAgent(BaseExecutionAgent):
"""Tool Execution Agent
This agent is a wrapper around the zero-shot agent from Langchain. Using
a set of tools, the agent is expected to carry out and complete some task
a set of tools, the agent is expected to carry out and complete some task
that will help achieve an overall objective.
The agents overall behavior is controlled by the LlamaAgentPrompts.agent_prefix
The agents overall behavior is controlled by the LlamaAgentPrompts.agent_prefix
and LlamaAgentPrompts.agent_suffix prompt templates.
The execution template kwargs are automatically extracted and expected to be
specified in execute_task().
The execution template kwargs are automatically extracted and expected to be
specified in execute_task().
execute_task() also returns the intermediate steps, for additional debugging and is
used for the streamlit example.
used for the streamlit example.
Args:
llm (Union[BaseLLM, BaseChatModel]): The langchain LLM class to use.
model_name: (str): The name of the OpenAI model to use, if the LLM is
model_name: (str): The name of the OpenAI model to use, if the LLM is
not provided.
max_tokens: (int): The maximum number of tokens the LLM can generate.
prompts: (LlamaAgentPrompts): The prompt templates used during execution.
The Tool Execution Agent uses LlamaAgentPrompts.agent_prefix and
prompts: (LlamaAgentPrompts): The prompt templates used during execution.
The Tool Execution Agent uses LlamaAgentPrompts.agent_prefix and
LlamaAgentPrompts.agent_suffix.
tools: (List[Tool]): The list of langchain tools for the execution agent to use.
"""
def __init__(
self,
llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
@@ -71,7 +72,10 @@ class ToolExecutionAgent(BaseExecutionAgent):
llm_chain=self._llm_chain, tools=self.tools, verbose=True
)
self._execution_chain = AgentExecutor.from_agent_and_tools(
agent=self._agent, tools=self.tools, verbose=True, return_intermediate_steps=True
agent=self._agent,
tools=self.tools,
verbose=True,
return_intermediate_steps=True,
)
def execute_task(self, **prompt_kwargs: Any) -> Dict[str, str]:
@@ -1,7 +1,4 @@
from .SimpleExecutionAgent import SimpleExecutionAgent
from .ToolExecutionAgent import ToolExecutionAgent
__all__ = [
SimpleExecutionAgent,
ToolExecutionAgent
]
__all__ = [SimpleExecutionAgent, ToolExecutionAgent]
+5 -6
View File
@@ -7,9 +7,7 @@ from langchain.llms import OpenAI, BaseLLM
from langchain.chat_models.base import BaseChatModel
from langchain.chat_models import ChatOpenAI
from llama_agi.default_task_prompts import (
LC_PREFIX, LC_SUFFIX, LC_EXECUTION_PROMPT
)
from llama_agi.default_task_prompts import LC_PREFIX, LC_SUFFIX, LC_EXECUTION_PROMPT
@dataclass
@@ -21,16 +19,17 @@ class LlamaAgentPrompts:
class BaseExecutionAgent:
"""Base Execution Agent
Args:
llm (Union[BaseLLM, BaseChatModel]): The langchain LLM class to use.
model_name: (str): The name of the OpenAI model to use, if the LLM is
model_name: (str): The name of the OpenAI model to use, if the LLM is
not provided.
max_tokens: (int): The maximum number of tokens the LLM can generate.
prompts: (LlamaAgentPrompts): The prompt templates used during execution.
tools: (List[Tool]): The list of langchain tools for the execution
tools: (List[Tool]): The list of langchain tools for the execution
agent to use.
"""
def __init__(
self,
llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+4 -2
View File
@@ -35,7 +35,9 @@ class AutoAGIRunner(BaseAGIRunner):
completed_tasks_summary=initial_completed_tasks_summary,
)
initial_task_list = self.task_manager.parse_task_list(initial_task_list_result['output'])
initial_task_list = self.task_manager.parse_task_list(
initial_task_list_result["output"]
)
# add tasks to the task manager
self.task_manager.add_new_tasks(initial_task_list)
@@ -53,7 +55,7 @@ class AutoAGIRunner(BaseAGIRunner):
objective=objective,
cur_task=cur_task,
completed_tasks_summary=completed_tasks_summary,
)['output']
)["output"]
# store the task and result as completed
self.task_manager.add_completed_task(cur_task, result)
@@ -27,14 +27,14 @@ class AutoStreamlitAGIRunner(BaseAGIRunner):
initial_task: str,
sleep_time: int,
initial_task_list: Optional[List[str]] = None,
max_iterations: Optional[int] = None
max_iterations: Optional[int] = None,
) -> None:
run_initial_task = False
if 'logs' not in st.session_state:
st.session_state['logs'] = []
st.session_state['state_str'] = "No state yet!"
st.session_state['tasks_summary'] = ""
if "logs" not in st.session_state:
st.session_state["logs"] = []
st.session_state["state_str"] = "No state yet!"
st.session_state["tasks_summary"] = ""
run_initial_task = True
logs_col, state_col = st.columns(2)
@@ -42,12 +42,12 @@ class AutoStreamlitAGIRunner(BaseAGIRunner):
with logs_col:
st.subheader("Execution Log")
st_logs = st.empty()
st_logs.write(st.session_state['logs'])
st_logs.write(st.session_state["logs"])
with state_col:
st.subheader("AGI State")
st_state = st.empty()
st_state.write(st.session_state['state_str'])
st_state.write(st.session_state["state_str"])
if run_initial_task:
# get initial list of tasks
@@ -71,7 +71,9 @@ class AutoStreamlitAGIRunner(BaseAGIRunner):
completed_tasks_summary=initial_completed_tasks_summary,
)
initial_task_list = self.task_manager.parse_task_list(initial_task_list_result['output'])
initial_task_list = self.task_manager.parse_task_list(
initial_task_list_result["output"]
)
# add tasks to the task manager
self.task_manager.add_new_tasks(initial_task_list)
@@ -80,12 +82,18 @@ class AutoStreamlitAGIRunner(BaseAGIRunner):
self.task_manager.prioritize_tasks(objective)
tasks_summary = initial_completed_tasks_summary
st.session_state['tasks_summary'] = tasks_summary
st.session_state["tasks_summary"] = tasks_summary
# update streamlit state
st.session_state['state_str'] = log_current_status(initial_task, initial_task_list_result['output'], tasks_summary, self.task_manager.current_tasks, return_str=True)
if st.session_state['state_str']:
st_state.markdown(st.session_state['state_str'].replace("\n", "\n\n"))
st.session_state["state_str"] = log_current_status(
initial_task,
initial_task_list_result["output"],
tasks_summary,
self.task_manager.current_tasks,
return_str=True,
)
if st.session_state["state_str"]:
st_state.markdown(st.session_state["state_str"].replace("\n", "\n\n"))
for _ in range(0, max_iterations):
# Get the next task
@@ -95,14 +103,16 @@ class AutoStreamlitAGIRunner(BaseAGIRunner):
result_dict = self.execution_agent.execute_task(
objective=objective,
cur_task=cur_task,
completed_tasks_summary=st.session_state['tasks_summary'],
completed_tasks_summary=st.session_state["tasks_summary"],
)
result = result_dict['output']
# update logs
log = make_intermediate_steps_pretty(json.dumps(result_dict['intermediate_steps'])) + [result]
st.session_state['logs'].append(log)
st_logs.write(st.session_state['logs'])
result = result_dict["output"]
# update logs
log = make_intermediate_steps_pretty(
json.dumps(result_dict["intermediate_steps"])
) + [result]
st.session_state["logs"].append(log)
st_logs.write(st.session_state["logs"])
# store the task and result as completed
self.task_manager.add_completed_task(cur_task, result)
@@ -112,18 +122,18 @@ class AutoStreamlitAGIRunner(BaseAGIRunner):
# Summarize completed tasks
completed_tasks_summary = self.task_manager.get_completed_tasks_summary()
st.session_state['tasks_summary'] = completed_tasks_summary
st.session_state["tasks_summary"] = completed_tasks_summary
# log state of AGI to streamlit
st.session_state['state_str'] = log_current_status(
st.session_state["state_str"] = log_current_status(
cur_task,
result,
completed_tasks_summary,
self.task_manager.current_tasks,
return_str=True
return_str=True,
)
if st.session_state['state_str'] is not None:
st_state.markdown(st.session_state['state_str'].replace("\n", "\n\n"))
if st.session_state["state_str"] is not None:
st_state.markdown(st.session_state["state_str"].replace("\n", "\n\n"))
# Quit the loop?
if len(self.task_manager.current_tasks) == 0:
@@ -132,4 +142,3 @@ class AutoStreamlitAGIRunner(BaseAGIRunner):
# wait a bit to let you read what's happening
time.sleep(sleep_time)
+1 -4
View File
@@ -1,7 +1,4 @@
from .AutoAGIRunner import AutoAGIRunner
from .AutoStreamlitAGIRunner import AutoStreamlitAGIRunner
__all__ = [
AutoAGIRunner,
AutoStreamlitAGIRunner
]
__all__ = [AutoAGIRunner, AutoStreamlitAGIRunner]
@@ -12,7 +12,7 @@ from llama_agi.default_task_prompts import NO_COMPLETED_TASKS_SUMMARY
class LlamaTaskManager(BaseTaskManager):
"""Llama Task Manager
This task manager uses LlamaIndex to create and prioritize tasks. Using
the LlamaTaskPrompts, the task manager will create tasks that work
towards achieving an overall objective.
@@ -24,12 +24,13 @@ class LlamaTaskManager(BaseTaskManager):
Args:
tasks (List[str]): The initial list of tasks to complete.
prompts: (LlamaTaskPrompts): The prompts to control the task creation
prompts: (LlamaTaskPrompts): The prompts to control the task creation
and prioritization.
tasK_service_context (ServiceContext): The LlamaIndex service context to use
tasK_service_context (ServiceContext): The LlamaIndex service context to use
for task creation and prioritization.
"""
def __init__(
self,
tasks: List[str],
@@ -51,7 +52,9 @@ class LlamaTaskManager(BaseTaskManager):
self.task_create_refine_template = self.prompts.task_create_refine_template
self.task_prioritize_qa_template = self.prompts.task_prioritize_qa_template
self.task_prioritize_refine_template = self.prompts.task_prioritize_refine_template
self.task_prioritize_refine_template = (
self.prompts.task_prioritize_refine_template
)
def _get_task_create_templates(
self, prev_task: str, prev_result: str
@@ -104,19 +107,19 @@ class LlamaTaskManager(BaseTaskManager):
"""Generate a summary of completed tasks."""
if len(self.completed_tasks) == 0:
return NO_COMPLETED_TASKS_SUMMARY
summary = self.completed_tasks_index.query(
"Summarize the current completed tasks", response_mode="tree_summarize"
summary = self.completed_tasks_index.as_query_engine(
response_mode="tree_summarize"
).query(
"Summarize the current completed tasks",
)
return str(summary)
def prioritize_tasks(self, objective: str) -> None:
"""Prioritize the current list of incomplete tasks."""
(text_qa_template, refine_template) = self._get_task_prioritize_templates()
prioritized_tasks = self.current_tasks_index.query(
objective,
text_qa_template=text_qa_template,
refine_template=refine_template,
)
prioritized_tasks = self.current_tasks_index.as_query_engine(
text_qa_template=text_qa_template, refine_template=refine_template
).query(objective)
new_tasks = []
for task in str(prioritized_tasks).split("\n"):
@@ -135,11 +138,9 @@ class LlamaTaskManager(BaseTaskManager):
(text_qa_template, refine_template) = self._get_task_create_templates(
prev_task, prev_result
)
task_list_response = self.completed_tasks_index.query(
objective,
text_qa_template=text_qa_template,
refine_template=refine_template,
)
task_list_response = self.completed_tasks_index.as_query_engine(
text_qa_template=text_qa_template, refine_template=refine_template
).query(objective)
new_tasks = self.parse_task_list(str(task_list_response))
self.add_new_tasks(new_tasks)
+1 -1
View File
@@ -2,4 +2,4 @@ from .LlamaTaskManager import LlamaTaskManager
__all__ = [
LlamaTaskManager,
]
]
+4 -3
View File
@@ -22,15 +22,16 @@ class LlamaTaskPrompts:
class BaseTaskManager:
"""Base Task Manager
Args:
tasks (List[str]): The initial list of tasks to complete.
prompts: (LlamaTaskPrompts): The prompts to control the task creation
prompts: (LlamaTaskPrompts): The prompts to control the task creation
and prioritization.
tasK_service_context (ServiceContext): The LlamaIndex service context to use
tasK_service_context (ServiceContext): The LlamaIndex service context to use
for task creation and prioritization.
"""
def __init__(
self,
tasks: List[str],
+3 -1
View File
@@ -17,5 +17,7 @@ def record_note(note: str) -> str:
def search_notes(query_str: str) -> str:
"""Useful for searching through notes that you previously recorded."""
global note_index
response = note_index.query(query_str, similarity_top_k=3, response_mode="compact")
response = note_index.as_query_engine(
similarity_top_k=3,
).query(query_str)
return str(response)
@@ -20,9 +20,7 @@ def search_webpage(prompt: str) -> str:
documents = loader.load_data(urls=[url])
service_context = ServiceContext.from_defaults(chunk_size_limit=512)
index = initialize_search_index(documents, service_context=service_context)
query_result = index.query(
query_str, similarity_top_k=3, response_mode="compact"
)
query_result = index.as_query_engine(similarity_top_k=3).query(query_str)
return str(query_result)
except ValueError as e:
return str(e)
+1 -5
View File
@@ -1,8 +1,4 @@
from .NoteTakingTools import record_note, search_notes
from .WebpageSearchTool import search_webpage
__all__ = [
record_note,
search_notes,
search_webpage
]
__all__ = [record_note, search_notes, search_webpage]
+7 -3
View File
@@ -1,6 +1,6 @@
from typing import Any, List, Optional
from llama_index import GPTSimpleVectorIndex, GPTListIndex, ServiceContext, Document
from llama_index import GPTVectorStoreIndex, GPTListIndex, ServiceContext, Document
from llama_index.indices.base import BaseGPTIndex
@@ -13,13 +13,17 @@ def initialize_task_list_index(
def initialize_search_index(
documents: List[Document], service_context: Optional[ServiceContext] = None
) -> BaseGPTIndex[Any]:
return GPTSimpleVectorIndex.from_documents(
return GPTVectorStoreIndex.from_documents(
documents, service_context=service_context
)
def log_current_status(
cur_task: str, result: str, completed_tasks_summary: str, task_list: List[Document], return_str: bool = False
cur_task: str,
result: str,
completed_tasks_summary: str,
task_list: List[Document],
return_str: bool = False,
) -> Optional[str]:
status_string = f"""
__________________________________
+7 -5
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "llama_agi"
version = "0.1.2"
version = "0.2.0"
description = "Building AGI loops using LlamaIndex and Langchain"
authors = []
license = "MIT"
@@ -13,10 +13,12 @@ keywords = ["LLM", "LlamaIndex", "Langchain", "AGI"]
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain = "==0.0.141"
llama-index = "==0.5.16"
streamlit = ">=1.21.0"
altair = "==4.2.2"
langchain = "==0.0.154"
llama-index = "==0.6.13"
streamlit = "==1.21.0"
transformers = ">=0.4.29"
google-api-python-client = ">=2.87.0"
[tool.poetry.group.lint.dependencies]
ruff = "^0.0.249"
+5 -2
View File
@@ -1,3 +1,6 @@
langchain==0.0.141
llama-index==0.5.16
altair==4.2.2
google-api-python-client>=2.87.0
langchain==0.0.154
llama-index==0.6.13
streamlit==1.21.0
transformers>=4.29.2