Add option to return used SQL in SqlDatabaseChain (#1313)

* Add option to return used SQL in SqlDatabaseChain

* Cleanup examples
This commit is contained in:
David Duong
2023-05-18 14:46:36 +02:00
committed by GitHub
parent a0f4ec4652
commit 0f62af5026
5 changed files with 106 additions and 18 deletions
@@ -1,5 +1,6 @@
import CodeBlock from "@theme/CodeBlock";
import SqlDBExample from "@examples/chains/sql_db.ts";
import SqlDBSqlOutputExample from "@examples/chains/sql_db_sql_output.ts";
# `SqlDatabaseChain`
@@ -35,3 +36,7 @@ const db = await SqlDatabase.fromDataSourceParams({
includesTables: ["Track"],
});
```
If desired, you can return the used SQL command when calling the chain.
<CodeBlock language="typescript">{SqlDBSqlOutputExample}</CodeBlock>
+14 -16
View File
@@ -8,22 +8,20 @@ import { SqlDatabaseChain } from "langchain/chains";
* To set it up follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the .db file
* in the examples folder.
*/
export const run = async () => {
const datasource = new DataSource({
type: "sqlite",
database: "Chinook.db",
});
const datasource = new DataSource({
type: "sqlite",
database: "Chinook.db",
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
const chain = new SqlDatabaseChain({
llm: new OpenAI({ temperature: 0 }),
database: db,
});
const chain = new SqlDatabaseChain({
llm: new OpenAI({ temperature: 0 }),
database: db,
});
const res = await chain.run("How many tracks are there?");
console.log(res);
// There are 3503 tracks.
};
const res = await chain.run("How many tracks are there?");
console.log(res);
// There are 3503 tracks.
+33
View File
@@ -0,0 +1,33 @@
import { DataSource } from "typeorm";
import { OpenAI } from "langchain/llms/openai";
import { SqlDatabase } from "langchain/sql_db";
import { SqlDatabaseChain } from "langchain/chains";
/**
* This example uses Chinook database, which is a sample database available for SQL Server, Oracle, MySQL, etc.
* To set it up follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the .db file
* in the examples folder.
*/
const datasource = new DataSource({
type: "sqlite",
database: "Chinook.db",
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
const chain = new SqlDatabaseChain({
llm: new OpenAI({ temperature: 0 }),
database: db,
sqlOutputKey: "sql",
});
const res = await chain.call({ query: "How many tracks are there?" });
/* Expected result:
* {
* result: ' There are 3503 tracks.',
* sql: ' SELECT COUNT(*) FROM "Track";'
* }
*/
console.log(res);
@@ -21,6 +21,7 @@ export interface SqlDatabaseChainInput extends ChainInputs {
topK?: number;
inputKey?: string;
outputKey?: string;
sqlOutputKey?: string;
prompt?: PromptTemplate;
}
@@ -41,6 +42,8 @@ export class SqlDatabaseChain extends BaseChain {
outputKey = "result";
sqlOutputKey: string | undefined = undefined;
// Whether to return the result of querying the SQL table directly.
returnDirect = false;
@@ -51,6 +54,7 @@ export class SqlDatabaseChain extends BaseChain {
this.topK = fields.topK ?? this.topK;
this.inputKey = fields.inputKey ?? this.inputKey;
this.outputKey = fields.outputKey ?? this.outputKey;
this.sqlOutputKey = fields.sqlOutputKey ?? this.sqlOutputKey;
this.prompt =
fields.prompt ??
getPromptTemplateFromDataSource(this.database.appDataSource);
@@ -114,6 +118,10 @@ export class SqlDatabaseChain extends BaseChain {
};
}
if (this.sqlOutputKey != null) {
finalResult[this.sqlOutputKey] = sqlCommand;
}
return finalResult;
}
@@ -126,6 +134,9 @@ export class SqlDatabaseChain extends BaseChain {
}
get outputKeys(): string[] {
if (this.sqlOutputKey != null) {
return [this.outputKey, this.sqlOutputKey];
}
return [this.outputKey];
}
@@ -37,12 +37,53 @@ test("Test SqlDatabaseChain", async () => {
expect(chain.prompt).toBe(SQL_SQLITE_PROMPT);
const res = await chain.run("How many users are there?");
console.log(res);
const run = await chain.run("How many users are there?");
console.log(run);
await datasource.destroy();
});
test("Test SqlDatabaseChain with sqlOutputKey", async () => {
const datasource = new DataSource({
type: "sqlite",
database: ":memory:",
synchronize: true,
});
await datasource.initialize();
await datasource.query(`
CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT, age INTEGER);
`);
await datasource.query(`
INSERT INTO users (name, age) VALUES ('Alice', 20);
`);
await datasource.query(`
INSERT INTO users (name, age) VALUES ('Bob', 21);
`);
await datasource.query(`
INSERT INTO users (name, age) VALUES ('Charlie', 22);
`);
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
const chain = new SqlDatabaseChain({
llm: new OpenAI({ temperature: 0 }),
database: db,
inputKey: "query",
sqlOutputKey: "sql",
});
expect(chain.prompt).toBe(SQL_SQLITE_PROMPT);
const run = await chain.call({ query: "How many users are there?" });
console.log(run);
expect(run).toHaveProperty("sql");
await datasource.destroy();
});
// We create this string to reach the token limit of the query built to describe the database and get the SQL query.
const veryLongString = `Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam orci nisi, vulputate ac pulvinar eu, maximus a tortor. Duis suscipit, nibh vel fermentum vehicula, mauris ante convallis metus, et feugiat turpis mauris non felis. Interdum et malesuada fames ac ante ipsum primis in faucibus. Maecenas efficitur nibh in nisi sagittis ultrices. Donec id velit nunc. Nam a lacus risus. Vestibulum molestie massa eget convallis pellentesque.
Mauris a nisl eget velit finibus blandit ac a odio. Sed sagittis consequat urna a egestas. Curabitur pretium convallis nibh, in ullamcorper odio tempus nec. Curabitur laoreet nec nisl sed accumsan. Sed elementum eleifend molestie. Aenean ullamcorper interdum risus, eget pharetra est volutpat ut. Aenean maximus consequat justo rutrum finibus. Mauris consequat facilisis consectetur. Vivamus rutrum dignissim libero, non aliquam lectus tempus id. In hac habitasse platea dictumst. Sed at magna dignissim, tincidunt lectus in, malesuada risus. Phasellus placerat blandit ligula. Integer posuere id elit at commodo. Sed consequat sagittis odio eget congue.