mirror of
https://github.com/Mintplex-Labs/abitat.git
synced 2026-07-01 10:05:27 -04:00
fix default model
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
---
|
||||
'aibitat': patch
|
||||
---
|
||||
|
||||
Fix an issue where the default model wasnt getting replaced when specifying on
|
||||
specific agent
|
||||
+20
-22
@@ -173,7 +173,7 @@ export type FunctionDefinition = {
|
||||
export class AIbitat {
|
||||
private emitter = new EventEmitter()
|
||||
|
||||
private defaultProvider
|
||||
private defaultProvider: ProviderConfig
|
||||
private defaultInterrupt
|
||||
private maxRounds
|
||||
private _chats
|
||||
@@ -194,10 +194,10 @@ export class AIbitat {
|
||||
this.defaultInterrupt = interrupt
|
||||
this.maxRounds = maxRounds
|
||||
|
||||
this.defaultProvider = this.getProviderForConfig({
|
||||
this.defaultProvider = {
|
||||
provider,
|
||||
...rest,
|
||||
})!
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -285,6 +285,8 @@ export class AIbitat {
|
||||
throw new Error(`Channel configuration "${channel}" not found`)
|
||||
}
|
||||
return {
|
||||
provider: 'openai' as const,
|
||||
model: 'gpt-4' as const,
|
||||
maxRounds: 10,
|
||||
role: 'Group chat manager.',
|
||||
...config,
|
||||
@@ -606,11 +608,7 @@ export class AIbitat {
|
||||
// get the provider that will be used for the manager
|
||||
// if the manager has a provider, use that otherwise
|
||||
// use the GPT-4 because it has a better reasoning
|
||||
const nodeProvider = this.getProviderForConfig(channelConfig)
|
||||
const provider =
|
||||
nodeProvider ||
|
||||
this.getProviderForConfig({provider: 'openai', model: 'gpt-4'})!
|
||||
|
||||
const provider = this.getProviderForConfig(channelConfig)
|
||||
const history = this.getHistory({to: channel})
|
||||
|
||||
// build the messages to send to the provider
|
||||
@@ -704,8 +702,7 @@ ${this.getHistory({to: route.to})
|
||||
?.map(name => this.functions.get(name))
|
||||
.filter(a => !!a) as FunctionDefinition[] | undefined
|
||||
|
||||
const nodeProvider = this.getProviderForConfig(fromConfig)
|
||||
const provider = nodeProvider || this.defaultProvider
|
||||
const provider = this.getProviderForConfig(fromConfig)
|
||||
|
||||
// get the chat completion
|
||||
const content = await provider.create(messages, functions)
|
||||
@@ -814,20 +811,21 @@ ${this.getHistory({to: route.to})
|
||||
* @param config The provider configuration.
|
||||
*/
|
||||
private getProviderForConfig(config: ProviderConfig) {
|
||||
if (typeof config.provider === 'string') {
|
||||
switch (config.provider) {
|
||||
case 'openai':
|
||||
return new OpenAIProvider({model: config.model})
|
||||
|
||||
default:
|
||||
throw new Error(
|
||||
`Unknown provider: ${config.provider}. Please use "openai"`,
|
||||
)
|
||||
}
|
||||
const x = {
|
||||
...this.defaultProvider,
|
||||
...config,
|
||||
}
|
||||
|
||||
if (config.provider) {
|
||||
return config.provider
|
||||
if (typeof x.provider === 'object') {
|
||||
return x.provider
|
||||
}
|
||||
|
||||
switch (x.provider) {
|
||||
case 'openai':
|
||||
return new OpenAIProvider({model: x.model})
|
||||
|
||||
default:
|
||||
throw new Error(`Unknown provider: ${x.provider}. Please use "openai"`)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -231,6 +231,7 @@ export class OpenAIProvider extends AIProvider<OpenAI> {
|
||||
call: OpenAI.Chat.ChatCompletionMessage.FunctionCall,
|
||||
) {
|
||||
const funcToCall = functions.find(f => f.name === call.name)
|
||||
log(`calling function "${call.name}" with arguments: `, call.arguments)
|
||||
if (!funcToCall) {
|
||||
throw new Error(`Function '${call.name}' not found`)
|
||||
}
|
||||
@@ -245,7 +246,6 @@ export class OpenAIProvider extends AIProvider<OpenAI> {
|
||||
)
|
||||
}
|
||||
|
||||
log('calling function: ', funcToCall.name, 'with arguments: ', json)
|
||||
return await funcToCall.handler(json)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user