Include placeholder value for all secrets used, not just those in kwargs (#1696)

* Include placeholder value for all secrets used, not just those in kwargs

* Fix test

* Fix test

* Add error handling for openai streaming when errors are sent as events in the SSE stream (#1698)

* Add error handling for openai streaming when errors are sent as events in the SSE stream

* Lint
This commit is contained in:
Nuno Campos
2023-06-19 15:41:15 +01:00
committed by GitHub
parent 542c2f1cf3
commit 276fd1ce75
9 changed files with 122 additions and 63 deletions
+6 -6
View File
@@ -124,7 +124,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
lc_serializable = true;
apiKey?: string;
anthropicApiKey?: string;
apiUrl?: string;
@@ -153,9 +153,9 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
constructor(fields?: Partial<AnthropicInput> & BaseChatModelParams) {
super(fields ?? {});
this.apiKey =
this.anthropicApiKey =
fields?.anthropicApiKey ?? getEnvironmentVariable("ANTHROPIC_API_KEY");
if (!this.apiKey) {
if (!this.anthropicApiKey) {
throw new Error("Anthropic API key not found");
}
@@ -266,14 +266,14 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
options: { signal?: AbortSignal },
runManager?: CallbackManagerForLLMRun
): Promise<CompletionResponse> {
if (!this.apiKey) {
if (!this.anthropicApiKey) {
throw new Error("Missing Anthropic API key.");
}
let makeCompletionRequest;
if (request.stream) {
if (!this.streamingClient) {
const options = this.apiUrl ? { apiUrl: this.apiUrl } : undefined;
this.streamingClient = new AnthropicApi(this.apiKey, options);
this.streamingClient = new AnthropicApi(this.anthropicApiKey, options);
}
makeCompletionRequest = async () => {
let currentCompletion = "";
@@ -308,7 +308,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
} else {
if (!this.batchClient) {
const options = this.apiUrl ? { apiUrl: this.apiUrl } : undefined;
this.batchClient = new AnthropicApi(this.apiKey, options);
this.batchClient = new AnthropicApi(this.anthropicApiKey, options);
}
makeCompletionRequest = async () =>
this.batchClient
+27 -20
View File
@@ -159,6 +159,8 @@ export class ChatOpenAI
maxTokens?: number;
openAIApiKey?: string;
azureOpenAIApiVersion?: string;
azureOpenAIApiKey?: string;
@@ -175,9 +177,6 @@ export class ChatOpenAI
fields?: Partial<OpenAIChatInput> &
Partial<AzureOpenAIInput> &
BaseChatModelParams & {
concurrency?: number;
cache?: boolean;
openAIApiKey?: string;
configuration?: ConfigurationParameters;
},
/** @deprecated */
@@ -185,25 +184,26 @@ export class ChatOpenAI
) {
super(fields ?? {});
const apiKey =
this.openAIApiKey =
fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY");
const azureApiKey =
this.azureOpenAIApiKey =
fields?.azureOpenAIApiKey ??
getEnvironmentVariable("AZURE_OPENAI_API_KEY");
if (!azureApiKey && !apiKey) {
if (!this.azureOpenAIApiKey && !this.openAIApiKey) {
throw new Error("(Azure) OpenAI API key not found");
}
const azureApiInstanceName =
this.azureOpenAIApiInstanceName =
fields?.azureOpenAIApiInstanceName ??
getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");
const azureApiDeploymentName =
this.azureOpenAIApiDeploymentName =
fields?.azureOpenAIApiDeploymentName ??
getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME");
const azureApiVersion =
this.azureOpenAIApiVersion =
fields?.azureOpenAIApiVersion ??
getEnvironmentVariable("AZURE_OPENAI_API_VERSION");
@@ -222,11 +222,6 @@ export class ChatOpenAI
this.streaming = fields?.streaming ?? false;
this.azureOpenAIApiVersion = azureApiVersion;
this.azureOpenAIApiKey = azureApiKey;
this.azureOpenAIApiInstanceName = azureApiInstanceName;
this.azureOpenAIApiDeploymentName = azureApiDeploymentName;
if (this.streaming && this.n > 1) {
throw new Error("Cannot stream results when n > 1");
}
@@ -244,7 +239,7 @@ export class ChatOpenAI
}
this.clientConfig = {
apiKey,
apiKey: this.openAIApiKey,
...configuration,
...fields?.configuration,
};
@@ -327,18 +322,29 @@ export class ChatOpenAI
responseType: "stream",
onmessage: (event) => {
if (event.data?.trim?.() === "[DONE]") {
if (resolved) {
if (resolved || rejected) {
return;
}
resolved = true;
resolve(response);
} else {
const message = JSON.parse(event.data) as {
const data = JSON.parse(event.data);
if (data?.error) {
if (rejected) {
return;
}
rejected = true;
reject(data.error);
return;
}
const message = data as {
id: string;
object: string;
created: number;
model: string;
choices: Array<{
choices?: Array<{
index: number;
finish_reason: string | null;
delta: {
@@ -361,7 +367,7 @@ export class ChatOpenAI
}
// on all messages, update choice
for (const part of message.choices) {
for (const part of message.choices ?? []) {
if (part != null) {
let choice = response.choices.find(
(c) => c.index === part.index
@@ -414,7 +420,8 @@ export class ChatOpenAI
// when all messages are finished, resolve
if (
!resolved &&
message.choices.every((c) => c.finish_reason != null)
!rejected &&
message.choices?.every((c) => c.finish_reason != null)
) {
resolved = true;
resolve(response);
+27 -17
View File
@@ -100,6 +100,8 @@ export class OpenAIChat
streaming = false;
openAIApiKey?: string;
azureOpenAIApiVersion?: string;
azureOpenAIApiKey?: string;
@@ -116,7 +118,6 @@ export class OpenAIChat
fields?: Partial<OpenAIChatInput> &
Partial<AzureOpenAIInput> &
BaseLLMParams & {
openAIApiKey?: string;
configuration?: ConfigurationParameters;
},
/** @deprecated */
@@ -124,26 +125,28 @@ export class OpenAIChat
) {
super(fields ?? {});
const apiKey =
this.openAIApiKey =
fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY");
const azureApiKey =
this.azureOpenAIApiKey =
fields?.azureOpenAIApiKey ??
getEnvironmentVariable("AZURE_OPENAI_API_KEY");
if (!azureApiKey && !apiKey) {
if (!this.azureOpenAIApiKey && !this.openAIApiKey) {
throw new Error("(Azure) OpenAI API key not found");
}
const azureApiInstanceName =
this.azureOpenAIApiInstanceName =
fields?.azureOpenAIApiInstanceName ??
getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");
const azureApiDeploymentName =
fields?.azureOpenAIApiDeploymentName ??
getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME");
this.azureOpenAIApiDeploymentName =
(fields?.azureOpenAIApiCompletionsDeploymentName ||
fields?.azureOpenAIApiDeploymentName) ??
(getEnvironmentVariable("AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME") ||
getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME"));
const azureApiVersion =
this.azureOpenAIApiVersion =
fields?.azureOpenAIApiVersion ??
getEnvironmentVariable("AZURE_OPENAI_API_VERSION");
@@ -163,11 +166,6 @@ export class OpenAIChat
this.streaming = fields?.streaming ?? false;
this.azureOpenAIApiVersion = azureApiVersion;
this.azureOpenAIApiKey = azureApiKey;
this.azureOpenAIApiInstanceName = azureApiInstanceName;
this.azureOpenAIApiDeploymentName = azureApiDeploymentName;
if (this.streaming && this.n > 1) {
throw new Error("Cannot stream results when n > 1");
}
@@ -185,7 +183,7 @@ export class OpenAIChat
}
this.clientConfig = {
apiKey,
apiKey: this.openAIApiKey,
...configuration,
...fields?.configuration,
};
@@ -266,13 +264,24 @@ export class OpenAIChat
responseType: "stream",
onmessage: (event) => {
if (event.data?.trim?.() === "[DONE]") {
if (resolved) {
if (resolved || rejected) {
return;
}
resolved = true;
resolve(response);
} else {
const message = JSON.parse(event.data) as {
const data = JSON.parse(event.data);
if (data?.error) {
if (rejected) {
return;
}
rejected = true;
reject(data.error);
return;
}
const message = data as {
id: string;
object: string;
created: number;
@@ -329,6 +338,7 @@ export class OpenAIChat
// when all messages are finished, resolve
if (
!resolved &&
!rejected &&
message.choices.every((c) => c.finish_reason != null)
) {
resolved = true;
+23 -15
View File
@@ -104,6 +104,8 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
streaming = false;
openAIApiKey?: string;
azureOpenAIApiVersion?: string;
azureOpenAIApiKey?: string;
@@ -120,7 +122,6 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
fields?: Partial<OpenAIInput> &
Partial<AzureOpenAIInput> &
BaseLLMParams & {
openAIApiKey?: string;
configuration?: ConfigurationParameters;
},
/** @deprecated */
@@ -136,28 +137,28 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
}
super(fields ?? {});
const apiKey =
this.openAIApiKey =
fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY");
const azureApiKey =
this.azureOpenAIApiKey =
fields?.azureOpenAIApiKey ??
getEnvironmentVariable("AZURE_OPENAI_API_KEY");
if (!azureApiKey && !apiKey) {
if (!this.azureOpenAIApiKey && !this.openAIApiKey) {
throw new Error("(Azure) OpenAI API key not found");
}
const azureApiInstanceName =
this.azureOpenAIApiInstanceName =
fields?.azureOpenAIApiInstanceName ??
getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");
const azureApiDeploymentName =
this.azureOpenAIApiDeploymentName =
(fields?.azureOpenAIApiCompletionsDeploymentName ||
fields?.azureOpenAIApiDeploymentName) ??
(getEnvironmentVariable("AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME") ||
getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME"));
const azureApiVersion =
this.azureOpenAIApiVersion =
fields?.azureOpenAIApiVersion ??
getEnvironmentVariable("AZURE_OPENAI_API_VERSION");
@@ -178,11 +179,6 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
this.streaming = fields?.streaming ?? false;
this.azureOpenAIApiVersion = azureApiVersion;
this.azureOpenAIApiKey = azureApiKey;
this.azureOpenAIApiInstanceName = azureApiInstanceName;
this.azureOpenAIApiDeploymentName = azureApiDeploymentName;
if (this.streaming && this.n > 1) {
throw new Error("Cannot stream results when n > 1");
}
@@ -204,7 +200,7 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
}
this.clientConfig = {
apiKey,
apiKey: this.openAIApiKey,
...configuration,
...fields?.configuration,
};
@@ -310,7 +306,7 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
responseType: "stream",
onmessage: (event) => {
if (event.data?.trim?.() === "[DONE]") {
if (resolved) {
if (resolved || rejected) {
return;
}
resolved = true;
@@ -319,7 +315,18 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
choices,
});
} else {
const message = JSON.parse(event.data) as Omit<
const data = JSON.parse(event.data);
if (data?.error) {
if (rejected) {
return;
}
rejected = true;
reject(data.error);
return;
}
const message = data as Omit<
CreateCompletionResponse,
"usage"
>;
@@ -352,6 +359,7 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
// when all messages are finished, resolve
if (
!resolved &&
!rejected &&
choices.every((c) => c.finish_reason != null)
) {
resolved = true;
+9 -3
View File
@@ -8,6 +8,7 @@ import { optionalImportEntrypoints } from "./import_constants.js";
import * as importMap from "./import_map.js";
import { OptionalImportMap, SecretMap } from "./import_type.js";
import { SerializedFields, keyFromJson, mapKeys } from "./map_keys.js";
import { getEnvironmentVariable } from "../util/env.js";
function combineAliasesAndInvert(constructor: typeof Serializable) {
const aliases: { [key: string]: string } = {};
@@ -50,9 +51,14 @@ async function reviver(
if (key in secretsMap) {
return secretsMap[key as keyof SecretMap];
} else {
throw new Error(
`Missing key "${key}" for ${pathStr} in load(secretsMap={})`
);
const secretValueInEnv = getEnvironmentVariable(key);
if (secretValueInEnv) {
return secretValueInEnv;
} else {
throw new Error(
`Missing key "${key}" for ${pathStr} in load(secretsMap={})`
);
}
}
} else if (
typeof value === "object" &&
+9 -1
View File
@@ -128,12 +128,20 @@ export abstract class Serializable {
Object.assign(kwargs, Reflect.get(current, "lc_attributes", this));
}
// include all secrets used, even if not in kwargs,
// will be replaced with sentinel value in replaceSecrets
for (const key in secrets) {
if (key in this && this[key as keyof this] !== undefined) {
kwargs[key] = this[key as keyof this] || kwargs[key];
}
}
return {
lc: 1,
type: "constructor",
id: [...this.lc_namespace, this.constructor.name],
kwargs: mapKeys(
this.lc_secrets ? replaceSecrets(kwargs, secrets) : kwargs,
Object.keys(secrets).length ? replaceSecrets(kwargs, secrets) : kwargs,
keyToJson,
aliases
),
@@ -93,6 +93,11 @@ kwargs:
prefix_messages:
- role: system
content: You're a nice assistant
openai_api_key:
lc: 1
type: secret
id:
- OPENAI_API_KEY
prompt:
lc: 1
type: constructor
+10 -1
View File
@@ -113,10 +113,11 @@ test("serialize + deserialize custom classes", async () => {
});
test("serialize + deserialize llm", async () => {
// eslint-disable-next-line no-process-env
process.env.OPENAI_API_KEY = "openai-key";
const llm = new OpenAI({
temperature: 0.5,
modelName: "davinci",
openAIApiKey: "openai-key",
});
llm.temperature = 0.7;
const lc_argumentsBefore = llm.lc_kwargs;
@@ -126,11 +127,17 @@ test("serialize + deserialize llm", async () => {
expect(JSON.parse(str).kwargs.temperature).toBe(0.7);
expect(JSON.parse(str).kwargs.model).toBe("davinci");
expect(JSON.parse(str).kwargs.openai_api_key.type).toBe("secret");
// Accept secret in secret map
const llm2 = await load<OpenAI>(str, {
OPENAI_API_KEY: "openai-key",
});
expect(llm2).toBeInstanceOf(OpenAI);
expect(JSON.stringify(llm2, null, 2)).toBe(str);
// Accept secret as env var
const llm3 = await load<OpenAI>(str);
expect(llm3).toBeInstanceOf(OpenAI);
expect(llm.openAIApiKey).toBe(llm3.openAIApiKey);
expect(JSON.stringify(llm3, null, 2)).toBe(str);
});
test("serialize + deserialize llm with optional deps", async () => {
@@ -182,6 +189,8 @@ test("serialize + deserialize llm chain string prompt", async () => {
});
test("serialize + deserialize llm chain chat prompt", async () => {
// eslint-disable-next-line no-process-env
process.env.OPENAI_API_KEY = undefined;
const llm = new ChatOpenAI({
temperature: 0.5,
modelName: "gpt-4",
+6
View File
@@ -48,6 +48,12 @@ export declare interface OpenAIBaseInput {
* Timeout to use when making requests to OpenAI.
*/
timeout?: number;
/**
* API key to use when making requests to OpenAI. Defaults to the value of
* `OPENAI_API_KEY` environment variable.
*/
openAIApiKey?: string;
}
export interface OpenAICallOptions extends BaseLanguageModelCallOptions {