add auth check to tools

This commit is contained in:
Logan Markewich
2024-10-21 17:46:33 -06:00
parent 4aeeb3c24c
commit 687d76f7e2
+41 -22
View File
@@ -4,7 +4,6 @@ from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import BaseTool
from llama_index.core.workflow import Context
from llama_index.llms.openai import OpenAI
from llama_index.utils.workflow import draw_all_possible_flows
from workflow import (
AgentConfig,
@@ -45,6 +44,12 @@ def get_stock_lookup_tools() -> list[BaseTool]:
def get_authentication_tools() -> list[BaseTool]:
async def is_authenticated(ctx: Context) -> bool:
"""Checks if the user has a session token."""
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
user_state = await ctx.get("user_state")
return user_state["session_token"] is not None
async def store_username(ctx: Context, username: str) -> None:
"""Adds the username to the user state."""
ctx.write_event_to_stream(ProgressEvent(msg="Recording username"))
@@ -66,12 +71,6 @@ def get_authentication_tools() -> list[BaseTool]:
return f"Logged in user {username} with session token {session_token}. They have an account with id {user_state['account_id']} and a balance of ${user_state['account_balance']}."
async def is_authenticated(ctx: Context) -> bool:
"""Checks if the user has a session token."""
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
user_state = await ctx.get("user_state")
return user_state["session_token"] is not None
return [
FunctionToolWithContext.from_defaults(async_fn=store_username),
FunctionToolWithContext.from_defaults(async_fn=login),
@@ -80,8 +79,18 @@ def get_authentication_tools() -> list[BaseTool]:
def get_account_balance_tools() -> list[BaseTool]:
async def is_authenticated(ctx: Context) -> bool:
"""Checks if the user has a session token."""
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
user_state = await ctx.get("user_state")
return user_state["session_token"] is not None
async def get_account_id(ctx: Context, account_name: str) -> str:
"""Useful for looking up an account ID."""
is_auth = await is_authenticated(ctx)
if not is_auth:
raise ValueError("User is not authenticated!")
ctx.write_event_to_stream(
ProgressEvent(msg=f"Looking up account ID for {account_name}")
)
@@ -92,6 +101,10 @@ def get_account_balance_tools() -> list[BaseTool]:
async def get_account_balance(ctx: Context, account_id: str) -> str:
"""Useful for looking up an account balance."""
is_auth = await is_authenticated(ctx)
if not is_auth:
raise ValueError("User is not authenticated!")
ctx.write_event_to_stream(
ProgressEvent(msg=f"Looking up account balance for {account_id}")
)
@@ -100,12 +113,6 @@ def get_account_balance_tools() -> list[BaseTool]:
return f"Account {account_id} has a balance of ${account_balance}"
async def is_authenticated(ctx: Context) -> bool:
"""Checks if the user has a session token."""
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
user_state = await ctx.get("user_state")
return user_state["session_token"] is not None
return [
FunctionToolWithContext.from_defaults(async_fn=get_account_id),
FunctionToolWithContext.from_defaults(async_fn=get_account_balance),
@@ -114,10 +121,20 @@ def get_account_balance_tools() -> list[BaseTool]:
def get_transfer_money_tools() -> list[BaseTool]:
def transfer_money(
async def is_authenticated(ctx: Context) -> bool:
"""Checks if the user has a session token."""
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
user_state = await ctx.get("user_state")
return user_state["session_token"] is not None
async def transfer_money(
ctx: Context, from_account_id: str, to_account_id: str, amount: int
) -> str:
"""Useful for transferring money between accounts."""
is_auth = await is_authenticated(ctx)
if not is_auth:
raise ValueError("User is not authenticated!")
ctx.write_event_to_stream(
ProgressEvent(
msg=f"Transferring {amount} from {from_account_id} to account {to_account_id}"
@@ -127,6 +144,10 @@ def get_transfer_money_tools() -> list[BaseTool]:
async def balance_sufficient(ctx: Context, account_id: str, amount: int) -> bool:
"""Useful for checking if an account has enough money to transfer."""
is_auth = await is_authenticated(ctx)
if not is_auth:
raise ValueError("User is not authenticated!")
ctx.write_event_to_stream(
ProgressEvent(msg="Checking if balance is sufficient")
)
@@ -135,6 +156,10 @@ def get_transfer_money_tools() -> list[BaseTool]:
async def has_balance(ctx: Context) -> bool:
"""Useful for checking if an account has a balance."""
is_auth = await is_authenticated(ctx)
if not is_auth:
raise ValueError("User is not authenticated!")
ctx.write_event_to_stream(
ProgressEvent(msg="Checking if account has a balance")
)
@@ -144,14 +169,8 @@ def get_transfer_money_tools() -> list[BaseTool]:
and user_state["account_balance"] > 0
)
async def is_authenticated(ctx: Context) -> bool:
"""Checks if the user has a session token."""
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
user_state = await ctx.get("user_state")
return user_state["session_token"] is not None
return [
FunctionToolWithContext.from_defaults(fn=transfer_money),
FunctionToolWithContext.from_defaults(async_fn=transfer_money),
FunctionToolWithContext.from_defaults(async_fn=balance_sufficient),
FunctionToolWithContext.from_defaults(async_fn=has_balance),
FunctionToolWithContext.from_defaults(async_fn=is_authenticated),
@@ -222,7 +241,7 @@ async def main():
workflow = ConciergeAgent(timeout=None)
# draw a diagram of the workflow
draw_all_possible_flows(workflow, filename="workflow.html")
# draw_all_possible_flows(workflow, filename="workflow.html")
handler = workflow.run(
user_msg="Hello!",