Add Aleph Alpha LLM (#1609)

* PR test

* AlephAlpha integration

* Eslint

* Some changes

* Update import map

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
paaatrrrick
2023-06-13 05:19:06 -07:00
committed by GitHub
parent 85de6e16b9
commit 1def2ba6a3
16 changed files with 383 additions and 0 deletions
@@ -159,6 +159,14 @@ import AI21Example from "@examples/models/llm/ai21.ts";
<CodeBlock language="typescript">{AI21Example}</CodeBlock>
## `AlephAlpha`
You can get started with AlephAlpha' Luminous family of models by signing up for an API key [on their website](https://www.aleph-alpha.com/).
import AlephAlphaExample from "@examples/models/llm/aleph_alpha.ts";
<CodeBlock language="typescript">{AlephAlphaExample}</CodeBlock>
## Additional LLM Implementations
### `PromptLayerOpenAI`
+15
View File
@@ -0,0 +1,15 @@
import { AlephAlpha } from "langchain/llms/aleph_alpha";
const model = new AlephAlpha({
aleph_alpha_api_key: "YOUR_ALEPH_ALPHA_API_KEY", // Or set as process.env.ALEPH_ALPHA_API_KEY
});
const res = await model.call(`Is cereal soup?`);
console.log({ res });
/*
{
res: "\nIs soup a cereal? I dont think so, but it is delicious."
}
*/
+3
View File
@@ -76,6 +76,9 @@ llms/openai.d.ts
llms/ai21.cjs
llms/ai21.js
llms/ai21.d.ts
llms/aleph_alpha.cjs
llms/aleph_alpha.js
llms/aleph_alpha.d.ts
llms/cohere.cjs
llms/cohere.js
llms/cohere.d.ts
+8
View File
@@ -88,6 +88,9 @@
"llms/ai21.cjs",
"llms/ai21.js",
"llms/ai21.d.ts",
"llms/aleph_alpha.cjs",
"llms/aleph_alpha.js",
"llms/aleph_alpha.d.ts",
"llms/cohere.cjs",
"llms/cohere.js",
"llms/cohere.d.ts",
@@ -873,6 +876,11 @@
"import": "./llms/ai21.js",
"require": "./llms/ai21.cjs"
},
"./llms/aleph_alpha": {
"types": "./llms/aleph_alpha.d.ts",
"import": "./llms/aleph_alpha.js",
"require": "./llms/aleph_alpha.cjs"
},
"./llms/cohere": {
"types": "./llms/cohere.d.ts",
"import": "./llms/cohere.js",
+1
View File
@@ -40,6 +40,7 @@ const entrypoints = {
"llms/base": "llms/base",
"llms/openai": "llms/openai",
"llms/ai21": "llms/ai21",
"llms/aleph_alpha": "llms/aleph_alpha",
"llms/cohere": "llms/cohere",
"llms/hf": "llms/hf",
"llms/replicate": "llms/replicate",
+285
View File
@@ -0,0 +1,285 @@
import { LLM, BaseLLMParams } from "./base.js";
import { getEnvironmentVariable } from "../util/env.js";
export interface AlephAlphaInput extends BaseLLMParams {
model: string;
maximum_tokens: number;
minimum_tokens?: number;
echo?: boolean;
temperature?: number;
top_k?: number;
top_p?: number;
presence_penalty?: number;
frequency_penalty?: number;
sequence_penalty?: number;
sequence_penalty_min_length?: number;
repetition_penalties_include_prompt?: boolean;
repetition_penalties_include_completion?: boolean;
use_multiplicative_presence_penalty?: boolean;
use_multiplicative_frequency_penalty?: boolean;
use_multiplicative_sequence_penalty?: boolean;
penalty_bias?: string;
penalty_exceptions?: string[];
penalty_exceptions_include_stop_sequences?: boolean;
best_of?: number;
n?: number;
logit_bias?: object;
log_probs?: number;
tokens?: boolean;
raw_completion: boolean;
disable_optimizations?: boolean;
completion_bias_inclusion?: string[];
completion_bias_inclusion_first_token_only: boolean;
completion_bias_exclusion?: string[];
completion_bias_exclusion_first_token_only: boolean;
contextual_control_threshold?: number;
control_log_additive: boolean;
stop?: string[];
aleph_alpha_api_key?: string;
base_url: string;
}
export class AlephAlpha extends LLM implements AlephAlphaInput {
model = "luminous-base";
maximum_tokens = 64;
minimum_tokens = 0;
echo: boolean;
temperature = 0.0;
top_k: number;
top_p = 0.0;
presence_penalty?: number;
frequency_penalty?: number;
sequence_penalty?: number;
sequence_penalty_min_length?: number;
repetition_penalties_include_prompt?: boolean;
repetition_penalties_include_completion?: boolean;
use_multiplicative_presence_penalty?: boolean;
use_multiplicative_frequency_penalty?: boolean;
use_multiplicative_sequence_penalty?: boolean;
penalty_bias?: string;
penalty_exceptions?: string[];
penalty_exceptions_include_stop_sequences?: boolean;
best_of?: number;
n?: number;
logit_bias?: object;
log_probs?: number;
tokens?: boolean;
raw_completion: boolean;
disable_optimizations?: boolean;
completion_bias_inclusion?: string[];
completion_bias_inclusion_first_token_only: boolean;
completion_bias_exclusion?: string[];
completion_bias_exclusion_first_token_only: boolean;
contextual_control_threshold?: number;
control_log_additive: boolean;
aleph_alpha_api_key? = getEnvironmentVariable("ALEPH_ALPHA_API_KEY");
stop?: string[];
base_url = "https://api.aleph-alpha.com/complete";
constructor(fields: Partial<AlephAlpha>) {
super(fields ?? {});
this.model = fields?.model ?? this.model;
this.temperature = fields?.temperature ?? this.temperature;
this.maximum_tokens = fields?.maximum_tokens ?? this.maximum_tokens;
this.minimum_tokens = fields?.minimum_tokens ?? this.minimum_tokens;
this.top_k = fields?.top_k ?? this.top_k;
this.top_p = fields?.top_p ?? this.top_p;
this.presence_penalty = fields?.presence_penalty ?? this.presence_penalty;
this.frequency_penalty =
fields?.frequency_penalty ?? this.frequency_penalty;
this.sequence_penalty = fields?.sequence_penalty ?? this.sequence_penalty;
this.sequence_penalty_min_length =
fields?.sequence_penalty_min_length ?? this.sequence_penalty_min_length;
this.repetition_penalties_include_prompt =
fields?.repetition_penalties_include_prompt ??
this.repetition_penalties_include_prompt;
this.repetition_penalties_include_completion =
fields?.repetition_penalties_include_completion ??
this.repetition_penalties_include_completion;
this.use_multiplicative_presence_penalty =
fields?.use_multiplicative_presence_penalty ??
this.use_multiplicative_presence_penalty;
this.use_multiplicative_frequency_penalty =
fields?.use_multiplicative_frequency_penalty ??
this.use_multiplicative_frequency_penalty;
this.use_multiplicative_sequence_penalty =
fields?.use_multiplicative_sequence_penalty ??
this.use_multiplicative_sequence_penalty;
this.penalty_bias = fields?.penalty_bias ?? this.penalty_bias;
this.penalty_exceptions =
fields?.penalty_exceptions ?? this.penalty_exceptions;
this.penalty_exceptions_include_stop_sequences =
fields?.penalty_exceptions_include_stop_sequences ??
this.penalty_exceptions_include_stop_sequences;
this.best_of = fields?.best_of ?? this.best_of;
this.n = fields?.n ?? this.n;
this.logit_bias = fields?.logit_bias ?? this.logit_bias;
this.log_probs = fields?.log_probs ?? this.log_probs;
this.tokens = fields?.tokens ?? this.tokens;
this.raw_completion = fields?.raw_completion ?? this.raw_completion;
this.disable_optimizations =
fields?.disable_optimizations ?? this.disable_optimizations;
this.completion_bias_inclusion =
fields?.completion_bias_inclusion ?? this.completion_bias_inclusion;
this.completion_bias_inclusion_first_token_only =
fields?.completion_bias_inclusion_first_token_only ??
this.completion_bias_inclusion_first_token_only;
this.completion_bias_exclusion =
fields?.completion_bias_exclusion ?? this.completion_bias_exclusion;
this.completion_bias_exclusion_first_token_only =
fields?.completion_bias_exclusion_first_token_only ??
this.completion_bias_exclusion_first_token_only;
this.contextual_control_threshold =
fields?.contextual_control_threshold ?? this.contextual_control_threshold;
this.control_log_additive =
fields?.control_log_additive ?? this.control_log_additive;
this.aleph_alpha_api_key =
fields?.aleph_alpha_api_key ?? this.aleph_alpha_api_key;
this.stop = fields?.stop ?? this.stop;
}
validateEnvironment() {
if (!this.aleph_alpha_api_key) {
throw new Error(
"Aleph Alpha API Key is missing in environment variables."
);
}
}
/** Get the default parameters for calling Aleph Alpha API. */
get defaultParams() {
return {
model: this.model,
temperature: this.temperature,
maximum_tokens: this.maximum_tokens,
minimum_tokens: this.minimum_tokens,
top_k: this.top_k,
top_p: this.top_p,
presence_penalty: this.presence_penalty,
frequency_penalty: this.frequency_penalty,
sequence_penalty: this.sequence_penalty,
sequence_penalty_min_length: this.sequence_penalty_min_length,
repetition_penalties_include_prompt:
this.repetition_penalties_include_prompt,
repetition_penalties_include_completion:
this.repetition_penalties_include_completion,
use_multiplicative_presence_penalty:
this.use_multiplicative_presence_penalty,
use_multiplicative_frequency_penalty:
this.use_multiplicative_frequency_penalty,
use_multiplicative_sequence_penalty:
this.use_multiplicative_sequence_penalty,
penalty_bias: this.penalty_bias,
penalty_exceptions: this.penalty_exceptions,
penalty_exceptions_include_stop_sequences:
this.penalty_exceptions_include_stop_sequences,
best_of: this.best_of,
n: this.n,
logit_bias: this.logit_bias,
log_probs: this.log_probs,
tokens: this.tokens,
raw_completion: this.raw_completion,
disable_optimizations: this.disable_optimizations,
completion_bias_inclusion: this.completion_bias_inclusion,
completion_bias_inclusion_first_token_only:
this.completion_bias_inclusion_first_token_only,
completion_bias_exclusion: this.completion_bias_exclusion,
completion_bias_exclusion_first_token_only:
this.completion_bias_exclusion_first_token_only,
contextual_control_threshold: this.contextual_control_threshold,
control_log_additive: this.control_log_additive,
};
}
/** Get the identifying parameters for this LLM. */
get identifyingParams() {
return { ...this.defaultParams };
}
/** Get the type of LLM. */
_llmType(): string {
return "aleph_alpha";
}
async _call(
prompt: string,
options: this["ParsedCallOptions"]
): Promise<string> {
let stop = options?.stop;
this.validateEnvironment();
if (this.stop && stop && this.stop.length > 0 && stop.length > 0) {
throw new Error("`stop` found in both the input and default params.");
}
stop = this.stop ?? stop ?? [];
const headers = {
Authorization: `Bearer ${this.aleph_alpha_api_key}`,
"Content-Type": "application/json",
Accept: "application/json",
};
const data = { prompt, stop_sequences: stop, ...this.defaultParams };
const responseData = await this.caller.call(async () => {
const response = await fetch(this.base_url, {
method: "POST",
headers,
body: JSON.stringify(data),
signal: options.signal,
});
if (!response.ok) {
// consume the response body to release the connection
// https://undici.nodejs.org/#/?id=garbage-collection
const text = await response.text();
const error = new Error(
`Aleph Alpha call failed with status ${response.status} and body ${text}`
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(error as any).response = response;
throw error;
}
return response.json();
});
if (
!responseData.completions ||
responseData.completions.length === 0 ||
!responseData.completions[0].completion
) {
throw new Error("No completions found in response");
}
return responseData.completions[0].completion ?? "";
}
}
@@ -0,0 +1,54 @@
import { test, describe, expect } from "@jest/globals";
import { AlephAlpha } from "../aleph_alpha.js";
describe("AI21", () => {
test("test call", async () => {
const aleph_alpha = new AlephAlpha({});
const result = await aleph_alpha.call(
"What is a good name for a company that makes colorful socks?"
);
console.log({ result });
});
test("test translation call", async () => {
const aleph_alpha = new AlephAlpha({});
const result = await aleph_alpha.call(
`Translate "I love programming" into German.`
);
console.log({ result });
});
test("test JSON output call", async () => {
const aleph_alpha = new AlephAlpha({});
const result = await aleph_alpha.call(
`Output a JSON object with three string fields: "name", "birthplace", "bio".`
);
console.log({ result });
});
test("should abort the request", async () => {
const aleph_alpha = new AlephAlpha({});
const controller = new AbortController();
await expect(() => {
const ret = aleph_alpha.call(
"Respond with an extremely verbose response",
{
signal: controller.signal,
}
);
controller.abort();
return ret;
}).rejects.toThrow("AbortError: This operation was aborted");
});
test("throws an error when response status is not ok", async () => {
const aleph_alpha = new AlephAlpha({
aleph_alpha_api_key: "BAD_KEY",
});
await expect(aleph_alpha.call("Test prompt")).rejects.toThrow(
'Aleph Alpha call failed with status 401 and body {"error":"InvalidToken","code":"UNAUTHENTICATED"}'
);
});
});
+1
View File
@@ -11,6 +11,7 @@ export * as embeddings__openai from "../embeddings/openai.js";
export * as llms__base from "../llms/base.js";
export * as llms__openai from "../llms/openai.js";
export * as llms__ai21 from "../llms/ai21.js";
export * as llms__aleph_alpha from "../llms/aleph_alpha.js";
export * as prompts from "../prompts/index.js";
export * as vectorstores__base from "../vectorstores/base.js";
export * as vectorstores__memory from "../vectorstores/memory.js";
+1
View File
@@ -56,6 +56,7 @@
"src/llms/base.ts",
"src/llms/openai.ts",
"src/llms/ai21.ts",
"src/llms/aleph_alpha.ts",
"src/llms/cohere.ts",
"src/llms/hf.ts",
"src/llms/replicate.ts",
+1
View File
@@ -10,6 +10,7 @@ export * from "langchain/embeddings/openai";
export * from "langchain/llms/base";
export * from "langchain/llms/openai";
export * from "langchain/llms/ai21";
export * from "langchain/llms/aleph_alpha";
export * from "langchain/prompts";
export * from "langchain/vectorstores/base";
export * from "langchain/vectorstores/memory";
+1
View File
@@ -10,6 +10,7 @@ const embeddings_openai = require("langchain/embeddings/openai");
const llms_base = require("langchain/llms/base");
const llms_openai = require("langchain/llms/openai");
const llms_ai21 = require("langchain/llms/ai21");
const llms_aleph_alpha = require("langchain/llms/aleph_alpha");
const prompts = require("langchain/prompts");
const vectorstores_base = require("langchain/vectorstores/base");
const vectorstores_memory = require("langchain/vectorstores/memory");
+1
View File
@@ -10,6 +10,7 @@ export * from "langchain/embeddings/openai";
export * from "langchain/llms/base";
export * from "langchain/llms/openai";
export * from "langchain/llms/ai21";
export * from "langchain/llms/aleph_alpha";
export * from "langchain/prompts";
export * from "langchain/vectorstores/base";
export * from "langchain/vectorstores/memory";
+1
View File
@@ -10,6 +10,7 @@ import * as embeddings_openai from "langchain/embeddings/openai";
import * as llms_base from "langchain/llms/base";
import * as llms_openai from "langchain/llms/openai";
import * as llms_ai21 from "langchain/llms/ai21";
import * as llms_aleph_alpha from "langchain/llms/aleph_alpha";
import * as prompts from "langchain/prompts";
import * as vectorstores_base from "langchain/vectorstores/base";
import * as vectorstores_memory from "langchain/vectorstores/memory";
+1
View File
@@ -10,6 +10,7 @@ import * as embeddings_openai from "langchain/embeddings/openai";
import * as llms_base from "langchain/llms/base";
import * as llms_openai from "langchain/llms/openai";
import * as llms_ai21 from "langchain/llms/ai21";
import * as llms_aleph_alpha from "langchain/llms/aleph_alpha";
import * as prompts from "langchain/prompts";
import * as vectorstores_base from "langchain/vectorstores/base";
import * as vectorstores_memory from "langchain/vectorstores/memory";
+1
View File
@@ -10,6 +10,7 @@ export * from "langchain/embeddings/openai";
export * from "langchain/llms/base";
export * from "langchain/llms/openai";
export * from "langchain/llms/ai21";
export * from "langchain/llms/aleph_alpha";
export * from "langchain/prompts";
export * from "langchain/vectorstores/base";
export * from "langchain/vectorstores/memory";
+1
View File
@@ -10,6 +10,7 @@ export * from "langchain/embeddings/openai";
export * from "langchain/llms/base";
export * from "langchain/llms/openai";
export * from "langchain/llms/ai21";
export * from "langchain/llms/aleph_alpha";
export * from "langchain/prompts";
export * from "langchain/vectorstores/base";
export * from "langchain/vectorstores/memory";