finetune-embedding/generate_dataset.ipynb
Jerry Liu e1deb11272 cr
2023-08-24 22:49:12 -07:00

{
"cells": [
{
"cell_type": "markdown",
"id": "e1bdd511-4fc0-4bbd-9f4b-df76bbcb756a",
"metadata": {},
"source": [
"# Generate Synthetic Dataset with LLM"
]
},
{
"cell_type": "markdown",
"id": "1986c191-629b-45a2-8edb-472c96daddeb",
"metadata": {},
"source": [
"In this notebook, we generate a synthetic dataset of (query, relevant documents) pairs from a corpus of documents *without labelers* by leveraging an LLM."
]
},
{
"cell_type": "markdown",
"id": "d4b48882-dd24-4872-a42c-f21ac303e550",
"metadata": {},
"source": [
"### Generate Corpus"
]
},
{
"cell_type": "markdown",
"id": "7e1fba52-b566-42ed-91f9-3a20be04f4e2",
"metadata": {},
"source": [
"First, we create the corpus of text chunks: we use LlamaIndex to load some financial PDFs, then parse and chunk them into plain-text chunks."
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "c262e939-9eef-421e-8a94-c1d8a6cf861d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import json\n",
"\n",
"from llama_index import SimpleDirectoryReader\n",
"from llama_index.node_parser import SimpleNodeParser\n",
"from llama_index.schema import MetadataMode"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "32eaaf6b-e3dc-4838-aaf5-f1d6305e2426",
"metadata": {},
"outputs": [],
"source": [
"TRAIN_FILES = ['../llama_index/docs/examples/data/10k/lyft_2021.pdf']\n",
"VAL_FILES = ['../llama_index/docs/examples/data/10k/uber_2021.pdf']\n",
"\n",
"TRAIN_CORPUS_FPATH = './data/train_corpus.json'\n",
"VAL_CORPUS_FPATH = './data/val_corpus.json'"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "f6ca90b0-eac9-420f-b6e9-a83749280b5e",
"metadata": {},
"outputs": [],
"source": [
"def load_corpus(files, verbose=False):\n",
"    \"\"\"Load files and return a dict mapping node ID to plain-text chunk.\"\"\"\n",
"    if verbose:\n",
"        print(f\"Loading files {files}\")\n",
"\n",
"    reader = SimpleDirectoryReader(input_files=files)\n",
"    docs = reader.load_data()\n",
"    if verbose:\n",
"        print(f'Loaded {len(docs)} docs')\n",
"\n",
"    # Split each document into text chunks (nodes).\n",
"    parser = SimpleNodeParser.from_defaults()\n",
"    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n",
"\n",
"    if verbose:\n",
"        print(f'Parsed {len(nodes)} nodes')\n",
"\n",
"    # Keep only the raw text of each node, without metadata.\n",
"    corpus = {node.node_id: node.get_content(metadata_mode=MetadataMode.NONE) for node in nodes}\n",
"    return corpus"
]
},
{
"cell_type": "markdown",
"id": "b516ecaa-56fc-443b-ac74-f6fb1fc760a3",
"metadata": {},
"source": [
"We do a very naive train/val split: the Lyft corpus is the train dataset and the Uber corpus is the val dataset."
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "15af0b0c-547c-493a-8fea-1d880c3c63fb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded 238 docs\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a1670320ee1a448dbab193affff76446",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Parsing documents into nodes: 0%| | 0/238 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parsed 334 nodes\n",
"Loaded 307 docs\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38ccf19175c94fe58f399019c9d7b570",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Parsing documents into nodes: 0%| | 0/307 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parsed 395 nodes\n"
]
}
],
"source": [
"train_corpus = load_corpus(TRAIN_FILES, verbose=True)\n",
"val_corpus = load_corpus(VAL_FILES, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "92f8838e-96d0-4def-978b-61bd1dce3217",
"metadata": {},
"outputs": [],
"source": [
"with open(TRAIN_CORPUS_FPATH, 'w') as f:\n",
"    json.dump(train_corpus, f)\n",
"\n",
"with open(VAL_CORPUS_FPATH, 'w') as f:\n",
"    json.dump(val_corpus, f)"
]
},
{
"cell_type": "markdown",
"id": "d08491e5-6273-4b09-b19f-bbf3276d562b",
"metadata": {},
"source": [
"### Generate synthetic queries"
]
},
{
"cell_type": "markdown",
"id": "c2e7f9f3-0aa2-4d04-bfb0-92d4eabc404b",
"metadata": {},
"source": [
"Now, we use an LLM (gpt-3.5-turbo) to generate questions using each text chunk in the corpus as context.\n",
"\n",
"Each pair of (generated question, text chunk used as context) becomes a datapoint in the finetuning dataset (either for training or evaluation)."
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "af162c19-4362-447c-9c57-e35d13cdb4ba",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import re\n",
"import uuid\n",
"\n",
"from llama_index.llms import OpenAI\n",
"from llama_index.schema import MetadataMode\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "e48e338f-0cf0-497b-988b-5a7d112b3502",
"metadata": {},
"outputs": [],
"source": [
"TRAIN_QUERIES_FPATH = './data/train_queries.json'\n",
"TRAIN_RELEVANT_DOCS_FPATH = './data/train_relevant_docs.json'\n",
"\n",
"VAL_QUERIES_FPATH = './data/val_queries.json'\n",
"VAL_RELEVANT_DOCS_FPATH = './data/val_relevant_docs.json'"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "4bfeb3cd-59fd-496c-a4f0-c269835123ed",
"metadata": {},
"outputs": [],
"source": [
"with open(TRAIN_CORPUS_FPATH, 'r') as f:\n",
"    train_corpus = json.load(f)\n",
"\n",
"with open(VAL_CORPUS_FPATH, 'r') as f:\n",
"    val_corpus = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 107,
"id": "0c26a5cd-9ec4-4c7b-bc58-9349d83a248f",
"metadata": {},
"outputs": [],
"source": [
"def generate_queries(\n",
"    corpus,\n",
"    num_questions_per_chunk=2,\n",
"    prompt_template=None,\n",
"    verbose=False,\n",
"):\n",
"    \"\"\"\n",
"    Automatically generate hypothetical questions that could be answered with\n",
"    a doc in the corpus.\n",
"    \"\"\"\n",
"    llm = OpenAI(model='gpt-3.5-turbo')\n",
"\n",
"    prompt_template = prompt_template or \"\"\"\\\n",
"    Context information is below.\n",
"\n",
"    ---------------------\n",
"    {context_str}\n",
"    ---------------------\n",
"\n",
"    Given the context information and not prior knowledge,\n",
"    generate only questions based on the below query.\n",
"\n",
"    You are a Teacher/Professor. Your task is to set up \\\n",
"    {num_questions_per_chunk} questions for an upcoming \\\n",
"    quiz/examination. The questions should be diverse in nature \\\n",
"    across the document. Restrict the questions to the \\\n",
"    context information provided.\n",
"    \"\"\"\n",
"\n",
"    queries = {}\n",
"    relevant_docs = {}\n",
"    for node_id, text in tqdm(corpus.items()):\n",
"        query = prompt_template.format(context_str=text, num_questions_per_chunk=num_questions_per_chunk)\n",
"        response = llm.complete(query)\n",
"\n",
"        # One question per line; strip leading list numbering such as \"1.\" or \"2)\".\n",
"        result = str(response).strip().split(\"\\n\")\n",
"        questions = [\n",
"            re.sub(r\"^\\d+[\\).\\s]\", \"\", question).strip() for question in result\n",
"        ]\n",
"        questions = [question for question in questions if len(question) > 0]\n",
"\n",
"        # Each generated question is labeled with the chunk it was generated from.\n",
"        for question in questions:\n",
"            question_id = str(uuid.uuid4())\n",
"            queries[question_id] = question\n",
"            relevant_docs[question_id] = [node_id]\n",
"    return queries, relevant_docs"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "84780125-1c09-4904-bce1-23586d012c60",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "31fa3a6fa5a448e1901b48f02574a85a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/334 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"train_queries, train_relevant_docs = generate_queries(train_corpus)"
]
},
{
"cell_type": "code",
"execution_count": 91,
"id": "9452dcfc-7084-496f-87eb-8c7be6a6dc6a",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8cfb77b0f3844a5ea68a5bd02593b4e7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/395 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"val_queries, val_relevant_docs = generate_queries(val_corpus)"
]
},
{
"cell_type": "code",
"execution_count": 94,
"id": "96087eb2-607b-4115-ab37-426bfcf6af1c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"with open(TRAIN_QUERIES_FPATH, 'w') as f:\n",
"    json.dump(train_queries, f)\n",
"\n",
"with open(TRAIN_RELEVANT_DOCS_FPATH, 'w') as f:\n",
"    json.dump(train_relevant_docs, f)\n",
"\n",
"with open(VAL_QUERIES_FPATH, 'w') as f:\n",
"    json.dump(val_queries, f)\n",
"\n",
"with open(VAL_RELEVANT_DOCS_FPATH, 'w') as f:\n",
"    json.dump(val_relevant_docs, f)"
]
},
{
"cell_type": "markdown",
"id": "71453dc5-25e0-45bf-9d86-5e72b3a891d5",
"metadata": {},
"source": [
"### Merge data"
]
},
{
"cell_type": "markdown",
"id": "f252a7fe-c51d-47f5-ba94-45c8788a9674",
"metadata": {},
"source": [
"Finally, we do some minor re-organization to make it easier to access the dataset for training and evaluation."
]
},
{
"cell_type": "code",
"execution_count": 104,
"id": "3f465498-daa5-41b3-9ea3-8114254832b7",
"metadata": {},
"outputs": [],
"source": [
"TRAIN_DATASET_FPATH = './data/train_dataset.json'\n",
"VAL_DATASET_FPATH = './data/val_dataset.json'"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "430e34b0-699d-4eec-a26d-6d100d81cca4",
"metadata": {},
"outputs": [],
"source": [
"train_dataset = {\n",
"    'queries': train_queries,\n",
"    'corpus': train_corpus,\n",
"    'relevant_docs': train_relevant_docs,\n",
"}\n",
"\n",
"val_dataset = {\n",
"    'queries': val_queries,\n",
"    'corpus': val_corpus,\n",
"    'relevant_docs': val_relevant_docs,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 103,
"id": "b09071a2-6c32-408a-b971-39b5d6e42486",
"metadata": {},
"outputs": [],
"source": [
"with open(TRAIN_DATASET_FPATH, 'w') as f:\n",
"    json.dump(train_dataset, f)\n",
"\n",
"with open(VAL_DATASET_FPATH, 'w') as f:\n",
"    json.dump(val_dataset, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}