Adds TS example agent (#2)

* Adds TS example agent

* Adds README

* Fix calculator
This commit is contained in:
Jacob Lee
2024-10-04 09:56:36 -07:00
committed by GitHub
parent 02360ed634
commit 5bec194aa0
16 changed files with 2990 additions and 3 deletions
+1
View File
@@ -2,3 +2,4 @@ __pycache__
.venv
.mypy_cache
.pytest_cache
.env
+24 -3
View File
@@ -8,6 +8,8 @@ Specifically, we enable this model to call tools by providing it a list of [Lang
## Installation
### Python
Make sure you're running Python 3.10 or later, then install `uv` to be able to run the project:
```bash
@@ -23,15 +25,34 @@ export TAVILY_API_KEY=your_tavily_api_key
Note: the Tavily API key is for the Tavily search engine, you can get an API key [here](https://app.tavily.com/). This is just an example tool, and if you do not want to use it you do not have to (see [Adding your own tools](#adding-your-own-tools))
### TypeScript
Navigate into the `js_server` folder, then install required dependencies with `yarn`:
```bash
yarn
```
You will also need to copy the provided `js_server/.env.example` file to `.env` and fill in your OpenAI and Tavily keys.
## Running the project
To run the project, execute the following command:
### Python
To run the project, execute the following commands:
```bash
cd server
uv run src/server/app.py
```
### TypeScript
```bash
cd js_server
yarn dev
```
## Open the browser
Now you can open the browser and navigate to `http://localhost:3000` to see the project running.
@@ -44,11 +65,11 @@ You may need to make sure that your browser can access your microphone.
## Adding your own tools
You can add your own tools by adding them to the `server/src/server/tools.py` file.
You can add your own tools by adding them to the `server/src/server/tools.py` file for Python or the `js_server/src/tools.ts` folder for TypeScript.
## Adding your own custom instructions
You can add your own custom instructions by adding them to the `server/src/server/prompt.py` file.
You can add your own custom instructions by adding them to the `server/src/server/prompt.py` file for Python or the `js_server/src/prompt.ts` folder for TypeScript.
## Next steps
+2
View File
@@ -0,0 +1,2 @@
OPENAI_API_KEY=
TAVILY_API_KEY=
+28
View File
@@ -0,0 +1,28 @@
# dev
.yarn/
!.yarn/releases
.vscode/*
!.vscode/launch.json
!.vscode/*.code-snippets
.idea/workspace.xml
.idea/usage.statistics.xml
.idea/shelf
# deps
node_modules/
# env
.env
.env.production
# logs
logs/
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
# misc
.DS_Store
+47
View File
@@ -0,0 +1,47 @@
# 🦜🎤 Voice ReAct Agent (TypeScript)
This is an implementation of a [ReAct](https://arxiv.org/abs/2210.03629)-style agent that uses OpenAI's new [Realtime API](https://platform.openai.com/docs/guides/realtime). It is a light [Hono](https://hono.dev/) app that statically serves a simple frontend from `/src/static` as well as a websocket endpoint for handling streaming audio input and output.
Specifically, we enable this model to call tools by providing it a list of [LangChain tools](https://js.langchain.com/docs/how_to/custom_tools/). It is easy to write these custom tools, and you can easily pass these to the model.
![](../static/react.png)
## Installation
Install required dependencies with `yarn`:
```bash
yarn
```
You will also need to copy the provided `.env.example` file to `.env` and fill in your OpenAI and Tavily keys.
## Running the project
```bash
yarn dev
```
## Open the browser
Now you can open the browser and navigate to `http://localhost:3000` to see the project running.
### Enable microphone
You may need to make sure that your browser can access your microphone.
- [Chrome](http://0.0.0.0:3000/)
## Adding your own tools
You can add your own tools by adding them to the `/src/tools.ts` folder for TypeScript.
## Adding your own custom instructions
You can add your own custom instructions by adding them to the `/src/prompt.ts` folder for TypeScript.
## Next steps
- [ ] Enable interrupting the AI
- [ ] Enable changing of instructions/tools based on state
- [ ] Add auth middleware
+23
View File
@@ -0,0 +1,23 @@
{
"name": "js_server",
"scripts": {
"dev": "tsx watch src/index.ts"
},
"dependencies": {
"@hono/node-server": "^1.13.1",
"@hono/node-ws": "^1.0.4",
"@langchain/community": "^0.3.4",
"@langchain/core": "^0.3.7",
"dotenv": "^16.4.5",
"hono": "^4.6.3",
"ws": "^8.18.0",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.23.3"
},
"devDependencies": {
"@types/node": "^20.11.17",
"@types/ws": "^8.5.12",
"tsx": "^4.7.1"
},
"packageManager": "yarn@3.5.1+sha512.8cd0e31bd60779ef4ca92b855fb3462c7ec35ce8b345752b7349a68239776417f46d41d79c8047242d9c93b48a1516f64c7444ebe747d9a02bf26868e6fa1f2b"
}
+53
View File
@@ -0,0 +1,53 @@
import "dotenv/config";
import { serve } from "@hono/node-server";
import { Hono } from "hono";
import { createNodeWebSocket } from "@hono/node-ws";
import { serveStatic } from "@hono/node-server/serve-static";
import { WebSocket } from "ws";
import { OpenAIVoiceReactAgent } from "./lib/langchain_openai_voice";
import { INSTRUCTIONS } from "./prompt";
import { TOOLS } from "./tools";
import { createStreamFromWebsocket } from "./lib/utils";
const app = new Hono();
const { injectWebSocket, upgradeWebSocket } = createNodeWebSocket({ app });
app.use("/", serveStatic({ path: "./static/index.html" }));
app.use("/static/*", serveStatic({ root: "./" }));
app.get(
"/ws",
upgradeWebSocket((c) => ({
onOpen: async (c, ws) => {
if (!process.env.OPENAI_API_KEY) {
return ws.close();
}
const agent = new OpenAIVoiceReactAgent({
instructions: INSTRUCTIONS,
tools: TOOLS,
model: "gpt-4o-realtime-preview",
});
await agent.connect(
createStreamFromWebsocket(ws.raw as WebSocket),
ws.send.bind(ws)
);
},
onClose: () => {
console.log("CLOSING");
},
}))
);
const port = 3000;
const server = serve({
fetch: app.fetch,
port,
});
injectWebSocket(server);
console.log(`Server is running on port ${port}`);
+279
View File
@@ -0,0 +1,279 @@
import WebSocket from "ws";
import { StructuredTool } from "@langchain/core/tools";
import { mergeStreams, createStreamFromWebsocket } from "./utils";
import { zodToJsonSchema } from "zod-to-json-schema";
const DEFAULT_MODEL = "gpt-4o-realtime-preview";
const DEFAULT_URL = "wss://api.openai.com/v1/realtime";
const EVENTS_TO_IGNORE = [
"response.function_call_arguments.delta",
"rate_limits.updated",
"response.audio_transcript.delta",
"response.created",
"response.content_part.added",
"response.content_part.done",
"conversation.item.created",
"response.audio.done",
"session.created",
"session.updated",
"response.done",
"response.output_item.done",
];
class OpenAIWebSocketConnection {
ws?: WebSocket;
url: string;
apiKey?: string;
model: string;
constructor(params: { url?: string; apiKey?: string; model?: string }) {
this.url = params.url ?? DEFAULT_URL;
this.model = params.model ?? DEFAULT_MODEL;
this.apiKey = params.apiKey ?? process.env.OPENAI_API_KEY;
}
async connect() {
const headers = {
Authorization: `Bearer ${this.apiKey}`,
"OpenAI-Beta": "realtime=v1",
};
const finalUrl = `${this.url}?model=${this.model}`;
this.ws = new WebSocket(finalUrl, { headers });
await new Promise<void>((resolve, reject) => {
const timeout = setTimeout(() => {
reject(new Error("Connection timed out after 10 seconds."));
}, 10000);
this.ws?.once("open", () => {
clearTimeout(timeout);
resolve();
});
this.ws?.once("error", (error) => {
clearTimeout(timeout);
reject(error);
});
});
}
sendEvent(event: Record<string, unknown>) {
const formattedEvent = JSON.stringify(event);
if (this.ws === undefined) {
throw new Error("Socket connection is not active, call .connect() first");
}
this.ws?.send(formattedEvent);
}
async *eventStream() {
if (!this.ws) {
throw new Error("Socket connection is not active, call .connect() first");
}
yield* createStreamFromWebsocket(this.ws);
}
}
/**
* Can accept function calls and emits function call outputs to a stream.
*/
class VoiceToolExecutor {
protected toolsByName: Record<string, StructuredTool>;
protected triggerPromise: Promise<any> | null = null;
protected triggerResolve: ((value: any) => void) | null = null;
protected lock: Promise<void> | null = null;
constructor(toolsByName: Record<string, StructuredTool>) {
this.toolsByName = toolsByName;
}
protected async triggerFunc(): Promise<any> {
if (!this.triggerPromise) {
this.triggerPromise = new Promise((resolve) => {
this.triggerResolve = resolve;
});
}
return this.triggerPromise;
}
async addToolCall(toolCall: any): Promise<void> {
while (this.lock) {
await this.lock;
}
this.lock = (async () => {
if (this.triggerResolve) {
this.triggerResolve(toolCall);
this.triggerPromise = null;
this.triggerResolve = null;
} else {
throw new Error("Tool call adding already in progress");
}
})();
await this.lock;
this.lock = null;
}
protected async createToolCallTask(toolCall: any): Promise<any> {
const tool = this.toolsByName[toolCall.name];
if (!tool) {
throw new Error(
`Tool ${toolCall.name} not found. Must be one of ${Object.keys(
this.toolsByName
)}`
);
}
let args;
try {
args = JSON.parse(toolCall.arguments);
} catch (error) {
throw new Error(
`Failed to parse arguments '${toolCall.arguments}'. Must be valid JSON.`
);
}
const result = await tool.call(args);
const resultStr =
typeof result === "string" ? result : JSON.stringify(result);
return {
type: "conversation.item.create",
item: {
id: toolCall.call_id,
call_id: toolCall.call_id,
type: "function_call_output",
output: resultStr,
},
};
}
async *outputIterator(): AsyncGenerator<any, void, unknown> {
while (true) {
const toolCall = await this.triggerFunc();
try {
const result = await this.createToolCallTask(toolCall);
yield result;
} catch (error: any) {
yield {
type: "conversation.item.create",
item: {
id: toolCall.call_id,
call_id: toolCall.call_id,
type: "function_call_output",
output: `Error: ${error.message}`,
},
};
}
}
}
}
export class OpenAIVoiceReactAgent {
protected connection: OpenAIWebSocketConnection;
protected instructions?: string;
protected tools: StructuredTool[];
constructor(params: {
instructions?: string;
tools?: StructuredTool[];
url?: string;
apiKey?: string;
model?: string;
}) {
this.connection = new OpenAIWebSocketConnection({
url: params.url,
apiKey: params.apiKey,
model: params.model,
});
this.instructions = params.instructions;
this.tools = params.tools ?? [];
}
/**
* Connect to the OpenAI API and send and receive messages.
* @param inputStream
* @param sendOutputChunk
*/
async connect(
inputStream: AsyncGenerator<string>,
sendOutputChunk: (chunk: string) => void | Promise<void>
) {
const toolsByName = this.tools.reduce(
(toolsByName: Record<string, StructuredTool>, tool) => {
toolsByName[tool.name] = tool;
return toolsByName;
},
{}
);
const toolExecutor = new VoiceToolExecutor(toolsByName);
await this.connection.connect();
const modelReceiveStream = this.connection.eventStream();
// Send tools and instructions with initial chunk
const toolDefs = Object.values(toolsByName).map((tool) => ({
type: "function",
name: tool.name,
description: tool.description,
parameters: zodToJsonSchema(tool.schema),
}));
this.connection.sendEvent({
type: "session.update",
session: {
instructions: this.instructions,
input_audio_transcription: {
model: "whisper-1",
},
tools: toolDefs,
},
});
for await (const [streamKey, dataRaw] of mergeStreams({
input_mic: inputStream,
output_speaker: modelReceiveStream,
tool_outputs: toolExecutor.outputIterator(),
})) {
let data: any;
try {
data = typeof dataRaw === "string" ? JSON.parse(dataRaw) : dataRaw;
} catch (error) {
console.error("Error decoding data:", dataRaw);
continue;
}
if (streamKey === "input_mic") {
this.connection.sendEvent(data);
} else if (streamKey === "tool_outputs") {
console.log("tool output", data);
this.connection.sendEvent(data);
this.connection.sendEvent({ type: "response.create", response: {} });
} else if (streamKey === "output_speaker") {
const { type } = data;
if (type === "response.audio.delta") {
sendOutputChunk(JSON.stringify(data));
} else if (type === "response.audio_buffer.speech_started") {
console.log("interrupt");
sendOutputChunk(JSON.stringify(data));
} else if (type === "error") {
console.error("error:", data);
} else if (type === "response.function_call_arguments.done") {
console.log("tool call", data);
toolExecutor.addToolCall(data);
} else if (type === "response.audio_transcript.done") {
console.log("model:", data.transcript);
} else if (
type === "conversation.item.input_audio_transcription.completed"
) {
console.log("user:", data.transcript);
} else if (!EVENTS_TO_IGNORE.includes(type)) {
console.log(type);
}
}
}
}
}
+77
View File
@@ -0,0 +1,77 @@
import WebSocket from "ws";
/**
* Merge multiple streams into one stream.
*/
export async function* mergeStreams<T>(
streams: Record<string, AsyncGenerator<T>>
): AsyncGenerator<[string, T]> {
// start the first iteration of each output iterator
const tasks = new Map(
Object.entries(streams).map(([key, stream], i) => {
return [key, stream.next().then((result) => ({ key, stream, result }))];
})
);
// yield chunks as they become available,
// starting new iterations as needed,
// until all iterators are done
while (tasks.size) {
const { key, result, stream } = await Promise.race(tasks.values());
tasks.delete(key);
if (!result.done) {
yield [key, result.value];
tasks.set(
key,
stream.next().then((result) => ({ key, stream, result }))
);
}
}
}
export async function* createStreamFromWebsocket(ws: WebSocket) {
const messageQueue: string[] = [];
let resolveMessage: ((value: string | PromiseLike<string>) => void) | null =
null;
let rejectMessage: ((reason?: any) => void) | null = null;
const onMessage = (data: WebSocket.Data) => {
const message = data.toString();
if (resolveMessage) {
resolveMessage(message);
resolveMessage = null;
rejectMessage = null;
} else {
messageQueue.push(message);
}
};
const onError = (error: Error) => {
if (rejectMessage) {
rejectMessage(error);
resolveMessage = null;
rejectMessage = null;
}
};
ws.on("message", onMessage);
ws.on("error", onError);
try {
while (ws.readyState === WebSocket.OPEN) {
let message: string;
if (messageQueue.length > 0) {
message = messageQueue.shift()!;
} else {
message = await new Promise<string>((resolve, reject) => {
resolveMessage = resolve;
rejectMessage = reject;
});
}
yield JSON.parse(message);
}
} finally {
ws.off("message", onMessage);
ws.off("error", onError);
}
}
+1
View File
@@ -0,0 +1 @@
export const INSTRUCTIONS = "You are a helpful assistant. Speak English.";
+30
View File
@@ -0,0 +1,30 @@
import { z } from "zod";
import { tool } from "@langchain/core/tools";
import { TavilySearchResults } from "@langchain/community/tools/tavily_search";
const add = tool(
async ({ a, b }) => {
return a + b;
},
{
name: "add",
description:
"Add two numbers. Please let the user know that you're adding the numbers BEFORE you call the tool",
schema: z.object({
a: z.number(),
b: z.number(),
}),
}
);
const tavilyTool = new TavilySearchResults({
maxResults: 5,
kwargs: {
includeAnswer: true,
},
});
tavilyTool.description = `This is a search tool for accessing the internet.\n\nLet the user know you're asking your friend Tavily for help before you call the tool.`;
export const TOOLS = [add, tavilyTool];
@@ -0,0 +1,34 @@
// source: https://github.com/Azure-Samples/aisearch-openai-rag-audio/blob/7f685a8969e3b63e8c3ef345326c21f5ab82b1c3/app/frontend/public/audio-playback-worklet.js
class AudioPlaybackWorklet extends AudioWorkletProcessor {
constructor() {
super();
this.port.onmessage = this.handleMessage.bind(this);
this.buffer = [];
}
handleMessage(event) {
if (event.data === null) {
this.buffer = [];
return;
}
this.buffer.push(...event.data);
}
process(inputs, outputs, parameters) {
const output = outputs[0];
const channel = output[0];
if (this.buffer.length > channel.length) {
const toProcess = this.buffer.slice(0, channel.length);
this.buffer = this.buffer.slice(channel.length);
channel.set(toProcess.map(v => v / 32768));
} else {
channel.set(this.buffer.map(v => v / 32768));
this.buffer = [];
}
return true;
}
}
registerProcessor("audio-playback-worklet", AudioPlaybackWorklet);
@@ -0,0 +1,31 @@
// source: https://github.com/Azure-Samples/aisearch-openai-rag-audio/blob/7f685a8969e3b63e8c3ef345326c21f5ab82b1c3/app/frontend/public/audio-processor-worklet.js
const MIN_INT16 = -0x8000;
const MAX_INT16 = 0x7fff;
class PCMAudioProcessor extends AudioWorkletProcessor {
constructor() {
super();
}
process(inputs, outputs, parameters) {
const input = inputs[0];
if (input.length > 0) {
const float32Buffer = input[0];
const int16Buffer = this.float32ToInt16(float32Buffer);
this.port.postMessage(int16Buffer);
}
return true;
}
float32ToInt16(float32Array) {
const int16Array = new Int16Array(float32Array.length);
for (let i = 0; i < float32Array.length; i++) {
let val = Math.floor(float32Array[i] * MAX_INT16);
val = Math.max(MIN_INT16, Math.min(MAX_INT16, val));
int16Array[i] = val;
}
return int16Array;
}
}
registerProcessor("audio-processor-worklet", PCMAudioProcessor);
+191
View File
@@ -0,0 +1,191 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Microphone to Speaker</title>
<style>
body {
font-family: Arial, sans-serif;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background-color: #f0f0f0;
}
#toggleAudio {
font-size: 18px;
padding: 10px 20px;
cursor: pointer;
background-color: #4CAF50;
color: white;
border: none;
border-radius: 5px;
transition: background-color 0.3s;
}
#toggleAudio:hover {
background-color: #45a049;
}
</style>
</head>
<body>
<button id="toggleAudio">Start Audio</button>
<script>
// Create audio context
const BUFFER_SIZE = 4800;
class Player {
constructor() {
this.playbackNode = null;
}
async init(sampleRate) {
const audioContext = new AudioContext({ sampleRate });
await audioContext.audioWorklet.addModule("/static/audio-playback-worklet.js");
this.playbackNode = new AudioWorkletNode(audioContext, "audio-playback-worklet");
this.playbackNode.connect(audioContext.destination);
}
play(buffer) {
if (this.playbackNode) {
this.playbackNode.port.postMessage(buffer);
}
}
stop() {
if (this.playbackNode) {
this.playbackNode.port.postMessage(null);
}
}
}
class Recorder {
constructor(onDataAvailable) {
this.onDataAvailable = onDataAvailable;
this.audioContext = null;
this.mediaStream = null;
this.mediaStreamSource = null;
this.workletNode = null;
}
async start(stream) {
try {
if (this.audioContext) {
await this.audioContext.close();
}
this.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 24000 });
await this.audioContext.audioWorklet.addModule("/static/audio-processor-worklet.js");
this.mediaStream = stream;
this.mediaStreamSource = this.audioContext.createMediaStreamSource(this.mediaStream);
this.workletNode = new AudioWorkletNode(this.audioContext, "audio-processor-worklet");
this.workletNode.port.onmessage = event => {
this.onDataAvailable(event.data.buffer);
};
this.mediaStreamSource.connect(this.workletNode);
this.workletNode.connect(this.audioContext.destination);
} catch (error) {
this.stop();
}
}
async stop() {
if (this.mediaStream) {
this.mediaStream.getTracks().forEach(track => track.stop());
this.mediaStream = null;
}
if (this.audioContext) {
await this.audioContext.close();
this.audioContext = null;
}
this.mediaStreamSource = null;
this.workletNode = null;
}
}
// Function to get microphone input and send it to WebSocket
async function startAudio() {
try {
// handle output -> speaker stuff
const ws = new WebSocket("ws://localhost:3000/ws");
const audioPlayer = new Player();
audioPlayer.init(24000);
ws.onmessage = event => {
const data = JSON.parse(event.data);
if (data?.type !== 'response.audio.delta') return;
const binary = atob(data.delta);
const bytes = Uint8Array.from(binary, c => c.charCodeAt(0));
const pcmData = new Int16Array(bytes.buffer);
audioPlayer.play(pcmData);
};
let buffer = new Uint8Array();
const appendToBuffer = (newData) => {
const newBuffer = new Uint8Array(buffer.length + newData.length);
newBuffer.set(buffer);
newBuffer.set(newData, buffer.length);
buffer = newBuffer;
};
const handleAudioData = (data) => {
const uint8Array = new Uint8Array(data);
appendToBuffer(uint8Array);
if (buffer.length >= BUFFER_SIZE) {
const toSend = new Uint8Array(buffer.slice(0, BUFFER_SIZE));
buffer = new Uint8Array(buffer.slice(BUFFER_SIZE));
const regularArray = String.fromCharCode(...toSend);
const base64 = btoa(regularArray);
ws.send(JSON.stringify({type: 'input_audio_buffer.append', audio: base64}));
}
};
// handle microphone -> input websocket
const audioRecorder = new Recorder(handleAudioData);
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
await audioRecorder.start(stream);
} catch (error) {
console.error('Error accessing the microphone', error);
alert('Error accessing the microphone. Please check your settings and try again.');
}
}
// Button to toggle audio
const toggleButton = document.getElementById('toggleAudio');
let isAudioOn = false;
toggleButton.addEventListener('click', async () => {
if (!isAudioOn) {
await startAudio();
toggleButton.textContent = 'Stop Audio';
isAudioOn = true;
} else {
// audioContext.suspend();
toggleButton.textContent = 'Start Audio';
isAudioOn = false;
}
});
</script>
</body>
</html>
+14
View File
@@ -0,0 +1,14 @@
{
"compilerOptions": {
"target": "ESNext",
"module": "ESNext",
"moduleResolution": "Bundler",
"strict": true,
"skipLibCheck": true,
"types": [
"node"
],
"jsx": "react-jsx",
"jsxImportSource": "hono/jsx",
}
}
+2155
View File
File diff suppressed because it is too large Load Diff