This commit is contained in:
Wlad Paiva
2023-10-25 12:12:45 -03:00
parent 0aca38ea5d
commit 19ebcc47e3
2 changed files with 27 additions and 18 deletions
+6
View File
@@ -0,0 +1,6 @@
---
'aibitat': patch
---
Fix a bug where the provider given to AIbitat was not replacing the default
provider for channels
+21 -18
View File
@@ -285,8 +285,6 @@ export class AIbitat {
throw new Error(`Channel configuration "${channel}" not found`)
}
return {
provider: 'openai' as const,
model: 'gpt-4' as const,
maxRounds: 10,
role: 'Group chat manager.',
...config,
@@ -567,7 +565,7 @@ export class AIbitat {
/**
* Select the next node to chat with from a group. The node will be selected based on the history of chats.
* It will select the node that has not reached the maximum number of rounds yet and has not chatted with the manager in the last round.
* It will select the node that has not reached the maximum number of rounds yet and has not chatted with the channel in the last round.
* If it could not determine the next node, it will return a random node.
*
* @param channel The name of the group.
@@ -591,7 +589,7 @@ export class AIbitat {
node => !this.hasReachedMaximumRounds(channel, node),
)
// remove the last node that chatted with the manager so it doesn't chat again
// remove the last node that chatted with the channel so it doesn't chat again
const lastChat = this._chats.filter(c => c.to === channel).at(-1)
if (lastChat) {
const index = availableNodes.indexOf(lastChat.from)
@@ -605,10 +603,15 @@ export class AIbitat {
return
}
// get the provider that will be used for the manager
// if the manager has a provider, use that otherwise
// get the provider that will be used for the channel
// if the channel has a provider, use that otherwise
// use the GPT-4 because it has a better reasoning
const provider = this.getProviderForConfig(channelConfig)
const provider = this.getProviderForConfig({
// @ts-expect-error
model: 'gpt-4',
...this.defaultProvider,
...channelConfig,
})
const history = this.getHistory({to: channel})
// build the messages to send to the provider
@@ -702,7 +705,10 @@ ${this.getHistory({to: route.to})
?.map(name => this.functions.get(name))
.filter(a => !!a) as FunctionDefinition[] | undefined
const provider = this.getProviderForConfig(fromConfig)
const provider = this.getProviderForConfig({
...this.defaultProvider,
...fromConfig,
})
// get the chat completion
const content = await provider.create(messages, functions)
@@ -811,21 +817,18 @@ ${this.getHistory({to: route.to})
* @param config The provider configuration.
*/
private getProviderForConfig(config: ProviderConfig) {
const x = {
...this.defaultProvider,
...config,
if (typeof config.provider === 'object') {
return config.provider
}
if (typeof x.provider === 'object') {
return x.provider
}
switch (x.provider) {
switch (config.provider) {
case 'openai':
return new OpenAIProvider({model: x.model})
return new OpenAIProvider({model: config.model})
default:
throw new Error(`Unknown provider: ${x.provider}. Please use "openai"`)
throw new Error(
`Unknown provider: ${config.provider}. Please use "openai"`,
)
}
}