Clean up code

This commit is contained in:
Wlad Paiva
2023-10-17 11:06:01 -03:00
parent 5b38a65925
commit 937ac287d9
8 changed files with 214 additions and 109 deletions
+5
View File
@@ -0,0 +1,5 @@
---
'aibitat': patch
---
`.on` method is being replaced by especialized `.onXXX` methods
+4 -6
View File
@@ -55,6 +55,8 @@ setting them on the node config.
## Usage
> For a more complete example, check out the [examples](./examples) folder.
You can install the package:
```bash
@@ -88,9 +90,7 @@ const aibitat = new AIbitat({
},
})
aibitat.on('message', ({from, to, content}) =>
console.log(`${from}: ${content}`),
)
aibitat.onMessage(({from, to, content}) => console.log(`${from}: ${content}`))
// 🧑: How much is 2 + 2?
// 🐭: The sum of 2 + 2 is 4.
// 🦁: That is correct.
@@ -148,9 +148,7 @@ to each other. The `config` object is used to configure each node.
You can listen to events using the `on` method:
```ts
aibitat.on('message', ({from, to, content}) =>
console.log(`${from}: ${content}`),
)
aibitat.onMessage(({from, to, content}) => console.log(`${from}: ${content}`))
```
The following events are available:
+3 -3
View File
@@ -16,10 +16,10 @@ const aibitat = new AIbitat({
},
})
aibitat.on('message', terminal.print)
aibitat.on('terminate', () => console.timeEnd('🚀 chat finished'))
aibitat.onMessage(terminal.print)
aibitat.onTerminate(() => console.timeEnd('🚀 chat finished'))
aibitat.on('interrupt', async node => {
aibitat.onInterrupt(async node => {
const feedback = await terminal.askForFeedback(node)
await aibitat.continue(feedback)
})
+2 -2
View File
@@ -21,8 +21,8 @@ const aibitat = new AIbitat({
},
})
aibitat.on('message', terminal.print)
aibitat.on('terminate', () => console.timeEnd('🚀 chat finished'))
aibitat.onMessage(terminal.print)
aibitat.onTerminate(() => console.timeEnd('🚀 chat finished'))
// Ask for the mathematical problem of the chat before starting the conversation
const math = await input({
+3 -3
View File
@@ -61,9 +61,9 @@ const aibitat = new AIbitat({
},
})
aibitat.on('message', terminal.print)
aibitat.on('terminate', () => console.timeEnd('🚀 chat finished'))
aibitat.on('interrupt', async node => {
aibitat.onMessage(terminal.print)
aibitat.onTerminate(() => console.timeEnd('🚀 chat finished'))
aibitat.onInterrupt(async node => {
const feedback = await terminal.askForFeedback(node)
await aibitat.continue(feedback)
})
+3 -3
View File
@@ -25,10 +25,10 @@ const aibitat = new AIbitat({
},
})
aibitat.on('message', terminal.print)
aibitat.on('terminate', () => console.timeEnd('🚀 chat finished'))
aibitat.onMessage(terminal.print)
aibitat.onTerminate(() => console.timeEnd('🚀 chat finished'))
aibitat.on('interrupt', async node => {
aibitat.onInterrupt(async node => {
const feedback = await terminal.askForFeedback(node)
await aibitat.continue(feedback)
})
+5 -5
View File
@@ -137,7 +137,7 @@ describe('direct message', () => {
const aibitat = new AIbitat(defaultaibitat)
const callback = mock(() => {})
aibitat.on('message', callback)
aibitat.onMessage(callback)
await aibitat.start(defaultStart)
@@ -153,7 +153,7 @@ describe('direct message', () => {
})
const callback = mock(() => {})
aibitat.on('interrupt', callback)
aibitat.onInterrupt(callback)
await aibitat.start(defaultStart)
@@ -172,7 +172,7 @@ describe('direct message', () => {
})
const callback = mock(() => {})
aibitat.on('interrupt', callback)
aibitat.onInterrupt(callback)
await aibitat.start(defaultStart)
@@ -188,7 +188,7 @@ describe('direct message', () => {
// so I have to work around it.
// https://github.com/oven-sh/bun/issues/1825
const p = new Promise(async resolve => {
aibitat.on('interrupt', async () => {
aibitat.onInterrupt(async () => {
if (aibitat.chats.length < 4) {
await aibitat.continue()
} else {
@@ -218,7 +218,7 @@ describe('direct message', () => {
// so I have to work around it.
// https://github.com/oven-sh/bun/issues/1825
const p = new Promise(async resolve => {
aibitat.on('interrupt', a => {
aibitat.onInterrupt(a => {
if (aibitat.chats.length < 4) {
aibitat.continue('my feedback')
} else {
+189 -87
View File
@@ -43,6 +43,7 @@ export type BaseNodeConfig = ProviderConfig & {
* Agents are fully autonomous and can solve tasks with LLM.
*/
export type Agent = BaseNodeConfig & {
/** The type of the node */
type: 'agent'
}
@@ -50,6 +51,7 @@ export type Agent = BaseNodeConfig & {
* Managers are designed to take care of a group of agents.
*/
export type Manager = BaseNodeConfig & {
/** The type of the node */
type: 'manager'
/**
@@ -64,6 +66,7 @@ export type Manager = BaseNodeConfig & {
* when he understands that the task is completed.
*/
export type Assistant = BaseNodeConfig & {
/** The type of the node */
type: 'assistant'
}
@@ -105,7 +108,8 @@ type Chat = {
*/
type ChatState = Omit<Chat, 'content'> & {
content?: string
state: 'success' | 'loading' | 'error' | 'interrupt'
// state: 'success' | 'loading' | 'error' | 'interrupt'
state: 'success' | 'interrupt'
}
/**
@@ -117,7 +121,14 @@ type History = Array<ChatState>
* AIbitat props.
*/
export type AIbitatProps = ProviderConfig & {
/**
* The nodes and their connections.
*/
nodes: Nodes
/**
* The configuration for all nodes.
*/
config: Config
/**
@@ -140,10 +151,10 @@ export type AIbitatProps = ProviderConfig & {
}
/**
* AIbitat is a class that manages the aibitat of a chat.
* AIbitat is a class that manages the conversation between agents.
* It is designed to solve a task with LLM.
*
* Guiding the chat through a aibitat of nodes.
* Guiding the chat through a graph of nodes.
*/
export class AIbitat {
private emitter = new EventEmitter()
@@ -185,6 +196,133 @@ export class AIbitat {
return this._chats
}
/**
* Get the specific node configuration.
*
* @param node The name of the node.
* @throws When the node configuration is not found.
* @returns The node configuration.
*/
private getNodeConfig(node: string) {
const config = this.config[node]
if (!config) {
throw new Error(`Node configuration "${node}" not found`)
}
return config
}
/**
* Get the connections of a node.
*
* @param node The name of the node.
* @throws When the node connections are not found.
* @returns The node connections.
*/
private getNodeConnections(node: string) {
const connections = this.nodes[node]
if (!connections) {
throw new Error(`Node connections "${node}" not found`)
}
return connections
}
/**
* Get the members of a group.
* @throws When the group is not defined as an array in the connections.
* @param node The name of the group.
* @returns The members of the group.
*/
private getGroupMembers(node: string) {
const group = this.getNodeConnections(node)
if (!Array.isArray(group)) {
throw new Error(
`Group ${node} is not defined as an array in your connections`,
)
}
return group
}
/**
* Triggered when a chat is terminated. After this, the chat can't be continued.
*
* @param listener
* @returns
*/
public onTerminate(listener: (node: string) => void) {
this.emitter.on('terminate', listener)
return this
}
/**
* Terminate the chat. After this, the chat can't be continued.
*
* @param node Last node to chat with
*/
private terminate(node: string) {
this.emitter.emit('terminate', node)
}
/**
* Triggered when a chat is interrupted by a node.
*
* @param listener
* @returns
*/
public onInterrupt(listener: (chat: {from: string; to: string}) => void) {
this.emitter.on('interrupt', listener)
return this
}
/**
* Interruption the chat.
*
* @param chat The nodes that participated in the interruption.
* @returns
*/
private interrupt(chat: {from: string; to: string}) {
this._chats.push({
...chat,
state: 'interrupt',
})
this.emitter.emit('interrupt', chat)
}
/**
* Triggered when a message is added to the chat history.
* This can either be the first message or a reply to a message.
*
* @param listener
* @returns
*/
public onMessage(listener: (chat: ChatState) => void) {
this.emitter.on('message', listener)
return this
}
/**
* Register a new successful message in the chat history.
* This will trigger the `onMessage` event.
*
* @param message
*/
private newMessage(message: {from: string; to: string; content: string}) {
const chat = {
...message,
state: 'success' as const,
}
this._chats.push(chat)
this.emitter.emit('message', chat)
return this
}
/**
* Start a new chat.
*
* @param message The message to start the chat.
*/
async start(message: Chat) {
log(
`starting a chat from ${chalk.yellow(message.from)} to ${chalk.yellow(
@@ -192,17 +330,10 @@ export class AIbitat {
)} with ${chalk.green(message.content)}`,
)
const x = {
...message,
state: 'success' as const,
}
// register the message in the chat history
this.newMessage(message)
// chats have no state
this._chats.push(x)
this.emitter.emit('message', {
...x,
state: 'success',
})
// ask the node to reply
await this.chat({
to: message.from,
from: message.to,
@@ -227,14 +358,9 @@ export class AIbitat {
// check if the message is for a group
// if it is, select the next node to chat with from the group
// and then ask them to reply.
const fromNode = this.config[message.from]
if (!fromNode) {
throw new Error(`Node configuration "${message.from}" not found`)
}
const fromNode = this.getNodeConfig(message.from)
const isManager = fromNode.type === 'manager'
if (isManager) {
if (fromNode.type === 'manager') {
// select a node from the group
const nextNode = await this.selectNext(message.from)
@@ -242,25 +368,7 @@ export class AIbitat {
// TODO: should it throw an error or keep the chat alive when there is no node to chat with in the group?
// maybe it should wrap up the chat and reply to the original node
// For now, it will terminate the chat
this.emitter.emit('terminate', {...message, content: 'TERMINATE'})
return
}
const group = this.nodes[message.from]
if (!Array.isArray(group)) {
throw new Error(
`Group ${message.from} is not defined as an array in your nodes`,
)
}
const {maxRounds = 10} = fromNode
// get chats only from the group's nodes
const rounds = this.getHistory({to: message.from}).filter(chat =>
group.includes(chat.from),
).length
if (rounds >= maxRounds) {
this.emitter.emit('terminate', message.to)
this.terminate(message.from)
return
}
@@ -270,11 +378,19 @@ export class AIbitat {
}
if (this.shouldNodeInterrupt(nextNode)) {
this._chats.push({
...nextChat,
state: 'interrupt',
})
this.emitter.emit('interrupt', nextChat)
this.interrupt(nextChat)
return
}
// get chats only from the group's nodes
const history = this.getHistory({to: message.from})
const group = this.getGroupMembers(message.from)
const rounds = history.filter(chat => group.includes(chat.from)).length
// TODO: maybe this default should be defined somewhere else
const {maxRounds = 10} = fromNode
if (rounds >= maxRounds) {
this.terminate(message.to)
return
}
@@ -289,18 +405,14 @@ export class AIbitat {
reply === 'TERMINATE' ||
this.hasReachedMaximumRounds(message.from, message.to)
) {
this.emitter.emit('terminate', message.to)
this.terminate(message.to)
return
}
const newChat = {to: message.from, from: message.to}
if (reply === 'INTERRUPT' || this.shouldNodeInterrupt(message.to)) {
this._chats.push({
...newChat,
state: 'interrupt',
})
this.emitter.emit('interrupt', newChat)
this.interrupt(newChat)
return
}
@@ -327,17 +439,18 @@ export class AIbitat {
}
/**
* Select the next node to chat with from a group.
* Select the next node to chat with from a group. The node will be selected based on the history of chats.
* It will select the node that has not reached the maximum number of rounds yet and has not chatted with the manager in the last round.
* If it could not determine the next node, it will return a random node.
*
* @param manager The manager node.
* @returns The name of the node to chat with.
*/
private async selectNext(manager: string) {
// get all members of the group
const nodes = this.nodes[manager]
if (!nodes || !Array.isArray(nodes)) {
throw new Error(`Group ${manager} not found`)
}
const nodes = this.getGroupMembers(manager)
// TODO: move this to when the group is created
// warn if the group is underpopulated
if (nodes.length < 3) {
console.warn(
@@ -360,6 +473,7 @@ export class AIbitat {
}
if (!availableNodes.length) {
// TODO: what should it do when there is no node to chat with?
return
}
@@ -373,6 +487,7 @@ export class AIbitat {
const history = this.getHistory({to: manager})
// build the messages to send to the provider
const messages = [
{
role: 'system' as const,
@@ -396,6 +511,8 @@ Only return the role.
},
]
// ask the provider to select the next node to chat with
// and remove the brackets from the response
const name = (await provider.create(messages)).replace(/^\[|\]$/g, '')
if (this.config[name]) {
return name
@@ -420,20 +537,15 @@ Only return the role.
*/
private async reply({from, to}: {from: string; to: string}) {
// get the provider for the node that will reply
const nodeProvider = this.getProviderFromConfig(this.config[from])
const nodeProvider = this.getProviderFromConfig(this.getNodeConfig(from))
const provider = nodeProvider || this.defaultProvider
const newChat: ChatState = {
from,
to,
state: 'loading',
}
this._chats.push(newChat)
const isManager = this.config[to].type === 'manager'
const isManager = this.getNodeConfig(to).type === 'manager'
let chatHistory: Message[]
// if the node is a manager, send the group chat history to the provider
// otherwise, send the chat history between the two nodes
if (isManager) {
chatHistory = [
{
@@ -453,7 +565,7 @@ ${this.getHistory({to})
} else {
chatHistory = this.getHistory({from, to}).map(c => ({
content: c.content,
role: c.from == to ? ('user' as const) : ('assistant' as const),
role: c.from === to ? ('user' as const) : ('assistant' as const),
}))
}
@@ -470,14 +582,19 @@ ${this.getHistory({to})
// get the chat completion
const content = await provider.create(messages)
// TODO: add error handling
newChat.state = 'success'
newChat.content = content
this.emitter.emit('message', newChat)
this.newMessage({from, to, content})
return content
}
/**
* Continue the chat from the last interruption.
* If the last chat was not an interruption, it will throw an error.
* Provide a feedback where it was interrupted if you want to.
*
* @param feedback The feedback to the interruption if any.
* @returns
*/
public async continue(feedback?: string | null) {
const lastChat = this._chats.at(-1)
if (!lastChat || lastChat.state !== 'interrupt') {
@@ -537,26 +654,11 @@ ${this.getHistory({to})
}[]
}
public on(event: 'terminate', listener: (node: string) => void): this
public on(
event: 'interrupt',
listener: (chat: {from: string; to: string}) => void,
): this
public on(event: 'message', listener: (chat: ChatState) => void): this
/**
* Get provider based on configurations.
* If the provider is a string, it will return the default provider for that string.
*
* @param event
* @param listener
* @returns
*/
public on(event: string, listener: (...args: any[]) => void) {
this.emitter.on(event, listener)
return this
}
/**
* Get provider based on configurations
* @param config The provider configuration.
*/
private getProviderFromConfig(config: ProviderConfig) {
if (typeof config.provider === 'string') {