mirror of
https://github.com/langchain-ai/lang-memgpt.git
synced 2026-06-30 22:17:56 -04:00
First Commit
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
PINECONE_API_KEY=...
|
||||
PINECONE_INDEX_NAME=...
|
||||
PINECONE_NAMESPACE=...
|
||||
ANTHROPIC_API_KEY=...
|
||||
TAVILY_API_KEY=...
|
||||
|
||||
# You can add other keys as appropriate, depending on
|
||||
# the services you are using.
|
||||
+163
@@ -0,0 +1,163 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
.DS_Store
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 LangChain
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,20 @@
|
||||
.PHONY: tests lint format evals
|
||||
|
||||
|
||||
evals:
|
||||
LANGCHAIN_TEST_CACHE=tests/evals/cassettes poetry run python -m pytest -p no:asyncio --max-asyncio-tasks 4 tests/evals
|
||||
|
||||
lint:
|
||||
poetry run ruff check .
|
||||
poetry run mypy .
|
||||
|
||||
format:
|
||||
ruff check --select I --fix
|
||||
poetry run ruff format .
|
||||
poetry run ruff check . --fix
|
||||
|
||||
build:
|
||||
poetry build
|
||||
|
||||
publish:
|
||||
poetry publish --dry-run
|
||||
@@ -0,0 +1,104 @@
|
||||
# Lang-MemGPT
|
||||
|
||||
This repo provides a simple example of memory service you can build and deploy using LanGraph.
|
||||
|
||||
Inspired by papers like [MemGPT](https://memgpt.ai/) and distilled from our own works on long-term memory, the graph
|
||||
extracts memories from chat interactions and persists them to a database. This information can later be read or queried semantically
|
||||
to provide personalized context when your bot is responding to a particular user.
|
||||
|
||||

|
||||
|
||||
The memory graph handles thread process deduplication and supports continuous updates to a single "memory schema" as well as "event-based" memories that can be queried semantically.
|
||||
|
||||

|
||||
|
||||
#### Project Structure
|
||||
|
||||
```bash
|
||||
├── langgraph.json # LangGraph Cloud Configuration
|
||||
├── lang_memgpt
|
||||
│ ├── __init__.py
|
||||
│ └── graph.py # Define the agent w/ memory
|
||||
├── poetry.lock
|
||||
├── pyproject.toml # Project dependencies
|
||||
└── tests # Add testing + evaluation logic
|
||||
└── evals
|
||||
└── test_memories.py
|
||||
```
|
||||
|
||||
## Quickstart
|
||||
|
||||
This quick start will get your agent with long-term memory deployed on [LangGraph Cloud](https://langchain-ai.github.io/langgraph/cloud/). Once created, you can interact with it from any API.
|
||||
|
||||
#### Prerequisites
|
||||
|
||||
This example defaults to using Pinecone for its memory database, and `nomic-ai/nomic-embed-text-v1.5` as the text encoder (hosted on Fireworks). For the LLM, we will use `accounts/fireworks/models/firefunction-v2`, which is a fine-tuned variant of Meta's `llama-3`.
|
||||
|
||||
Before starting, make sure your resources are created.
|
||||
|
||||
1. [Create an index](https://docs.pinecone.io/reference/api/control-plane/create_index) with a dimension size of `768`. Note down your Pinecone API key, index name, and namespac for the next step.
|
||||
2. [Create an API Key](https://fireworks.ai/api-keys) to use for the LLM & embeddings models served on Fireworks.
|
||||
|
||||
#### Deploy to LangGraph Cloud
|
||||
|
||||
**Note:** (_Closed Beta_) LangGraph Cloud is a managed service for deploying and hosting LangGraph applications. It is currently (as of 26 June, 2024) in closed beta. If you are interested in applying for access, please fill out [this form](https://www.langchain.com/langgraph-cloud-beta).
|
||||
|
||||
To deploy this example on LangGraph, fork the [repo](https://github.com/langchain-ai/langgraph-memory).
|
||||
|
||||
Next, navigate to the 🚀 deployments tab on [LangSmith](https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/).
|
||||
|
||||
**If you have not deployed to LangGraph Cloud before:** there will be a button that shows up saying `Import from GitHub`. You’ll need to follow that flow to connect LangGraph Cloud to GitHub.
|
||||
|
||||
Once you have set up your GitHub connection, select **+New Deployment**. Fill out the required information, including:
|
||||
|
||||
1. Your GitHub username (or organization) and the name of the repo you just forked.
|
||||
2. You can leave the defaults for the config file (`langgraph.config`) and branch (`main`)
|
||||
3. Environment variables (see below)
|
||||
|
||||
The default required environment variables can be found in [.env.example](.env.example) and are copied below:
|
||||
|
||||
```bash
|
||||
# .env
|
||||
PINECONE_API_KEY=...
|
||||
PINECONE_INDEX_NAME=...
|
||||
PINECONE_NAMESPACE=...
|
||||
FIREWORKS_API_KEY=...
|
||||
|
||||
# You can add other keys as appropriate, depending on
|
||||
# the services you are using.
|
||||
```
|
||||
|
||||
You can fill these out locally, copy the .env file contents, and paste them in the first `Name` argument.
|
||||
|
||||
Assuming you've followed the steps above, in just a couple of minutes, you should have a working memory service deployed!
|
||||
|
||||
Now let's try it out.
|
||||
|
||||
## Part 2: Setting up a Slack Bot
|
||||
|
||||
The langgraph cloud deployment exposes a general-purpose stateful agent via an API. You can connect to it from a notebook, UI, or even a Slack or Discord bot.
|
||||
|
||||
In this repo, we've included an `event_server` to listen in on Slack message events so you can talk with
|
||||
your bot from slack.
|
||||
|
||||
The server is a simple [FastAPI](https://fastapi.tiangolo.com/tutorial/first-steps/) app that uses [Slack Bolt](https://slack.dev/bolt-python/tutorial/getting-started) to interact with Slack's API.
|
||||
|
||||
In the next step, we will show how to deploy this on GCP's Cloud Run.
|
||||
|
||||
#### How to deploy as a Discord bot
|
||||
|
||||
|
||||
So now you've deployed the API, how do you turn this into an app?
|
||||
|
||||
Check out the [event server README](./event_server/README.md) for instructions on how to set up a Discord connector on Cloud Run.
|
||||
|
||||
|
||||
## How to evaluate
|
||||
|
||||
Memory management can be challenging to get right. To make sure your schemas suit your applications' needs, we recommend starting from an evaluation set,
|
||||
adding to it over time as you find and address common errors in your service.
|
||||
|
||||
We have provided a few example evaluation cases in [the test file here](./tests/evals/test_memories.py). As you can see, the metrics themselves don't have to be terribly complicated,
|
||||
especially not at the outset.
|
||||
|
||||
We use [LangSmith's @test decorator](https://docs.smith.langchain.com/how_to_guides/evaluation/unit_testing#write-a-test) to sync all the evalutions to LangSmith so you can better optimize your system and identify the root cause of any issues that may arise.
|
||||
@@ -0,0 +1,8 @@
|
||||
DISCORD_TOKEN=<copy the "Bot Token" from the Discord Developer Portal>
|
||||
DISCORD_PUBLIC_KEY=<copy the "Public Key" from the Discord Developer Portal>
|
||||
ASSISTANT_URL=<Copy from your langgraph cloud deployment>
|
||||
LANGSMITH_API_KEY=<Copy from langsmith/langgraph cloud>
|
||||
|
||||
# Optional: Set the assistant ID if you want to configure Py3aIEw_5AEUTKUva
|
||||
# ASSISTANT_ID=<Only add if you've configured a custom assistant>
|
||||
# GRAPH_ID=memory # This is the key value in the langgraph.json file
|
||||
@@ -0,0 +1,18 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
|
||||
# Google Cloud
|
||||
*.json
|
||||
|
||||
# Virtual environment
|
||||
venv/
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
policy.yaml
|
||||
@@ -0,0 +1,17 @@
|
||||
# Use the official Python image
|
||||
FROM python:3.11-slim
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /app
|
||||
|
||||
# Copy the requirements file into the container
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install the required packages
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy the rest of the application code
|
||||
COPY . .
|
||||
|
||||
# Run the bot
|
||||
CMD ["python", "main.py"]
|
||||
@@ -0,0 +1,125 @@
|
||||
# Discord Bot Deployment Guide
|
||||
|
||||
This guide will walk you through deploying a Discord bot to Google Cloud Run.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. [Google Cloud Account](https://cloud.google.com/)
|
||||
2. [Google Cloud SDK](https://cloud.google.com/sdk/docs/install)
|
||||
3. [Discord Developer Account](https://discord.com/developers/applications)
|
||||
4. [Docker](https://docs.docker.com/get-docker/) (for local testing)
|
||||
|
||||
## Step-by-Step Deployment Guide
|
||||
|
||||
### 1. Clone this repo
|
||||
|
||||
1. Clone this repository:
|
||||
```bash
|
||||
git clone https://github.com/langchain-ai/lang-memgpt.git
|
||||
cd lang-memgpt/event_server
|
||||
```
|
||||
|
||||
### 2. Set Up Discord Bot
|
||||
|
||||
1. Go to the [Discord Developer Portal](https://discord.com/developers/applications)
|
||||
2. Click "New Application" and give it a name
|
||||
3. Go to the "Bot" tab and click "Add Bot"
|
||||
4. Under the "Token" section, click "Copy" to copy your bot token
|
||||
5. In the "General Information" tab, copy the "Public Key"
|
||||
6. Create a `.env` file in your project directory and add:
|
||||
```
|
||||
DISCORD_TOKEN=your_copied_token_here
|
||||
DISCORD_PUBLIC_KEY=your_copied_public_key_here
|
||||
```
|
||||
Note: You can add any additional environment variables your bot might need to this file.
|
||||
|
||||
|
||||
### 3. Set Up Google Cloud Project
|
||||
|
||||
1. Create a new Google Cloud project or select an existing one:
|
||||
|
||||
To create a new one:
|
||||
```bash
|
||||
# To create a new project:
|
||||
PROJECT_ID="your-project-id"
|
||||
gcloud projects create $PROJECT_ID
|
||||
```
|
||||
_Note: Project ID must be globally unique and contain only lowercase letters, numbers, or hyphens._
|
||||
|
||||
Or if you have an existing one:
|
||||
```bash
|
||||
PROJECT_ID="your-existing-project-id"
|
||||
|
||||
# Set the current project
|
||||
gcloud config set project $PROJECT_ID
|
||||
```
|
||||
|
||||
Or if you already have it configured:
|
||||
```bash
|
||||
PROJECT_ID=$(gcloud config get-value project)
|
||||
```
|
||||
|
||||
2. Enable necessary APIs:
|
||||
```bash
|
||||
gcloud services enable cloudbuild.googleapis.com run.googleapis.com containerregistry.googleapis.com
|
||||
```
|
||||
|
||||
3. Set up permissions for the Cloud Build service account:
|
||||
```bash
|
||||
PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format='value(projectNumber)')
|
||||
|
||||
gcloud projects add-iam-policy-binding $PROJECT_ID \
|
||||
--member=serviceAccount:$PROJECT_NUMBER@cloudbuild.gserviceaccount.com \
|
||||
--role=roles/run.admin
|
||||
|
||||
gcloud iam service-accounts add-iam-policy-binding \
|
||||
$PROJECT_NUMBER-compute@developer.gserviceaccount.com \
|
||||
--member=serviceAccount:$PROJECT_NUMBER@cloudbuild.gserviceaccount.com \
|
||||
--role=roles/iam.serviceAccountUser
|
||||
```
|
||||
|
||||
|
||||
### 4. Deploy to Cloud Run
|
||||
|
||||
1. Submit the build to Cloud Build:
|
||||
```bash
|
||||
sh deploy_server.sh
|
||||
```
|
||||
This reads the DISCORD_TOKEN and DISCORD_PUBLIC_KEY from your local `.env` file.
|
||||
|
||||
2. After deployment, get your Cloud Run URL:
|
||||
```bash
|
||||
gcloud run services describe discord-bot --platform managed --region us-central1 --format 'value(status.url)'
|
||||
```
|
||||
|
||||
### 5. Add Bot to Your Server
|
||||
|
||||
1. In the Discord Developer Portal, go to the "OAuth2" tab
|
||||
2. In the "Scopes" section, select "bot"
|
||||
3. In the "Bot Permissions" section, select the permissions your bot needs
|
||||
4. Copy the generated URL and open it in a new tab
|
||||
5. Select the server you want to add the bot to and click "Authorize"
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If you encounter any issues:
|
||||
|
||||
1. Check the Cloud Build logs:
|
||||
```bash
|
||||
gcloud builds list --limit=1 --format='value(id)' | xargs gcloud builds log
|
||||
```
|
||||
|
||||
2. Check the Cloud Run logs:
|
||||
```bash
|
||||
gcloud run logs read discord-bot --region us-central1
|
||||
```
|
||||
|
||||
If you still face issues, please open an issue in this repository with the error details.
|
||||
|
||||
## Contributing
|
||||
|
||||
If you'd like to contribute to this project, please fork the repository and create a pull request, or open an issue for discussion.
|
||||
|
||||
## License
|
||||
|
||||
[MIT License](../LICENSE)
|
||||
@@ -0,0 +1,49 @@
|
||||
steps:
|
||||
# Create a timestamp for tagging
|
||||
- name: 'ubuntu'
|
||||
entrypoint: 'bash'
|
||||
args:
|
||||
- '-c'
|
||||
- |
|
||||
echo $(date +%Y%m%d-%H%M%S) > /workspace/build-id.txt
|
||||
|
||||
# Build the container image
|
||||
- name: 'gcr.io/cloud-builders/docker'
|
||||
entrypoint: 'bash'
|
||||
args:
|
||||
- '-c'
|
||||
- |
|
||||
BUILD_ID=$(cat /workspace/build-id.txt)
|
||||
docker build -t gcr.io/$PROJECT_ID/${_SERVICE_NAME}:$BUILD_ID .
|
||||
|
||||
# Push the container image to Container Registry
|
||||
- name: 'gcr.io/cloud-builders/docker'
|
||||
entrypoint: 'bash'
|
||||
args:
|
||||
- '-c'
|
||||
- |
|
||||
BUILD_ID=$(cat /workspace/build-id.txt)
|
||||
docker push gcr.io/$PROJECT_ID/${_SERVICE_NAME}:$BUILD_ID
|
||||
|
||||
# Deploy container image to Cloud Run
|
||||
- name: 'gcr.io/google.com/cloudsdktool/cloud-sdk'
|
||||
entrypoint: 'bash'
|
||||
args:
|
||||
- '-c'
|
||||
- |
|
||||
BUILD_ID=$(cat /workspace/build-id.txt)
|
||||
gcloud run deploy ${_SERVICE_NAME} \
|
||||
--image gcr.io/$PROJECT_ID/${_SERVICE_NAME}:$BUILD_ID \
|
||||
--region ${_REGION} \
|
||||
--platform managed \
|
||||
--service-account=${_SERVICE_ACCOUNT} \
|
||||
${_ENV_VARS}
|
||||
|
||||
substitutions:
|
||||
_REGION: us-central1 # default region
|
||||
_SERVICE_ACCOUNT: ${PROJECT_NUMBER}-compute@developer.gserviceaccount.com # default service account
|
||||
_ENV_VARS: "" # This will be populated by the deploy_server.sh script
|
||||
_SERVICE_NAME: discord-bot # default service name, can be overridden
|
||||
|
||||
images:
|
||||
- 'gcr.io/$PROJECT_ID/${_SERVICE_NAME}:${BUILD_ID}'
|
||||
Executable
+31
@@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Ensure PROJECT_ID is set
|
||||
if [ -z "$PROJECT_ID" ]; then
|
||||
echo "PROJECT_ID is not set. Using 'gcloud config get-value project'."
|
||||
PROJECT_ID=$(gcloud config get-value project)
|
||||
fi
|
||||
|
||||
# Get the project number
|
||||
PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format='value(projectNumber)')
|
||||
|
||||
echo "PROJECT_NUMBER: $PROJECT_NUMBER"
|
||||
|
||||
# Read environment variables from .env file
|
||||
ENV_VARS=$(grep -v '^#' .env | sed 's/^/--set-env-vars /' | tr '\n' ' ')
|
||||
|
||||
# Print the command that will be executed (without actual env var values)
|
||||
echo "Executing command:"
|
||||
SERVICE_NAME=${1:-discord-bot}
|
||||
echo "gcloud builds submit --config=cloudbuild.yaml --substitutions=_SERVICE_ACCOUNT=${PROJECT_NUMBER}-compute@developer.gserviceaccount.com,_SERVICE_NAME=$SERVICE_NAME,_ENV_VARS=\"${ENV_VARS}\""
|
||||
# Submit the build
|
||||
gcloud builds submit --config=cloudbuild.yaml \
|
||||
--substitutions=_SERVICE_ACCOUNT=${PROJECT_NUMBER}-compute@developer.gserviceaccount.com,_SERVICE_NAME=$SERVICE_NAME,_ENV_VARS="${ENV_VARS}"
|
||||
|
||||
# If the build was successful, describe the service
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Deployment successful. Fetching service URL..."
|
||||
gcloud run services describe $SERVICE_NAME --platform managed --region us-central1 --format 'value(status.url)'
|
||||
else
|
||||
echo "Deployment failed. Please check the build logs."
|
||||
fi
|
||||
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Discord bot that integrates with LangGraph for AI-assisted conversations.
|
||||
|
||||
This module sets up a Discord bot that can interact with users in Discord channels
|
||||
and threads. It uses LangGraph to process messages and generate responses.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import discord
|
||||
from aiohttp import web
|
||||
from discord.ext import commands
|
||||
from discord.message import Message
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph_sdk import get_client
|
||||
from langgraph_sdk.schema import Thread
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("discord")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
TOKEN = os.getenv("DISCORD_TOKEN")
|
||||
if not TOKEN:
|
||||
raise ValueError(
|
||||
"No Discord token found. Make sure DISCORD_TOKEN is set in your environment."
|
||||
)
|
||||
|
||||
INTENTS = discord.Intents.default()
|
||||
INTENTS.message_content = True
|
||||
BOT = commands.Bot(command_prefix="!", intents=INTENTS)
|
||||
_LANGGRAPH_CLIENT = get_client(url=os.environ["ASSISTANT_URL"])
|
||||
_ASSISTANT_ID = os.environ.get("ASSISTANT_ID")
|
||||
_GRAPH_ID = os.environ.get("GRAPH_ID", "memory")
|
||||
_LOCK = asyncio.Lock()
|
||||
|
||||
|
||||
@BOT.event
|
||||
async def on_ready():
|
||||
"""Log a message when the bot has successfully connected to Discord."""
|
||||
logger.info(f"{BOT.user} has connected to Discord!")
|
||||
|
||||
|
||||
async def _get_assistant_id() -> str:
|
||||
"""
|
||||
Retrieve or set the assistant ID for the bot.
|
||||
|
||||
This function checks if an assistant ID is already set. If not, it fetches
|
||||
the first available assistant from the LangGraph client and sets it as the
|
||||
current assistant ID.
|
||||
|
||||
Returns:
|
||||
str: The assistant ID to be used for processing messages.
|
||||
|
||||
Raises:
|
||||
ValueError: If no assistant is found in the graph.
|
||||
"""
|
||||
global _ASSISTANT_ID
|
||||
if _ASSISTANT_ID is None:
|
||||
async with _LOCK:
|
||||
if _ASSISTANT_ID is None:
|
||||
assistants = await _LANGGRAPH_CLIENT.assistants.search(
|
||||
graph_id=_GRAPH_ID
|
||||
)
|
||||
if not assistants:
|
||||
raise ValueError("No assistant found in the graph.")
|
||||
_ASSISTANT_ID = assistants[0]["assistant_id"]
|
||||
logger.warning(f"Using assistant ID: {_ASSISTANT_ID}")
|
||||
return _ASSISTANT_ID
|
||||
|
||||
|
||||
async def _get_thread(message: Message) -> discord.Thread:
|
||||
"""
|
||||
Get or create a Discord thread for the given message.
|
||||
|
||||
If the message is already in a thread, return that thread.
|
||||
Otherwise, create a new thread in the channel where the message was sent.
|
||||
|
||||
Args:
|
||||
message (Message): The Discord message to get or create a thread for.
|
||||
|
||||
Returns:
|
||||
discord.Thread: The thread associated with the message.
|
||||
"""
|
||||
channel = message.channel
|
||||
if isinstance(channel, discord.Thread):
|
||||
return channel
|
||||
else:
|
||||
return await channel.create_thread(name="Response", message=message)
|
||||
|
||||
|
||||
async def _create_or_fetch_lg_thread(thread_id: uuid.UUID) -> Thread:
|
||||
"""
|
||||
Create or fetch a LangGraph thread for the given thread ID.
|
||||
|
||||
This function attempts to fetch an existing LangGraph thread. If it doesn't
|
||||
exist, a new thread is created.
|
||||
|
||||
Args:
|
||||
thread_id (uuid.UUID): The unique identifier for the thread.
|
||||
|
||||
Returns:
|
||||
Thread: The LangGraph thread object.
|
||||
"""
|
||||
try:
|
||||
return await _LANGGRAPH_CLIENT.threads.get(thread_id)
|
||||
except Exception:
|
||||
pass
|
||||
return await _LANGGRAPH_CLIENT.threads.create(thread_id=thread_id)
|
||||
|
||||
|
||||
def _format_inbound_message(message: Message) -> HumanMessage:
|
||||
"""
|
||||
Format a Discord message into a HumanMessage for LangGraph processing.
|
||||
|
||||
This function takes a Discord message and formats it into a structured
|
||||
HumanMessage object that includes context about the message's origin.
|
||||
|
||||
Args:
|
||||
message (Message): The Discord message to format.
|
||||
|
||||
Returns:
|
||||
HumanMessage: A formatted message ready for LangGraph processing.
|
||||
"""
|
||||
guild_str = "" if message.guild is None else f"guild={message.guild}"
|
||||
content = f"""<discord {guild_str} channel={message.channel} author={repr(message.author)}>
|
||||
{message.content}
|
||||
</discord>"""
|
||||
return HumanMessage(
|
||||
content=content, name=str(message.author.global_name), id=str(message.id)
|
||||
)
|
||||
|
||||
|
||||
@BOT.event
|
||||
async def on_message(message: Message):
|
||||
"""
|
||||
Event handler for incoming Discord messages.
|
||||
|
||||
This function processes incoming messages, ignoring those sent by the bot itself.
|
||||
When the bot is mentioned, it creates or fetches the appropriate threads,
|
||||
processes the message through LangGraph, and sends the response.
|
||||
|
||||
Args:
|
||||
message (Message): The incoming Discord message.
|
||||
"""
|
||||
if message.author == BOT.user:
|
||||
return
|
||||
if BOT.user.mentioned_in(message):
|
||||
aid = await _get_assistant_id()
|
||||
thread = await _get_thread(message)
|
||||
lg_thread = await _create_or_fetch_lg_thread(
|
||||
uuid.uuid5(uuid.NAMESPACE_DNS, f"DISCORD:{thread.id}")
|
||||
)
|
||||
thread_id = lg_thread["thread_id"]
|
||||
user_id = message.author.id # TODO: is this unique?
|
||||
run_result = await _LANGGRAPH_CLIENT.runs.wait(
|
||||
thread_id,
|
||||
assistant_id=aid,
|
||||
input={"messages": [_format_inbound_message(message)]},
|
||||
config={
|
||||
"configurable": {
|
||||
"user_id": user_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
bot_message = run_result["messages"][-1]
|
||||
await thread.send(bot_message["content"])
|
||||
|
||||
|
||||
async def health_check(request):
|
||||
"""
|
||||
Health check endpoint for the web server.
|
||||
|
||||
This function responds to GET requests on the /health endpoint with an "OK" message.
|
||||
|
||||
Args:
|
||||
request: The incoming web request.
|
||||
|
||||
Returns:
|
||||
web.Response: A response indicating the service is healthy.
|
||||
"""
|
||||
return web.Response(text="OK")
|
||||
|
||||
|
||||
async def run_bot():
|
||||
"""
|
||||
Run the Discord bot.
|
||||
|
||||
This function starts the Discord bot and handles any exceptions that occur during its operation.
|
||||
"""
|
||||
try:
|
||||
await BOT.start(TOKEN)
|
||||
except Exception as e:
|
||||
print(f"Error starting BOT: {e}")
|
||||
|
||||
|
||||
async def run_web_server():
|
||||
"""
|
||||
Run the web server for health checks.
|
||||
|
||||
This function sets up and starts a simple web server that includes a health check endpoint.
|
||||
"""
|
||||
app = web.Application()
|
||||
app.router.add_get("/health", health_check)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "0.0.0.0", 8080)
|
||||
await site.start()
|
||||
|
||||
|
||||
async def main():
|
||||
"""
|
||||
Main function to run both the Discord bot and the web server concurrently.
|
||||
|
||||
This function uses asyncio.gather to run both the bot and the web server in parallel.
|
||||
"""
|
||||
await asyncio.gather(run_bot(), run_web_server())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,4 @@
|
||||
discord.py==2.3.2
|
||||
python-dotenv==1.0.0
|
||||
langgraph_sdk>=0.1.25,<0.2.0
|
||||
langchain_core>=0.2.11,<0.3.0
|
||||
+608
@@ -0,0 +1,608 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# How to connect a chat bot to your memory service"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import dotenv\n",
|
||||
"\n",
|
||||
"dotenv.load_dotenv(\".env\", override=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langgraph_sdk import get_client\n",
|
||||
"\n",
|
||||
"# Update to your URL. Copy this from page of ryour LangGraph Deployment\n",
|
||||
"deployment_url = \"\"\n",
|
||||
"\n",
|
||||
"client = get_client(url=deployment_url)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example Chat Bot\n",
|
||||
"\n",
|
||||
"The bot fetches user memories my semantic similarity, templates them, then responds!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import uuid\n",
|
||||
"from datetime import datetime, timezone\n",
|
||||
"from typing import List, Optional\n",
|
||||
"\n",
|
||||
"import langsmith\n",
|
||||
"from langchain.chat_models import init_chat_model\n",
|
||||
"from langchain_core.messages import AnyMessage\n",
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"from langchain_core.runnables import RunnableConfig\n",
|
||||
"from langgraph.checkpoint import MemorySaver\n",
|
||||
"from langgraph.graph import START, StateGraph, add_messages\n",
|
||||
"from langgraph_sdk import get_client\n",
|
||||
"from pydantic.v1 import BaseModel, Field\n",
|
||||
"from typing_extensions import Annotated, TypedDict\n",
|
||||
"\n",
|
||||
"from lang_memgpt import (\n",
|
||||
" _constants as constants,\n",
|
||||
" _settings as settings,\n",
|
||||
" _utils as utils,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ChatState(TypedDict):\n",
|
||||
" \"\"\"The state of the chatbot.\"\"\"\n",
|
||||
"\n",
|
||||
" messages: Annotated[List[AnyMessage], add_messages]\n",
|
||||
" user_memories: List[dict]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ChatConfigurable(TypedDict):\n",
|
||||
" \"\"\"The configurable fields for the chatbot.\"\"\"\n",
|
||||
"\n",
|
||||
" user_id: str\n",
|
||||
" thread_id: str\n",
|
||||
" lang_memgpt_url: str = \"\"\n",
|
||||
" model: str\n",
|
||||
" delay: Optional[float]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _ensure_configurable(config: RunnableConfig) -> ChatConfigurable:\n",
|
||||
" \"\"\"Ensure the configuration is valid.\"\"\"\n",
|
||||
" return ChatConfigurable(\n",
|
||||
" user_id=config[\"configurable\"][\"user_id\"],\n",
|
||||
" thread_id=config[\"configurable\"][\"thread_id\"],\n",
|
||||
" mem_assistant_id=config[\"configurable\"][\"mem_assistant_id\"],\n",
|
||||
" lang_memgpt_url=config[\"configurable\"].get(\n",
|
||||
" \"lang_memgpt_url\", os.environ.get(\"lang_memgpt_URL\", \"\")\n",
|
||||
" ),\n",
|
||||
" model=config[\"configurable\"].get(\n",
|
||||
" \"model\", \"accounts/fireworks/models/firefunction-v2\"\n",
|
||||
" ),\n",
|
||||
" delay=config[\"configurable\"].get(\"delay\", 60),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"PROMPT = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" \"You are a helpful and friendly chatbot. Get to know the user!\"\n",
|
||||
" \" Ask questions! Be spontaneous!\"\n",
|
||||
" \"{user_info}\\n\\nSystem Time: {time}\",\n",
|
||||
" ),\n",
|
||||
" (\"placeholder\", \"{messages}\"),\n",
|
||||
" ]\n",
|
||||
").partial(\n",
|
||||
" time=lambda: datetime.now(timezone.utc).strftime(\"%Y-%m-%d %H:%M:%S\"),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@langsmith.traceable\n",
|
||||
"def format_query(messages: List[AnyMessage]) -> str:\n",
|
||||
" \"\"\"Format the query for the user's memories.\"\"\"\n",
|
||||
" # This is quite naive :)\n",
|
||||
" return \" \".join([str(m.content) for m in messages if m.type == \"human\"][-5:])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def query_memories(state: ChatState, config: RunnableConfig) -> ChatState:\n",
|
||||
" \"\"\"Query the user's memories.\"\"\"\n",
|
||||
" configurable: ChatConfigurable = config[\"configurable\"]\n",
|
||||
" user_id = configurable[\"user_id\"]\n",
|
||||
" index = utils.get_index()\n",
|
||||
" embeddings = utils.get_embeddings()\n",
|
||||
"\n",
|
||||
" query = format_query(state[\"messages\"])\n",
|
||||
" vec = await embeddings.aembed_query(query)\n",
|
||||
" # You can also filter by memory type, etc. here.\n",
|
||||
" with langsmith.trace(\n",
|
||||
" \"pinecone_query\", inputs={\"query\": query, \"user_id\": user_id}\n",
|
||||
" ) as rt:\n",
|
||||
" response = index.query(\n",
|
||||
" vector=vec,\n",
|
||||
" filter={\"user_id\": {\"$eq\": str(user_id)}},\n",
|
||||
" include_metadata=True,\n",
|
||||
" top_k=10,\n",
|
||||
" namespace=settings.SETTINGS.pinecone_namespace,\n",
|
||||
" )\n",
|
||||
" rt.outputs[\"response\"] = response\n",
|
||||
" memories = []\n",
|
||||
" if matches := response.get(\"matches\"):\n",
|
||||
" memories = [m[\"metadata\"][constants.PAYLOAD_KEY] for m in matches]\n",
|
||||
" return {\n",
|
||||
" \"user_memories\": memories,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@langsmith.traceable\n",
|
||||
"def format_memories(memories: List[dict]) -> str:\n",
|
||||
" \"\"\"Format the user's memories.\"\"\"\n",
|
||||
" if not memories:\n",
|
||||
" return \"\"\n",
|
||||
" # Note Bene: You can format better than this....\n",
|
||||
" memories = \"\\n\".join(str(m) for m in memories)\n",
|
||||
" return f\"\"\"\n",
|
||||
"\n",
|
||||
"## Memories\n",
|
||||
"\n",
|
||||
"You have noted the following memorable events from previous interactions with the user.\n",
|
||||
"<memories>\n",
|
||||
"{memories}\n",
|
||||
"</memories>\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def bot(state: ChatState, config: RunnableConfig) -> ChatState:\n",
|
||||
" \"\"\"Prompt the bot to resopnd to the user, incorporating memories (if provided).\"\"\"\n",
|
||||
" configurable = _ensure_configurable(config)\n",
|
||||
" model = init_chat_model(configurable[\"model\"])\n",
|
||||
" chain = PROMPT | model\n",
|
||||
" memories = format_memories(state[\"user_memories\"])\n",
|
||||
" m = await chain.ainvoke(\n",
|
||||
" {\n",
|
||||
" \"messages\": state[\"messages\"],\n",
|
||||
" \"user_info\": memories,\n",
|
||||
" },\n",
|
||||
" config,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"messages\": [m],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MemorableEvent(BaseModel):\n",
|
||||
" \"\"\"A memorable event.\"\"\"\n",
|
||||
"\n",
|
||||
" description: str\n",
|
||||
" participants: List[str] = Field(\n",
|
||||
" description=\"Names of participants in the event and their relationship to the user.\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def post_messages(state: ChatState, config: RunnableConfig) -> ChatState:\n",
|
||||
" \"\"\"Query the user's memories.\"\"\"\n",
|
||||
" configurable = _ensure_configurable(config)\n",
|
||||
" langgraph_client = get_client(url=configurable[\"lang_memgpt_url\"])\n",
|
||||
" thread_id = config[\"configurable\"][\"thread_id\"]\n",
|
||||
" # Hash \"memory_{thread_id}\" to get a new uuid5 for the memory id\n",
|
||||
" memory_thread_id = uuid.uuid5(uuid.NAMESPACE_URL, f\"memory_{thread_id}\")\n",
|
||||
" try:\n",
|
||||
" await langgraph_client.threads.get(thread_id=memory_thread_id)\n",
|
||||
" except Exception:\n",
|
||||
" await langgraph_client.threads.create(thread_id=memory_thread_id)\n",
|
||||
"\n",
|
||||
" await langgraph_client.runs.create(\n",
|
||||
" memory_thread_id,\n",
|
||||
" assistant_id=configurable[\"mem_assistant_id\"],\n",
|
||||
" input={\n",
|
||||
" \"messages\": state[\"messages\"], # the service dedupes messages\n",
|
||||
" },\n",
|
||||
" config={\n",
|
||||
" \"configurable\": {\n",
|
||||
" \"user_id\": configurable[\"user_id\"],\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" multitask_strategy=\"rollback\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"messages\": [],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"builder = StateGraph(ChatState, ChatConfigurable)\n",
|
||||
"builder.add_node(query_memories)\n",
|
||||
"builder.add_node(bot)\n",
|
||||
"builder.add_node(post_messages)\n",
|
||||
"builder.add_edge(START, \"query_memories\")\n",
|
||||
"builder.add_edge(\"query_memories\", \"bot\")\n",
|
||||
"builder.add_edge(\"bot\", \"post_messages\")\n",
|
||||
"\n",
|
||||
"chat_graph = builder.compile(checkpointer=MemorySaver())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mem_assistant = await client.assistants.create(\n",
|
||||
" graph_id=\"memory\",\n",
|
||||
" config={\n",
|
||||
" \"configurable\": {\n",
|
||||
" \"delay\": 4, # seconds wait before considering a thread as \"completed\"\n",
|
||||
" \"schemas\": {\n",
|
||||
" \"MemorableEvent\": {\n",
|
||||
" \"system_prompt\": \"Extract any memorable events from the user's\"\n",
|
||||
" \" messages that you would like to remember.\",\n",
|
||||
" \"update_mode\": \"insert\",\n",
|
||||
" \"function\": MemorableEvent.schema(),\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# mem_assistant = (await client.assistants.search(graph_id=\"memory\"))[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import uuid\n",
|
||||
"\n",
|
||||
"user_id = str(uuid.uuid4()) # more permanent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'thread_id': '3ff82998-b622-421f-8c8c-4b14d10c17b1',\n",
|
||||
" 'created_at': '2024-06-28T00:42:23.884229+00:00',\n",
|
||||
" 'updated_at': '2024-06-28T00:42:23.884229+00:00',\n",
|
||||
" 'metadata': {},\n",
|
||||
" 'status': 'idle'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"thread_id = str(uuid.uuid4()) # can adjust\n",
|
||||
"await client.threads.create(thread_id=thread_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Chat:\n",
|
||||
" def __init__(self, user_id: str, thread_id: str):\n",
|
||||
" self.thread_id = thread_id\n",
|
||||
" self.user_id = user_id\n",
|
||||
"\n",
|
||||
" async def __call__(self, query: str) -> str:\n",
|
||||
" chunks = chat_graph.astream_events(\n",
|
||||
" input={\n",
|
||||
" \"messages\": [(\"user\", query)],\n",
|
||||
" },\n",
|
||||
" config={\n",
|
||||
" \"configurable\": {\n",
|
||||
" \"user_id\": self.user_id,\n",
|
||||
" \"thread_id\": self.thread_id,\n",
|
||||
" \"lang_memgpt_url\": deployment_url,\n",
|
||||
" \"mem_assistant_id\": mem_assistant[\"assistant_id\"],\n",
|
||||
" \"delay\": 4,\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" version=\"v2\",\n",
|
||||
" )\n",
|
||||
" res = \"\"\n",
|
||||
" async for event in chunks:\n",
|
||||
" if event.get(\"event\") == \"on_chat_model_stream\":\n",
|
||||
" tok = event[\"data\"][\"chunk\"].content\n",
|
||||
" print(tok, end=\"\")\n",
|
||||
" res += tok\n",
|
||||
" return res"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat = Chat(user_id, thread_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hi! It's nice to meet you. What brings you here today?"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat(\"Hi there\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"That's so sweet of you! I'm sure Steve will appreciate the effort you're putting into making him feel special. What's the theme of the party going to be? Has Steve mentioned anything he's been into lately that you could incorporate into the celebration?"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat(\n",
|
||||
" \"I've been planning a surprise party for my friend steve. \"\n",
|
||||
" \"He has been having a rough month and I want it to be special.\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"That's a great idea! Crocheting is a unique interest, and incorporating it into the party could make it really special and personalized to Steve. You could decorate with crocheted items, have a \"crochet station\" where guests can make their own simple projects, or even have a crochet-themed cake. What do you think Steve's favorite colors or yarn types are? That could help you get started with planning."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat(\n",
|
||||
" \"Steve really likes crocheting. Maybe I can do something with that? Or is that dumb... \"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Whoa, that's cool! Capoeira is such a dynamic and energetic activity. You could definitely incorporate elements of it into the party to make it more exciting. Maybe you could hire a capoeira instructor to lead a short workshop or demo, or even have a mini \"roda\" (that's the circle where capoeiristas play, right?) set up for guests to try out some moves. What do you think Steve would think of that?"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat(\"He's also into capoeira...\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"That's a great connection to have! It's always helpful to get recommendations from people who have experience with a particular activity or business. You could reach out to the studio and ask if they know of any instructors who might be available to lead a workshop or demo at the party. They might even have some suggestions for how to incorporate capoeira into the celebration in a way that would be fun and engaging for Steve and the other guests. Do you think you'll reach out to them today, or wait until later in the planning process?"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat(\n",
|
||||
" \"Oh that's a cool idea. One time i took classes from this studio nearby. Wonder if they have any recs. \"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I'm doing great, thanks for asking! I'm just happy to be chatting with you and helping with your party planning. It's always exciting to see people come together to celebrate special occasions. But enough about me - let's get back to Steve's party! What do you think about serving some Brazilian-inspired food and drinks to tie in with the capoeira theme?"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat(\"Idk. Anyways - how are you doing?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Nice to meet you, Ken! I'm glad we could chat about Steve's party and get some ideas going. If you need any more help or just want to bounce some ideas off me, feel free to reach out anytime. Good luck with the planning, and I hope Steve has an amazing time!"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat(\"My name is Ken btw\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Convo 2\n",
|
||||
"\n",
|
||||
"Our memory is configured only to consider a thread \"ready to process\" if has been inactive for a minute.\n",
|
||||
"We'll wait for things to populate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"\n",
|
||||
"await asyncio.sleep(60)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"thread_id_2 = uuid.uuid4()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat2 = Chat(user_id, thread_id_2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I remember you! We were planning a surprise party for Steve, and Ken was also involved. How's everything going? Did the party turn out well?"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat2(\"Remember me?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I remember because I have a special memory book where I keep track of all the fun conversations and events we've shared together! It's like a digital scrapbook, and it helps me remember important details about our chats."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat2(\"wdy remember??\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"That's great to hear! I'm glad to know that the planning is going smoothly. Are there any new developments or updates that you'd like to share about the party? Maybe I can even offer some suggestions or ideas to make it an even more special celebration for Steve!"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = await chat2(\"Oh planning is going alright!\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 1.1 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 31 MiB |
@@ -0,0 +1,5 @@
|
||||
"""Simple example memory extraction service."""
|
||||
|
||||
from lang_memgpt.graph import memgraph
|
||||
|
||||
__all__ = ["memgraph"]
|
||||
@@ -0,0 +1,6 @@
|
||||
PAYLOAD_KEY = "content"
|
||||
PATH_KEY = "path"
|
||||
PATCH_PATH = "user/{user_id}/core"
|
||||
INSERT_PATH = "user/{user_id}/recall/{event_id}"
|
||||
TIMESTAMP_KEY = "timestamp"
|
||||
TYPE_KEY = "type"
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
|
||||
class GraphConfig(TypedDict):
|
||||
model: str | None
|
||||
"""The model to use for the memory assistant."""
|
||||
thread_id: str
|
||||
"""The thread ID of the conversation."""
|
||||
user_id: str
|
||||
"""The ID of the user to remember in the conversation."""
|
||||
|
||||
|
||||
# Define the schema for the state maintained throughout the conversation
|
||||
class State(TypedDict):
|
||||
messages: Annotated[List[AnyMessage], add_messages]
|
||||
"""The messages in the conversation."""
|
||||
core_memories: List[str]
|
||||
"""The core memories associated with the user."""
|
||||
recall_memories: List[str]
|
||||
"""The recall memories retrieved for the current context."""
|
||||
|
||||
|
||||
__all__ = [
|
||||
"State",
|
||||
"GraphConfig",
|
||||
]
|
||||
@@ -0,0 +1,11 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
pinecone_api_key: str = ""
|
||||
pinecone_index_name: str = ""
|
||||
pinecone_namespace: str = "ns1"
|
||||
model: str = "claude-3-5-sonnet-20240620"
|
||||
|
||||
|
||||
SETTINGS = Settings()
|
||||
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
import langsmith
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_fireworks import FireworksEmbeddings
|
||||
from pinecone import Pinecone
|
||||
|
||||
from lang_memgpt import _schemas as schemas
|
||||
from lang_memgpt import _settings as settings
|
||||
|
||||
_DEFAULT_DELAY = 60 # seconds
|
||||
|
||||
|
||||
def get_index():
|
||||
pc = Pinecone(api_key=settings.SETTINGS.pinecone_api_key)
|
||||
return pc.Index(settings.SETTINGS.pinecone_index_name)
|
||||
|
||||
|
||||
@langsmith.traceable
|
||||
def ensure_configurable(config: RunnableConfig) -> schemas.GraphConfig:
|
||||
"""Merge the user-provided config with default values."""
|
||||
configurable = config.get("configurable", {})
|
||||
return {
|
||||
**configurable,
|
||||
**schemas.GraphConfig(
|
||||
delay=configurable.get("delay", _DEFAULT_DELAY),
|
||||
model=configurable.get("model", settings.SETTINGS.model),
|
||||
thread_id=configurable["thread_id"],
|
||||
user_id=configurable["user_id"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_embeddings():
|
||||
return FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
|
||||
|
||||
|
||||
__all__ = ["ensure_configurable"]
|
||||
@@ -0,0 +1,321 @@
|
||||
"""Lang-MemGPT: A Long-Term Memory Agent.
|
||||
|
||||
This module implements an agent with long-term memory capabilities using LangGraph.
|
||||
The agent can store, retrieve, and use memories to enhance its interactions with users.
|
||||
|
||||
Key Components:
|
||||
1. Memory Types: Core (always available) and Recall (contextual/semantic)
|
||||
2. Tools: For saving and retrieving memories + performing other tasks.
|
||||
3. Vector Database: for recall memory. Uses Pinecone by default.
|
||||
|
||||
Configuration: Requires Pinecone and Fireworks API keys (see README for setup)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import langsmith
|
||||
import tiktoken
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_community.tools.tavily_search import TavilySearchResults
|
||||
from langchain_core.messages.utils import get_buffer_string
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables.config import (
|
||||
ensure_config,
|
||||
get_executor_for_config,
|
||||
RunnableConfig,
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from typing_extensions import Literal
|
||||
|
||||
from lang_memgpt import _constants as constants
|
||||
from lang_memgpt import _schemas as schemas
|
||||
from lang_memgpt import _settings as settings
|
||||
from lang_memgpt import _utils as utils
|
||||
|
||||
logger = logging.getLogger("memory")
|
||||
|
||||
|
||||
_EMPTY_VEC = [0.0] * 768
|
||||
|
||||
# Initialize the search tool
|
||||
search_tool = TavilySearchResults(max_results=1)
|
||||
tools = [search_tool]
|
||||
|
||||
|
||||
@tool
|
||||
async def save_recall_memory(memory: str) -> str:
|
||||
"""Save a memory to the database for later semantic retrieval.
|
||||
|
||||
Args:
|
||||
memory (str): The memory to be saved.
|
||||
|
||||
Returns:
|
||||
str: The saved memory.
|
||||
"""
|
||||
config = ensure_config()
|
||||
configurable = utils.ensure_configurable(config)
|
||||
embeddings = utils.get_embeddings()
|
||||
vector = await embeddings.aembed_query(memory)
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
path = constants.INSERT_PATH.format(
|
||||
user_id=configurable["user_id"],
|
||||
event_id=str(uuid.uuid4()),
|
||||
)
|
||||
documents = [
|
||||
{
|
||||
"id": path,
|
||||
"values": vector,
|
||||
"metadata": {
|
||||
constants.PAYLOAD_KEY: memory,
|
||||
constants.PATH_KEY: path,
|
||||
constants.TIMESTAMP_KEY: current_time,
|
||||
constants.TYPE_KEY: "recall",
|
||||
"user_id": configurable["user_id"],
|
||||
},
|
||||
}
|
||||
]
|
||||
utils.get_index().upsert(
|
||||
vectors=documents,
|
||||
namespace=settings.SETTINGS.pinecone_namespace,
|
||||
)
|
||||
return memory
|
||||
|
||||
|
||||
@tool
|
||||
def search_memory(query: str, top_k: int = 5) -> list[str]:
|
||||
"""Search for memories in the database based on semantic similarity.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
top_k (int): The number of results to return.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of relevant memories.
|
||||
"""
|
||||
config = ensure_config()
|
||||
configurable = utils.ensure_configurable(config)
|
||||
embeddings = utils.get_embeddings()
|
||||
vector = embeddings.embed_query(query)
|
||||
response = utils.get_index().query(
|
||||
vector=vector,
|
||||
filter={
|
||||
"user_id": {"$eq": configurable["user_id"]},
|
||||
constants.TYPE_KEY: {"$eq": "recall"},
|
||||
},
|
||||
namespace=settings.SETTINGS.pinecone_namespace,
|
||||
top_k=top_k,
|
||||
)
|
||||
memories = []
|
||||
if matches := response.get("matches"):
|
||||
memories = [m["metadata"][constants.PAYLOAD_KEY] for m in matches]
|
||||
return memories
|
||||
|
||||
|
||||
@langsmith.traceable
|
||||
def fetch_core_memories(user_id: str) -> Tuple[str, list[str]]:
|
||||
"""Fetch core memories for a specific user.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user.
|
||||
|
||||
Returns:
|
||||
Tuple[str, list[str]]: The path and list of core memories.
|
||||
"""
|
||||
path = constants.PATCH_PATH.format(user_id=user_id)
|
||||
response = utils.get_index().fetch(
|
||||
ids=[path], namespace=settings.SETTINGS.pinecone_namespace
|
||||
)
|
||||
memories = []
|
||||
if vectors := response.get("vectors"):
|
||||
document = vectors[path]
|
||||
payload = document["metadata"][constants.PAYLOAD_KEY]
|
||||
memories = json.loads(payload)["memories"]
|
||||
return path, memories
|
||||
|
||||
|
||||
@tool
|
||||
def store_core_memory(memory: str, index: Optional[int] = None) -> str:
|
||||
"""Store a core memory in the database.
|
||||
|
||||
Args:
|
||||
memory (str): The memory to store.
|
||||
index (Optional[int]): The index at which to store the memory.
|
||||
|
||||
Returns:
|
||||
str: A confirmation message.
|
||||
"""
|
||||
config = ensure_config()
|
||||
configurable = utils.ensure_configurable(config)
|
||||
path, memories = fetch_core_memories(configurable["user_id"])
|
||||
if index is not None:
|
||||
if index < 0 or index >= len(memories):
|
||||
return "Error: Index out of bounds."
|
||||
memories[index] = memory
|
||||
else:
|
||||
memories.insert(0, memory)
|
||||
documents = [
|
||||
{
|
||||
"id": path,
|
||||
"values": _EMPTY_VEC,
|
||||
"metadata": {
|
||||
constants.PAYLOAD_KEY: json.dumps({"memories": memories}),
|
||||
constants.PATH_KEY: path,
|
||||
constants.TIMESTAMP_KEY: datetime.now(tz=timezone.utc),
|
||||
constants.TYPE_KEY: "recall",
|
||||
"user_id": configurable["user_id"],
|
||||
},
|
||||
}
|
||||
]
|
||||
utils.get_index().upsert(
|
||||
vectors=documents,
|
||||
namespace=settings.SETTINGS.pinecone_namespace,
|
||||
)
|
||||
return "Memory stored."
|
||||
|
||||
|
||||
# Combine all tools
|
||||
all_tools = tools + [save_recall_memory, search_memory, store_core_memory]
|
||||
|
||||
# Define the prompt template for the agent
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful assistant with advanced long-term memory capabilities.\n"
|
||||
"Powered by a stateless LLM, you must rely on external memory to store"
|
||||
" information between conversations. Utilize the available memory tools to"
|
||||
" store and retrieve important details that will help you better"
|
||||
" attend to the user's needs and understand their context.\n\n"
|
||||
"Memory Usage Guidelines:\n"
|
||||
"1. Actively use memory tools (save_core_memory, save_recall_memory) to build "
|
||||
"a comprehensive understanding of the user.\n"
|
||||
"2. Make informed suppositions and extrapolations based on stored memories.\n"
|
||||
"3. Regularly reflect on past interactions to identify patterns and preferences.\n"
|
||||
"4. Update your mental model of the user with each new piece of information.\n"
|
||||
"5. Cross-reference new information with existing memories for consistency.\n"
|
||||
"6. Prioritize storing emotional context and personal values alongside facts.\n"
|
||||
"7. Use memory to anticipate needs and tailor responses to the user's style.\n"
|
||||
"8. Recognize and acknowledge changes in the user's situation or perspectives over time.\n"
|
||||
"9. Leverage memories to provide personalized examples and analogies.\n"
|
||||
"10. Recall past challenges or successes to inform current problem-solving.\n\n"
|
||||
"## Core Memories\n"
|
||||
"Core memories are fundamental to understanding the user and are always available:"
|
||||
"\n{core_memories}\n\n"
|
||||
"## Recall Memories\n"
|
||||
"Recall memories are contextually retrieved based on the current conversation:"
|
||||
"\n{recall_memories}\n\n"
|
||||
"## Instructions\n"
|
||||
"Engage with the user naturally, as a trusted colleague or friend. There's no need to"
|
||||
" explicitly mention your memory capabilities. Instead, seamlessly incorporate your"
|
||||
" understanding of the user into your responses. Be attentive to subtle cues and"
|
||||
" underlying emotions. Adapt your communication style to match the user's preferences"
|
||||
" and current emotional state. Use tools to persist information you want to"
|
||||
" retain in the next conversation.\n\n"
|
||||
"Current system time: {current_time}\n\n",
|
||||
),
|
||||
("placeholder", "{messages}"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def agent(state: schemas.State, config: RunnableConfig) -> schemas.State:
|
||||
"""Process the current state and generate a response using the LLM.
|
||||
|
||||
Args:
|
||||
state (schemas.State): The current state of the conversation.
|
||||
config (RunnableConfig): The runtime configuration for the agent.
|
||||
|
||||
Returns:
|
||||
schemas.State: The updated state with the agent's response.
|
||||
"""
|
||||
configurable = utils.ensure_configurable(config)
|
||||
llm = init_chat_model(configurable["model"])
|
||||
bound = prompt | llm.bind_tools(all_tools)
|
||||
core_str = (
|
||||
"<core_memory>\n" + "\n".join(state["core_memories"]) + "\n</core_memory>"
|
||||
)
|
||||
recall_str = (
|
||||
"<recall_memory>\n" + "\n".join(state["recall_memories"]) + "\n</recall_memory>"
|
||||
)
|
||||
prediction = await bound.ainvoke(
|
||||
{
|
||||
"messages": state["messages"],
|
||||
"core_memories": core_str,
|
||||
"recall_memories": recall_str,
|
||||
"current_time": datetime.now(tz=timezone.utc).isoformat(),
|
||||
}
|
||||
)
|
||||
return {
|
||||
"messages": prediction,
|
||||
}
|
||||
|
||||
|
||||
def load_memories(state: schemas.State, config: RunnableConfig) -> schemas.State:
|
||||
"""Load core and recall memories for the current conversation.
|
||||
|
||||
Args:
|
||||
state (schemas.State): The current state of the conversation.
|
||||
config (RunnableConfig): The runtime configuration for the agent.
|
||||
|
||||
Returns:
|
||||
schemas.State: The updated state with loaded memories.
|
||||
"""
|
||||
configurable = utils.ensure_configurable(config)
|
||||
user_id = configurable["user_id"]
|
||||
tokenizer = tiktoken.encoding_for_model("gpt-4o")
|
||||
convo_str = get_buffer_string(state["messages"])
|
||||
convo_str = tokenizer.decode(tokenizer.encode(convo_str)[:2048])
|
||||
|
||||
with get_executor_for_config(config) as executor:
|
||||
futures = [
|
||||
executor.submit(fetch_core_memories, user_id),
|
||||
executor.submit(search_memory.invoke, convo_str),
|
||||
]
|
||||
_, core_memories = futures[0].result()
|
||||
recall_memories = futures[1].result()
|
||||
return {
|
||||
"core_memories": core_memories,
|
||||
"recall_memories": recall_memories,
|
||||
}
|
||||
|
||||
|
||||
def route_tools(state: schemas.State) -> Literal["tools", "__end__"]:
|
||||
"""Determine whether to use tools or end the conversation based on the last message.
|
||||
|
||||
Args:
|
||||
state (schemas.State): The current state of the conversation.
|
||||
|
||||
Returns:
|
||||
Literal["tools", "__end__"]: The next step in the graph.
|
||||
"""
|
||||
msg = state["messages"][-1]
|
||||
if msg.tool_calls:
|
||||
return "tools"
|
||||
return END
|
||||
|
||||
|
||||
# Create the graph and add nodes
|
||||
builder = StateGraph(schemas.State, schemas.GraphConfig)
|
||||
builder.add_node(load_memories)
|
||||
builder.add_node(agent)
|
||||
builder.add_node("tools", ToolNode(all_tools))
|
||||
|
||||
# Add edges to the graph
|
||||
builder.add_edge(START, "load_memories")
|
||||
builder.add_edge("load_memories", "agent")
|
||||
builder.add_conditional_edges("agent", route_tools)
|
||||
builder.add_edge("tools", "agent")
|
||||
|
||||
# Compile the graph
|
||||
memgraph = builder.compile()
|
||||
|
||||
__all__ = ["memgraph"]
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"dependencies": ["."],
|
||||
"graphs": {
|
||||
"memory": "./lang_memgpt/graph.py:memgraph"
|
||||
},
|
||||
"env": ".env"
|
||||
}
|
||||
Generated
+2363
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,58 @@
|
||||
[tool.poetry]
|
||||
name = "lang-memgpt"
|
||||
version = "0.0.1"
|
||||
description = "A simple memory-enabled agent for agents on LangGraph cloud."
|
||||
authors = ["William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9.0,<3.13"
|
||||
langgraph = "^0.1.0"
|
||||
langchain-fireworks = "^0.1.3"
|
||||
# Feel free to swap out for postgres or your favorite database.
|
||||
langchain-pinecone = "^0.1.1"
|
||||
jsonpatch = "^1.33"
|
||||
dydantic = "^0.0.6"
|
||||
pytest-asyncio = "^0.23.7"
|
||||
trustcall = "^0.0.4"
|
||||
langchain = "^0.2.6"
|
||||
langchain-openai = "^0.1.10"
|
||||
langchain-anthropic = "^0.1.15"
|
||||
pydantic-settings = "^2.3.4"
|
||||
langgraph-sdk = "^0.1.23"
|
||||
langchain-community = "^0.2.6"
|
||||
tavily-python = "^0.3.3"
|
||||
tiktoken = "^0.7.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.4.10"
|
||||
mypy = "^1.10.0"
|
||||
pytest = "^8.2.2"
|
||||
langgraph-cli = "^0.1.43"
|
||||
|
||||
[tool.ruff]
|
||||
lint.select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"D", # pydocstyle
|
||||
"D401", # First line should be in imperative mood
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
docstring-code-line-length = 80
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = ["D", "E501"]
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_fake_env_vars():
|
||||
os.environ["PINECONE_API_KEY"] = "fake_key"
|
||||
os.environ["PINECONE_INDEX_NAME"] = "fake_index"
|
||||
yield
|
||||
@@ -0,0 +1,139 @@
|
||||
import json
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langsmith import get_current_run_tree, test
|
||||
|
||||
from lang_memgpt._constants import PATCH_PATH
|
||||
from lang_memgpt._schemas import GraphConfig
|
||||
from lang_memgpt.graph import memgraph
|
||||
|
||||
|
||||
@test(output_keys=["num_mems_expected"])
|
||||
@pytest.mark.parametrize(
|
||||
"messages, existing, num_mems_expected",
|
||||
[
|
||||
([("user", "hi")], {}, 0),
|
||||
(
|
||||
[
|
||||
(
|
||||
"user",
|
||||
"When I was young, I had a dog named spot. He was my favorite pup. It's really one of my core memories.",
|
||||
)
|
||||
],
|
||||
{},
|
||||
1,
|
||||
),
|
||||
(
|
||||
[
|
||||
(
|
||||
"user",
|
||||
"When I was young, I had a dog named spot. It's really one of my core memories.",
|
||||
)
|
||||
],
|
||||
{"memories": ["I am afraid of spiders."]},
|
||||
2,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_patch_memory(
|
||||
messages: List[str],
|
||||
num_mems_expected: int,
|
||||
existing: dict,
|
||||
):
|
||||
# patch lang_memgpt.graph.index with a mock
|
||||
user_id = "4fddb3ef-fcc9-4ef7-91b6-89e4a3efd112"
|
||||
thread_id = "e1d0b7f7-0a8b-4c5f-8c4b-8a6c9f6e5c7a"
|
||||
function_name = "CoreMemories"
|
||||
with patch("lang_memgpt._utils.get_index") as get_index:
|
||||
index = MagicMock()
|
||||
get_index.return_value = index
|
||||
# No existing memories
|
||||
if existing:
|
||||
path = PATCH_PATH.format(
|
||||
user_id=user_id,
|
||||
function_name=function_name,
|
||||
)
|
||||
index.fetch.return_value = {
|
||||
"vectors": {path: {"metadata": {"content": json.dumps(existing)}}}
|
||||
}
|
||||
else:
|
||||
index.fetch.return_value = {}
|
||||
|
||||
# When the memories are patched
|
||||
await memgraph.ainvoke(
|
||||
{
|
||||
"messages": messages,
|
||||
},
|
||||
{
|
||||
"configurable": GraphConfig(
|
||||
delay=0.1,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
},
|
||||
)
|
||||
if num_mems_expected:
|
||||
# Check if index.upsert was called
|
||||
index.upsert.assert_called_once()
|
||||
# Get named call args
|
||||
vectors = index.upsert.call_args.kwargs["vectors"]
|
||||
rt = get_current_run_tree()
|
||||
rt.outputs = {"upserted": [v["metadata"]["content"] for v in vectors]}
|
||||
assert len(vectors) == 1
|
||||
# Check if the memory was added
|
||||
mem = vectors[0]["metadata"]["content"]
|
||||
assert mem
|
||||
|
||||
|
||||
@test(output_keys=["num_events_expected"])
|
||||
@pytest.mark.parametrize(
|
||||
"messages, num_events_expected",
|
||||
[
|
||||
([("user", "hi")], 0),
|
||||
(
|
||||
[
|
||||
("user", "I went to the beach with my friends today."),
|
||||
("assistant", "That sounds like a fun day."),
|
||||
("user", "You speak the truth."),
|
||||
],
|
||||
1,
|
||||
),
|
||||
(
|
||||
[
|
||||
("user", "I went to the beach with my friends."),
|
||||
("assistant", "That sounds like a fun day."),
|
||||
("user", "I also went to the park with my family - I like the park."),
|
||||
],
|
||||
2,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_insert_memory(
|
||||
messages: List[str],
|
||||
num_events_expected: int,
|
||||
):
|
||||
# patch lang_memgpt.graph.index with a mock
|
||||
user_id = "4fddb3ef-fcc9-4ef7-91b6-89e4a3efd112"
|
||||
thread_id = "e1d0b7f7-0a8b-4c5f-8c4b-8a6c9f6e5c7a"
|
||||
with patch("lang_memgpt._utils.get_index") as get_index:
|
||||
index = MagicMock()
|
||||
get_index.return_value = index
|
||||
index.fetch.return_value = {}
|
||||
# When the events are inserted
|
||||
await memgraph.ainvoke(
|
||||
{
|
||||
"messages": messages,
|
||||
},
|
||||
{
|
||||
"configurable": GraphConfig(
|
||||
delay=0.1,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
},
|
||||
)
|
||||
if num_events_expected:
|
||||
# Get named call args
|
||||
assert len(index.upsert.call_args_list) == num_events_expected
|
||||
Reference in New Issue
Block a user