mirror of
https://github.com/run-llama/multi-agent-concierge.git
synced 2026-06-30 21:07:58 -04:00
add auth check to tools
This commit is contained in:
@@ -4,7 +4,6 @@ from llama_index.core.memory import ChatMemoryBuffer
|
||||
from llama_index.core.tools import BaseTool
|
||||
from llama_index.core.workflow import Context
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.utils.workflow import draw_all_possible_flows
|
||||
|
||||
from workflow import (
|
||||
AgentConfig,
|
||||
@@ -45,6 +44,12 @@ def get_stock_lookup_tools() -> list[BaseTool]:
|
||||
|
||||
|
||||
def get_authentication_tools() -> list[BaseTool]:
|
||||
async def is_authenticated(ctx: Context) -> bool:
|
||||
"""Checks if the user has a session token."""
|
||||
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
|
||||
user_state = await ctx.get("user_state")
|
||||
return user_state["session_token"] is not None
|
||||
|
||||
async def store_username(ctx: Context, username: str) -> None:
|
||||
"""Adds the username to the user state."""
|
||||
ctx.write_event_to_stream(ProgressEvent(msg="Recording username"))
|
||||
@@ -66,12 +71,6 @@ def get_authentication_tools() -> list[BaseTool]:
|
||||
|
||||
return f"Logged in user {username} with session token {session_token}. They have an account with id {user_state['account_id']} and a balance of ${user_state['account_balance']}."
|
||||
|
||||
async def is_authenticated(ctx: Context) -> bool:
|
||||
"""Checks if the user has a session token."""
|
||||
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
|
||||
user_state = await ctx.get("user_state")
|
||||
return user_state["session_token"] is not None
|
||||
|
||||
return [
|
||||
FunctionToolWithContext.from_defaults(async_fn=store_username),
|
||||
FunctionToolWithContext.from_defaults(async_fn=login),
|
||||
@@ -80,8 +79,18 @@ def get_authentication_tools() -> list[BaseTool]:
|
||||
|
||||
|
||||
def get_account_balance_tools() -> list[BaseTool]:
|
||||
async def is_authenticated(ctx: Context) -> bool:
|
||||
"""Checks if the user has a session token."""
|
||||
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
|
||||
user_state = await ctx.get("user_state")
|
||||
return user_state["session_token"] is not None
|
||||
|
||||
async def get_account_id(ctx: Context, account_name: str) -> str:
|
||||
"""Useful for looking up an account ID."""
|
||||
is_auth = await is_authenticated(ctx)
|
||||
if not is_auth:
|
||||
raise ValueError("User is not authenticated!")
|
||||
|
||||
ctx.write_event_to_stream(
|
||||
ProgressEvent(msg=f"Looking up account ID for {account_name}")
|
||||
)
|
||||
@@ -92,6 +101,10 @@ def get_account_balance_tools() -> list[BaseTool]:
|
||||
|
||||
async def get_account_balance(ctx: Context, account_id: str) -> str:
|
||||
"""Useful for looking up an account balance."""
|
||||
is_auth = await is_authenticated(ctx)
|
||||
if not is_auth:
|
||||
raise ValueError("User is not authenticated!")
|
||||
|
||||
ctx.write_event_to_stream(
|
||||
ProgressEvent(msg=f"Looking up account balance for {account_id}")
|
||||
)
|
||||
@@ -100,12 +113,6 @@ def get_account_balance_tools() -> list[BaseTool]:
|
||||
|
||||
return f"Account {account_id} has a balance of ${account_balance}"
|
||||
|
||||
async def is_authenticated(ctx: Context) -> bool:
|
||||
"""Checks if the user has a session token."""
|
||||
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
|
||||
user_state = await ctx.get("user_state")
|
||||
return user_state["session_token"] is not None
|
||||
|
||||
return [
|
||||
FunctionToolWithContext.from_defaults(async_fn=get_account_id),
|
||||
FunctionToolWithContext.from_defaults(async_fn=get_account_balance),
|
||||
@@ -114,10 +121,20 @@ def get_account_balance_tools() -> list[BaseTool]:
|
||||
|
||||
|
||||
def get_transfer_money_tools() -> list[BaseTool]:
|
||||
def transfer_money(
|
||||
async def is_authenticated(ctx: Context) -> bool:
|
||||
"""Checks if the user has a session token."""
|
||||
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
|
||||
user_state = await ctx.get("user_state")
|
||||
return user_state["session_token"] is not None
|
||||
|
||||
async def transfer_money(
|
||||
ctx: Context, from_account_id: str, to_account_id: str, amount: int
|
||||
) -> str:
|
||||
"""Useful for transferring money between accounts."""
|
||||
is_auth = await is_authenticated(ctx)
|
||||
if not is_auth:
|
||||
raise ValueError("User is not authenticated!")
|
||||
|
||||
ctx.write_event_to_stream(
|
||||
ProgressEvent(
|
||||
msg=f"Transferring {amount} from {from_account_id} to account {to_account_id}"
|
||||
@@ -127,6 +144,10 @@ def get_transfer_money_tools() -> list[BaseTool]:
|
||||
|
||||
async def balance_sufficient(ctx: Context, account_id: str, amount: int) -> bool:
|
||||
"""Useful for checking if an account has enough money to transfer."""
|
||||
is_auth = await is_authenticated(ctx)
|
||||
if not is_auth:
|
||||
raise ValueError("User is not authenticated!")
|
||||
|
||||
ctx.write_event_to_stream(
|
||||
ProgressEvent(msg="Checking if balance is sufficient")
|
||||
)
|
||||
@@ -135,6 +156,10 @@ def get_transfer_money_tools() -> list[BaseTool]:
|
||||
|
||||
async def has_balance(ctx: Context) -> bool:
|
||||
"""Useful for checking if an account has a balance."""
|
||||
is_auth = await is_authenticated(ctx)
|
||||
if not is_auth:
|
||||
raise ValueError("User is not authenticated!")
|
||||
|
||||
ctx.write_event_to_stream(
|
||||
ProgressEvent(msg="Checking if account has a balance")
|
||||
)
|
||||
@@ -144,14 +169,8 @@ def get_transfer_money_tools() -> list[BaseTool]:
|
||||
and user_state["account_balance"] > 0
|
||||
)
|
||||
|
||||
async def is_authenticated(ctx: Context) -> bool:
|
||||
"""Checks if the user has a session token."""
|
||||
ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
|
||||
user_state = await ctx.get("user_state")
|
||||
return user_state["session_token"] is not None
|
||||
|
||||
return [
|
||||
FunctionToolWithContext.from_defaults(fn=transfer_money),
|
||||
FunctionToolWithContext.from_defaults(async_fn=transfer_money),
|
||||
FunctionToolWithContext.from_defaults(async_fn=balance_sufficient),
|
||||
FunctionToolWithContext.from_defaults(async_fn=has_balance),
|
||||
FunctionToolWithContext.from_defaults(async_fn=is_authenticated),
|
||||
@@ -222,7 +241,7 @@ async def main():
|
||||
workflow = ConciergeAgent(timeout=None)
|
||||
|
||||
# draw a diagram of the workflow
|
||||
draw_all_possible_flows(workflow, filename="workflow.html")
|
||||
# draw_all_possible_flows(workflow, filename="workflow.html")
|
||||
|
||||
handler = workflow.run(
|
||||
user_msg="Hello!",
|
||||
|
||||
Reference in New Issue
Block a user