Merge pull request #1 from run-llama/jerry/add_finetune_tutorial

[wip] add sql finetuning tutorial with modal
This commit is contained in:
Jerry Liu
2023-08-16 23:44:35 -07:00
committed by GitHub
14 changed files with 7251 additions and 1 deletions
+2
View File
@@ -0,0 +1,2 @@
.venv
*__pycache__*
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Modal Labs
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+34 -1
View File
@@ -1 +1,34 @@
# modal_finetune_sql
# Finetuning LLaMa + Text-to-SQL
This walkthrough shows you how to fine-tune LLaMa-7B on a Text-to-SQL dataset, and then use it for inference against
any database of structured data using LlamaIndex.
This code is taken and adapted from the Modal `doppel-bot` repo: https://github.com/modal-labs/doppel-bot.
### Stack
- LlamaIndex
- Modal
- Hugging Face datasets
- OpenLLaMa
- Peft
### Steps for running
Please see the notebook `tutorial.ipynb` for full instructions.
In the meantime you can run each step individually as below:
Loading data:
`modal run src.load_data_sql`
Finetuning:
`modal run --detach src.finetune_sql`
Inference:
`modal run src.inference_sql_llamaindex::main --query "Which city has the highest population?" --sqlite-file-path "nbs/cities.db"`
(Optional) Downloading model weights:
`modal run src.download_weights --output-dir out_model`
+13
View File
@@ -0,0 +1,13 @@
[tool.black]
line-length = 120
[tool.ruff]
ignore = [
'E501',
'E741',
]
select = [
'E',
'F',
'W',
]
+4
View File
@@ -0,0 +1,4 @@
modal-client==0.50.3044
llama-index==0.8.2.post1
datasets==2.14.4
peft
+82
View File
@@ -0,0 +1,82 @@
from modal import Image, Stub, NetworkFileSystem, Dict
import random
from typing import Optional
from pathlib import Path
VOL_MOUNT_PATH = Path("/vol")
WANDB_PROJECT = "test-finetune-modal"
MODEL_PATH = "/model"
def download_models():
from transformers import LlamaForCausalLM, LlamaTokenizer
model_name = "openlm-research/open_llama_7b"
model = LlamaForCausalLM.from_pretrained(model_name)
model.save_pretrained(MODEL_PATH)
tokenizer = LlamaTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained(MODEL_PATH)
openllama_image = (
Image.micromamba()
.micromamba_install(
"cudatoolkit=11.7",
"cudnn=8.1.0",
"cuda-nvcc",
channels=["conda-forge", "nvidia"],
)
.apt_install("git")
.pip_install(
"accelerate==0.18.0",
"bitsandbytes==0.37.0",
"bitsandbytes-cuda117==0.26.0.post2",
"datasets==2.10.1",
"fire==0.5.0",
"gradio==3.23.0",
"peft @ git+https://github.com/huggingface/peft.git@e536616888d51b453ed354a6f1e243fecb02ea08",
"transformers @ git+https://github.com/huggingface/transformers.git@a92e0ad2e20ef4ce28410b5e05c5d63a5a304e65",
"torch==2.0.0",
"torchvision==0.15.1",
"sentencepiece==0.1.97",
"llama-index==0.8.1",
"sentence-transformers",
)
.run_function(download_models)
.pip_install("wandb==0.15.0")
)
stub = Stub(name="doppel-bot", image=openllama_image)
stub.model_dict = Dict.new()
stub.data_dict = Dict.new()
output_vol = NetworkFileSystem.new(cloud="gcp").persisted("doppelbot-vol")
def generate_prompt_sql(input, context, output=""):
return f"""You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.
You must output the SQL query that answers the question.
### Input:
{input}
### Context:
{context}
### Response:
{output}"""
def get_data_path(data_dir: str = "data_sql") -> Path:
return VOL_MOUNT_PATH / data_dir / "data_sql.jsonl"
def get_model_path(data_dir: str = "data_sql", checkpoint: Optional[str] = None) -> Path:
path = VOL_MOUNT_PATH / data_dir
if checkpoint:
path = path / checkpoint
return path
+2
View File
@@ -0,0 +1,2 @@
{"input": "How many heads of the departments are older than 56 ?", "context": "CREATE TABLE head (age INTEGER)"}
{"input": "List the name, born state and age of the heads of departments ordered by age.", "context": "CREATE TABLE head (name VARCHAR, born_state VARCHAR, age VARCHAR)"}
+55
View File
@@ -0,0 +1,55 @@
"""Download weights."""
from .common import (
stub, output_vol, VOL_MOUNT_PATH, get_model_path
)
import os
import json
from pathlib import Path
@stub.function(
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp"
)
def load_model(model_dir: str = "data_sql"):
"""Load model."""
path = get_model_path(model_dir=model_dir)
config_path = path / "adapter_config.json"
model_path = path / "adapter_model.bin"
config_data = json.load(open(config_path))
with open(model_path, "rb") as f:
model_data = f.read()
print(f'loaded config, model data from {path}')
# read data, put this in `model_dict` on stub
stub.model_dict["config"] = config_data
stub.model_dict["model"] = model_data
@stub.local_entrypoint()
def main(output_dir: str, model_dir: str = "data_sql"):
# copy adapter_config.json and adapter_model.bin files into dict
load_model.call(model_dir=model_dir)
model_data = stub.model_dict["model"]
config_data = stub.model_dict["config"]
print(f"Loaded model data, storing in {output_dir}")
# store locally
if not os.path.exists(output_dir):
os.makedirs(output_dir)
out_model_path = Path(output_dir) / "adapter_model.bin"
out_config_path = Path(output_dir) / "adapter_config.json"
with open(out_model_path, "wb") as f:
f.write(model_data)
with open(out_config_path, "w") as f:
json.dump(config_data, f)
print("Done!")
+109
View File
@@ -0,0 +1,109 @@
from typing import Optional
from modal import gpu, method, Retries
from modal.cls import ClsMixin
import json
from .common import (
output_vol,
stub,
VOL_MOUNT_PATH,
get_data_path,
generate_prompt_sql
)
from .inference_utils import OpenLlamaLLM
@stub.function(
gpu="A100",
retries=Retries(
max_retries=3,
initial_delay=5.0,
backoff_coefficient=2.0,
),
timeout=60 * 60 * 2,
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp",
)
def run_evals(
sample_data,
model_dir: str = "data_sql",
use_finetuned_model: bool = True
):
llm = OpenLlamaLLM(
model_dir=model_dir, max_new_tokens=256, use_finetuned_model=use_finetuned_model
)
inputs_outputs = []
for row_dict in sample_data:
prompt = generate_prompt_sql(row_dict["input"], row_dict["context"])
completion = llm.complete(
prompt,
do_sample=True,
temperature=0.3,
top_p=0.85,
top_k=40,
num_beams=1,
max_new_tokens=600,
repetition_penalty=1.2,
)
inputs_outputs.append((row_dict, completion.text))
return inputs_outputs
@stub.function(
gpu="A100",
retries=Retries(
max_retries=3,
initial_delay=5.0,
backoff_coefficient=2.0,
),
timeout=60 * 60 * 2,
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp",
)
def run_evals_all(
data_dir: str = "data_sql",
model_dir: str = "data_sql",
num_samples: int = 10,
):
# evaluate a sample from the same training set
from datasets import load_dataset
data_path = get_data_path(data_dir).as_posix()
data = load_dataset("json", data_files=data_path)
# load sample data
sample_data = data["train"].shuffle().select(range(num_samples))
print('*** Running inference with finetuned model ***')
inputs_outputs_0 = run_evals(
sample_data=sample_data,
model_dir=model_dir,
use_finetuned_model=True
)
print('*** Running inference with base model ***')
input_outputs_1 = run_evals(
sample_data=sample_data,
model_dir=model_dir,
use_finetuned_model=False
)
return inputs_outputs_0, input_outputs_1
@stub.local_entrypoint()
def main(data_dir: str = "data_sql", model_dir: str = "data_sql", num_samples: int = 10):
"""Main function."""
inputs_outputs_0, input_outputs_1 = run_evals_all.call(
data_dir=data_dir,
model_dir=model_dir,
num_samples=num_samples
)
for idx, (row_dict, completion) in enumerate(inputs_outputs_0):
print('************ Row {idx} ************')
print(f"Input {idx}: " + str(row_dict))
print(f"Output {idx} (finetuned model): " + str(completion))
print(f"Output {idx} (base model): " + str(input_outputs_1[idx][1]))
print('***********************************')
+251
View File
@@ -0,0 +1,251 @@
from modal import Secret
from datetime import datetime
import os
from math import ceil
from .common import (
MODEL_PATH,
VOL_MOUNT_PATH,
WANDB_PROJECT,
output_vol,
stub,
get_data_path,
get_model_path,
generate_prompt_sql,
)
# This code is adapter from https://github.com/tloen/alpaca-lora/blob/65fb8225c09af81feb5edb1abb12560f02930703/finetune.py
# with modifications mainly to expose more parameters to the user.
def _train(
# model/data params
base_model: str,
data,
output_dir: str = "./lora-alpaca",
eval_steps: int = 20,
save_steps: int = 20,
# training hyperparams
batch_size: int = 128,
micro_batch_size: int = 32,
max_steps: int = 200,
learning_rate: float = 3e-4,
cutoff_len: int = 512,
val_set_size: int = 100,
# lora hyperparams
lora_r: int = 16,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = True,
group_by_length: bool = True, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
):
import os
import sys
import torch
import transformers
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
gradient_accumulation_steps = batch_size // micro_batch_size
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=True,
torch_dtype=torch.float16,
device_map=device_map,
)
tokenizer = LlamaTokenizer.from_pretrained(base_model, add_eos_token=True)
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
tokenizer.padding_side = "left" # Allow batched inference
def tokenize(prompt, add_eos_token=True):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = generate_prompt_sql(
data_point["input"],
data_point["context"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
raise NotImplementedError("not implemented yet")
return tokenized_full_prompt
model = prepare_model_for_int8_training(model)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
if resume_from_checkpoint:
# Check the available weights and load them
checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin") # Full checkpoint
if not os.path.exists(checkpoint_name):
checkpoint_name = os.path.join(
resume_from_checkpoint, "adapter_model.bin"
) # only LoRA model - LoRA config above has to fit
resume_from_checkpoint = False # So the trainer won't try loading its state
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
set_peft_model_state_dict(model, adapters_weights)
else:
print(f"Checkpoint {checkpoint_name} not found")
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
if val_set_size > 0:
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
if not ddp and torch.cuda.device_count() > 1:
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
model.is_parallelizable = True
model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=100,
max_steps=max_steps,
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=eval_steps if val_set_size > 0 else None,
save_steps=save_steps,
output_dir=output_dir,
# save_total_limit=3,
load_best_model_at_end=False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else "none",
run_name=wandb_run_name if use_wandb else None,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())).__get__(
model, type(model)
)
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print("\n If there's a warning about missing keys above, please disregard :)")
@stub.function(
gpu="A100",
# TODO: Modal should support optional secrets.
secret=Secret.from_name("my-wandb-secret") if WANDB_PROJECT else None,
timeout=60 * 60 * 2,
network_file_systems={VOL_MOUNT_PATH: output_vol},
cloud="oci",
allow_cross_region_volumes=True,
)
def finetune(data_dir: str = "data_sql", model_dir: str = "data_sql"):
from datasets import load_dataset
data_path = get_data_path(data_dir).as_posix()
data = load_dataset("json", data_files=data_path)
num_samples = len(data["train"])
val_set_size = ceil(0.1 * num_samples)
print(f"Loaded {num_samples} samples. ")
_train(
MODEL_PATH,
data,
val_set_size=val_set_size,
output_dir=get_model_path(model_dir).as_posix(),
wandb_project=WANDB_PROJECT,
wandb_run_name=f"openllama-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
)
# Delete scraped data after fine-tuning
os.remove(data_path)
+98
View File
@@ -0,0 +1,98 @@
from typing import Optional
from modal import gpu, method, Retries
from modal.cls import ClsMixin
import json
from .common import (
MODEL_PATH,
output_vol,
stub,
VOL_MOUNT_PATH,
get_model_path,
generate_prompt_sql
)
from .inference_utils import OpenLlamaLLM
from llama_index.callbacks import CallbackManager
from llama_index.llms import (
CustomLLM,
LLMMetadata,
CompletionResponse,
CompletionResponseGen,
)
from llama_index.llms.base import llm_completion_callback
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index import SQLDatabase, ServiceContext, Prompt
from typing import Any
@stub.function(
gpu="A100",
retries=Retries(
max_retries=3,
initial_delay=5.0,
backoff_coefficient=2.0,
),
timeout=60 * 60 * 2,
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp",
)
def run_query(query: str, model_dir: str = "data_sql", use_finetuned_model: bool = True):
"""Run query."""
import pandas as pd
from sqlalchemy import create_engine
# define SQL database
assert "sqlite_data" in stub.data_dict
with open(VOL_MOUNT_PATH / "test_data.db", "wb") as fp:
fp.write(stub.data_dict["sqlite_data"])
# define service context (containing custom LLM)
print('setting up service context')
# finetuned llama LLM
num_output = 256
llm = OpenLlamaLLM(
model_dir=model_dir, max_new_tokens=num_output, use_finetuned_model=use_finetuned_model
)
service_context = ServiceContext.from_defaults(llm=llm)
sql_path = VOL_MOUNT_PATH / "test_data.db"
engine = create_engine(f'sqlite:///{sql_path}', echo=True)
sql_database = SQLDatabase(engine)
# define custom text-to-SQL prompt with generate prompt
prompt_prefix = "Dialect: {dialect}\n\n"
prompt_suffix = generate_prompt_sql("{query_str}", "{schema}", output="")
sql_prompt = Prompt(prompt_prefix + prompt_suffix)
query_engine = NLSQLTableQueryEngine(
sql_database,
text_to_sql_prompt=sql_prompt,
service_context=service_context,
synthesize_response=False
)
response = query_engine.query(query)
print(
f'Model output: \n'
f'SQL Query: {str(response.metadata["sql_query"])}'
f"Response: {response.response}"
)
return response
@stub.local_entrypoint()
def main(query: str, sqlite_file_path: str, model_dir: str = "data_sql", use_finetuned_model: str = "True"):
"""Main function."""
fp = open(sqlite_file_path, "rb")
stub.data_dict["sqlite_data"] = fp.read()
if use_finetuned_model == "None":
# try both
run_query.call(query, model_dir=model_dir, use_finetuned_model=True)
run_query.call(query, model_dir=model_dir, use_finetuned_model=False)
else:
bool_toggle = use_finetuned_model == "True"
run_query.call(query, model_dir=model_dir, use_finetuned_model=bool_toggle)
+121
View File
@@ -0,0 +1,121 @@
"""Get inference utils."""
from typing import Optional
from modal import gpu
from modal.cls import ClsMixin
from .common import (
MODEL_PATH,
output_vol,
stub,
VOL_MOUNT_PATH,
get_model_path,
)
from llama_index.callbacks import CallbackManager
from llama_index.llms import (
CustomLLM,
LLMMetadata,
CompletionResponse,
CompletionResponseGen,
)
from llama_index.llms.base import llm_completion_callback
from typing import Any
@stub.cls(
gpu=gpu.A100(memory=20),
network_file_systems={VOL_MOUNT_PATH: output_vol},
)
class OpenLlamaLLM(CustomLLM, ClsMixin):
"""OpenLlamaLLM is a custom LLM that uses the OpenLlamaModel."""
def __init__(
self,
model_dir: str = "data_sql",
max_new_tokens: int = 128,
callback_manager: Optional[CallbackManager] = None,
use_finetuned_model: bool = True,
):
super().__init__(callback_manager=callback_manager)
import sys
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer
CHECKPOINT = get_model_path(model_dir)
load_8bit = False
device = "cuda"
self.tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
MODEL_PATH,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
if use_finetuned_model:
model = PeftModel.from_pretrained(
model,
CHECKPOINT,
torch_dtype=torch.float16,
)
if not load_8bit:
model.half() # seems to fix bugs for some users.
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
self.model = model
self.device = device
self._max_new_tokens = max_new_tokens
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=2048,
num_output=self._max_new_tokens,
model_name="finetuned_openllama_sql"
)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
import torch
from transformers import GenerationConfig
# TODO: TO fill
inputs = self.tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(self.device)
# tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
# print(tokens)
generation_config = GenerationConfig(
**kwargs,
)
with torch.no_grad():
generation_output = self.model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=self._max_new_tokens,
)
s = generation_output.sequences[0]
output = self.tokenizer.decode(s, skip_special_tokens=True)
# NOTE: parsing response this way means that the model can mostly
# only be used for text-to-SQL, not other purposes
response_text = output.split("### Response:")[1].strip()
return CompletionResponse(text=response_text)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
raise NotImplementedError()
+40
View File
@@ -0,0 +1,40 @@
import json
from modal import Retries
from .common import (
stub,
VOL_MOUNT_PATH,
output_vol,
get_data_path
)
@stub.function(
retries=Retries(
max_retries=3,
initial_delay=5.0,
backoff_coefficient=2.0,
),
timeout=60 * 60 * 2,
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp",
)
def load_data_sql(data_dir: str = "data_sql"):
from datasets import load_dataset
dataset = load_dataset("b-mc2/sql-create-context")
dataset_splits = {"train": dataset["train"]}
out_path = get_data_path(data_dir)
out_path.parent.mkdir(parents=True, exist_ok=True)
for key, ds in dataset_splits.items():
with open(out_path, "w") as f:
for item in ds:
newitem = {
"input": item["question"],
"context": item["context"],
"output": item["answer"],
}
f.write(json.dumps(newitem) + "\n")
+6419
View File
File diff suppressed because it is too large Load Diff