From 2474ddc43f7cabfe4b932cd7e4851a9efcd780a2 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Fri, 4 Oct 2024 10:13:44 -0700 Subject: [PATCH] Allow passing WebSocket in directly (#4) --- js_server/src/index.ts | 8 ++------ js_server/src/lib/langchain_openai_voice.ts | 10 ++++++++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/js_server/src/index.ts b/js_server/src/index.ts index 12a7ae8..8f7320e 100644 --- a/js_server/src/index.ts +++ b/js_server/src/index.ts @@ -1,15 +1,14 @@ import "dotenv/config"; +import { type WebSocket } from "ws"; import { serve } from "@hono/node-server"; import { Hono } from "hono"; import { createNodeWebSocket } from "@hono/node-ws"; import { serveStatic } from "@hono/node-server/serve-static"; -import { WebSocket } from "ws"; import { OpenAIVoiceReactAgent } from "./lib/langchain_openai_voice"; import { INSTRUCTIONS } from "./prompt"; import { TOOLS } from "./tools"; -import { createStreamFromWebsocket } from "./lib/utils"; const app = new Hono(); @@ -30,10 +29,7 @@ app.get( tools: TOOLS, model: "gpt-4o-realtime-preview", }); - await agent.connect( - createStreamFromWebsocket(ws.raw as WebSocket), - ws.send.bind(ws) - ); + await agent.connect(ws.raw as WebSocket, ws.send.bind(ws)); }, onClose: () => { console.log("CLOSING"); diff --git a/js_server/src/lib/langchain_openai_voice.ts b/js_server/src/lib/langchain_openai_voice.ts index 03948df..71c2b78 100644 --- a/js_server/src/lib/langchain_openai_voice.ts +++ b/js_server/src/lib/langchain_openai_voice.ts @@ -198,13 +198,19 @@ export class OpenAIVoiceReactAgent { /** * Connect to the OpenAI API and send and receive messages. - * @param inputStream + * @param websocketOrStream * @param sendOutputChunk */ async connect( - inputStream: AsyncGenerator, + websocketOrStream: AsyncGenerator | WebSocket, sendOutputChunk: (chunk: string) => void | Promise ) { + let inputStream; + if ("next" in websocketOrStream) { + inputStream = websocketOrStream; + } else { + inputStream = createStreamFromWebsocket(websocketOrStream); + } const toolsByName = this.tools.reduce( (toolsByName: Record, tool) => { toolsByName[tool.name] = tool;