Compare commits

...

5 Commits

Author SHA1 Message Date
Alex Yang 0f47d185c3 feat: code 2024-01-15 12:52:25 -06:00
Alex Yang 8e2beaddca Merge remote-tracking branch 'origin/main' into himself65/fix-type
# Conflicts:
#	packages/core/src/llm/LLM.ts
2024-01-15 12:49:51 -06:00
Alex Yang bd3a7fd450 feat: improve 2024-01-15 12:48:32 -06:00
Alex Yang 2f1894b251 docs(changeset): fix: openai type might be missing 2024-01-14 21:32:18 -06:00
Alex Yang 2f1afecea7 fix: abstract openai class 2024-01-14 21:30:54 -06:00
6 changed files with 134 additions and 101 deletions
+5
View File
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---
fix: openai type might be missing
+37 -32
View File
@@ -13,8 +13,8 @@ export enum OpenAIEmbeddingModelType {
TEXT_EMBED_ADA_002 = "text-embedding-ada-002",
}
export class OpenAIEmbedding extends BaseEmbedding {
model: OpenAIEmbeddingModelType | string;
export abstract class OpenAIEmbeddingLike extends BaseEmbedding {
abstract model: string;
// OpenAI session params
apiKey?: string = undefined;
@@ -27,15 +27,47 @@ export class OpenAIEmbedding extends BaseEmbedding {
session: OpenAISession;
constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) {
constructor(init?: Partial<OpenAIEmbeddingLike>) {
super();
this.model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002;
this.maxRetries = init?.maxRetries ?? 10;
this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
this.additionalSessionOptions = init?.additionalSessionOptions;
this.apiKey = init?.apiKey ?? undefined;
this.session =
init?.session ??
getOpenAISession({
apiKey: this.apiKey,
maxRetries: this.maxRetries,
timeout: this.timeout,
...this.additionalSessionOptions,
});
}
private async getOpenAIEmbedding(input: string) {
const { data } = await this.session.openai.embeddings.create({
model: this.model,
input,
});
return data[0].embedding;
}
async getTextEmbedding(text: string): Promise<number[]> {
return this.getOpenAIEmbedding(text);
}
async getQueryEmbedding(query: string): Promise<number[]> {
return this.getOpenAIEmbedding(query);
}
}
export class OpenAIEmbedding extends OpenAIEmbeddingLike {
public override model: OpenAIEmbeddingModelType;
constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) {
super(init);
this.model = init?.model ?? OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002;
if (init?.azure || shouldUseAzure()) {
const azureConfig = getAzureConfigFromEnv({
...init?.azure,
@@ -60,33 +92,6 @@ export class OpenAIEmbedding extends BaseEmbedding {
defaultQuery: { "api-version": azureConfig.apiVersion },
...this.additionalSessionOptions,
});
} else {
this.apiKey = init?.apiKey ?? undefined;
this.session =
init?.session ??
getOpenAISession({
apiKey: this.apiKey,
maxRetries: this.maxRetries,
timeout: this.timeout,
...this.additionalSessionOptions,
});
}
}
private async getOpenAIEmbedding(input: string) {
const { data } = await this.session.openai.embeddings.create({
model: this.model,
input,
});
return data[0].embedding;
}
async getTextEmbedding(text: string): Promise<number[]> {
return this.getOpenAIEmbedding(text);
}
async getQueryEmbedding(query: string): Promise<number[]> {
return this.getOpenAIEmbedding(query);
}
}
+3 -3
View File
@@ -1,8 +1,8 @@
import { OpenAIEmbedding } from "./OpenAIEmbedding";
import { OpenAIEmbeddingLike } from "./OpenAIEmbedding";
export class TogetherEmbedding extends OpenAIEmbedding {
export class TogetherEmbedding extends OpenAIEmbeddingLike {
override model: string;
constructor(init?: Partial<OpenAIEmbedding>) {
constructor(init?: Partial<TogetherEmbedding>) {
super({
apiKey: process.env.TOGETHER_API_KEY,
...init,
+59 -52
View File
@@ -184,17 +184,22 @@ export const GPT35_MODELS = {
/**
* We currently support GPT-3.5 and GPT-4 models
*/
export const ALL_AVAILABLE_OPENAI_MODELS = {
export const ALL_AVAILABLE_OPENAI_MODELS: Record<
string,
{
contextWindow: number;
}
> = {
...GPT4_MODELS,
...GPT35_MODELS,
};
/**
* OpenAI LLM implementation
*/
export class OpenAI extends BaseLLM {
// Per completion OpenAI params
model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS | string;
type OpenAIModel = keyof typeof GPT4_MODELS | keyof typeof GPT35_MODELS;
export abstract class OpenAILike extends BaseLLM implements LLM {
hasStreaming: boolean = true;
abstract model: string;
temperature: number;
topP: number;
maxTokens?: number;
@@ -215,13 +220,8 @@ export class OpenAI extends BaseLLM {
callbackManager?: CallbackManager;
constructor(
init?: Partial<OpenAI> & {
azure?: AzureOpenAIConfig;
},
) {
constructor(init?: Partial<OpenAILike>) {
super();
this.model = init?.model ?? "gpt-3.5-turbo";
this.temperature = init?.temperature ?? 0.1;
this.topP = init?.topP ?? 1;
this.maxTokens = init?.maxTokens ?? undefined;
@@ -231,56 +231,26 @@ export class OpenAI extends BaseLLM {
this.additionalChatOptions = init?.additionalChatOptions;
this.additionalSessionOptions = init?.additionalSessionOptions;
if (init?.azure || shouldUseAzure()) {
const azureConfig = getAzureConfigFromEnv({
...init?.azure,
model: getAzureModel(this.model),
this.apiKey = init?.apiKey ?? undefined;
this.session =
init?.session ??
getOpenAISession({
apiKey: this.apiKey,
maxRetries: this.maxRetries,
timeout: this.timeout,
...this.additionalSessionOptions,
});
if (!azureConfig.apiKey) {
throw new Error(
"Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.",
);
}
this.apiKey = azureConfig.apiKey;
this.session =
init?.session ??
getOpenAISession({
azure: true,
apiKey: this.apiKey,
baseURL: getAzureBaseUrl(azureConfig),
maxRetries: this.maxRetries,
timeout: this.timeout,
defaultQuery: { "api-version": azureConfig.apiVersion },
...this.additionalSessionOptions,
});
} else {
this.apiKey = init?.apiKey ?? undefined;
this.session =
init?.session ??
getOpenAISession({
apiKey: this.apiKey,
maxRetries: this.maxRetries,
timeout: this.timeout,
...this.additionalSessionOptions,
});
}
this.callbackManager = init?.callbackManager;
}
get metadata() {
const contextWindow =
ALL_AVAILABLE_OPENAI_MODELS[
this.model as keyof typeof ALL_AVAILABLE_OPENAI_MODELS
]?.contextWindow ?? 1024;
return {
model: this.model,
temperature: this.temperature,
topP: this.topP,
maxTokens: this.maxTokens,
contextWindow,
contextWindow: ALL_AVAILABLE_OPENAI_MODELS[this.model].contextWindow,
tokenizer: Tokenizers.CL100K_BASE,
};
}
@@ -421,6 +391,43 @@ export class OpenAI extends BaseLLM {
}
}
export class OpenAI extends OpenAILike {
model: OpenAIModel;
constructor(
init?: Partial<OpenAI> & {
azure?: AzureOpenAIConfig;
},
) {
super(init);
this.model = init?.model ?? "gpt-3.5-turbo";
if (init?.azure || shouldUseAzure()) {
const azureConfig = getAzureConfigFromEnv({
...init?.azure,
model: getAzureModel(this.model),
});
if (!azureConfig.apiKey) {
throw new Error(
"Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.",
);
}
this.apiKey = azureConfig.apiKey;
this.session =
init?.session ??
getOpenAISession({
azure: true,
apiKey: this.apiKey,
baseURL: getAzureBaseUrl(azureConfig),
maxRetries: this.maxRetries,
timeout: this.timeout,
defaultQuery: { "api-version": azureConfig.apiVersion },
...this.additionalSessionOptions,
});
}
}
}
export const ALL_AVAILABLE_LLAMADEUCE_MODELS = {
"Llama-2-70b-chat-old": {
contextWindow: 4096,
+8 -11
View File
@@ -1,4 +1,3 @@
import _ from "lodash";
import OpenAI, { ClientOptions } from "openai";
export class AzureOpenAI extends OpenAI {
@@ -35,8 +34,10 @@ export class OpenAISession {
// I'm not 100% sure this is necessary vs. just starting a new session
// every time we make a call. They say they try to reuse connections
// so in theory this is more efficient, but we should test it in the future.
let defaultOpenAISession: { session: OpenAISession; options: ClientOptions }[] =
[];
let defaultOpenAISession: {
session: OpenAISession;
options: ClientOptions;
} | null = null;
/**
* Get a session for the OpenAI API. If one already exists with the same options,
@@ -47,14 +48,10 @@ let defaultOpenAISession: { session: OpenAISession; options: ClientOptions }[] =
export function getOpenAISession(
options: ClientOptions & { azure?: boolean } = {},
) {
let session = defaultOpenAISession.find((session) => {
return _.isEqual(session.options, options);
})?.session;
if (!session) {
session = new OpenAISession(options);
defaultOpenAISession.push({ session, options });
if (!defaultOpenAISession) {
const session = new OpenAISession(options);
defaultOpenAISession = { session, options };
}
return session;
return defaultOpenAISession.session;
}
+22 -3
View File
@@ -1,7 +1,13 @@
import { OpenAI } from "./LLM";
import { Tokenizers } from "../GlobalsHelper";
import { OpenAILike } from "./LLM";
export class TogetherLLM extends OpenAI {
constructor(init?: Partial<OpenAI>) {
export class TogetherLLM extends OpenAILike {
override model: string;
constructor(
init?: Partial<TogetherLLM> & {
model?: string;
},
) {
super({
...init,
apiKey: process.env.TOGETHER_API_KEY,
@@ -10,5 +16,18 @@ export class TogetherLLM extends OpenAI {
baseURL: "https://api.together.xyz/v1",
},
});
this.model = init?.model ?? '"mistralai/Mixtral-8x7B-Instruct-v0.1';
}
get metadata() {
return {
model: this.model,
temperature: this.temperature,
topP: this.topP,
maxTokens: this.maxTokens,
// todo: cannot find context window in documentation
contextWindow: 1024,
tokenizer: Tokenizers.CL100K_BASE,
};
}
}